fepegar/torchio

View on GitHub
src/torchio/transforms/augmentation/intensity/random_motion.py

Summary

Maintainability
B
5 hrs
Test Coverage
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np
import SimpleITK as sitk
import torch

from .. import RandomTransform
from ... import FourierTransform
from ... import IntensityTransform
from ....data.io import nib_to_sitk
from ....data.subject import Subject
from ....typing import TypeTripletFloat


class RandomMotion(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Tuple :math:`(a, b)` defining the rotation range in degrees of
            the simulated movements. The rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        translation: Tuple :math:`(a, b)` defining the translation in mm of
            the simulated movements. The translations along each axis are
            :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`t` is provided,
            :math:`t_i \sim \mathcal{U}(-t, t)`.
            Larger values generate more distorted images.
        num_transforms: Number of simulated movements.
            Larger values generate more distorted images.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If ``num_transforms`` is not an ``int`` greater than or
            equal to 1.

    .. warning:: Large numbers of movements lead to longer execution times for
        3D images.
    """

    def __init__(
        self,
        degrees: Union[float, Tuple[float, float]] = 10,
        translation: Union[float, Tuple[float, float]] = 10,  # in mm
        num_transforms: int = 2,
        image_interpolation: str = 'linear',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.degrees_range = self.parse_degrees(degrees)
        self.translation_range = self.parse_translation(translation)
        # Check the type first so that a non-numeric argument raises the
        # ValueError below instead of a TypeError from the comparison
        if not isinstance(num_transforms, int) or num_transforms < 1:
            message = (
                'Number of transforms must be a strictly positive natural'
                f' number, not {num_transforms}'
            )
            raise ValueError(message)
        self.num_transforms = num_transforms
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample motion parameters for each image and apply them.

        Parameters are sampled independently for every image returned by
        ``get_images_dict`` and stored per image name, so that a single
        :class:`Motion` instance can be built and applied to the subject.

        Args:
            subject: Subject whose images will be corrupted.

        Returns:
            The subject returned by the :class:`Motion` transform.
        """
        arguments: Dict[str, dict] = defaultdict(dict)
        for name, image in self.get_images_dict(subject).items():
            params = self.get_params(
                self.degrees_range,
                self.translation_range,
                self.num_transforms,
                is_2d=image.is_2d(),
            )
            times_params, degrees_params, translation_params = params
            arguments['times'][name] = times_params
            arguments['degrees'][name] = degrees_params
            arguments['translation'][name] = translation_params
            arguments['image_interpolation'][name] = self.image_interpolation
        transform = Motion(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        degrees_range: Tuple[float, float],
        translation_range: Tuple[float, float],
        num_transforms: int,
        perturbation: float = 0.3,
        is_2d: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample the times, rotations and translations of the movements.

        Args:
            degrees_range: Bounds of the uniform distribution from which the
                rotation angles are drawn.
            translation_range: Bounds of the uniform distribution from which
                the translations are drawn.
            num_transforms: Number of movements to sample.
            perturbation: Fraction of the regular time step used as the bound
                of the uniform noise added to each time. If it is 0, time
                intervals between movements are constant.
            is_2d: If ``True``, the first two rotation components and the
                third translation component are set to zero.

        Returns:
            Tuple ``(times, degrees, translation)``. ``times`` has
            ``num_transforms`` elements; ``degrees`` and ``translation`` have
            shape ``(num_transforms, 3)``.
        """
        degrees_params = self.get_params_array(
            degrees_range,
            num_transforms,
        )
        translation_params = self.get_params_array(
            translation_range,
            num_transforms,
        )
        if is_2d:  # imagine sagittal (1, A, S)
            degrees_params[:, :-1] = 0  # rotate around Z axis only
            translation_params[:, 2] = 0  # translate in XY plane only
        step = 1 / (num_transforms + 1)
        # Interior points of a regular grid on [0, 1]. linspace always yields
        # exactly num_transforms values here, whereas arange with a fractional
        # step may return one extra element due to floating-point rounding,
        # which would make the in-place addition below fail
        times = torch.linspace(0, 1, num_transforms + 2)[1:-1]
        noise = torch.FloatTensor(num_transforms)
        noise.uniform_(-step * perturbation, step * perturbation)
        times += noise
        times_params = times.numpy()
        return times_params, degrees_params, translation_params

    @staticmethod
    def get_params_array(
        nums_range: Tuple[float, float],
        num_transforms: int,
    ) -> np.ndarray:
        """Draw a ``(num_transforms, 3)`` array uniformly from a range.

        Args:
            nums_range: Lower and upper bounds of the uniform distribution.
            num_transforms: Number of rows of the returned array.

        Returns:
            Array of shape ``(num_transforms, 3)`` with the sampled values.
        """
        tensor = torch.FloatTensor(num_transforms, 3).uniform_(*nums_range)
        return tensor.numpy()


class Motion(IntensityTransform, FourierTransform):
    r"""Add MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Sequence of rotations :math:`(\theta_1, \theta_2, \theta_3)`.
        translation: Sequence of translations :math:`(t_1, t_2, t_3)` in mm.
        times: Sequence of times from 0 to 1 at which the motions happen.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Each argument may also be a dictionary mapping image names to the
    corresponding per-image value.
    """

    def __init__(
        self,
        degrees: Union[TypeTripletFloat, Dict[str, TypeTripletFloat]],
        translation: Union[TypeTripletFloat, Dict[str, TypeTripletFloat]],
        times: Union[Sequence[float], Dict[str, Sequence[float]]],
        image_interpolation: Union[str, Dict[str, str]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.degrees = degrees
        self.translation = translation
        self.times = times
        self.image_interpolation = image_interpolation
        self.args_names = [
            'degrees',
            'translation',
            'times',
            'image_interpolation',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Corrupt every channel of every image of the subject.

        Args:
            subject: Subject whose images are modified in place.

        Returns:
            The same subject instance, with the image data replaced.
        """
        degrees = self.degrees
        translation = self.translation
        times = self.times
        image_interpolation = self.image_interpolation
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                assert isinstance(self.degrees, dict)
                assert isinstance(self.translation, dict)
                assert isinstance(self.times, dict)
                assert isinstance(self.image_interpolation, dict)
                degrees = self.degrees[image_name]
                translation = self.translation[image_name]
                times = self.times[image_name]
                image_interpolation = self.image_interpolation[image_name]
            assert isinstance(image_interpolation, str)
            # The parameters are shared by all channels of an image, so they
            # are converted once per image rather than once per channel
            degrees_array = np.asarray(degrees)
            translation_array = np.asarray(translation)
            times_array = np.asarray(times)
            result_arrays = []
            for channel in image.data:
                sitk_image = nib_to_sitk(
                    channel[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                transforms = self.get_rigid_transforms(
                    degrees_array,
                    translation_array,
                    sitk_image,
                )
                transformed_channel = self.add_artifact(
                    sitk_image,
                    transforms,
                    times_array,
                    image_interpolation,
                )
                result_arrays.append(transformed_channel)
            result = np.stack(result_arrays)
            image.set_data(torch.as_tensor(result))
        return subject

    def get_rigid_transforms(
        self,
        degrees_params: np.ndarray,
        translation_params: np.ndarray,
        image: sitk.Image,
    ) -> List[sitk.Euler3DTransform]:
        """Build one rigid transform per movement, preceded by the identity.

        Args:
            degrees_params: Rotation angles in degrees, one triplet per
                movement.
            translation_params: Translations, one triplet per movement.
            image: Image whose geometry is used to compute the center.

        Returns:
            List with the identity transform followed by one transform per
            row of the input parameters.
        """
        center_ijk = np.array(image.GetSize()) / 2
        center_lps = image.TransformContinuousIndexToPhysicalPoint(center_ijk)
        identity = np.eye(4)
        matrices = [identity]
        for degrees, translation in zip(degrees_params, translation_params):
            radians = np.radians(degrees).tolist()
            motion = sitk.Euler3DTransform()
            # NOTE(review): transform_to_matrix copies only GetMatrix() and
            # GetTranslation(), and matrix_to_transform never calls
            # SetCenter, so this center does not seem to reach the returned
            # transforms. Confirm whether that is intended
            motion.SetCenter(center_lps)
            motion.SetRotation(*radians)
            motion.SetTranslation(translation.tolist())
            motion_matrix = self.transform_to_matrix(motion)
            matrices.append(motion_matrix)
        transforms = [self.matrix_to_transform(m) for m in matrices]
        return transforms

    @staticmethod
    def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray:
        """Pack the rotation and translation of a transform in a 4x4 matrix.

        Args:
            transform: Transform to convert.

        Returns:
            Array of shape ``(4, 4)`` whose upper-left 3x3 block is the
            rotation matrix and whose last column holds the translation.
        """
        matrix = np.eye(4)
        rotation = np.array(transform.GetMatrix()).reshape(3, 3)
        matrix[:3, :3] = rotation
        matrix[:3, 3] = transform.GetTranslation()
        return matrix

    @staticmethod
    def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform:
        """Create a transform from a 4x4 matrix.

        Args:
            matrix: Array of shape ``(4, 4)`` laid out as returned by
                :meth:`transform_to_matrix`.

        Returns:
            Transform with the rotation and translation stored in ``matrix``.
        """
        transform = sitk.Euler3DTransform()
        rotation = matrix[:3, :3].flatten().tolist()
        transform.SetMatrix(rotation)
        # Plain Python floats, consistent with the other SimpleITK setters
        # used in this class
        transform.SetTranslation(matrix[:3, 3].tolist())
        return transform

    def resample_images(
        self,
        image: sitk.Image,
        transforms: Sequence[sitk.Euler3DTransform],
        interpolation: str,
    ) -> List[sitk.Image]:
        """Resample the image once per non-identity transform.

        Args:
            image: Image to resample. It is also used as reference grid.
            transforms: Transforms to apply. The first one is assumed to be
                the identity and is skipped.
            interpolation: Name passed to ``get_sitk_interpolator``.

        Returns:
            List containing the input image followed by one resampled image
            per remaining transform.
        """
        # Voxels sampled outside of the image take the minimum intensity
        default_value = np.float64(sitk.GetArrayViewFromImage(image).min())
        # Everything except the transform is the same for all movements, so
        # the resampler is configured once and reused
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(image)
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetDefaultPixelValue(default_value)
        images = [image]  # first is identity
        for transform in transforms[1:]:  # first is identity
            resampler.SetTransform(transform)
            images.append(resampler.Execute(image))
        return images

    @staticmethod
    def sort_spectra(spectra: List[torch.Tensor], times: np.ndarray) -> None:
        """Use original spectrum to fill the center of k-space.

        The list is modified in place: the first spectrum is swapped with the
        one at the position of the first time larger than 0.5, or with the
        last one if no time is larger than 0.5.

        Args:
            spectra: Spectra of the original and the moved images.
            times: Times at which the movements happen.
        """
        num_spectra = len(spectra)
        if np.any(times > 0.5):
            index = np.where(times > 0.5)[0].min()
        else:
            index = num_spectra - 1
        spectra[0], spectra[index] = spectra[index], spectra[0]

    def add_artifact(
        self,
        image: sitk.Image,
        transforms: Sequence[sitk.Euler3DTransform],
        times: np.ndarray,
        interpolation: str,
    ) -> torch.Tensor:
        """Combine the spectra of the moved images into a corrupted image.

        The last axis of the spectrum is split into consecutive bands
        delimited by ``times``, and each band is copied from a different
        spectrum.

        Args:
            image: Image to corrupt.
            transforms: Transforms as returned by
                :meth:`get_rigid_transforms`.
            times: Times, as fractions of the last axis, at which the
                spectrum switches from one image to the next.
            interpolation: Name passed to ``get_sitk_interpolator``.

        Returns:
            Real part of the inverse transform of the combined spectrum, as a
            float tensor.
        """
        images = self.resample_images(image, transforms, interpolation)
        spectra = []
        for resampled in images:
            array = sitk.GetArrayFromImage(resampled).transpose()  # sitk to np
            spectrum = self.fourier_transform(torch.from_numpy(array))
            spectra.append(spectrum)
        self.sort_spectra(spectra, times)
        # The buffer is uninitialized; it is fully overwritten below as long
        # as the band limits are non-decreasing and within the last axis
        # (assumes times are sorted and in [0, 1] - TODO confirm for
        # user-provided times)
        result_spectrum = torch.empty_like(spectra[0])
        last_index = result_spectrum.shape[2]
        indices = (last_index * times).astype(int).tolist()
        indices.append(last_index)
        ini = 0
        for spectrum, fin in zip(spectra, indices):
            result_spectrum[..., ini:fin] = spectrum[..., ini:fin]
            ini = fin
        result_image = self.inv_fourier_transform(result_spectrum).real.float()
        return result_image