fepegar/torchio

View on GitHub
src/torchio/data/loader.py

Summary

Maintainability
A
35 mins
Test Coverage
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import TypeVar

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from .subject import Subject

T = TypeVar('T')


class SubjectsLoader(DataLoader):
    def __init__(
        self,
        dataset: Dataset,
        collate_fn: Optional[Callable[[List[T]], Any]] = None,
        **kwargs,
    ):
        if collate_fn is None:
            collate_fn = self._collate  # type: ignore[assignment]
        super().__init__(
            dataset=dataset,
            collate_fn=collate_fn,
            **kwargs,
        )

    @staticmethod
    def _collate(subjects: List[Subject]) -> Dict[str, Any]:
        first_subject = subjects[0]
        batch_dict = {}
        for key in first_subject.keys():
            collated_value = _stack([subject[key] for subject in subjects])
            batch_dict[key] = collated_value
        return batch_dict


def _stack(x):
    """Determine the type of the input and stack it accordingly.

    Args:
        x: List of elements to stack.
    Returns:
        Stacked elements, as either a torch.Tensor, np.ndarray, dict or list.
    """
    first_element = x[0]
    if isinstance(first_element, torch.Tensor):
        return torch.stack(x, dim=0)
    elif isinstance(first_element, np.ndarray):
        return np.stack(x, axis=0)
    elif isinstance(first_element, dict):
        # Assume that all elements have the same keys
        collated_dict = {}
        for key in first_element.keys():
            collated_dict[key] = _stack([element[key] for element in x])
        return collated_dict
    else:
        return x