# deepreg/dataset/preprocess.py
"""
Module containing data augmentation techniques.
- 3D Affine/DDF Transforms for moving and fixed images.
"""
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from deepreg.model.layer import Resize3d
from deepreg.model.layer_util import get_reference_grid, resample, warp_grid
from deepreg.registry import REGISTRY
class RandomTransformation3D(tf.keras.layers.Layer):
    """
    Base layer shared by the random 3D transformations.

    Subclasses supply how parameters are sampled (``gen_transform_params``)
    and how they are applied to a volume (``transform``).
    """

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
    ):
        """
        Store the configuration and build one reference grid per image size.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step,
            over all devices.
        :param name: name of layer
        :param trainable: if this layer is trainable
        """
        super().__init__(trainable=trainable, name=name)

        # plain configuration, echoed back by get_config
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size

        # reference grids are built once and reused on every call
        self.moving_grid_ref = get_reference_grid(grid_size=moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=fixed_image_size)

    @abstractmethod
    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Sample transformation parameters for the moving and the fixed image.

        :return: two tensors, (moving parameters, fixed parameters)
        """

    @staticmethod
    @abstractmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transform the reference grid, then resample the image on it.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: parameters for transformation
        :return: shape = (batch, dim1, dim2, dim3)
        """

    def call(self, inputs: Dict[str, tf.Tensor], **kwargs) -> Dict[str, tf.Tensor]:
        """
        Sample one set of random parameters and apply it to the images
        and, when present, to their labels.

        :param inputs: a dict having multiple tensors
            if labeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                moving_label, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_label, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
            else, unlabeled:
                moving_image, shape = (batch, m_dim1, m_dim2, m_dim3)
                fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
                indices, shape = (batch, num_indices)
        :param kwargs: other arguments
        :return: dictionary with the same structure as inputs
        """
        src_moving = inputs["moving_image"]
        src_fixed = inputs["fixed_image"]
        sample_indices = inputs["indices"]

        # the same parameters are reused for an image and its label
        params_moving, params_fixed = self.gen_transform_params()

        outputs = {
            "moving_image": self.transform(
                src_moving, self.moving_grid_ref, params_moving
            ),
            "fixed_image": self.transform(
                src_fixed, self.fixed_grid_ref, params_fixed
            ),
        }

        if "moving_label" in inputs:  # labeled
            src_moving_label = inputs["moving_label"]
            src_fixed_label = inputs["fixed_label"]
            outputs["moving_label"] = self.transform(
                src_moving_label, self.moving_grid_ref, params_moving
            )
            outputs["fixed_label"] = self.transform(
                src_fixed_label, self.fixed_grid_ref, params_fixed
            )

        # indices are passed through untouched
        outputs["indices"] = sample_indices
        return outputs

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            batch_size=self.batch_size,
        )
        return config
