IlyaGusev/rulm
self_instruct/src/dataset.py

import json
import random
from typing import List, Dict

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from tqdm import tqdm

from src.util.chat import Conversation


class ChatDataset(Dataset):
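    """Eagerly preprocessed chat dataset.

    Renders each record's messages through a chat template and stores
    ready-to-train input_ids / labels / attention_mask tensors.
    """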
    def __init__(
        self,
        original_records: List[Dict],
        tokenizer: AutoTokenizer,
        max_tokens_count: int,
        templates_path: str,
        sample_rate: float = 1.0,
        only_target_loss: bool = True,
        add_global_bos: bool = True,
        add_global_eos: bool = True,
        labels_pad_token_id: int = -100,
        truncation_side: str = "left"
    ):
        self.templates_path = templates_path
        with open(templates_path) as r:
            self.template = json.load(r)
        self.original_records = original_records
        self.sample_rate = sample_rate
        self.tokenizer = tokenizer
        self.max_tokens_count = max_tokens_count
        self.only_target_loss = only_target_loss
        self.labels_pad_token_id = labels_pad_token_id
        self.add_global_bos = add_global_bos
        self.add_global_eos = add_global_eos
        self.truncation_side = truncation_side  # stored for configuration; not used in this file
        self.is_printed = False

        self.records = []
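        # Keep each record with probability sample_rate and tokenize it up
        # front; records that fail to convert are skipped.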
        for record in tqdm(original_records):
            if random.random() > self.sample_rate:
                continue

            # Normalize role aliases to the canonical names defined by the
            # template ("bot"/"assistant" map to bot_role, "user"/"human" to user_role).
            mapping = {
                "bot": self.template["bot_role"],
                "assistant": self.template["bot_role"],
                "user": self.template["user_role"],
                "human": self.template["user_role"],
            }
            for m in record["messages"]:
                m["role"] = mapping.get(m["role"], m["role"])

            tensors = self.convert_record(record)
            if tensors is None:
                continue
            self.records.append(tensors)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        return self.records[index]

    def get_tokens(self, text):
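        # Plain tokenization: no special tokens, padding, or truncation;
        # BOS/EOS handling is done explicitly in convert_record.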
        return self.tokenizer(
            text,
            add_special_tokens=False,
            padding=False,
            truncation=False
        )["input_ids"]

    def convert_record(self, record):
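        # Render one record through the chat template and convert it into
        # tensors; returns None when no message fits into the token budget.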
        conversation = Conversation.from_template(self.templates_path)
        conversation.expand(record["messages"])

        input_ids, labels = [], []
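        # Tokenize message by message; once the budget is exhausted, the
        # remaining (latest) messages are dropped.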
        # Reserve room for the optional global BOS/EOS tokens so the final
        # length assertion cannot fail on a record that exactly fills the budget.
        budget = self.max_tokens_count - int(self.add_global_bos) - int(self.add_global_eos)
        for message, role in conversation.iter_messages():
            message_input_ids = self.get_tokens(message)
            message_labels = message_input_ids
            if len(input_ids) + len(message_input_ids) > budget:
                break

            # With only_target_loss, only bot messages are supervised; tokens of
            # all other roles get the pad label and are ignored by the loss.
            labels_mask = [self.labels_pad_token_id for _ in range(len(message_input_ids))]
            if role != conversation.bot_role and self.only_target_loss:
                message_labels = labels_mask

            input_ids.extend(message_input_ids)
            labels.extend(message_labels)

        if not input_ids:
            return None

        # Prepend the global BOS token (masked in the labels) unless the
        # template already starts with it.
        if self.add_global_bos and input_ids[0] != self.tokenizer.bos_token_id:
            input_ids.insert(0, self.tokenizer.bos_token_id)
            labels.insert(0, self.labels_pad_token_id)

        # If EOS ended up second to last (e.g. the template suffix appends one
        # token after it), drop the trailing token; the length guard avoids an
        # IndexError on single-token sequences.
        if len(input_ids) >= 2 and input_ids[-2] == self.tokenizer.eos_token_id:
            input_ids = input_ids[:-1]
            labels = labels[:-1]

        # Append the global EOS token and supervise it so the model learns to
        # stop generating.
        if self.add_global_eos and input_ids[-1] != self.tokenizer.eos_token_id:
            input_ids.append(self.tokenizer.eos_token_id)
            labels.append(self.tokenizer.eos_token_id)

        # Print the first converted example once so the rendered prompt can be
        # inspected.
        if not self.is_printed:
            print("Input IDs:", input_ids)
            print("Labels:", labels)
            print("Full prompt:", self.tokenizer.decode(input_ids, skip_special_tokens=False))
            self.is_printed = True

        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        attention_mask = input_ids.new_ones(input_ids.size())
        assert input_ids.size(0) == labels.size(0) == attention_mask.size(0) <= self.max_tokens_count
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask
        }
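
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of wiring
# ChatDataset into a DataLoader with a generic right-padding collator. The
# model name ("gpt2"), the template path, and the toy record below are
# illustrative assumptions, not values taken from this repository; the chat
# template JSON is assumed to define "bot_role" and "user_role", and the
# messages to carry "role"/"content" keys.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    def collate(batch, pad_id, labels_pad_id=-100):
        # Right-pad variable-length examples into rectangular batch tensors.
        return {
            "input_ids": pad_sequence(
                [b["input_ids"] for b in batch], batch_first=True, padding_value=pad_id
            ),
            "labels": pad_sequence(
                [b["labels"] for b in batch], batch_first=True, padding_value=labels_pad_id
            ),
            "attention_mask": pad_sequence(
                [b["attention_mask"] for b in batch], batch_first=True, padding_value=0
            ),
        }

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    pad_id = tokenizer.pad_token_id
    if pad_id is None:  # GPT-2 has no pad token; reuse EOS for padding
        pad_id = tokenizer.eos_token_id

    records = [{"messages": [
        {"role": "user", "content": "2 + 2?"},
        {"role": "bot", "content": "4"},
    ]}]
    dataset = ChatDataset(
        original_records=records,
        tokenizer=tokenizer,
        max_tokens_count=256,
        templates_path="chat_template.json",  # assumed template file
    )
    loader = DataLoader(dataset, batch_size=2, collate_fn=lambda b: collate(b, pad_id))
    print(next(iter(loader))["input_ids"].shape)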