IlyaGusev/rulm

View on GitHub
self_instruct/src/util/generate.py

Summary

Maintainability
A
45 mins
Test Coverage
from typing import List
from transformers import AutoModel, AutoTokenizer, GenerationConfig


def generate(
    model: AutoModel,
    tokenizer: AutoTokenizer,
    prompts: List[str],
    generation_config: GenerationConfig,
    source_max_length: int = 2048,
    eos_token_id: int = None
):
    if eos_token_id is not None:
        generation_config.eos_token_id = eos_token_id

    data = tokenizer(
        prompts,
        return_tensors="pt",
        truncation=True,
        max_length=source_max_length,
        padding=True
    )
    data = {k: v.to(model.device) for k, v in data.items()}
    output_ids = model.generate(
        **data,
        generation_config=generation_config
    )
    outputs = []
    for sample_output_ids, sample_input_ids in zip(output_ids, data["input_ids"]):
        sample_output_ids = sample_output_ids[len(sample_input_ids):]
        sample_output = tokenizer.decode(sample_output_ids, skip_special_tokens=True)
        sample_output = sample_output.replace("</s>", "").strip()
        outputs.append(sample_output)
    return outputs