@REGISTRY.register_data_augmentation(name="affine")
class RandomAffineTransform3D(RandomTransformation3D):
    """Random affine augmentation, sampled independently for moving and fixed."""

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        scale: float = 0.1,
        name: str = "RandomAffineTransform3D",
        **kwargs,
    ):
        """
        Set up the layer.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step,
            over all devices.
        :param scale: a positive float controlling the scale of transformation
        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to the parent class
        """
        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )
        self.scale = scale

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(scale=self.scale)
        return config

    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Sample random affine parameters for the moving and fixed batches.

        A single draw of twice the batch size is made and then cut in half.

        :return: a tuple of tensors, each has shape = (batch, 4, 3)
        """
        num_samples = self.batch_size
        stacked = gen_rand_affine_transform(
            batch_size=num_samples * 2, scale=self.scale
        )
        moving_theta = stacked[:num_samples]
        fixed_theta = stacked[num_samples:]
        return moving_theta, fixed_theta

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Warp the reference grid with the affine parameters, then resample.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: shape = (batch, 4, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        warped_grid = warp_grid(grid_ref, params)
        return resample(vol=image, loc=warped_grid)
@REGISTRY.register_data_augmentation(name="ddf")
class RandomDDFTransform3D(RandomTransformation3D):
    """Apply random DDF transformation to moving/fixed images separately."""

    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        field_strength: int = 1,
        low_res_size: tuple = (1, 1, 1),
        name: str = "RandomDDFTransform3D",
        **kwargs,
    ):
        """
        Creates a DDF transformation for data augmentation.

        To simulate smooth deformation fields, we interpolate from a low
        resolution field of size low_res_size. The strength of the
        deformation field is drawn from a uniform variable
        between [0, field_strength].

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step,
            over all devices.
        :param field_strength: int = 1. It is used as the upper bound for the
            deformation field strength
        :param low_res_size: tuple = (1, 1, 1). Size of the low resolution
            field; no axis may be larger than the matching axis of either
            image size.
        :param name: name of layer
        :param kwargs: additional arguments, forwarded to the parent class
        :raises AssertionError: if low_res_size exceeds the moving or fixed
            image size along any axis
        """
        super().__init__(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            batch_size=batch_size,
            name=name,
            **kwargs,
        )

        # Tuples compare lexicographically, so `low_res <= image_size` would
        # accept e.g. (1, 100, 100) against (2, 3, 3). Check each axis instead.
        low_res = tuple(low_res_size)
        for label, image_size in (
            ("moving", tuple(moving_image_size)),
            ("fixed", tuple(fixed_image_size)),
        ):
            assert all(low <= full for low, full in zip(low_res, image_size)), (
                f"low_res_size {low_res} must not exceed the {label} image size "
                f"{image_size} along any axis"
            )

        self.field_strength = field_strength
        self.low_res_size = low_res_size

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["field_strength"] = self.field_strength
        config["low_res_size"] = self.low_res_size
        return config

    def gen_transform_params(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates two random ddf fields for moving and fixed images.

        The two fields are sampled by separate calls, one per image size.

        :return: tuple, one has shape = (batch, m_dim1, m_dim2, m_dim3, 3)
            another one has shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        moving = gen_rand_ddf(
            image_size=self.moving_image_size,
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        fixed = gen_rand_ddf(
            image_size=self.fixed_image_size,
            batch_size=self.batch_size,
            field_strength=self.field_strength,
            low_res_size=self.low_res_size,
        )
        return moving, fixed

    @staticmethod
    def transform(
        image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor
    ) -> tf.Tensor:
        """
        Transforms the reference grid and then resample the image.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        # add a leading batch axis to the grid so the per-sample
        # displacements can be added to it
        return resample(vol=image, loc=grid_ref[None, ...] + params)
def resize_inputs(
    inputs: Dict[str, tf.Tensor],
    moving_image_size: Tuple[int, ...],
    fixed_image_size: tuple,
) -> Dict[str, tf.Tensor]:
    """
    Resize the images (and labels, when present) of one sample.

    One resize layer is created per target size and shared between an
    image and its label. ``indices`` is passed through unchanged.

    :param inputs:
        if labeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            moving_label, shape = (None, None, None)
            fixed_label, shape = (None, None, None)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (None, None, None)
            fixed_image, shape = (None, None, None)
            indices, shape = (num_indices, )
    :param moving_image_size: Tuple[int, ...], (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: Tuple[int, ...], (f_dim1, f_dim2, f_dim3)
    :return:
        if labeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            moving_label, shape = (m_dim1, m_dim2, m_dim3)
            fixed_label, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
        else, unlabeled:
            moving_image, shape = (m_dim1, m_dim2, m_dim3)
            fixed_image, shape = (f_dim1, f_dim2, f_dim3)
            indices, shape = (num_indices, )
    """
    raw_moving = inputs["moving_image"]
    raw_fixed = inputs["fixed_image"]
    sample_indices = inputs["indices"]

    to_moving_size = Resize3d(shape=moving_image_size)
    to_fixed_size = Resize3d(shape=fixed_image_size)

    resized = {
        "moving_image": to_moving_size(raw_moving),
        "fixed_image": to_fixed_size(raw_fixed),
    }

    if "moving_label" in inputs:  # labeled
        raw_moving_label = inputs["moving_label"]
        raw_fixed_label = inputs["fixed_label"]
        resized["moving_label"] = to_moving_size(raw_moving_label)
        resized["fixed_label"] = to_fixed_size(raw_fixed_label)

    resized["indices"] = sample_indices
    return resized
