# self_instruct/src/infer_saiga_vllm.py
import json
import fire
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from src.util.io import read_jsonl
# Default system prompt, in Russian ("You are Saiga, a Russian-language automatic
# assistant..."); prepended to conversations unless disabled via CLI flag.
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
def infer_saiga_vllm(
    model_name: str,
    input_path: str,
    output_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 30,
    max_tokens: int = 2048,
    repetition_penalty: float = 1.1,
    disable_system_prompt: bool = False
):
    """Batch-generate answers for a JSONL dataset with a Saiga model via vLLM.

    Each input record must carry either an "instruction" string or a
    "messages"/"prompt" list of {"role": ..., "content": ...} dicts.
    The generated answer is stored under record["answer"] and every record
    is written to *output_path* as one JSON object per line.

    Args:
        model_name: HF model id or local path, used for both tokenizer and vLLM engine.
        input_path: Path to the input JSONL file.
        output_path: Path to the output JSONL file (overwritten).
        temperature, top_p, top_k, max_tokens, repetition_penalty: sampling settings.
        disable_system_prompt: If True, do not inject SYSTEM_PROMPT into
            conversations that lack a system turn.

    Raises:
        ValueError: If a record has neither "instruction" nor "messages"/"prompt",
            or its messages list is empty.
    """
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_tokens,
        repetition_penalty=repetition_penalty
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    llm = LLM(model=model_name, max_context_len_to_capture=8192)
    records = read_jsonl(input_path)
    # Normalize dataset role names to the chat-template vocabulary.
    role_mapping = {
        "bot": "assistant",
        "gpt": "assistant",
        "human": "user",
    }
    prompts = []
    actual_records = []
    for record in records:
        if "instruction" in record:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": record["instruction"]}
            ]
        elif "messages" in record or "prompt" in record:
            messages = record.get("messages", record.get("prompt"))
            # Was `assert messages`: asserts vanish under `python -O`,
            # so validate explicitly.
            if not messages:
                raise ValueError("Empty 'messages'/'prompt' list in record!")
            if messages[0]["role"] != "system" and not disable_system_prompt:
                messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
            for message in messages:
                message["role"] = role_mapping.get(message["role"], message["role"])
            # Drop a trailing assistant turn so the model regenerates it.
            if messages[-1]["role"] == "assistant":
                messages = messages[:-1]
        else:
            raise ValueError("Wrong input format!")
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The chat template may prepend BOS and vLLM adds it again, so strip it.
        # BUG FIX: some tokenizers have no BOS token (bos_token is None), which
        # made str.replace raise TypeError.
        if tokenizer.bos_token:
            prompt = prompt.replace(tokenizer.bos_token, "")
        actual_records.append(record)
        prompts.append(prompt)
    outputs = llm.generate(prompts, sampling_params)
    # Explicit encoding: the answers are Russian text and must not depend on locale.
    with open(output_path, "w", encoding="utf-8") as w:
        for record, output in zip(actual_records, outputs):
            generated_text = output.outputs[0].text
            print(output.prompt)
            print(generated_text)
            print()
            print()
            # BUG FIX: the "ignore" handler belongs on encode(). In the original
            # (.encode("utf-8").decode("utf-8", "ignore")) a stray surrogate would
            # raise UnicodeEncodeError before decode() could ignore anything.
            record["answer"] = generated_text.encode("utf-8", "ignore").decode("utf-8")
            w.write(json.dumps(record, ensure_ascii=False).strip() + "\n")
# CLI entry point: python-fire turns the function's parameters into command-line flags.
if __name__ == "__main__":
    fire.Fire(infer_saiga_vllm)