DeepRegNet/DeepReg

View on GitHub
deepreg/dataset/loader/grouped_loader.py

Summary

Maintainability
C
1 day
Test Coverage
"""
Load grouped data.
Supported formats: h5 and Nifti.
Image data can be labeled or unlabeled.
Read https://deepreg.readthedocs.io/en/latest/api/loader.html#module-deepreg.dataset.loader.grouped_loader for more details.
"""
import random
from copy import deepcopy
from typing import List, Optional, Tuple, Union

from deepreg.dataset.loader.interface import (
    AbstractUnpairedDataLoader,
    GeneratorDataLoader,
)
from deepreg.dataset.util import check_difference_between_two_lists
from deepreg.registry import REGISTRY


@REGISTRY.register_data_loader(name="grouped")
class GroupedDataLoader(AbstractUnpairedDataLoader, GeneratorDataLoader):
    """
    Load grouped data.

    Yield indexes of images to load using
    sample_index_generator from GeneratorDataLoader.
    AbstractUnpairedDataLoader handles different file formats.
    """

    def __init__(
        self,
        file_loader,
        data_dir_paths: List[str],
        labeled: bool,
        sample_label: Optional[str],
        intra_group_prob: float,
        intra_group_option: str,
        sample_image_in_group: bool,
        seed: Optional[int],
        image_shape: Union[Tuple[int, ...], List[int]],
    ):
        """
        :param file_loader: a subclass of FileLoader
        :param data_dir_paths: paths of the directory storing data,
          the data has to be saved under two different sub-directories:

          - images
          - labels

        :param labeled: bool, true if the data is labeled, false if unlabeled
        :param sample_label: "sample" or "all", read `get_label_indices`
            in deepreg/dataset/util.py for more details.
        :param intra_group_prob: float between 0 and 1,

          - 0 means generating only inter-group samples,
          - 1 means generating only intra-group samples

        :param intra_group_option: str, "forward", "backward", or "unconstrained"
        :param sample_image_in_group: bool,

          - if true, only one image pair will be yielded for each group,
            so one epoch has num_groups pairs of data,
          - if false, iterate through this loader will generate all possible pairs

        :param seed: controls the randomness in sampling,
            if seed=None, then the randomness is not fixed
        :param image_shape: list or tuple of length 3,
            corresponding to (dim1, dim2, dim3) of the 3D image
        :raises ValueError: if intra_group_prob < 1 but there are fewer than
            two groups; if sample_image_in_group is False and intra_group_prob
            is neither 0 nor 1; or if all intra-group pairs are enumerated
            with an unknown intra_group_option.
        """
        super().__init__(
            image_shape=image_shape,
            labeled=labeled,
            sample_label=sample_label,
            seed=seed,
        )
        assert isinstance(
            data_dir_paths, list
        ), f"data_dir_paths must be list of strings, got {data_dir_paths}"
        # init
        # the indices for identifying an image pair are (group1, sample1, group2, sample2, label)
        self.num_indices = 5
        self.intra_group_option = intra_group_option
        self.intra_group_prob = intra_group_prob
        self.sample_image_in_group = sample_image_in_group
        # set file loaders
        # grouped data are not paired data, so moving/fixed share the same file loader for images/labels
        loader_image = file_loader(
            dir_paths=data_dir_paths, name="images", grouped=True
        )
        self.loader_moving_image = loader_image
        self.loader_fixed_image = loader_image
        if self.labeled is True:
            loader_label = file_loader(
                dir_paths=data_dir_paths, name="labels", grouped=True
            )
            self.loader_moving_label = loader_label
            self.loader_fixed_label = loader_label
        self.validate_data_files()
        # get group related stats
        self.num_groups = self.loader_moving_image.get_num_groups()
        self.num_images_per_group = self.loader_moving_image.get_num_images_per_group()
        if self.intra_group_prob < 1:
            # inter-group sampling may happen, which needs a second group
            if self.num_groups < 2:
                raise ValueError(
                    f"There are {self.num_groups} groups, "
                    f"we need at least two groups for inter group sampling"
                )
        # calculate number of samples and save pre-calculated sample indices
        if self.sample_image_in_group is True:
            # one image pair in each group (pair) will be yielded
            self.sample_indices = None
            self._num_samples = self.num_groups
        else:
            # all possible pair in each group (pair) will be yielded
            if intra_group_prob not in [0, 1]:
                raise ValueError(
                    "Mixing intra and inter groups is not supported"
                    " when not sampling pairs."
                )
            if intra_group_prob == 0:  # inter group
                self.sample_indices = self.get_inter_sample_indices()
            else:  # intra group
                self.sample_indices = self.get_intra_sample_indices()

            self._num_samples = len(self.sample_indices)  # type: ignore

    def validate_data_files(self):
        """If the data are labeled, verify image loader and label loader have the same files."""
        if self.labeled is True:
            image_ids = self.loader_moving_image.get_data_ids()
            label_ids = self.loader_moving_label.get_data_ids()
            check_difference_between_two_lists(
                list1=image_ids,
                list2=label_ids,
                name="images and labels in grouped loader",
            )

    def get_intra_sample_indices(self) -> list:
        """
        Calculate the sample indices for intra-group sampling
        The index to identify a sample is (group1, image1, group2, image2), means

          - image1 of group1 is moving image
          - image2 of group2 is fixed image

        Assuming group i has ni images,
        then in total the number of samples are

          - sum( ni * (ni-1) / 2 ) for forward/backward
          - sum( ni * (ni-1) ) for unconstrained

        :return: a list of sample indices
        :raises ValueError: if intra_group_option is not
            forward/backward/unconstrained
        """
        if self.intra_group_option not in ("forward", "backward", "unconstrained"):
            raise ValueError(
                "Unknown intra_group_option, must be forward/backward/unconstrained"
            )
        # forward pairs have moving index < fixed index, backward pairs the
        # reverse; unconstrained keeps both (forward first, matching the
        # historical ordering so seeded shuffles stay reproducible)
        add_forward = self.intra_group_option in ("forward", "unconstrained")
        add_backward = self.intra_group_option in ("backward", "unconstrained")
        intra_sample_indices = []
        for group_index in range(self.num_groups):
            num_images_in_group = self.num_images_per_group[group_index]
            for i in range(num_images_in_group):
                for j in range(i):
                    # here j < i
                    if add_forward:
                        intra_sample_indices.append((group_index, j, group_index, i))
                    if add_backward:
                        intra_sample_indices.append((group_index, i, group_index, j))
        return intra_sample_indices

    def get_inter_sample_indices(self) -> list:
        """
        Calculate the sample indices for inter-group sampling
        The index to identify a sample is (group1, image1, group2, image2), means

          - image1 of group1 is moving image
          - image2 of group2 is fixed image

        All pairs of images in the dataset are registered.
        Assuming group i has ni images and that N=[n1, n2, ..., nI],
        then in total the number of samples are:
        sum(N) * (sum(N)-1) - sum( N * (N-1) )

        :return: a list of sample indices
        """
        inter_sample_indices = []
        for group_index1 in range(self.num_groups):
            for group_index2 in range(self.num_groups):
                if group_index1 == group_index2:  # do not sample from the same group
                    continue
                num_images_in_group1 = self.num_images_per_group[group_index1]
                num_images_in_group2 = self.num_images_per_group[group_index2]
                for image_index1 in range(num_images_in_group1):
                    for image_index2 in range(num_images_in_group2):
                        inter_sample_indices.append(
                            (group_index1, image_index1, group_index2, image_index2)
                        )
        return inter_sample_indices

    def sample_index_generator(self):
        """
        Yield (moving_index, fixed_index, image_indices) sequentially, where

          - moving_index = (group1, image1)
          - fixed_index = (group2, image2)
          - image_indices = [group1, image1, group2, image2]

        :raises ValueError: when an intra-group pair is drawn and
            intra_group_option is not forward/backward/unconstrained
        """
        rnd = random.Random(self.seed)  # set random seed
        if self.sample_image_in_group is True:
            # for each group sample one image pair only
            group_indices = list(range(self.num_groups))
            rnd.shuffle(group_indices)
            for group_index in group_indices:
                # random() is in [0, 1), so a strict comparison guarantees that
                # intra_group_prob=0 never and intra_group_prob=1 always
                # selects intra-group sampling
                if rnd.random() < self.intra_group_prob:
                    # intra-group sampling
                    # inside the group_index-th group, we sample two images as moving/fixed
                    group_index1 = group_index
                    group_index2 = group_index
                    num_images_in_group = self.num_images_per_group[group_index]
                    if num_images_in_group < 2:
                        # skip groups having <2 images
                        # currently have not encountered
                        continue  # pragma: no cover

                    # sample two unique indices
                    image_index1, image_index2 = rnd.sample(
                        range(num_images_in_group), 2
                    )
                    if self.intra_group_option == "forward":
                        # image_index1 < image_index2
                        image_index1, image_index2 = (
                            min(image_index1, image_index2),
                            max(image_index1, image_index2),
                        )
                    elif self.intra_group_option == "backward":
                        # image_index1 > image_index2
                        image_index1, image_index2 = (
                            max(image_index1, image_index2),
                            min(image_index1, image_index2),
                        )
                    elif self.intra_group_option == "unconstrained":
                        pass
                    else:
                        raise ValueError(
                            f"Unknown intra_group_option, "
                            f"must be forward/backward/unconstrained, "
                            f"got {self.intra_group_option}"
                        )
                else:
                    # inter-group sampling
                    # we sample another group, then in each group we sample one image
                    group_index1 = group_index
                    group_index2 = rnd.choice(
                        [i for i in range(self.num_groups) if i != group_index]
                    )
                    num_images_in_group1 = self.num_images_per_group[group_index1]
                    num_images_in_group2 = self.num_images_per_group[group_index2]
                    image_index1 = rnd.choice(range(num_images_in_group1))
                    image_index2 = rnd.choice(range(num_images_in_group2))

                moving_index = (group_index1, image_index1)
                fixed_index = (group_index2, image_index2)
                image_indices = [group_index1, image_index1, group_index2, image_index2]
                yield moving_index, fixed_index, image_indices
        else:
            # sample indices are pre-calculated
            assert self.sample_indices is not None
            # shuffle a copy so the pre-calculated order is reused every epoch
            sample_indices = deepcopy(self.sample_indices)
            rnd.shuffle(sample_indices)  # shuffle in place
            for sample_index in sample_indices:
                group_index1, image_index1, group_index2, image_index2 = sample_index
                moving_index = (group_index1, image_index1)
                fixed_index = (group_index2, image_index2)
                image_indices = [group_index1, image_index1, group_index2, image_index2]
                yield moving_index, fixed_index, image_indices

    def close(self):
        """Close file loaders (moving and fixed share the same loader objects)."""
        self.loader_moving_image.close()
        if self.labeled is True:
            self.loader_moving_label.close()