def gen_rand_affine_transform(
    batch_size: int, scale: float, seed: Optional[int] = None
) -> tf.Tensor:
    """
    Function that generates a random 3D transformation parameters for a batch of data.

    for 3D coordinates, affine transformation is

    .. code-block:: text

        [[x' y' z' 1]] = [[x y z 1]] * [[* * * 0]
                                        [* * * 0]
                                        [* * * 0]
                                        [* * * 1]]

    where each * represents a degree of freedom,
    so there are in total 12 degrees of freedom
    the equation can be denoted as

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 4)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 4)

    the equation can be simplified to

    .. code-block:: text

        [[x' y' z']] = [[x y z 1]] * [[* * *]
                                      [* * *]
                                      [* * *]
                                      [* * *]]

    so that

        new = old * T

    where

    - new is the transformed coordinates, of shape (1, 3)
    - old is the original coordinates, of shape (1, 4)
    - T is the transformation matrix, of shape (4, 3)

    Given original and transformed coordinates,
    we can calculate the transformation matrix using

        x = np.linalg.lstsq(a, b)

    such that

        a x = b

    In our case,

    - a = old
    - b = new
    - x = T

    To generate random transformation,
    we choose to add random perturbation to corner coordinates as follows:
    for corner of coordinates (x, y, z), the noise is

        -(x, y, z) .* (r1, r2, r3)

    where ri is a random number between (0, scale).
    So

        (x', y', z') = (x, y, z) .* (1-r1, 1-r2, 1-r3)

    Thus, we can directly sample between 1-scale and 1 instead

    We choose to calculate the transformation based on
    four corners in a cube centered at (0, 0, 0).
    A cube is shown as below, where

    - C = (-1, -1, -1)
    - G = (-1, -1, 1)
    - D = (-1, 1, -1)
    - A = (1, -1, -1)

    .. code-block:: text

                    G — — — — — — — — H
                  / |               / |
                /   |             /   |
              /     |           /     |
            /       |         /       |
          /         |       /         |
        E — — — — — — — — F           |
        |           |     |           |
        |           |     |           |
        |           C — — | — — — — — D
        |         /       |         /
        |       /         |       /
        |     /           |     /
        |   /             |   /
        | /               | /
        A — — — — — — — — B

    The random numbers come from a generator private to this call, so the
    process-wide NumPy random state is neither reseeded nor advanced.

    :param batch_size: total number of samples consumed per step, over all devices.
    :param scale: a float number between 0 and 1
    :param seed: control the randomness; the same seed gives the same output,
        None gives a fresh, non-reproducible draw
    :return: shape = (batch, 4, 3)
    :raises AssertionError: if scale is outside [0, 1]
    """
    assert 0 <= scale <= 1

    # np.random.seed(seed) would reset the global generator on every call and
    # silently discard any seeding done by the caller. A local RandomState
    # yields the same stream for a given seed without that side effect.
    rng = np.random.RandomState(seed)
    noise = rng.uniform(1 - scale, 1, [batch_size, 4, 3])  # shape = (batch, 4, 3)

    # old represents four corners of a cube
    # corresponding to the corner C G D A as shown above
    old = np.tile(
        [[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, 1, -1, 1], [1, -1, -1, 1]]],
        [batch_size, 1, 1],
    )  # shape = (batch, 4, 4)
    new = old[:, :, :3] * noise  # shape = (batch, 4, 3)

    # solve old @ T = new independently for every sample
    theta = np.array(
        [
            np.linalg.lstsq(corners, moved, rcond=-1)[0]
            for corners, moved in zip(old, new)
        ]
    )
    # no-op for non-empty batches; keeps the documented shape when batch_size == 0
    theta = theta.reshape(batch_size, 4, 3)  # shape = (batch, 4, 3)
    return tf.cast(theta, dtype=tf.float32)
def gen_rand_ddf(
    batch_size: int,
    image_size: Tuple[int, ...],
    field_strength: Union[Tuple, List, int, float],
    low_res_size: Union[Tuple, List],
    seed: Optional[int] = None,
) -> tf.Tensor:
    """
    Function that generates a random 3D DDF for a batch of data.

    A low resolution field of Gaussian noise is scaled by a random strength,
    drawn per sample and per displacement component, and then resized to
    image_size.

    The random numbers come from a generator private to this call, so the
    process-wide NumPy random state is neither reseeded nor advanced.

    :param batch_size: total number of samples consumed per step, over all devices.
    :param image_size: target spatial size the low resolution field is
        resized to, passed to Resize3d as its shape
    :param field_strength: maximum field strength, computed as a U[0,field_strength]
    :param low_res_size: low_resolution deformation field that will be upsampled to
        the original size in order to get smooth and more realistic fields.
        Only its first three entries are used.
    :param seed: control the randomness; the same seed gives the same output,
        None gives a fresh, non-reproducible draw
    :return: the resized field as returned by Resize3d; presumably of
        shape = (batch, *image_size, 3) -- confirm against Resize3d
    """
    # np.random.seed(seed) would reset the global generator on every call and
    # silently discard any seeding done by the caller. A local RandomState
    # yields the same stream for a given seed without that side effect.
    rng = np.random.RandomState(seed)

    # one strength per sample and per displacement component
    low_res_strength = rng.uniform(0, field_strength, (batch_size, 1, 1, 1, 3))
    low_res_field = low_res_strength * rng.randn(
        batch_size, low_res_size[0], low_res_size[1], low_res_size[2], 3
    )

    # resizing the coarse field is what makes the final field smooth
    high_res_field = Resize3d(shape=image_size)(low_res_field)
    return high_res_field