fepegar/torchio

View on GitHub
src/torchio/datasets/ixi.py

Summary

Maintainability
A
2 hrs
Test Coverage
"""The `Information eXtraction from Images (IXI) <https://brain-development.org/ixi-dataset/>`_
dataset contains "nearly 600 MR images from normal, healthy subjects",
including "T1, T2 and PD-weighted images, MRA images and Diffusion-weighted
images (15 directions)".

.. note ::
    This data is made available under the
    Creative Commons CC BY-SA 3.0 license.
    If you use it, please acknowledge the source of the IXI data, e.g.
    `the IXI website <https://brain-development.org/ixi-dataset/>`_.
"""

# Adapted from
# https://pytorch.org/docs/stable/_modules/torchvision/datasets/mnist.html#MNIST
import shutil
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Optional
from typing import Sequence

from .. import LabelMap
from .. import ScalarImage
from .. import Subject
from .. import SubjectsDataset
from ..download import download_and_extract_archive
from ..transforms import Transform
from ..typing import TypePath


class IXI(SubjectsDataset):
    """Full IXI dataset.

    Args:
        root: Root directory to which the dataset will be downloaded.
        transform: An instance of
            :class:`~torchio.transforms.transform.Transform`.
        download: If set to ``True``, will download the data into :attr:`root`.
        modalities: List of modalities to be downloaded. They must be in
            ``('T1', 'T2', 'PD', 'MRA', 'DTI')``. A single modality may also
            be given as a string, e.g. ``'T1'``.

    Raises:
        ValueError: If :attr:`modalities` is empty or contains an unknown
            modality.
        RuntimeError: If the directory of any requested modality is missing
            from :attr:`root` (e.g. :attr:`download` is ``False`` and the
            data has not been downloaded yet).

    .. warning:: The size of this dataset is multiple GB.
        If you set :attr:`download` to ``True``, it will take some time
        to be downloaded if it is not already present.

    Example:

        >>> import torchio as tio
        >>> transforms = [
        ...     tio.ToCanonical(),  # to RAS
        ...     tio.Resample((1, 1, 1)),  # to 1 mm iso
        ... ]
        >>> ixi_dataset = tio.datasets.IXI(
        ...     'path/to/ixi_root/',
        ...     modalities=('T1', 'T2'),
        ...     transform=tio.Compose(transforms),
        ...     download=True,
        ... )
        >>> print('Number of subjects in dataset:', len(ixi_dataset))  # 577
        >>> sample_subject = ixi_dataset[0]
        >>> print('Keys in subject:', tuple(sample_subject.keys()))  # ('T1', 'T2')
        >>> print('Shape of T1 data:', sample_subject['T1'].shape)  # [1, 180, 268, 268]
        >>> print('Shape of T2 data:', sample_subject['T2'].shape)  # [1, 241, 257, 188]
    """  # noqa: B950

    base_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-{modality}.tar'  # noqa: FS003,B950
    md5_dict = {
        'T1': '34901a0593b41dd19c1a1f746eac2d58',
        'T2': 'e3140d78730ecdd32ba92da48c0a9aaa',
        'PD': '88ecd9d1fa33cb4a2278183b42ffd749',
        'MRA': '29be7d2fee3998f978a55a9bdaf3407e',
        'DTI': '636573825b1c8b9e8c78f1877df3ee66',
    }

    def __init__(
        self,
        root: TypePath,
        transform: Optional[Transform] = None,
        download: bool = False,
        modalities: Sequence[str] = ('T1', 'T2'),
        **kwargs,
    ):
        root = Path(root)
        modalities = self._parse_modalities(modalities)
        if download:
            self._download(root, modalities)
        if not self._check_exists(root, modalities):
            message = 'Dataset not found. You can use download=True to download it'
            raise RuntimeError(message)
        subjects_list = self._get_subjects_list(root, modalities)
        super().__init__(subjects_list, transform=transform, **kwargs)

    @classmethod
    def _parse_modalities(cls, modalities):
        """Validate the requested modalities and return them as a tuple.

        Raises:
            ValueError: If no modality is given or one is not in
                :attr:`md5_dict`.
        """
        # A bare string would otherwise be iterated character by character
        if isinstance(modalities, str):
            modalities = (modalities,)
        # Materialize once so one-shot iterables survive repeated passes
        modalities = tuple(modalities)
        if not modalities:
            raise ValueError('At least one modality must be specified')
        for modality in modalities:
            if modality not in cls.md5_dict:
                message = (
                    f'Modality "{modality}" must be'
                    f' one of {tuple(cls.md5_dict.keys())}'
                )
                raise ValueError(message)
        return modalities

    @staticmethod
    def _check_exists(root, modalities):
        """Return ``True`` if ``root`` has a directory for every modality."""
        return all((root / modality).is_dir() for modality in modalities)

    @staticmethod
    def _get_subjects_list(root, modalities):
        """Build one subject per image that has all requested modalities.

        The number of files for each modality is not the same (e.g. 581 for
        T1, 578 for T2). The first modality is used as reference: for each
        of its ``*.nii.gz`` files, the matching
        ``<subject_id>-<modality>.nii.gz`` file is looked up in every other
        modality directory. Subjects missing any of them are skipped.
        """
        reference_modality, *other_modalities = modalities
        subjects = []
        for reference_path in sglob(root / reference_modality, '*.nii.gz'):
            subject_id = get_subject_id(reference_path)
            images_dict = {
                'subject_id': subject_id,
                reference_modality: ScalarImage(reference_path),
            }
            for modality in other_modalities:
                path = root / modality / f'{subject_id}-{modality}.nii.gz'
                if not path.is_file():
                    break  # incomplete subject: skip it
                images_dict[modality] = ScalarImage(path)
            else:
                subjects.append(Subject(**images_dict))
        return subjects

    def _download(self, root, modalities):
        """Download the IXI data for every modality not yet present in ``root``.

        A modality is skipped if its directory already exists. If
        downloading or extracting fails (or is interrupted), the directory
        created for that modality is removed. Otherwise a later call would
        treat the incomplete data as already downloaded. The temporary
        archive is always deleted.
        """
        for modality in modalities:
            modality_dir = root / modality
            if modality_dir.is_dir():
                continue
            modality_dir.mkdir(exist_ok=True, parents=True)

            url = self.base_url.format(modality=modality)
            md5 = self.md5_dict[modality]

            # Reserve a unique path for the archive; the handle is closed
            # right away so the downloader can reopen the file by name.
            with NamedTemporaryFile(suffix='.tar', delete=False) as f:
                archive_path = Path(f.name)
            try:
                download_and_extract_archive(
                    url,
                    download_root=modality_dir,
                    filename=f.name,
                    md5=md5,
                )
            except BaseException:
                shutil.rmtree(modality_dir, ignore_errors=True)
                raise
            finally:
                # delete=False means nobody else removes this (multi-GB) file
                if archive_path.exists():
                    archive_path.unlink()


