IlyaGusev/rulm

self_instruct/src/infer_chatgpt.py

import json

import fire
from jinja2 import Template
from tqdm import tqdm

from src.util.io import read_jsonl
from src.util.openai import openai_batch_completion, OpenAIDecodingArguments
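
# Expected data shapes (the concrete record below is a hypothetical example;
# only the "prompt" input key and the "answer" output key are fixed by the
# code in this file):
#   input JSONL line:  {"prompt": "Explain recursion in one sentence."}
#   output JSONL line: {"prompt": "Explain recursion in one sentence.", "answer": "..."}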



def infer_batch(batch, model_name, output_file):
    # Send one batch of prompts to the API and write each record back to the
    # output file as soon as its answer arrives, so progress survives a
    # failure on a later batch.
    prompts = [r["prompt"] for r in batch]
    results = openai_batch_completion(
        batch=prompts,
        model_name=model_name,
        decoding_args=OpenAIDecodingArguments(
            max_tokens=4096
        )
    )
    for r, prompt, result in zip(batch, prompts, results):
        answer = result.message["content"]
        print(prompt)
        print(answer)
        r["answer"] = answer
        output_file.write(json.dumps(r, ensure_ascii=False).strip() + "\n")
    return results


def main(
    input_path,
    output_path,
    model_name,
    request_batch_size=5
):
    records = read_jsonl(input_path)

    batch = []
    with open(output_path, "w") as w:
        for record in tqdm(records):
            # Accumulate records until a full batch is ready, then send it.
            batch.append(record)
            if len(batch) != request_batch_size:
                continue
            infer_batch(
                batch=batch,
                model_name=model_name,
                output_file=w
            )
            batch = []

        # Flush the final, possibly smaller, batch.
        if batch:
            infer_batch(
                batch=batch,
                model_name=model_name,
                output_file=w
            )


if __name__ == "__main__":
    fire.Fire(main)
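
Because the entry point is wrapped in fire.Fire(main), the parameters of main are exposed as command-line flags. A hypothetical invocation, assuming the script is run from the self_instruct directory so the src.* imports resolve; the file paths and model name are placeholders:

    python -m src.infer_chatgpt \
        --input_path prompts.jsonl \
        --output_path answers.jsonl \
        --model_name gpt-3.5-turbo \
        --request_batch_size 5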
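
The repository's actual openai_batch_completion lives in src/util/openai.py and is not shown on this page. For orientation only, here is a minimal sequential sketch of what such a helper could look like against the legacy openai<1.0 ChatCompletion API. Everything in it is an assumption except the call shape that infer_batch relies on (each result exposing result.message["content"]); the real helper presumably issues requests concurrently and handles retries.

    import openai

    def openai_batch_completion_sketch(batch, model_name, decoding_args):
        # One request per prompt; the repo's helper likely parallelizes
        # these instead of looping sequentially.
        results = []
        for prompt in batch:
            response = openai.ChatCompletion.create(
                model=model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=decoding_args.max_tokens,
            )
            # choices[0] carries a .message mapping whose "content" holds
            # the generated text, matching what infer_batch reads.
            results.append(response.choices[0])
        return results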