fepegar/torchio

View on GitHub
src/torchio/transforms/augmentation/intensity/random_spike.py

Summary

Maintainability
A
25 mins
Test Coverage
from collections import defaultdict
from numbers import Number
from typing import Dict
from typing import Tuple
from typing import Union

import numpy as np
import torch

from .. import RandomTransform
from ... import FourierTransform
from ... import IntensityTransform
from ....data.subject import Subject


class RandomSpike(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact, it creates stripes in different
    directions in image space due to spikes in k-space.

    Args:
        num_spikes: Number of spikes :math:`n` present in k-space.
            If a tuple :math:`(a, b)` is provided, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
            Larger values generate more distorted images.
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum.
            If a tuple :math:`(a, b)` is provided, then
            :math:`r \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`r \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: The execution time of this transform does not depend on the
        number of spikes.
    """

    def __init__(
        self,
        num_spikes: Union[int, Tuple[int, int]] = 1,
        intensity: Union[float, Tuple[float, float]] = (1, 3),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intensity_range = self._parse_range(
            intensity,
            'intensity_range',
        )
        self.num_spikes_range: Tuple[int, int] = self._parse_range(  # type: ignore[assignment]  # noqa: B950
            num_spikes,
            'num_spikes',
            min_constraint=0,
            type_constraint=int,
        )

    def apply_transform(self, subject: Subject) -> Subject:
        arguments: Dict[str, dict] = defaultdict(dict)
        for image_name in self.get_images_dict(subject):
            spikes_positions_param, intensity_param = self.get_params(
                self.num_spikes_range,
                self.intensity_range,
            )
            arguments['spikes_positions'][image_name] = spikes_positions_param
            arguments['intensity'][image_name] = intensity_param
        transform = Spike(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        num_spikes_range: Tuple[int, int],
        intensity_range: Tuple[float, float],
    ) -> Tuple[np.ndarray, float]:
        ns_min, ns_max = num_spikes_range
        num_spikes_param = int(torch.randint(ns_min, ns_max + 1, (1,)).item())
        intensity_param = self.sample_uniform(*intensity_range)
        spikes_positions = torch.rand(num_spikes_param, 3).numpy()
        return spikes_positions, intensity_param


class Spike(IntensityTransform, FourierTransform):
    r"""Add MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact, it creates stripes in different
    directions in image space due to spikes in k-space.

    Args:
        spikes_positions:
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: The execution time of this transform does not depend on the
        number of spikes.
    """

    def __init__(
        self,
        spikes_positions: Union[np.ndarray, Dict[str, np.ndarray]],
        intensity: Union[float, Dict[str, float]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.spikes_positions = spikes_positions
        self.intensity = intensity
        self.args_names = ['spikes_positions', 'intensity']
        self.invert_transform = False

    def apply_transform(self, subject: Subject) -> Subject:
        spikes_positions = self.spikes_positions
        intensity = self.intensity
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                spikes_positions = self.spikes_positions[image_name]
                assert isinstance(self.intensity, dict)
                intensity = self.intensity[image_name]
            transformed_tensors = []
            for channel in image.data:
                assert isinstance(intensity, Number)
                transformed_tensor = self.add_artifact(
                    channel,
                    np.asarray(spikes_positions),
                    intensity,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
        return subject

    def add_artifact(
        self,
        tensor: torch.Tensor,
        spikes_positions: np.ndarray,
        intensity_factor: float,
    ):
        if intensity_factor == 0 or len(spikes_positions) == 0:
            return tensor
        spectrum = self.fourier_transform(tensor)
        shape = np.array(spectrum.shape)
        mid_shape = shape // 2
        indices = np.floor(spikes_positions * shape).astype(int)
        for index in indices:
            diff = index - mid_shape
            i, j, k = mid_shape + diff
            # As of torch 1.7, "max is not yet implemented for complex tensors"
            artifact = spectrum.cpu().numpy().max() * intensity_factor
            if self.invert_transform:
                spectrum[i, j, k] -= artifact
            else:
                spectrum[i, j, k] += artifact
            # If we wanted to add a pure cosine, we should add spikes to both
            # sides of k-space. However, having only one is a better
            # representation og the actual cause of the artifact in real
            # scans. Therefore the next two lines have been removed.
            # #i, j, k = mid_shape - diff
            # #spectrum[i, j, k] = spectrum.max() * intensity_factor
        result = self.inv_fourier_transform(spectrum).real.float()
        return result