fepegar/torchio

View on GitHub
src/torchio/datasets/rsna_miccai.py

Summary

Maintainability
A
2 hrs
Test Coverage
import csv
import warnings
from pathlib import Path
from typing import Dict
from typing import List
from typing import Sequence
from typing import Union

from .. import ScalarImage
from .. import Subject
from .. import SubjectsDataset
from ..typing import TypePath


class RSNAMICCAI(SubjectsDataset):
    """RSNA-MICCAI Brain Tumor Radiogenomic Classification challenge dataset.

    This is a helper class for the dataset used in the
    `RSNA-MICCAI Brain Tumor Radiogenomic Classification challenge`_ hosted on
    `kaggle <https://www.kaggle.com/>`_. The dataset must be downloaded before
    instantiating this class (as opposed to, e.g., :class:`torchio.datasets.IXI`).

    This `kaggle kernel <https://www.kaggle.com/fepegar/preprocessing-mri-with-torchio/>`_
    shows a usage example, including preprocessing of all the scans.

    If you reference or use the dataset in any form, include the following
    citation:

    U.Baid, et al., "The RSNA-ASNR-MICCAI BraTS 2021 Benchmark on Brain Tumor
    Segmentation and Radiogenomic Classification", arXiv:2107.02314, 2021.

    Args:
        root_dir: Directory containing the dataset (``train`` directory,
            ``test`` directory, etc.).
        train: If ``True``, the ``train`` set will be used. Otherwise the
            ``test`` set will be used.
        ignore_empty: If ``True``, the three subjects flagged as "presenting
            issues" (empty images) by the challenge organizers will be ignored.
            The subject IDs are ``00109``, ``00123`` and ``00709``.
        modalities: Name, or sequence of names, of the modalities to load.
            Each name must match a subdirectory of every subject directory.
            A single string is treated as a one-element list.
        **kwargs: Additional keyword arguments passed to
            :class:`torchio.SubjectsDataset`.

    Example:
        >>> import torchio as tio
        >>> from subprocess import call
        >>> call('kaggle competitions download -c rsna-miccai-brain-tumor-radiogenomic-classification'.split())
        >>> root_dir = 'rsna-miccai-brain-tumor-radiogenomic-classification'
        >>> train_set = tio.datasets.RSNAMICCAI(root_dir, train=True)
        >>> test_set = tio.datasets.RSNAMICCAI(root_dir, train=False)
        >>> len(train_set), len(test_set)
        (582, 87)


    .. _RSNA-MICCAI Brain Tumor Radiogenomic Classification challenge: https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification
    """  # noqa: B950

    # Column names of ``train_labels.csv``; they are also used as the keys
    # under which the ID and the label are stored in each subject.
    id_key = 'BraTS21ID'
    label_key = 'MGMT_value'
    # IDs skipped when ``ignore_empty`` is ``True``.
    bad_subjects = '00109', '00123', '00709'

    def __init__(
        self,
        root_dir: TypePath,
        train: bool = True,
        ignore_empty: bool = True,
        modalities: Sequence[str] = ('T1w', 'T1wCE', 'T2w', 'FLAIR'),
        **kwargs,
    ):
        """Collect the subjects found in ``root_dir`` and build the dataset."""
        self.root_dir = Path(root_dir).expanduser().resolve()
        # A bare string would otherwise be iterated character by character.
        if isinstance(modalities, str):
            modalities = [modalities]
        # Must be set before _get_subjects(), which reads self.modalities.
        self.modalities = modalities
        subjects = self._get_subjects(self.root_dir, train, ignore_empty)
        super().__init__(subjects, **kwargs)
        self.train = train

    def _get_subjects(
        self,
        root_dir: Path,
        train: bool,
        ignore_empty: bool,
    ) -> List[Subject]:
        """Build one subject per numerically named directory of the split.

        Args:
            root_dir: Directory containing the ``train`` and ``test``
                directories and, optionally, ``train_labels.csv``.
            train: If ``True``, read ``root_dir / 'train'`` and try to attach
                the labels found in ``train_labels.csv``. Otherwise read
                ``root_dir / 'test'`` without labels.
            ignore_empty: If ``True``, skip the IDs in ``bad_subjects``.

        Returns:
            Subjects sorted by directory path. Each one stores its ID under
            ``id_key``, its label under ``label_key`` (only if labels were
            loaded) and one :class:`ScalarImage` per requested modality.

        Raises:
            KeyError: If labels were loaded but a subject ID is missing from
                the CSV file.
        """
        subjects = []
        # Bound up front so the loop below never depends on which branch ran.
        labels_dict: Dict[str, int] = {}
        if train:
            csv_path = root_dir / 'train_labels.csv'
            try:
                # The csv module expects files to be opened with newline=''.
                with open(csv_path, newline='') as csvfile:
                    reader = csv.DictReader(csvfile)
                    labels_dict = {
                        row[self.id_key]: int(row[self.label_key]) for row in reader
                    }
            except FileNotFoundError:
                # Labels are optional: the images can still be loaded.
                warnings.warn(
                    'Labels CSV not found. Ignoring MGMT labels',
                    stacklevel=2,
                )
            subjects_dir = root_dir / 'train'
        else:
            subjects_dir = root_dir / 'test'

        for subject_dir in sorted(subjects_dir.iterdir()):
            subject_id = subject_dir.name
            if ignore_empty and subject_id in self.bad_subjects:
                continue
            # Only entries whose name parses as an integer are subjects.
            try:
                int(subject_id)
            except ValueError:
                continue
            images_dict: Dict[str, Union[str, int, ScalarImage]]
            images_dict = {self.id_key: subject_id}
            # labels_dict is empty for the test split or if the CSV is missing.
            if labels_dict:
                images_dict[self.label_key] = labels_dict[subject_id]
            for modality in self.modalities:
                image_dir = subject_dir / modality
                filepaths = list(image_dir.iterdir())
                # A lone file is passed directly; otherwise the whole
                # directory is handed to ScalarImage.
                path = filepaths[0] if len(filepaths) == 1 else image_dir
                images_dict[modality] = ScalarImage(path)
            subjects.append(Subject(images_dict))
        return subjects