class IXITiny(SubjectsDataset):
    r"""This is the dataset used in the main `notebook`_. It is a tiny version
    of IXI, containing 566 :math:`T_1`-weighted brain MR images and their
    corresponding brain segmentations, all with size :math:`83 \times 44 \times
    55`.

    It can be used as a medical image MNIST.

    Args:
        root: Root directory to which the dataset will be downloaded.
        transform: An instance of
            :class:`~torchio.transforms.transform.Transform`.
        download: If set to ``True``, will download the data into :attr:`root`.

    Raises:
        RuntimeError: If :attr:`root` is not a directory (e.g.
            :attr:`download` is ``False`` and the data is not present).
        FileNotFoundError: If the image or label files are missing from
            :attr:`root`, or their numbers differ.

    .. _notebook: https://github.com/fepegar/torchio/blob/main/tutorials/README.md
    """  # noqa: B950

    url = 'https://www.dropbox.com/s/ogxjwjxdv5mieah/ixi_tiny.zip?dl=1'
    md5 = 'bfb60f4074283d78622760230bfa1f98'

    def __init__(
        self,
        root: TypePath,
        transform: Optional[Transform] = None,
        download: bool = False,
        **kwargs,
    ):
        root = Path(root)
        if download:
            self._download(root)
        if not root.is_dir():
            message = 'Dataset not found. You can use download=True to download it'
            raise RuntimeError(message)
        subjects_list = self._get_subjects_list(root)
        super().__init__(subjects_list, transform=transform, **kwargs)

    @staticmethod
    def _get_subjects_list(root):
        """Pair the images in ``root/image`` with the labels in ``root/label``.

        Files are paired by their position in sorted order, so both
        directories must contain the same number of ``*.nii.gz`` files.
        """
        image_paths = sglob(root / 'image', '*.nii.gz')
        label_paths = sglob(root / 'label', '*.nii.gz')
        if not (image_paths and label_paths):
            message = (
                f'Images not found. Remove the root directory ({root}) and try again'
            )
            raise FileNotFoundError(message)
        if len(image_paths) != len(label_paths):
            # zip() would silently drop the extra files, and every pair after
            # the first missing file would be misaligned
            message = (
                f'Found {len(image_paths)} images but {len(label_paths)} labels.'
                f' Remove the root directory ({root}) and try again'
            )
            raise FileNotFoundError(message)

        subjects = []
        for image_path, label_path in zip(image_paths, label_paths):
            subject = Subject(
                image=ScalarImage(image_path),
                label=LabelMap(label_path),
                subject_id=get_subject_id(image_path),
            )
            subjects.append(subject)
        return subjects

    def _download(self, root):
        """Download the tiny IXI data if it doesn't exist already.

        The archive is extracted into ``root``, and its ``image`` and
        ``label`` folders are moved directly under ``root``. If anything
        fails (or is interrupted), a ``root`` that did not exist before this
        call is removed so that a later call can retry. The temporary
        archive is always deleted.
        """
        if root.is_dir():  # assume it's been downloaded
            print('Root directory for IXITiny found:', root)  # noqa: T201
            return
        print('Root directory for IXITiny not found:', root)  # noqa: T201
        print('Downloading...')  # noqa: T201
        created_root = not root.exists()
        # Reserve a unique path for the archive; the handle is closed right
        # away so the downloader can reopen the file by name.
        with NamedTemporaryFile(suffix='.zip', delete=False) as f:
            archive_path = Path(f.name)
        try:
            download_and_extract_archive(
                self.url,
                download_root=root,
                filename=f.name,
                md5=self.md5,
            )
            ixi_tiny_dir = root / 'ixi_tiny'
            (ixi_tiny_dir / 'image').rename(root / 'image')
            (ixi_tiny_dir / 'label').rename(root / 'label')
            shutil.rmtree(ixi_tiny_dir)
        except BaseException:
            if created_root:
                shutil.rmtree(root, ignore_errors=True)
            raise
        finally:
            # delete=False means nobody else removes this file
            if archive_path.exists():
                archive_path.unlink()


def sglob(directory, pattern):
    """Return the paths in ``directory`` that match ``pattern``, sorted.

    Args:
        directory: Directory to search; anything accepted by
            :class:`pathlib.Path`.
        pattern: Glob pattern relative to ``directory``, e.g. ``'*.nii.gz'``.

    Returns:
        A list of :class:`pathlib.Path` objects in ascending order.
    """
    matches = list(Path(directory).glob(pattern))
    matches.sort()
    return matches


def get_subject_id(path):
    """Return the file name of ``path`` without its last ``-``-separated field.

    For example, ``IXI002-Guys-0828-T1.nii.gz`` gives ``IXI002-Guys-0828``.
    A name that contains no ``-`` gives an empty string.

    Args:
        path: A path object exposing a ``name`` attribute, such as
            :class:`pathlib.Path`.
    """
    head, _separator, _tail = path.name.rpartition('-')
    return head