fepegar/torchio

View on GitHub
src/torchio/transforms/augmentation/spatial/random_flip.py

Summary

Maintainability
A
0 mins
Test Coverage
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import torch

from .. import RandomTransform
from ... import SpatialTransform
from ....data.subject import Subject
from ....utils import to_tuple


class RandomFlip(RandomTransform, SpatialTransform):
    """Reverse the order of elements in an image along the given axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image might be flipped. If they are integers, they must be in
            ``(0, 1, 2)``. Anatomical labels may also be used, such as
            ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
            ``'Inferior'``, ``'Superior'``, ``'Height'`` and ``'Width'``,
            ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width) or
            ``'i'`` (inferior). Only the first letter of the string will be
            used. If the image is 2D, ``'Height'`` and ``'Width'`` may be
            used.
        flip_probability: Probability that the image will be flipped. This is
            computed on a per-axis basis.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> fpg = tio.datasets.FPG()
        >>> flip = tio.RandomFlip(axes=('LR',))  # flip along lateral axis only

    .. tip:: It is handy to specify the axes as anatomical labels when the
        image orientation is not known.
    """

    def __init__(
        self,
        axes: Union[int, Tuple[int, ...]] = 0,
        flip_probability: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.axes = _parse_axes(axes)
        self.flip_probability = self.parse_probability(flip_probability)

    def apply_transform(self, subject: Subject) -> Subject:
        potential_axes = _ensure_axes_indices(subject, self.axes)
        axes_to_flip_hot = self.get_params(self.flip_probability)
        for i in range(3):
            if i not in potential_axes:
                axes_to_flip_hot[i] = False
        (axes,) = np.where(axes_to_flip_hot)
        axes = axes.tolist()
        if not axes:
            return subject

        arguments = {'axes': axes}
        transform = Flip(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    @staticmethod
    def get_params(probability: float) -> List[bool]:
        return (probability > torch.rand(3)).tolist()


class Flip(SpatialTransform):
    """Reverse the order of elements in an image along the given axes.

    Args:
        axes: Index or tuple of indices of the spatial dimensions along which
            the image will be flipped. See
            :class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`
            for more information.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: It is handy to specify the axes as anatomical labels when the
        image orientation is not known.
    """

    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = _parse_axes(axes)
        self.args_names = ('axes',)

    def apply_transform(self, subject: Subject) -> Subject:
        axes = _ensure_axes_indices(subject, self.axes)
        for image in self.get_images(subject):
            _flip_image(image, axes)
        return subject

    @staticmethod
    def is_invertible():
        return True

    def inverse(self):
        return self


def _parse_axes(axes: Union[int, Tuple[int, ...]]):
    axes_tuple = to_tuple(axes)
    for axis in axes_tuple:
        is_int = isinstance(axis, int)
        is_string = isinstance(axis, str)
        valid_number = is_int and axis in (0, 1, 2)
        if not is_string and not valid_number:
            message = (
                f'All axes must be 0, 1 or 2, but found "{axis}" with type {type(axis)}'
            )
            raise ValueError(message)
    return axes_tuple


def _ensure_axes_indices(subject, axes):
    if any(isinstance(n, str) for n in axes):
        subject.check_consistent_orientation()
        image = subject.get_first_image()
        axes = sorted(3 + image.axis_name_to_index(n) for n in axes)
    return axes


def _flip_image(image, axes):
    spatial_axes = np.array(axes, int) + 1
    data = image.numpy()
    data = np.flip(data, axis=spatial_axes)
    data = data.copy()  # remove negative strides
    data = torch.as_tensor(data)
    image.set_data(data)