fepegar/torchio

View on GitHub
src/torchio/transforms/augmentation/intensity/random_labels_to_image.py

Summary

Maintainability
D
1 day
Test Coverage
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

import torch

from .. import RandomTransform
from ... import IntensityTransform
from ....data.image import LabelMap
from ....data.image import ScalarImage
from ....data.subject import Subject
from ....typing import TypeData
from ....typing import TypeRangeFloat
from ....utils import check_sequence


class RandomLabelsToImage(RandomTransform, IntensityTransform):
    r"""Randomly generate an image from a segmentation.

    Based on the work by Billot et al.: `A Learning Strategy for Contrast-agnostic MRI Segmentation`_
    and `Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast`_.

    .. _A Learning Strategy for Contrast-agnostic MRI Segmentation: http://proceedings.mlr.press/v121/billot20a.html

    .. _Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast: https://link.springer.com/chapter/10.1007/978-3-030-59728-3_18

    .. plot::

        import torch
        import torchio as tio
        torch.manual_seed(42)
        colin = tio.datasets.Colin27(2008)
        label_map = colin.cls
        colin.remove_image('t1')
        colin.remove_image('t2')
        colin.remove_image('pd')
        downsample = tio.Resample(1)
        blurring_transform = tio.RandomBlur(std=0.6)
        create_synthetic_image = tio.RandomLabelsToImage(
            image_key='synthetic',
            ignore_background=True,
        )
        transform = tio.Compose((
            downsample,
            create_synthetic_image,
            blurring_transform,
        ))
        colin_synth = transform(colin)
        colin_synth.plot()

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image. If ``None``, the
            first :class:`~torchio.LabelMap` found in each subject is used.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_mean` range will be used for every
            label.
            If not ``None``, it must contain one value per label found in
            the label map, assigned in ascending label order (categorical
            label maps) or in channel order (partial-volume label maps).
            If :attr:`used_labels` is also not ``None``, :attr:`mean` and
            :attr:`used_labels` must have the same length.
        std: Sequence of standard deviations for each label.
            For each value :math:`v`, if a tuple :math:`(a, b)` is
            provided then :math:`v \sim \mathcal{U}(a, b)`.
            If ``None``, :attr:`default_std` range will be used for every
            label.
            If not ``None``, it must contain one value per label found in
            the label map, in the same order as :attr:`mean`.
            If :attr:`used_labels` is also not ``None``, :attr:`std` and
            :attr:`used_labels` must have the same length.
        default_mean: Default mean range.
        default_std: Default standard deviation range.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is recommended to blur the new images in order to simulate
        partial volume effects at the borders of the synthetic structures. See
        :class:`~torchio.transforms.augmentation.intensity.random_blur.RandomBlur`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.ICBM2009CNonlinearSymmetric()
        >>> # Using the default parameters
        >>> transform = tio.RandomLabelsToImage(label_key='tissues')
        >>> # Using custom mean and std
        >>> transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0]
        ... )
        >>> # Discretizing the partial volume maps and blurring the result
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues', mean=[0.33, 0.66, 1.], std=[0, 0, 0], discretize=True
        ... )
        >>> blurring_transform = tio.RandomBlur(std=0.3)
        >>> transform = tio.Compose([simulation_transform, blurring_transform])
        >>> transformed = transform(subject)  # subject has a new key 'image_from_labels' with the simulated image
        >>> # Filling holes of the simulated image with the original T1 image
        >>> rescale_transform = tio.RescaleIntensity(
        ...     out_min_max=(0, 1), percentiles=(1, 99))   # Rescale intensity before filling holes
        >>> simulation_transform = tio.RandomLabelsToImage(
        ...     label_key='tissues',
        ...     image_key='t1',
        ...     used_labels=[0, 1]
        ... )
        >>> transform = tio.Compose([rescale_transform, simulation_transform])
        >>> transformed = transform(subject)  # subject's key 't1' has been replaced with the simulated image

    .. seealso:: :class:`~torchio.transforms.preprocessing.label.remap_labels.RemapLabels`.
    """  # noqa: B950

    def __init__(
        self,
        label_key: Optional[str] = None,
        used_labels: Optional[Sequence[int]] = None,
        image_key: str = 'image_from_labels',
        mean: Optional[Sequence[TypeRangeFloat]] = None,
        std: Optional[Sequence[TypeRangeFloat]] = None,
        default_mean: TypeRangeFloat = (0.1, 0.9),
        default_std: TypeRangeFloat = (0.01, 0.1),
        discretize: bool = False,
        ignore_background: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        # Must be set before parsing mean/std, which compare against it
        self.used_labels = _parse_used_labels(used_labels)  # type: ignore[arg-type]  # noqa: B950
        self.mean, self.std = self.parse_mean_and_std(mean, std)  # type: ignore[arg-type,assignment]  # noqa: B950
        self.default_mean = self.parse_gaussian_parameter(
            default_mean,
            'default_mean',
        )
        self.default_std = self.parse_gaussian_parameter(
            default_std,
            'default_std',
        )
        self.image_key = image_key
        self.discretize = discretize
        self.ignore_background = ignore_background

    def parse_mean_and_std(
        self,
        mean: Optional[Sequence[TypeRangeFloat]],
        std: Optional[Sequence[TypeRangeFloat]],
    ) -> Tuple[Optional[List[TypeRangeFloat]], Optional[List[TypeRangeFloat]]]:
        """Validate the per-label ``mean`` and ``std`` arguments.

        Args:
            mean: Sequence of numbers or ``(min, max)`` ranges, or ``None``.
            std: Sequence of numbers or ``(min, max)`` ranges, or ``None``.

        Returns:
            The parsed ``(mean, std)`` pair. Each one is either ``None`` or a
            list of ``(min, max)`` tuples.

        Raises:
            AssertionError: If both are given with different lengths, or if
                either length differs from that of :attr:`used_labels`.
        """
        if mean is not None:
            mean = self.parse_gaussian_parameters(mean, 'mean')
        if std is not None:
            std = self.parse_gaussian_parameters(std, 'std')
        if mean is not None and std is not None:
            message = (
                'If both "mean" and "std" are defined they must have the same length'
            )
            assert len(mean) == len(std), message
        return mean, std  # type: ignore[return-value]

    def parse_gaussian_parameters(
        self,
        params: Sequence[TypeRangeFloat],
        name: str,
    ) -> List[TypeRangeFloat]:
        """Parse every item of ``params`` into a ``(min, max)`` tuple.

        Raises:
            AssertionError: If :attr:`used_labels` is set and its length
                differs from that of ``params``.
        """
        check_sequence(params, name)
        params = [
            self.parse_gaussian_parameter(p, f'{name}[{i}]')
            for i, p in enumerate(params)
        ]
        if self.used_labels is not None:
            message = (
                f'If both "{name}" and "used_labels" are defined, '
                'they must have the same length'
            )
            assert len(params) == len(self.used_labels), message
        return params

    @staticmethod
    def parse_gaussian_parameter(
        nums_range: TypeRangeFloat,
        name: str,
    ) -> Tuple[float, float]:
        """Convert a number or a ``(min, max)`` pair into a ``(min, max)`` tuple.

        A single number ``v`` becomes ``(v, v)``.

        Raises:
            ValueError: If a sequence does not have exactly two items, or if
                its first item is greater than the second one.
        """
        if isinstance(nums_range, (int, float)):
            return nums_range, nums_range

        if len(nums_range) != 2:
            raise ValueError(
                f'If {name} is a sequence, it must be of len 2, not {nums_range}',
            )
        min_value, max_value = nums_range
        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}',
            )
        return min_value, max_value

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample Gaussian parameters per label and apply :class:`LabelsToImage`.

        Raises:
            RuntimeError: If no label key was given and the subject has no
                label map, or if :attr:`mean`/:attr:`std` do not provide one
                value per label found in the label map.
        """
        # Resolve the key per call instead of storing it on ``self``, so a
        # single transform instance can be reused with subjects whose label
        # maps have different names
        label_key = self.label_key
        if label_key is None:
            iterable = subject.get_images_dict(intensity_only=False).items()
            for name, image in iterable:
                if isinstance(image, LabelMap):
                    label_key = name
                    break
            else:
                message = f'No label maps found in subject: {subject}'
                raise RuntimeError(message)

        label_map = subject[label_key].data

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image.
        # Only the channel dimension is inspected: squeezing would also drop
        # singleton spatial dimensions (e.g. 2D images) and misclassify
        # multi-channel maps as categorical ones.
        is_discretized = label_map.shape[0] == 1 and bool(
            label_map.eq(label_map.float().round()).all(),
        )

        if not is_discretized and self.discretize:
            # Take label with highest value in voxel
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True

        labels: List[int]
        if is_discretized:
            labels = label_map.unique().long().tolist()
            if -1 in labels:
                labels.remove(-1)
        else:
            labels = list(range(label_map.shape[0]))

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(labels, self.mean, self.std)  # type: ignore[arg-type]  # noqa: B950

        # Parameters are indexed by the label's position, which is the same
        # position-based indexing that LabelsToImage uses to consume them.
        # Indexing by label value would fail for non-contiguous labels.
        means: List[float] = []
        stds: List[float] = []
        for position in range(len(labels)):
            mean, std = self.get_params(position)
            means.append(mean)
            stds.append(std)

        arguments = {
            'label_key': label_key,
            'mean': means,
            'std': stds,
            'image_key': self.image_key,
            'used_labels': self.used_labels,
            'discretize': self.discretize,
            'ignore_background': self.ignore_background,
        }

        transform = LabelsToImage(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(self, label: int) -> Tuple[float, float]:
        """Sample the Gaussian parameters for one label.

        Args:
            label: Position of the label among the labels found in the label
                map. It indexes :attr:`mean` and :attr:`std`, which hold one
                range per label in that same order.

        Returns:
            A ``(mean, std)`` pair sampled from the corresponding ranges, or
            from :attr:`default_mean`/:attr:`default_std` when those are
            ``None``.
        """
        if self.mean is None:
            mean_range = self.default_mean
        else:
            assert isinstance(self.mean, Sequence)
            mean_range = self.mean[label]
        if self.std is None:
            std_range = self.default_std
        else:
            std_range = self.std[label]
        mean = self.sample_uniform(*mean_range)  # type: ignore[misc]
        std = self.sample_uniform(*std_range)  # type: ignore[misc]
        return mean, std


class LabelsToImage(IntensityTransform):
    r"""Generate an image from a segmentation.

    Args:
        label_key: String designating the label map in the subject
            that will be used to generate the new image.
        used_labels: Sequence of integers designating the labels used
            to generate the new image. If categorical encoding is used,
            :attr:`used_labels` refers to the values of the
            categorical encoding. If one hot encoding or partial-volume
            label maps are used, :attr:`used_labels` refers to the
            channels of the label maps.
            Default uses all labels. Missing voxels will be filled with zero
            or with voxels from an already existing volume,
            see :attr:`image_key`.
        image_key: String designating the key to which the new volume will be
            saved. If this key corresponds to an already existing volume,
            missing voxels will be filled with the corresponding values
            in the original volume.
        mean: Sequence of means, one per label found in the label map,
            in ascending label order (categorical label maps) or in channel
            order (partial-volume label maps).
        std: Sequence of standard deviations, one per label found in the
            label map, in the same order as :attr:`mean`.
        discretize: If ``True``, partial-volume label maps will be discretized.
            Does not have any effects if not using partial-volume label maps.
            Discretization is done taking the class of the highest value per
            voxel in the different partial-volume label maps using
            :func:`torch.argmax()` on the channel dimension (i.e. 0).
        ignore_background: If ``True``, input voxels labeled as ``0`` will not
            be modified.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: It is recommended to blur the new images to make the result more
        realistic. See
        :class:`~torchio.transforms.augmentation.RandomBlur`.
    """

    def __init__(
        self,
        label_key: str,
        mean: Optional[Sequence[float]],
        std: Optional[Sequence[float]],
        image_key: str = 'image_from_labels',
        used_labels: Optional[Sequence[int]] = None,
        ignore_background: bool = False,
        discretize: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.label_key = _parse_label_key(label_key)
        self.used_labels = _parse_used_labels(used_labels)
        self.mean, self.std = mean, std  # type: ignore[assignment]
        self.image_key = image_key
        self.ignore_background = ignore_background
        self.discretize = discretize
        self.args_names = [
            'label_key',
            'mean',
            'std',
            'image_key',
            'used_labels',
            'ignore_background',
            'discretize',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Add to ``subject`` an image simulated from its label map.

        Each used label is filled with Gaussian noise of the corresponding
        mean and std. If :attr:`image_key` already exists in the subject,
        voxels not covered by any used label keep their original values.
        The subject's label map is not modified.

        Raises:
            RuntimeError: If :attr:`mean` or :attr:`std` do not provide one
                value per label found in the label map.
        """
        original_image = subject.get(self.image_key)

        label_map_image = subject[self.label_key]
        label_map = label_map_image.data
        affine = label_map_image.affine

        # Find out if we face a partial-volume image or a label map.
        # One-hot-encoded label map is considered as a partial-volume image.
        # Only the channel dimension is inspected: squeezing would also drop
        # singleton spatial dimensions (e.g. 2D images) and misclassify
        # multi-channel maps as categorical ones.
        is_discretized = label_map.shape[0] == 1 and bool(
            label_map.eq(label_map.float().round()).all(),
        )

        if is_discretized:
            # Work on a signed copy: unused labels are marked with -1 below,
            # which must not leak into the subject's label map and cannot be
            # stored in unsigned integer types such as uint8
            label_map = label_map.to(torch.long, copy=True)
        elif self.discretize:
            # Take label with highest value in voxel (a new tensor)
            max_label, label_map = label_map.max(dim=0, keepdim=True)
            # Remove values where all labels are 0 (i.e. missing labels)
            label_map[max_label == 0] = -1
            is_discretized = True
        else:
            # Unused channels are zeroed below; work on a copy so that the
            # subject's partial-volume label map is left untouched
            label_map = label_map.clone()

        tissues = torch.zeros(1, *label_map_image.spatial_shape).float()
        if is_discretized:
            labels_in_image = label_map.unique().long().tolist()
            if -1 in labels_in_image:
                labels_in_image.remove(-1)
        else:
            labels_in_image = list(range(label_map.shape[0]))

        # Raise error if mean and std are not defined for every label
        _check_mean_and_std_length(
            labels_in_image,
            self.mean,  # type: ignore[arg-type]
            self.std,
        )

        for i, label in enumerate(labels_in_image):
            if label == 0 and self.ignore_background:
                continue
            if self.used_labels is None or label in self.used_labels:
                assert isinstance(self.mean, Sequence)
                assert isinstance(self.std, Sequence)
                # Parameters are indexed by position, not by label value
                mean = self.mean[i]
                std = self.std[i]
                if is_discretized:
                    mask = label_map == label
                else:
                    mask = label_map[label]
                tissues += self.generate_tissue(mask, mean, std)

            else:
                # Mark unused labels as background so that, below, those
                # voxels can be filled from the original image if it exists
                if is_discretized:
                    label_map[label_map == label] = -1
                else:
                    label_map[label] = 0

        final_image = ScalarImage(affine=affine, tensor=tissues)

        if original_image is not None:
            if is_discretized:
                bg_mask = label_map == -1
            else:
                bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
            final_image.data[bg_mask] = original_image.data[bg_mask].float()

        subject.add_image(final_image, self.image_key)
        return subject

    @staticmethod
    def generate_tissue(
        data: TypeData,
        mean: float,
        std: float,
    ) -> TypeData:
        """Return ``data`` multiplied by Gaussian noise of the given parameters.

        ``data`` acts as a (possibly soft) mask. The noise tensor has the same
        shape as ``data``.
        """
        # Create the simulated tissue using a gaussian random variable
        gaussian = torch.randn(data.shape) * std + mean
        return gaussian * data


