self_instruct/src/data_processing/generate_chars.py
import time
import json
import os
import random
import re
import shutil
from functools import partial
from multiprocessing import Pool

import fire
import tqdm
from jinja2 import Template
from rouge_score import rouge_scorer

from src.util.io import read_jsonl, write_jsonl
from src.util.openai import openai_batch_completion, OpenAIDecodingArguments

NON_ALPHANUM_RE = re.compile(r"[^a-zа-яё0-9]+")
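# Tokenizer for the ROUGE similarity check: lowercase, map every character that
# is not a Latin/Cyrillic letter or a digit to a space, then split on whitespace.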
def tokenize(text):
    text = text.lower()
    text = NON_ALPHANUM_RE.sub(" ", text)
    return text.split()
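# Render the few-shot prompt: the Jinja2 template receives the example
# characters serialized as JSON, with similarity bookkeeping fields dropped.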
def encode_prompt(example_chars, template_path):
    with open(template_path) as f:
        template = Template(f.read())
    # Work on copies: popping on the sampled dicts directly would permanently
    # strip these fields from the seed/machine character pools.
    example_chars = [dict(char) for char in example_chars]
    for char in example_chars:
        char.pop("most_similar_chars", None)
        char.pop("avg_similarity_score", None)
    return template.render(
        example_chars=json.dumps(example_chars, ensure_ascii=False)
    ).strip() + "\n"
def post_process(response):
    if not response:
        return []
    if response["finish_reason"] == "length":
        return []
    raw_content = response["message"]["content"]
    try:
        chars = json.loads(raw_content)
        if isinstance(chars, list):
            return chars
        if isinstance(chars, dict):
            return chars["characters"]
    except Exception:
        return []
    # Parsed JSON that is neither a list nor a dict carries no characters.
    return []
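# Self-Instruct-style generation loop: prompt the model with pairs of example
# characters, parse the completions, and keep only characters whose
# descriptions are sufficiently novel under ROUGE-L.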
def generate_chars(
    output_path: str,
    seed_chars_path: str,
    template_path: str,
    num_chars_to_generate: int = 200,
    model_name: str = "gpt-4",
    request_batch_size: int = 5,
    temperature: float = 1.0,
    top_p: float = 0.95,
    num_cpus: int = 8,
    rouge_cutoff: float = 0.24
):
    random.seed(43)
    seed_chars = read_jsonl(seed_chars_path)
    print(f"Loaded {len(seed_chars)} character examples")
    machine_chars = []
    if os.path.exists(output_path):
        machine_chars = read_jsonl(output_path)
        print(f"Loaded {len(machine_chars)} machine-generated characters")

    all_descriptions = [d["context"] for d in seed_chars + machine_chars]
    all_description_tokens = [tokenize(d) for d in all_descriptions]

    request_idx = 0
    progress_bar = tqdm.tqdm(total=num_chars_to_generate)
    if machine_chars:
        progress_bar.update(len(machine_chars))

    is_prompt_printed = False
    is_output_printed = False
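    # Request batches until enough novel characters have been accumulated.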
    while len(machine_chars) < num_chars_to_generate:
        request_idx += 1
        batch = []
        for _ in range(request_batch_size):
            # Once machine-generated characters exist, mix one of them with one
            # seed example per prompt; otherwise sample two seed examples.
            if machine_chars:
                prompt_chars = random.sample(machine_chars, 1)
                prompt_chars += random.sample(seed_chars, 1)
            else:
                prompt_chars = random.sample(seed_chars, 2)
            random.shuffle(prompt_chars)
            prompt = encode_prompt(prompt_chars, template_path)
            messages = [{"role": "user", "content": prompt}]
            batch.append(messages)

        if not is_prompt_printed:
            is_prompt_printed = True
            print("Prompt example:")
            for message in batch[0]:
                print("Role: {}, content: {}".format(message["role"], message["content"]))

        request_start = time.time()
        results = openai_batch_completion(
            batch=batch,
            model_name=model_name,
            decoding_args=OpenAIDecodingArguments(
                temperature=temperature,
                top_p=top_p
            )
        )
        if not is_output_printed:
            is_output_printed = True
            print("Output example:")
            print(results[0]["message"]["content"])
        request_duration = time.time() - request_start

        process_start = time.time()
        new_chars = []
        for result in results:
            new_chars.extend(post_process(result))
        total = len(new_chars)
        keep = 0
        for new_char in new_chars:
            # Skip malformed generations that lack a description to compare.
            if not isinstance(new_char, dict) or "context" not in new_char:
                continue
            new_description_tokens = tokenize(new_char["context"])
            # ROUGE-L F-measure of the new description against every known one;
            # anything above the cutoff is a near-duplicate and is discarded.
            with Pool(num_cpus) as p:
                rouge_scores = p.map(
                    partial(rouge_scorer._score_lcs, new_description_tokens),
                    all_description_tokens,
                )
            rouge_scores = [score.fmeasure for score in rouge_scores]
            if max(rouge_scores) > rouge_cutoff:
                continue
            keep += 1
            machine_chars.append(new_char)
            all_descriptions.append(new_char["context"])
            all_description_tokens.append(new_description_tokens)
            progress_bar.update(1)

        process_duration = time.time() - process_start
        print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s")
        print(f"Generated {total} chars, kept {keep} chars")
        print("===================================")
        write_jsonl(machine_chars, output_path + "_tmp")
        shutil.move(output_path + "_tmp", output_path)


if __name__ == "__main__":
    fire.Fire(generate_chars)
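# Example invocation via fire (flags mirror the generate_chars arguments;
# the paths below are illustrative, run from the self_instruct directory):
#   python -m src.data_processing.generate_chars \
#       --output_path=chars.jsonl \
#       --seed_chars_path=seed_chars.jsonl \
#       --template_path=templates/generate_chars.jinja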