DeepRegNet/DeepReg

View on GitHub
deepreg/train.py

Summary

Maintainability
A
2 hrs
Test Coverage
# coding=utf-8

"""
Module to train a network from YAML configuration files, either through the
`train` function or through a command-line interface.
"""

import argparse
import os
from typing import Dict, List, Tuple, Union

import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import build_dataset, build_log_dir

# module-level logger for messages emitted while setting up training
logger = log.get(__name__)


def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Function to initialise the log directory, load and back up the
    configuration for training, and normalise the checkpoint path.

    :param config_path: path of the config file, or a list of paths,
        passed as-is to ``config_parser.load_configs``.
    :param log_dir: path of the log root directory.
    :param exp_name: name of the experiment, used with ``log_dir`` to build
        the directory where logs are saved.
    :param ckpt_path: path of a saved checkpoint; a leading ``~`` is expanded
        to the user home directory.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configured
        number of epochs, and cap the save period at this value.
    :return: - config: a dictionary saving configuration
             - log_dir: the path of directory to save logs
             - ckpt_path: the checkpoint path with ``~`` expanded
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # load config
    config = config_parser.load_configs(config_path)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary;
    # save_period is capped so that at least one save happens within max_epochs
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs, config["train"]["save_period"])

    # backup config into the log directory, after any overwrite above
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir, ckpt_path


def _set_env_variables(gpu: str, gpu_allow_growth: bool, num_workers: int):
    """
    Set the environment variables controlling GPU visibility, GPU memory
    growth and CPU thread usage.

    Called first in train() so the variables are in place before the
    configuration, datasets and model are built.

    :param gpu: value for CUDA_VISIBLE_DEVICES; None leaves the GPU-related
        variables untouched.
    :param gpu_allow_growth: whether TF_FORCE_GPU_ALLOW_GROWTH is set to
        "true" or "false"; only applied when gpu is not None.
    :param num_workers: number of cpu threads to be used, <= 0 means not limited.
    """
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    # only positive values impose a limit; 0 or negative keeps the defaults
    if num_workers > 0:
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        for env_name in (
            "OMP_NUM_THREADS",
            "TF_NUM_INTRAOP_THREADS",
            "TF_NUM_INTEROP_THREADS",
        ):
            os.environ[env_name] = str(num_workers)


def _build_distribute_strategy(batch_size: int):
    """
    Return the tf.distribute strategy in which the model is built.

    With more than one visible GPU, a MirroredStrategy mirrors the network on
    each GPU so that a larger batch size can be used; otherwise the default
    strategy is returned.
    https://www.tensorflow.org/guide/distributed_training

    :param batch_size: batch size from the training preprocess configuration.
    :return: a tf.distribute strategy.
    :raises ValueError: if several GPUs are visible and batch_size is not
        divisible by their number.
    """
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices ({num_devices})."
            )
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()


def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    Environment variables are set first. Then the configuration is loaded
    and backed up, and the train/valid datasets are built. Next, the model is
    built and compiled within a tf.distribute strategy, and finally it is
    fitted. The data loaders are always closed at the end, even if a step fails.

    :param gpu: which local gpu(s) to use to train, e.g. "0" or "0,1";
        "" for using CPU. If None, CUDA_VISIBLE_DEVICES and
        TF_FORCE_GPU_ALLOW_GROWTH are left untouched.
    :param config_path: path(s) to the configuration files.
    :param ckpt_path: path of a saved checkpoint to load, "" to train from
        scratch; passed to build_checkpoint_callback after "~" expansion.
    :param num_workers: number of cpu threads to be used, <=0 means not limited.
    :param gpu_allow_growth: if True, TensorFlow allocates GPU memory on
        demand instead of reserving the whole GPU memory for training;
        only applied when gpu is not None.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    :raises ValueError: if several GPUs are visible and the batch size cannot
        be divided evenly by the number of GPUs.
    """
    _set_env_variables(
        gpu=gpu, gpu_allow_growth=gpu_allow_growth, num_workers=num_workers
    )

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None

    data_loader_val = None
    try:
        data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
            dataset_config=config["dataset"],
            preprocess_config=config["train"]["preprocess"],
            split="valid",
            training=False,
            repeat=True,
        )

        # only model, optimizer and metrics need to be defined inside the strategy
        batch_size = config["train"]["preprocess"]["batch_size"]
        strategy = _build_distribute_strategy(batch_size=batch_size)
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader_train.moving_image_shape,
                    fixed_image_size=data_loader_train.fixed_image_shape,
                    index_size=data_loader_train.num_indices,
                    labeled=config["dataset"]["train"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # build callbacks
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir,
            histogram_freq=config["train"]["save_period"],
            update_freq=config["train"].get("update_freq", "epoch"),
        )
        ckpt_callback, initial_epoch = build_checkpoint_callback(
            model=model,
            dataset=dataset_train,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )
        callbacks = [tensorboard_callback, ckpt_callback]

        # train
        # it's necessary to define the steps_per_epoch
        # and validation_steps to prevent errors like
        # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
        model.fit(
            x=dataset_train,
            steps_per_epoch=steps_per_epoch_train,
            initial_epoch=initial_epoch,
            epochs=config["train"]["epochs"],
            validation_data=dataset_val,
            validation_steps=steps_per_epoch_val,
            callbacks=callbacks,
        )
    finally:
        # close file loaders in data loaders, also when training fails
        data_loader_train.close()
        if data_loader_val is not None:
            data_loader_val.close()


def _str2bool(value: Union[str, bool]) -> bool:
    """
    Convert a command-line value to a boolean.

    :param value: a bool, or a string such as "true"/"false", "yes"/"no",
        "1"/"0"; the empty string is treated as False.
    :return: the corresponding boolean.
    :raises argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # "" is accepted as False because an empty string used to be falsy
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def main(args=None):
    """
    Entry point for train script.

    Parse the command-line arguments and call train() with them.

    :param args: list of command-line argument strings;
        if None, argparse reads them from sys.argv.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training. "
        "Omit to keep the CUDA_VISIBLE_DEVICES environment variable unchanged, "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=False,
    )

    # accepts a bare flag (-gr) as well as an explicit value (-gr true / -gr false),
    # so that "-gr false" is no longer parsed into the truthy string "false"
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory. "
        "Can be given as a bare flag or with an explicit true/false value.",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load. "
        "No need to provide if start training from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of log directory. "
        "The directory is under log root, e.g. logs/ by default. "
        "If not provided, a timestamp based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs, -1 means following configuration.",
        type=int,
        default=-1,
    )

    parsed = parser.parse_args(args)
    train(
        gpu=parsed.gpu,
        config_path=parsed.config_path,
        num_workers=parsed.num_workers,
        gpu_allow_growth=parsed.gpu_allow_growth,
        ckpt_path=parsed.ckpt_path,
        log_dir=parsed.log_dir,
        exp_name=parsed.exp_name,
        max_epochs=parsed.max_epochs,
    )


if __name__ == "__main__":  # pragma: no cover
    # executed only when this file is run as a script, not when imported
    main()