self_instruct/src/data_processing/improve_instructions.py
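"""Evolve instruction-tuning tasks with an OpenAI chat model.

Each original task is rewritten by one randomly chosen strategy
("depth", "breadth", "xml", "json" or "few_shot"), and an elimination
step then drops rewrites that the model judges equivalent to the
original task.
"""
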
import os
import random
import shutil
from typing import Optional

import fire
from jinja2 import Template
from tqdm import tqdm

from src.util.io import read_jsonl, write_jsonl
from src.util.openai import openai_batch_completion, OpenAIDecodingArguments


def to_messages(prompt):
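    """Wrap a prompt string as a single-turn user message list."""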
return [{"role": "user", "content": prompt}]


def depth_encode_prompt(task, template_path, methods_path):
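    """Render one "depth" evolution prompt per complication method from methods_path."""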
    with open(methods_path) as r:
        methods = [line.strip() for line in r if line.strip()]
with open(template_path) as f:
template = Template(f.read())
return [to_messages(template.render(
task=task,
method=method
).strip()) for method in methods]


def task_only_encode_prompt(task, template_path):
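    """Render a single evolution prompt from a template that needs only the task itself."""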
with open(template_path) as f:
template = Template(f.read())
return to_messages(template.render(
task=task,
).strip())


def elimination_encode_prompt(task, template_path):
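    """Render a prompt asking whether the evolved task duplicates its direct predecessor."""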
with open(template_path) as f:
template = Template(f.read())
first_task = task
second_task = task["previous_tasks"][-1]
return to_messages(template.render(
first_task=first_task,
second_task=second_task
).strip())


def extend_post_process(response, original_task, method):
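    """Turn a model response into a new task record, extending the evolution history."""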
if not response:
return None
if response["finish_reason"] == "length":
return None
content = response["message"]["content"].strip()
    previous_tasks = original_task.get("previous_tasks", [])
previous_tasks.append({
"instruction": original_task["instruction"],
"input": original_task["input"] if "input" in original_task else "",
"method": method
})
return {
"instruction": content,
"input": "",
"previous_tasks": previous_tasks
}


def check_new_task(task):
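    """Reject generations that leak prompt scaffolding into the instruction text."""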
if "оригинальное задание" in task["instruction"].lower():
return False
if "усложнённое задание" in task["instruction"].lower():
return False
if "####" in task["instruction"]:
return False
return True


def get_key(task):
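    """Deduplication key: the root instruction that an evolution chain started from."""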
if "previous_tasks" in task:
return (task["previous_tasks"][0]["instruction"], )
return (task["instruction"], )


def evolve_batch(
    original_tasks,
    model_name,
    decoding_args,
    depth_template_path: Optional[str] = None,
    depth_methods_path: Optional[str] = None,
    breadth_template_path: Optional[str] = None,
    xml_template_path: Optional[str] = None,
    json_template_path: Optional[str] = None,
    few_shot_template_path: Optional[str] = None
):
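    """Build all candidate evolution prompts per task, sample one each, and post-process the completions."""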
    # For every task, build the full set of candidate evolution prompts:
    # one "depth" prompt per complication method, plus single-prompt
    # "xml", "json" and "breadth" strategies.
    batch = [depth_encode_prompt(task, depth_template_path, depth_methods_path) for task in original_tasks]
    batch = [[(prompt, "depth") for prompt in batch_part] for batch_part in batch]
    for batch_part, task in zip(batch, original_tasks):
        batch_part.append((task_only_encode_prompt(task, xml_template_path), "xml"))
        batch_part.append((task_only_encode_prompt(task, json_template_path), "json"))
        batch_part.append((task_only_encode_prompt(task, breadth_template_path), "breadth"))
        # The few-shot strategy needs an existing input to demonstrate
        if task.get("input"):
            batch_part.append((task_only_encode_prompt(task, few_shot_template_path), "few_shot"))
    # Sample exactly one strategy per task
    batch = [random.choice(batch_part) for batch_part in batch]
if not batch:
return []
prompts = [prompt for prompt, _ in batch]
results = openai_batch_completion(
batch=prompts,
model_name=model_name,
decoding_args=decoding_args
)
gen_tasks = []
for result, original_task, (_, method) in zip(results, original_tasks, batch):
gen_tasks.append(extend_post_process(result, original_task, method=method))
    new_tasks = [task for task in gen_tasks if task and check_new_task(task)]
    print(f"Kept {len(new_tasks)} of {len(gen_tasks)} evolved tasks")
    return new_tasks


def eliminate_batch(tasks, model_name, decoding_args, template_path):
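    """Drop evolved tasks that the model judges equivalent to their predecessors."""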
batch = [elimination_encode_prompt(task, template_path) for task in tasks]
results = openai_batch_completion(
batch=batch,
model_name=model_name,
decoding_args=decoding_args
)
filtered_tasks = []
eliminated_count = 0
total_count = len(tasks)
    for result, task in zip(results, tasks):
        # Keep the task if the elimination request returned nothing
        if not result:
            filtered_tasks.append(task)
            continue
        content = result["message"]["content"]
        # The elimination prompt is in Russian: a leading "да" ("yes") means
        # the model judged the evolved task equivalent to its predecessor.
        if content.lower().startswith("да"):
            eliminated_count += 1
            continue
        filtered_tasks.append(task)
print(f"Eliminated {eliminated_count} of {total_count} new tasks")
return filtered_tasks


def process_batch(
batch,
model_name,
decoding_args,
depth_template_path,
depth_methods_path,
breadth_template_path,
elimination_template_path,
xml_template_path,
json_template_path,
few_shot_template_path
):
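    """Evolve a batch of tasks, then run the elimination filter over the survivors."""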
new_tasks = evolve_batch(
batch,
model_name=model_name,
decoding_args=decoding_args,
depth_template_path=depth_template_path,
depth_methods_path=depth_methods_path,
json_template_path=json_template_path,
xml_template_path=xml_template_path,
breadth_template_path=breadth_template_path,
few_shot_template_path=few_shot_template_path
)
if not new_tasks:
return []
new_tasks = eliminate_batch(
new_tasks,
model_name=model_name,
decoding_args=decoding_args,
template_path=elimination_template_path
)
return new_tasks


def improve_instructions(
original_tasks_path,
output_path,
depth_template_path,
depth_methods_path,
breadth_template_path,
elimination_template_path,
xml_template_path,
json_template_path,
few_shot_template_path,
model_name="gpt-3.5-turbo",
request_batch_size=5,
temperature=1.0,
top_p=0.95,
num_cpus=8,
skip_count=0
):
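    """Evolve tasks from original_tasks_path and append the results to output_path.

    Tasks whose root instruction is already present in output_path are skipped,
    so an interrupted run can be resumed.
    """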
original_tasks = read_jsonl(original_tasks_path)
print(f"Loaded {len(original_tasks)} original tasks")
new_tasks = []
existing_keys = set()
if os.path.exists(output_path):
new_tasks = read_jsonl(output_path)
existing_keys = {get_key(task) for task in new_tasks}
print(f"Loaded {len(new_tasks)} new tasks")
decoding_args = OpenAIDecodingArguments(
temperature=temperature,
top_p=top_p
)
is_output_printed = False
batch = []
for original_task in tqdm(original_tasks[skip_count:]):
if get_key(original_task) in existing_keys:
print("Skip:", get_key(original_task))
continue
batch.append(original_task)
if len(batch) == request_batch_size:
new_batch_tasks = process_batch(
batch,
model_name=model_name,
decoding_args=decoding_args,
depth_template_path=depth_template_path,
depth_methods_path=depth_methods_path,
breadth_template_path=breadth_template_path,
elimination_template_path=elimination_template_path,
xml_template_path=xml_template_path,
json_template_path=json_template_path,
few_shot_template_path=few_shot_template_path
)
new_tasks.extend(new_batch_tasks)
batch = []
if not is_output_printed and new_batch_tasks:
is_output_printed = True
print(new_batch_tasks[0])
            # Write through a temp file so an interrupted run can't corrupt the output
            write_jsonl(new_tasks, output_path + "_tmp")
            shutil.move(output_path + "_tmp", output_path)
    # Flush the final partial batch
    if batch:
new_batch_tasks = process_batch(
batch,
model_name=model_name,
decoding_args=decoding_args,
depth_template_path=depth_template_path,
depth_methods_path=depth_methods_path,
breadth_template_path=breadth_template_path,
elimination_template_path=elimination_template_path,
xml_template_path=xml_template_path,
json_template_path=json_template_path,
few_shot_template_path=few_shot_template_path
)
new_tasks.extend(new_batch_tasks)
write_jsonl(new_tasks, output_path + "_tmp")
shutil.move(output_path + "_tmp", output_path)


if __name__ == "__main__":
    fire.Fire(improve_instructions)