fepegar/torchio

View on GitHub
src/torchio/transforms/augmentation/intensity/random_bias_field.py

Summary

Maintainability
A
1 hr
Test Coverage
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import torch

from .. import RandomTransform
from ... import IntensityTransform
from ....data.subject import Subject
from ....typing import TypeData


class RandomBiasField(RandomTransform, IntensityTransform):
    r"""Add random MRI bias field artifact.

    MRI magnetic field inhomogeneity creates intensity
    variations of very low frequency across the whole image.

    The bias field is modeled as a linear combination of
    polynomial basis functions, as in K. Van Leemput et al., 1999,
    *Automated model-based tissue classification of MR images of the brain*.

    It was implemented in NiftyNet by Carole Sudre and used in
    `Sudre et al., 2017, Longitudinal segmentation of age-related
    white matter hyperintensities
    <https://www.sciencedirect.com/science/article/pii/S1361841517300257?via%3Dihub>`_.

    Args:
        coefficients: Range of the polynomial coefficients. If a tuple
            :math:`(a, b)` is specified, every coefficient is sampled
            independently as :math:`c \sim \mathcal{U}(a, b)`. If a single
            value :math:`n` is specified, it is the maximum magnitude of the
            coefficients, i.e. :math:`c \sim \mathcal{U}(-n, n)`.
        order: Order of the basis polynomial functions. One coefficient is
            sampled per monomial :math:`x^i y^j z^k` with
            :math:`i + j + k \le` ``order``, i.e.
            :math:`\binom{\text{order} + 3}{3}` coefficients per image.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        coefficients: Union[float, Tuple[float, float]] = 0.5,
        order: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.coefficients_range = self._parse_range(
            coefficients,
            'coefficients_range',
        )
        self.order = _parse_order(order)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample per-image coefficients and apply a :class:`BiasField`."""
        # Seed both keys up front. With a bare defaultdict, an empty image
        # selection would leave them absent and BiasField would be called
        # without its two required arguments.
        arguments: Dict[str, dict] = {'coefficients': {}, 'order': {}}
        for image_name in self.get_images_dict(subject):
            # Each image gets its own independently sampled coefficients
            arguments['coefficients'][image_name] = self.get_params(
                self.order,
                self.coefficients_range,
            )
            arguments['order'][image_name] = self.order
        transform = BiasField(**self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

    def get_params(
        self,
        order: int,
        coefficients_range: Tuple[float, float],
    ) -> List[float]:
        """Sample one coefficient per polynomial basis function.

        Args:
            order: Maximum total degree of the polynomial basis.
            coefficients_range: ``(low, high)`` bounds passed to
                ``sample_uniform`` for every coefficient.

        Returns:
            List of coefficients. They are sampled in the same nested
            ``(x, y, z)`` order in which
            :meth:`BiasField.generate_bias_field` consumes them.
        """
        return [
            self.sample_uniform(*coefficients_range)
            for x_order in range(order + 1)
            for y_order in range(order + 1 - x_order)
            for _ in range(order + 1 - (x_order + y_order))
        ]


class BiasField(IntensityTransform):
    r"""Add MRI bias field artifact.

    The image is multiplied by a bias field computed as the exponential of
    a linear combination of 3D polynomial basis functions
    :math:`x^i y^j z^k` with :math:`i + j + k \le` ``order``.

    Args:
        coefficients: Magnitudes of the polynomial coefficients, one per
            basis function. Alternatively, a dict mapping image names to
            such lists.
        order: Order of the basis polynomial functions, or a dict mapping
            image names to orders. It must be a dict if and only if
            ``coefficients`` is a dict.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        coefficients: Union[List[float], Dict[str, List[float]]],
        order: Union[int, Dict[str, int]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.coefficients = coefficients
        self.order = order
        self.invert_transform = False
        self.args_names = ['coefficients', 'order']

    def arguments_are_dict(self) -> bool:
        """Return whether the arguments are given per image (as dicts).

        Raises:
            ValueError: If only one of ``coefficients`` and ``order`` is a
                dict.
        """
        coefficients_dict = isinstance(self.coefficients, dict)
        order_dict = isinstance(self.order, dict)
        if coefficients_dict != order_dict:
            message = 'If one of the arguments is a dict, all must be'
            raise ValueError(message)
        return coefficients_dict and order_dict

    def apply_transform(self, subject: Subject) -> Subject:
        """Multiply each selected image by its bias field.

        If ``invert_transform`` is set, the images are divided by the field
        instead.
        """
        # The argument layout does not depend on the image, so check it once
        per_image = self.arguments_are_dict()
        for name, image in self.get_images_dict(subject).items():
            if per_image:
                assert isinstance(self.coefficients, dict)
                assert isinstance(self.order, dict)
                coefficients = self.coefficients[name]
                order = self.order[name]
            else:
                coefficients, order = self.coefficients, self.order
            assert isinstance(order, int)
            bias_field = self.generate_bias_field(
                image.data,
                order,
                coefficients,  # type: ignore[arg-type]
            )
            if self.invert_transform:
                np.divide(1, bias_field, out=bias_field)
            image.set_data(image.data * torch.as_tensor(bias_field))
        return subject

    @staticmethod
    def generate_bias_field(
        data: TypeData,
        order: int,
        coefficients: TypeData,
    ) -> np.ndarray:
        """Compute the multiplicative bias field for ``data``.

        Args:
            data: Image data whose first axis is channels. The remaining
                three axes define the spatial shape of the field.
            order: Maximum total degree of the polynomial basis.
            coefficients: One coefficient per monomial
                :math:`x^i y^j z^k` with :math:`i + j + k \le` ``order``.
                They are consumed with ``k`` varying fastest, then ``j``,
                then ``i``.

        Returns:
            ``float32`` array with the spatial shape of ``data``, equal to
            ``exp(sum(c * x**i * y**j * z**k))``.
        """
        shape = np.array(data.shape[1:])  # first axis is channels
        half_shape = shape / 2

        # Voxel-centre coordinates, symmetric around zero along each axis
        ranges = [np.arange(-n, n) + 0.5 for n in half_shape]

        # 'ij' indexing keeps the meshes in the same axis order as the data.
        # The default 'xy' indexing swaps the first two axes, which would
        # otherwise have to be undone with a transpose of every term.
        meshes = np.asarray(np.meshgrid(*ranges, indexing='ij'))

        # Scale each coordinate to [-1, 1]. Singleton axes (max 0) stay zero.
        for mesh in meshes:
            mesh_max = mesh.max()
            if mesh_max > 0:
                mesh /= mesh_max
        x_mesh, y_mesh, z_mesh = meshes

        bias_field = np.zeros(shape)
        i = 0
        for x_order in range(order + 1):
            x_power = x_mesh**x_order  # invariant over the inner loops
            for y_order in range(order + 1 - x_order):
                y_power = y_mesh**y_order  # invariant over the z loop
                for z_order in range(order + 1 - (x_order + y_order)):
                    bias_field += (
                        coefficients[i] * x_power * y_power * z_mesh**z_order
                    )
                    i += 1
        return np.exp(bias_field).astype(np.float32)


def _parse_order(order):
    if not isinstance(order, int):
        raise TypeError(f'Order must be an int, not {type(order)}')
    if order < 0:
        raise ValueError(f'Order must be a positive int, not {order}')
    return order