IlyaGusev/rulm
self_instruct/src/benchmarks/eval_lora_rsg.py

from pathlib import Path
from typing import Optional, Tuple
import fire

from src.benchmarks.eval_zs_rsg import (
    predict_danetqa,
    predict_rcb,
    predict_terra,
    predict_lidirus,
    predict_parus,
    predict_muserc,
    predict_rucos,
    predict_rwsd,
    predict_russe,
    clean_rwsd_response,
    ALL_TASKS,
    predict_saiga_zero_shot
)
from src.data_processing.convert_rsg import (
    DANETQA_SOURCE_TEMPLATE,
    RCB_SOURCE_TEMPLATE,
    TERRA_SOURCE_TEMPLATE,
    LIDIRUS_SOURCE_TEMPLATE,
    PARUS_CAUSE_SOURCE_TEMPLATE,
    PARUS_EFFECT_SOURCE_TEMPLATE,
    MUSERC_SOURCE_TEMPLATE,
    RUCOS_SOURCE_TEMPLATE,
    RWSD_SOURCE_TEMPLATE,
    RUSSE_SOURCE_TEMPLATE
)
from src.util.load import load_saiga

HF_DATASET = "RussianNLP/russian_super_glue"


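# The clean_* helpers below map a model's free-form Russian response
# ("да"/"нет", etc.) to the label format expected by each RSG task.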
def clean_danetqa(response):
    return "да" in response.lower()


def clean_rcb(response):
    if "да" in response.lower():
        return "entailment"
    if "нет" in response.lower():
        return "contradiction"
    return "neutral"


def clean_terra(response):
    if "да" in response.lower():
        return "entailment"
    return "not_entailment"


def clean_muserc(response):
    return "да" in response.lower()


def clean_rucos(response, entities):
    _ = entities
    return response


def clean_russe(response):
    if "да" in response.lower():
        return 1
    return 0

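# Illustrative behaviour of the cleaners above (not exhaustive):
#   clean_danetqa("Да, верно")  -> True
#   clean_rcb("Нет")            -> "contradiction"
#   clean_terra("да")           -> "entailment"
#   clean_russe("нет")          -> 0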

def main(
    model_name: str,
    nrows: Optional[int] = None,
    template_path: str = "internal_prompts/saiga_v2.json",
    split: str = "test",
    predictions_dir: str = "submission_peft",
    debug: bool = False,
    tasks: Tuple[str, ...] = ALL_TASKS
):
    predictions_dir = Path(predictions_dir)

    model, tokenizer, generation_config = load_saiga(model_name)
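    # Near-greedy decoding: very low temperature plus a long no-repeat window.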
    generation_config.no_repeat_ngram_size = 64
    generation_config.temperature = 0.01

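    # Bind the loaded model, tokenizer and generation settings so the
    # task-specific predictors only need to pass prompt batches.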
    def predict_saiga_zero_shot_bound(batch):
        generation_config.max_new_tokens = 256
        return predict_saiga_zero_shot(
            model=model,
            tokenizer=tokenizer,
            generation_config=generation_config,
            template_path=template_path,
            prompts=batch,
            debug=debug
        )

    if "danetqa" in tasks:
        predict_danetqa(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "DaNetQA.jsonl",
            nrows=nrows,
            clean_func=clean_danetqa,
            template="Задание: danetqa\n" + DANETQA_SOURCE_TEMPLATE
        )
    if "rcb" in tasks:
        predict_rcb(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "RCB.jsonl",
            nrows=nrows,
            clean_func=clean_rcb,
            template="Задание: rcb\n" + RCB_SOURCE_TEMPLATE
        )
    if "terra" in tasks:
        predict_terra(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "TERRa.jsonl",
            nrows=nrows,
            clean_func=clean_terra,
            template="Задание: terra\n" + TERRA_SOURCE_TEMPLATE
        )
    if "lidirus" in tasks:
        predict_lidirus(
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "LiDiRus.jsonl",
            nrows=nrows,
            clean_func=clean_terra,
            template="Задание: terra\n" + LIDIRUS_SOURCE_TEMPLATE
        )
    if "parus" in tasks:
        predict_parus(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "PARus.jsonl",
            nrows=nrows,
            template_cause="Задание: parus\n" + PARUS_CAUSE_SOURCE_TEMPLATE,
            template_effect="Задание: parus\n" + PARUS_EFFECT_SOURCE_TEMPLATE
        )
    if "muserc" in tasks:
        predict_muserc(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "MuSeRC.jsonl",
            nrows=nrows,
            template="Задание: muserc\n" + MUSERC_SOURCE_TEMPLATE,
            clean_func=clean_muserc
        )
    if "rucos" in tasks:
        predict_rucos(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "RuCoS.jsonl",
            nrows=nrows,
            template="Задание: rucos\n" + RUCOS_SOURCE_TEMPLATE,
            clean_func=clean_rucos
        )
    if "rwsd" in tasks:
        predict_rwsd(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "RWSD.jsonl",
            nrows=nrows,
            template="Задание: rwsd\n" + RWSD_SOURCE_TEMPLATE,
            clean_func=clean_rwsd_response
        )
    if "russe" in tasks:
        predict_russe(
            split=split,
            predict_func=predict_saiga_zero_shot_bound,
            output_path=predictions_dir / "RUSSE.jsonl",
            nrows=nrows,
            template="Задание: russe\n" + RUSSE_SOURCE_TEMPLATE,
            clean_func=clean_russe
        )


if __name__ == "__main__":
    fire.Fire(main)
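
# Example invocation (a sketch, assuming the script is run from the
# self_instruct/ directory; the model name below is illustrative):
#
#   python -m src.benchmarks.eval_lora_rsg \
#       --model_name IlyaGusev/saiga2_7b_lora \
#       --split test \
#       --predictions_dir submission_peft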