def _parse_label_key(label_key: Optional[str]) -> Optional[str]:
    if label_key is not None and not isinstance(label_key, str):
        message = f'"label_key" must be a string or None, not {type(label_key)}'
        raise TypeError(message)
    return label_key


def _parse_used_labels(
    used_labels: Optional[Sequence[int]],
) -> Optional[Sequence[int]]:
    if used_labels is None:
        return None
    check_sequence(used_labels, 'used_labels')
    for e in used_labels:
        if not isinstance(e, int):
            message = (
                'Items in "used_labels" must be integers,'
                f' but some are not: {used_labels}'
            )
            raise ValueError(message)
    return used_labels


def _check_mean_and_std_length(
    labels: Sequence[int],
    means: Optional[Sequence[TypeRangeFloat]],
    stds: Optional[Sequence[TypeRangeFloat]],
) -> None:
    num_labels = len(labels)
    if means is not None:
        num_means = len(means)
        message = (
            '"mean" must define a value for each label but length of "mean"'
            f' is {num_means} while {num_labels} labels were found'
        )
        if num_means != num_labels:
            raise RuntimeError(message)
    if stds is not None:
        num_stds = len(stds)
        message = (
            '"std" must define a value for each label but length of "std"'
            f' is {num_stds} while {num_labels} labels were found'
        )
        if num_stds != num_labels:
            raise RuntimeError(message)