# Source: tensorflow/models — official/vision/ops/augment.py (View on GitHub).
# Scraped page summary (not part of the module): Maintainability F; 1 wk;
# Test Coverage.
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Augmentation policies for enhanced image/video preprocessing.

AutoAugment Reference:
  - AutoAugment Reference: https://arxiv.org/abs/1805.09501
  - AutoAugment for Object Detection Reference: https://arxiv.org/abs/1906.11172
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
  - Mixup: https://arxiv.org/abs/1710.09412
  - Cutmix: https://arxiv.org/abs/1905.04899

RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models

SSDRandCrop Reference:
  - Liu et al., SSD: Single shot multibox detector:
    https://arxiv.org/abs/1512.02325
  - Implementation from TF Object Detection API:
    https://github.com/tensorflow/models/
"""
from collections.abc import Sequence
import inspect
import math
from typing import Any, Iterable, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf, tf_keras

from official.vision.configs import common as configs
from official.vision.ops import box_ops

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme, i.e. the largest augmentation "level". It is stored as a
# float (10.), presumably so that ratios against it are floating point; its
# users are outside this chunk — confirm there.
_MAX_LEVEL = 10.


def to_4d(image: tf.Tensor) -> tf.Tensor:
  """Reshapes a 2D, 3D or 4D tensor to rank 4 by inserting unit dimensions.

    4D [N, H, W, C] or [N, C, H, W] -> unchanged
    3D [H, W, C] or [C, H, W]       -> [1, H, W, C] or [1, C, H, W]
    2D [H, W]                       -> [1, H, W, 1]

  Args:
    image: The 2/3/4D input tensor.

  Returns:
    A 4D tensor holding the same values as `image` (for 2/3/4D inputs).

  Note:
    The rank is not validated and no error is raised for other ranks: a 1D
    input only gains a leading unit dimension, and an input of rank 5 or more
    is returned with its shape unchanged.
  """
  rank = tf.rank(image)
  # A leading 1 is inserted for rank <= 3; a trailing 1 only for rank == 2.
  leading = tf.ones(
      shape=tf.cast(tf.less_equal(rank, 3), dtype=tf.int32), dtype=tf.int32
  )
  trailing = tf.ones(
      shape=tf.cast(tf.equal(rank, 2), dtype=tf.int32), dtype=tf.int32
  )
  target_shape = tf.concat([leading, tf.shape(image), trailing], axis=0)
  return tf.reshape(image, target_shape)


def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
  """Undoes `to_4d`, reshaping a 4D tensor back to rank `ndims`.

  Args:
    image: A 4D tensor, e.g. the (possibly processed) output of `to_4d`.
    ndims: The rank the tensor had before `to_4d` (callers in this file pass
      `tf.rank` of the original input).

  Returns:
    `image` with its first dimension dropped when `ndims <= 3`, and its last
    dimension also dropped when `ndims == 2`.
  """
  first = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
  last = 4 - tf.cast(tf.equal(ndims, 2), dtype=tf.int32)
  return tf.reshape(image, tf.shape(image)[first:last])


def _pad(
    image: tf.Tensor,
    filter_shape: Union[List[int], Tuple[int, ...]],
    mode: str = 'CONSTANT',
    constant_values: Union[int, tf.Tensor] = 0,
) -> tf.Tensor:
  """Pads the spatial dimensions of a 4-D image ahead of a 'VALID' convolution.

  This mirrors the implicit 'SAME' padding of `tf.nn.conv2d` and
  `tf.nn.depthwise_conv2d`, but also allows non-zero constant, reflect and
  symmetric padding. Each spatial axis receives `size - 1` extra values; for an
  even-sized filter the odd extra value goes to the bottom/right.

  Args:
    image: A 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
    filter_shape: A `tuple`/`list` of 2 integers, the filter height and width.
    mode: One of "REFLECT", "CONSTANT" or "SYMMETRIC" (checked
      case-insensitively and forwarded unchanged as the `mode` of `tf.pad`; see
      https://www.tensorflow.org/api_docs/python/tf/pad).
    constant_values: Scalar pad value used in "CONSTANT" mode.

  Returns:
    The padded image.

  Raises:
    ValueError: If `mode` is not one of the supported modes.
  """
  if mode.upper() not in ('REFLECT', 'CONSTANT', 'SYMMETRIC'):
    raise ValueError(
        'padding should be one of "REFLECT", "CONSTANT", or "SYMMETRIC".'
    )
  fill = tf.convert_to_tensor(constant_values, image.dtype)
  kernel_height, kernel_width = filter_shape

  def _before_after(size):
    # Split `size - 1` padding values, putting any extra one after.
    before = (size - 1) // 2
    return [before, size - 1 - before]

  paddings = [
      [0, 0],
      _before_after(kernel_height),
      _before_after(kernel_width),
      [0, 0],
  ]
  return tf.pad(image, paddings, mode=mode, constant_values=fill)


def _get_gaussian_kernel(sigma, filter_shape):
  """Computes a normalized 1D Gaussian kernel.

  Args:
    sigma: Scalar standard deviation (Python number or floating tensor).
    filter_shape: Python integer, the number of taps in the kernel.

  Returns:
    A 1-D tensor of length `filter_shape`, with the dtype of `sigma`, whose
    entries sum to 1. When `sigma` is exactly 0 the kernel is a unit impulse at
    the zero-offset tap (an identity filter), which is the limit of the
    Gaussian as sigma -> 0. Without this special case the result would be NaN.
  """
  sigma = tf.convert_to_tensor(sigma)
  # Integer offsets around 0; for an even `filter_shape` the extra tap is on
  # the positive side (e.g. 4 -> [-1, 0, 1, 2]).
  x = tf.range(-filter_shape // 2 + 1, filter_shape // 2 + 1)
  x = tf.cast(x**2, sigma.dtype)
  # softmax(-x^2 / (2 sigma^2)) == exp(-x^2 / (2 sigma^2)) normalized to sum 1.
  kernel = tf.nn.softmax(-x / (2.0 * (sigma**2)))
  # sigma == 0 (allowed by `gaussian_filter2d`) divides by zero above and the
  # 0/0 at the centre tap turns the whole softmax into NaN; use the impulse
  # instead. Only the sigma == 0 case is affected.
  impulse = tf.cast(tf.equal(x, 0), sigma.dtype)
  return tf.where(tf.equal(sigma, 0), impulse, kernel)


def _get_gaussian_kernel_2d(gaussian_filter_x, gaussian_filter_y):
  """Combines two 1D Gaussian kernels into a 2D kernel via matrix product.

  `gaussian_filter2d` passes a column vector `[k_h, 1]` first and a row vector
  `[1, k_w]` second, so the product is their `[k_h, k_w]` outer product.
  """
  return tf.matmul(gaussian_filter_x, gaussian_filter_y)


def _normalize_tuple(value, n, name):
  """Transforms an integer or iterable of integers into an integer tuple.

  Args:
    value: The value to validate and convert. Could an int, or any iterable of
      ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n integers.

  Raises:
    ValueError: If something else than an int/long or iterable thereof was
      passed.
  """
  if isinstance(value, int):
    return (value,) * n
  else:
    try:
      value_tuple = tuple(value)
    except TypeError as exc:
      raise TypeError(
          f'The {name} argument must be a tuple of {n} integers. '
          f'Received: {value}'
      ) from exc
    if len(value_tuple) != n:
      raise ValueError(
          f'The {name} argument must be a tuple of {n} integers. '
          f'Received: {value}'
      )
    for single_value in value_tuple:
      try:
        int(single_value)
      except (ValueError, TypeError) as exc:
        raise ValueError(
            f'The {name} argument must be a tuple of {n} integers. Received:'
            f' {value} including element {single_value} of type'
            f' {type(single_value)}.'
        ) from exc
    return value_tuple


def gaussian_filter2d(
    image: tf.Tensor,
    filter_shape: Union[List[int], Tuple[int, ...], int],
    sigma: Union[List[float], Tuple[float, float], float] = 1.0,
    padding: str = 'REFLECT',
    constant_values: Union[int, tf.Tensor] = 0,
    name: Optional[str] = None,
) -> tf.Tensor:
  """Blurs image(s) with a separable 2-D Gaussian kernel.

  Args:
    image: A 2-D `[height, width]`, 3-D `[height, width, channels]` or 4-D
      `[batch_size, height, width, channels]` `Tensor`.
    filter_shape: An integer, or a `tuple`/`list` of 2 integers giving the
      kernel height and width. A single integer is used for both.
    sigma: A float, or a `tuple`/`list` of 2 floats giving the standard
      deviation along the height (first) and width (second) axes. A single
      float is used for both. Must be >= 0.
    padding: One of "REFLECT", "CONSTANT" or "SYMMETRIC", forwarded as the
      `mode` of `tf.pad` (see https://www.tensorflow.org/api_docs/python/tf/pad).
    constant_values: Scalar pad value used in "CONSTANT" mode.
    name: Optional name scope for the operation.

  Returns:
    A `Tensor` with the same rank and dtype as `image`. Non-floating inputs
    are filtered in float32 and cast back.

  Raises:
    ValueError: If `sigma` is a list/tuple without exactly 2 entries, if any
      sigma is negative, if `padding` is unsupported, or if `filter_shape` is
      not a valid size pair.
  """
  with tf.name_scope(name or 'gaussian_filter2d'):
    if not isinstance(sigma, (list, tuple)):
      sigma = (sigma,) * 2
    elif len(sigma) != 2:
      raise ValueError('sigma should be a float or a tuple/list of 2 floats')

    if any(s < 0 for s in sigma):
      raise ValueError('sigma should be greater than or equal to 0.')

    image = tf.convert_to_tensor(image, name='image')
    sigma = tf.convert_to_tensor(sigma, name='sigma')

    input_rank = tf.rank(image)
    image = to_4d(image)

    # Floating inputs keep their precision; anything else is filtered in
    # float32 and cast back at the end.
    input_dtype = image.dtype
    if not input_dtype.is_floating:
      image = tf.cast(image, tf.float32)

    num_channels = tf.shape(image)[3]
    kernel_h, kernel_w = _normalize_tuple(filter_shape, 2, 'filter_shape')

    sigma = tf.cast(sigma, image.dtype)
    # Separable kernel: a [k_h, 1] column times a [1, k_w] row.
    column = _get_gaussian_kernel(sigma[0], kernel_h)[:, tf.newaxis]
    row = _get_gaussian_kernel(sigma[1], kernel_w)[tf.newaxis, :]
    kernel = _get_gaussian_kernel_2d(column, row)
    # depthwise_conv2d wants [k_h, k_w, in_channels, channel_multiplier].
    kernel = tf.tile(
        kernel[:, :, tf.newaxis, tf.newaxis], [1, 1, num_channels, 1]
    )

    padded = _pad(
        image,
        (kernel_h, kernel_w),
        mode=padding,
        constant_values=constant_values,
    )
    blurred = tf.nn.depthwise_conv2d(
        input=padded,
        filter=kernel,
        strides=(1, 1, 1, 1),
        padding='VALID',
    )
    return tf.cast(from_4d(blurred, input_rank), input_dtype)


def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
  """Builds projective transform rows that translate by `[dx, dy]`.

  Each row encodes the matrix
    [[1 0 -dx]
     [0 1 -dy]
     [0 0 1]]
  as `[1, 0, -dx, 0, 1, -dy, 0, 0]`.

  Args:
    translations: A `[dx, dy]` pair, or an `(N, 2)` matrix with one pair per
      image. The rank must be statically known.

  Returns:
    A float32 tensor of shape `(num_images, 8)`.

  Raises:
    TypeError: If the rank of `translations` is unknown or is not 1 or 2.
  """
  translations = tf.convert_to_tensor(translations, dtype=tf.float32)
  rank = translations.get_shape().ndims
  if rank is None:
    raise TypeError('translations rank must be statically known')
  if rank == 1:
    translations = translations[None]
  elif rank != 2:
    raise TypeError('translations should have rank 1 or 2.')

  count = tf.shape(translations)[0]
  ones = tf.ones((count, 1), tf.dtypes.float32)
  zeros = tf.zeros((count, 1), tf.dtypes.float32)
  dx = translations[:, 0, None]
  dy = translations[:, 1, None]
  return tf.concat(
      values=[
          ones, zeros, -dx,
          zeros, ones, -dy,
          tf.zeros((count, 2), tf.dtypes.float32),
      ],
      axis=1,
  )


def _convert_angles_to_transform(angles: tf.Tensor, image_width: tf.Tensor,
                                 image_height: tf.Tensor) -> tf.Tensor:
  """Builds projective transform rows that rotate about the image centre.

  Args:
    angles: Rotation angle(s) in radians: a scalar (applied to every image) or
      a vector with one angle per image.
    image_width: The width of the image(s) to be transformed.
    image_height: The height of the image(s) to be transformed.

  Returns:
    A float32 tensor of shape `(num_images, 8)`.

  Raises:
    TypeError: If `angles` is not rank 0 or 1.
  """
  angles = tf.convert_to_tensor(angles, dtype=tf.float32)
  rank = len(angles.get_shape())
  if rank == 0:
    angles = angles[None]
  elif rank != 1:
    raise TypeError('Angles should have a rank 0 or 1.')

  cos = tf.math.cos(angles)
  sin = tf.math.sin(angles)
  max_x = image_width - 1
  max_y = image_height - 1
  # Offsets chosen so the centre pixel ((w - 1) / 2, (h - 1) / 2) maps onto
  # itself, i.e. the rotation is about the image centre.
  x_offset = (max_x - (cos * max_x - sin * max_y)) / 2.0
  y_offset = (max_y - (sin * max_x + cos * max_y)) / 2.0
  return tf.concat(
      values=[
          cos[:, None],
          -sin[:, None],
          x_offset[:, None],
          sin[:, None],
          cos[:, None],
          y_offset[:, None],
          tf.zeros((tf.shape(angles)[0], 2), tf.dtypes.float32),
      ],
      axis=1,
  )


def _apply_transform_to_images(
    images,
    transforms,
    fill_mode='reflect',
    fill_value=0.0,
    interpolation='bilinear',
    output_shape=None,
    name=None,
):
  """Applies projective transform(s) to a batch of NHWC images.

  Thin wrapper around `tf.raw_ops.ImageProjectiveTransformV3`.

  Args:
    images: A tensor of shape `(num_images, num_rows, num_columns,
      num_channels)` (NHWC) whose rank is statically known.
    transforms: A vector of length 8 or a tensor of shape `(N, 8)`. A row
      `[a0, a1, a2, b0, b1, b2, c0, c1]` maps each *output* point `(x, y)` to
      the *input* point `((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`
      with `k = c0 x + c1 y + 1`, i.e. it is the inverse of the mapping from
      input points to output points. Gradients are not backpropagated into
      the transformation parameters.
    fill_mode: How points outside the input are filled (case-insensitive):
      - `"reflect"`: `(d c b a | a b c d | d c b a)`, reflect about the edge
        of the last pixel.
      - `"constant"`: `(k k k k | a b c d | k k k k)` with `k = fill_value`.
      - `"wrap"`: `(a b c d | a b c d | a b c d)`, wrap to the opposite edge.
      - `"nearest"`: `(a a a a | a b c d | d d d d)`, repeat the nearest
        pixel.
    fill_value: Float value used when `fill_mode="constant"`.
    interpolation: `"nearest"` or `"bilinear"` (case-insensitive).
    output_shape: `[height, width]` of the output; defaults to the spatial
      size of `images`.
    name: Optional name scope (defaults to `'transform'`).

  Returns:
    The transformed images, with the dtype of `images` and spatial size
    `output_shape`. Out-of-bounds samples are filled according to
    `fill_mode` / `fill_value`.

  Raises:
    ValueError: If `output_shape` is not compatible with a 2-element vector.
  """
  with tf.name_scope(name or 'transform'):
    if output_shape is None:
      output_shape = tf.shape(images)[1:3]
      if not tf.executing_eagerly():
        # In graph mode, prefer a constant shape when it is statically known.
        static_shape = tf.get_static_value(output_shape)
        if static_shape is not None:
          output_shape = static_shape

    output_shape = tf.convert_to_tensor(
        output_shape, tf.int32, name='output_shape'
    )
    if not output_shape.get_shape().is_compatible_with([2]):
      raise ValueError(
          'output_shape must be a 1-D Tensor of 2 elements: '
          'new_height, new_width, instead got '
          f'output_shape={output_shape}'
      )

    fill = tf.convert_to_tensor(fill_value, tf.float32, name='fill_value')
    return tf.raw_ops.ImageProjectiveTransformV3(
        images=images,
        output_shape=output_shape,
        fill_value=fill,
        transforms=transforms,
        fill_mode=fill_mode.upper(),
        interpolation=interpolation.upper(),
    )


def transform(
    image: tf.Tensor,
    transforms: Any,
    interpolation: str = 'nearest',
    output_shape=None,
    fill_mode: str = 'reflect',
    fill_value: float = 0.0,
) -> tf.Tensor:
  """Applies projective transform(s) to a 2D/3D/4D image tensor.

  Args:
    image: The image(s); reshaped to 4D internally and restored afterwards.
    transforms: A length-8 vector or an `(N, 8)` matrix of transform rows
      (see `_apply_transform_to_images`).
    interpolation: `"nearest"` or `"bilinear"`.
    output_shape: Optional `[height, width]` of the result.
    fill_mode: One of `"constant"`, `"reflect"`, `"wrap"`, `"nearest"`.
    fill_value: Fill value for `fill_mode="constant"`.

  Returns:
    The transformed image with the same rank as `image`.
  """
  input_rank = tf.rank(image)
  matrix = tf.convert_to_tensor(transforms, dtype=tf.float32)
  if matrix.shape.rank == 1:
    # Promote a single length-8 vector to shape [1, 8].
    matrix = matrix[None]
  transformed = _apply_transform_to_images(
      images=to_4d(image),
      transforms=matrix,
      interpolation=interpolation,
      fill_mode=fill_mode,
      fill_value=fill_value,
      output_shape=output_shape,
  )
  return from_4d(transformed, input_rank)


def translate(
    image: tf.Tensor,
    translations,
    fill_value: float = 0.0,
    fill_mode: str = 'reflect',
    interpolation: str = 'nearest',
) -> tf.Tensor:
  """Shifts image(s) by the given `[dx, dy]` vector(s).

  Output pixel `(x, y)` samples input pixel `(x - dx, y - dy)`, so image
  content moves by `+dx` columns and `+dy` rows.

  Args:
    image: An image Tensor of type uint8.
    translations: A `[dx, dy]` vector or an `(N, 2)` matrix of them.
    fill_value: Value used outside the boundaries when `fill_mode="constant"`.
    fill_mode: One of `"constant"`, `"reflect"`, `"wrap"`, `"nearest"`.
    interpolation: `"nearest"` or `"bilinear"`.

  Returns:
    The translated image.
  """
  return transform(
      image,
      transforms=_convert_translation_to_transform(translations),  # pytype: disable=wrong-arg-types  # always-use-return-annotations
      interpolation=interpolation,
      fill_value=fill_value,
      fill_mode=fill_mode,
  )


def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
  """Rotates image(s) about their centre by `degrees`.

  The sign of `degrees` selects the direction.

  NOTE(review): an earlier docstring stated that positive values rotate
  clockwise. With the output-to-input mapping built by
  `_convert_angles_to_transform` and row-down image coordinates, a positive
  angle appears to rotate *counterclockwise* (which is also how
  `tfa.image.rotate` documents the same transform). Confirm the direction
  before relying on it.

  Args:
    image: An image Tensor of type uint8 (2D/3D/4D).
    degrees: Scalar angle in degrees applied to all images.

  Returns:
    The rotated image, with the same rank as `image`.
  """
  radians = tf.cast(degrees * (math.pi / 180.0), tf.float32)

  input_rank = tf.rank(image)
  image_4d = to_4d(image)
  height = tf.cast(tf.shape(image_4d)[1], tf.float32)
  width = tf.cast(tf.shape(image_4d)[2], tf.float32)
  matrices = _convert_angles_to_transform(
      angles=radians, image_width=width, image_height=height)
  # Randomly flipping the sign of the angle (desirable for augmentation) is
  # left to the caller, which controls `degrees`.
  rotated = transform(image_4d, transforms=matrices)
  return from_4d(rotated, input_rank)


def blend(image1: tf.Tensor, image2: tf.Tensor, factor: float) -> tf.Tensor:
  """Linearly blends `image1` towards `image2` by `factor`.

  The result is `image1 + factor * (image2 - image1)`: 0.0 returns `image1`,
  1.0 returns `image2`, values in between interpolate, and values above 1.0
  extrapolate. Extrapolated results are clipped to [0, 255].

  Args:
    image1: An image Tensor of type uint8.
    image2: An image Tensor of type uint8.
    factor: A floating point value, normally >= 0.0.

  Returns:
    A blended image Tensor of type uint8. For `factor` of exactly 0.0 or 1.0
    the corresponding input is returned unchanged, so it keeps its own dtype.
  """
  if factor == 0.0:
    return tf.convert_to_tensor(image1)
  if factor == 1.0:
    return tf.convert_to_tensor(image2)

  start = tf.cast(image1, tf.float32)
  end = tf.cast(image2, tf.float32)
  mixed = start + factor * (end - start)

  if 0.0 < factor < 1.0:
    # A convex combination of two uint8 images cannot leave [0, 255].
    return tf.cast(mixed, tf.uint8)

  # Extrapolation can overshoot, so clamp before converting back.
  return tf.cast(tf.clip_by_value(mixed, 0.0, 255.0), tf.uint8)


def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
  """Apply cutout (https://arxiv.org/abs/1708.04552) to image.

  A square of up to `2 * pad_size` x `2 * pad_size` pixels, centred on a
  uniformly sampled location and clipped at the image borders, is filled with
  `replace`. Rank-4 inputs are treated as video and delegated to
  `cutout_video`. `pad_size` is not forwarded in that case, so the box size
  there is sampled randomly.

  Args:
    image: An image Tensor of type uint8, shape [H, W, C] (or [T, H, W, C]).
    pad_size: Half the side length of the square cutout.
    replace: Pixel value written into the cutout area.

  Returns:
    An image Tensor that is of type uint8.

  Raises:
    ValueError: If `image` is not rank 3 or 4.
  """
  rank = image.shape.rank
  if rank not in [3, 4]:
    raise ValueError('Bad image rank: {}'.format(image.shape.rank))
  if rank == 4:
    return cutout_video(image, replace=replace)

  height = tf.shape(image)[0]
  width = tf.shape(image)[1]

  # Sample the mask centre; the row is drawn before the column (order kept
  # unchanged so seeded runs stay reproducible).
  center_row = tf.random.uniform(
      shape=[], minval=0, maxval=height, dtype=tf.int32)
  center_col = tf.random.uniform(
      shape=[], minval=0, maxval=width, dtype=tf.int32)

  return _fill_rectangle(image, center_col, center_row,
                         pad_size, pad_size, replace)


def _fill_rectangle(image,
                    center_width,
                    center_height,
                    half_width,
                    half_height,
                    replace=None):
  """Fills an axis-aligned rectangle of a single [H, W, C] image.

  The rectangle covers rows `[center_height - half_height,
  center_height + half_height)` and columns `[center_width - half_width,
  center_width + half_width)`, clipped to the image, across all channels.

  Args:
    image: Image tensor with height on axis 0, width on axis 1 and channels
      last.
    center_width: Column index of the rectangle centre.
    center_height: Row index of the rectangle centre.
    half_width: Half the rectangle width.
    half_height: Half the rectangle height.
    replace: Fill source. `None` draws standard-normal noise of the image's
      dtype (this needs a floating-point image). A `tf.Tensor` is used as-is
      and must broadcast against `image`. Anything else is a scalar fill value.

  Returns:
    `image` with the rectangle replaced.
  """
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]
  # Use the real channel count (matching `_fill_rectangle_video`). A fixed 3
  # broke non-RGB inputs: either a shape error, or a 1-channel image being
  # broadcast to 3 channels by `tf.where`.
  image_channels = tf.shape(image)[-1]

  # Rows/columns outside the rectangle on each side (never negative).
  lower_pad = tf.maximum(0, center_height - half_height)
  upper_pad = tf.maximum(0, image_height - center_height - half_height)
  left_pad = tf.maximum(0, center_width - half_width)
  right_pad = tf.maximum(0, image_width - center_width - half_width)

  cutout_shape = [
      image_height - (lower_pad + upper_pad),
      image_width - (left_pad + right_pad)
  ]
  padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
  # Mask is 0 inside the rectangle and 1 elsewhere.
  mask = tf.pad(
      tf.zeros(cutout_shape, dtype=image.dtype),
      padding_dims,
      constant_values=1)
  mask = tf.expand_dims(mask, -1)
  mask = tf.tile(mask, [1, 1, image_channels])

  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  image = tf.where(tf.equal(mask, 0), fill, image)

  return image


def _fill_rectangle_video(image,
                          center_width,
                          center_height,
                          half_width,
                          half_height,
                          replace=None):
  """Fills the same rectangle in every frame of a [T, H, W, C] video.

  The rectangle covers rows `[center_height - half_height,
  center_height + half_height)` and columns `[center_width - half_width,
  center_width + half_width)`, clipped to the frame, across all frames and
  channels.

  Args:
    image: Video tensor of shape [T, H, W, C].
    center_width: Column index of the rectangle centre.
    center_height: Row index of the rectangle centre.
    half_width: Half the rectangle width.
    half_height: Half the rectangle height.
    replace: `None` for standard-normal noise of the video's dtype (needs a
      floating dtype), a `tf.Tensor` used as-is, or a scalar fill value.

  Returns:
    The video with the rectangle replaced.
  """
  dims = tf.shape(image)
  frames, height, width, channels = dims[0], dims[1], dims[2], dims[3]

  top = tf.maximum(0, center_height - half_height)
  bottom = tf.maximum(0, height - center_height - half_height)
  left = tf.maximum(0, center_width - half_width)
  right = tf.maximum(0, width - center_width - half_width)

  # 0 inside the rectangle, 1 elsewhere; then broadcast over channels.
  hole = tf.zeros(
      [frames, height - (top + bottom), width - (left + right)],
      dtype=image.dtype,
  )
  mask = tf.pad(hole, [[0, 0], [top, bottom], [left, right]],
                constant_values=1)
  mask = tf.tile(mask[..., tf.newaxis], [1, 1, 1, channels])

  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  return tf.where(tf.equal(mask, 0), fill, image)


def cutout_video(
    video: tf.Tensor,
    mask_shape: Optional[tf.Tensor] = None,
    replace: int = 0,
) -> tf.Tensor:
  """Apply cutout (https://arxiv.org/abs/1708.04552) to a video.

  This operation sets every pixel inside a random 3D box (frames x rows x
  columns, across all channels) of `video` to the value `replace`. The box
  centre is sampled uniformly over the whole video, and the box is clipped at
  the video borders.

  Along each axis the box extends a "half size" on both sides of its centre,
  so its unclipped extent is twice the half size. If `mask_shape` is given, the
  half sizes are `max(1, mask_shape // 2)`. Otherwise they are sampled
  uniformly (upper bound exclusive) from [max(1, height/4), max(2, height/2)),
  [max(1, width/4), max(2, width/2)) and [1, max(2, depth/4)) for height,
  width and number of frames respectively, with the divisions truncated to
  integers.

  NOTE(review): earlier documentation described the *mask size* as sampled
  from [0.25*height, 0.5*height] (and similarly for width/depth). The code
  samples the *half size* from those ranges, so the random box is about twice
  as large as that text suggested. Confirm which is intended.

  Args:
    video: A video Tensor of shape [T, H, W, C].
    mask_shape: An optional integer tensor that specifies the depth, height and
      width of the mask to cut. If it is not set, the shape is randomly sampled
      as described above. Odd sizes are rounded down to an even size, and sizes
      below 2 become 2.
    replace: What pixel value to fill in the video in the area that has the
      cutout mask applied to it.

  Returns:
    A video Tensor with cutout applied.
  """
  # Require a rank-4 [T, H, W, C] input.
  tf.debugging.assert_shapes([
      (video, ('T', 'H', 'W', 'C')),
  ])

  video_depth = tf.shape(video)[0]
  video_height = tf.shape(video)[1]
  video_width = tf.shape(video)[2]

  # Sample the centre location in the video where the zero mask is applied.
  # The order of the random draws is kept as-is (it affects seeded results).
  cutout_center_height = tf.random.uniform(
      shape=[], minval=0, maxval=video_height, dtype=tf.int32
  )

  cutout_center_width = tf.random.uniform(
      shape=[], minval=0, maxval=video_width, dtype=tf.int32
  )

  cutout_center_depth = tf.random.uniform(
      shape=[], minval=0, maxval=video_depth, dtype=tf.int32
  )

  if mask_shape is not None:
    # Explicit mask sizes are halved into per-side extents (at least 1).
    pad_shape = tf.maximum(1, mask_shape // 2)
    pad_size_depth, pad_size_height, pad_size_width = (
        pad_shape[0],
        pad_shape[1],
        pad_shape[2],
    )
  else:
    # Random per-side extents; tf.random.uniform's int maxval is exclusive.
    pad_size_height = tf.random.uniform(
        shape=[],
        minval=tf.maximum(1, tf.cast(video_height / 4, tf.int32)),
        maxval=tf.maximum(2, tf.cast(video_height / 2, tf.int32)),
        dtype=tf.int32,
    )
    pad_size_width = tf.random.uniform(
        shape=[],
        minval=tf.maximum(1, tf.cast(video_width / 4, tf.int32)),
        maxval=tf.maximum(2, tf.cast(video_width / 2, tf.int32)),
        dtype=tf.int32,
    )
    pad_size_depth = tf.random.uniform(
        shape=[],
        minval=1,
        maxval=tf.maximum(2, tf.cast(video_depth / 4, tf.int32)),
        dtype=tf.int32,
    )

  # Number of rows/columns/frames outside the box on each side (clipped at 0).
  lower_pad = tf.maximum(0, cutout_center_height - pad_size_height)
  upper_pad = tf.maximum(
      0, video_height - cutout_center_height - pad_size_height
  )
  left_pad = tf.maximum(0, cutout_center_width - pad_size_width)
  right_pad = tf.maximum(0, video_width - cutout_center_width - pad_size_width)
  back_pad = tf.maximum(0, cutout_center_depth - pad_size_depth)
  forward_pad = tf.maximum(
      0, video_depth - cutout_center_depth - pad_size_depth
  )

  cutout_shape = [
      video_depth - (back_pad + forward_pad),
      video_height - (lower_pad + upper_pad),
      video_width - (left_pad + right_pad),
  ]
  padding_dims = [[back_pad, forward_pad],
                  [lower_pad, upper_pad],
                  [left_pad, right_pad]]
  # Mask is 0 inside the box and 1 elsewhere, then tiled over channels.
  mask = tf.pad(
      tf.zeros(cutout_shape, dtype=video.dtype), padding_dims, constant_values=1
  )
  mask = tf.expand_dims(mask, -1)
  num_channels = tf.shape(video)[-1]
  mask = tf.tile(mask, [1, 1, 1, num_channels])
  video = tf.where(
      tf.equal(mask, 0), tf.ones_like(video, dtype=video.dtype) * replace, video
  )
  return video


def gaussian_noise(
    image: tf.Tensor, low: float = 0.1, high: float = 2.0) -> tf.Tensor:
  """Blurs image(s) with a 3x3 Gaussian kernel of random standard deviation.

  Despite the name, this applies a Gaussian *blur* via `gaussian_filter2d`;
  no noise is added to the pixel values.

  NOTE(review): sigma is drawn with NumPy on the Python side. When this
  function is traced (e.g. inside `tf.function` or a `tf.data` map), the
  sigma is therefore fixed per trace rather than re-sampled per call. Confirm
  this is acceptable.

  Args:
    image: A 2D, 3D or 4D image tensor.
    low: Lower bound (inclusive) of the uniformly sampled sigma.
    high: Upper bound (exclusive) of the uniformly sampled sigma.

  Returns:
    The blurred image, with the same rank and dtype as `image`.
  """
  sigma = np.random.uniform(low=low, high=high)
  return gaussian_filter2d(image, filter_shape=[3, 3], sigma=sigma)


def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
  """Solarize the input image(s).

  Pixels with value below `threshold` are kept. Every other pixel `p` is
  inverted to `255 - p`.
  """
  inverted = 255 - image
  return tf.where(image < threshold, image, inverted)


def solarize_add(image: tf.Tensor,
                 addition: int = 0,
                 threshold: int = 128) -> tf.Tensor:
  """Additive solarize the input image(s).

  Pixels below `threshold` get `addition` added (computed in int64 and clipped
  to [0, 255]). All other pixels are unchanged. Callers are expected to pass
  an `addition` between -128 and 128.
  """
  widened = tf.cast(image, tf.int64) + addition
  shifted = tf.cast(tf.clip_by_value(widened, 0, 255), tf.uint8)
  return tf.where(image < threshold, shifted, image)


def grayscale(image: tf.Tensor) -> tf.Tensor:
  """Convert image to grayscale, keeping three (identical) channels."""
  luma = tf.image.rgb_to_grayscale(image)
  return tf.image.grayscale_to_rgb(luma)


def color(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Equivalent of PIL Color: blends between grayscale (0) and `image` (1)."""
  return blend(grayscale(image), image, factor)


def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Equivalent of PIL Contrast.

  Blends `image` with a uniform gray image whose value is the mean grayscale
  intensity of `image`, rounded half up as PIL does. The mean is taken over
  all pixels, including any leading batch/time dimensions. `factor` follows
  the `blend` convention: 0 gives the flat gray image, 1 the original image,
  and values above 1 increase contrast.

  Args:
    image: An RGB image Tensor of type uint8, channels last.
    factor: Blend factor passed to `blend`.

  Returns:
    The contrast-adjusted image as uint8.
  """
  gray = tf.cast(tf.image.rgb_to_grayscale(image), tf.float32)
  # True mean intensity. Summing a histogram of counts gives the pixel count,
  # not the mean, and divided by 256 it saturates to 255 for typical images.
  # PIL computes int(mean + 0.5).
  mean = tf.math.floor(tf.reduce_mean(gray) + 0.5)
  degenerate = tf.ones_like(gray) * mean
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
  return blend(degenerate, image, factor)


def brightness(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Equivalent of PIL Brightness: blends between black (0) and `image` (1)."""
  return blend(tf.zeros_like(image), image, factor)


def posterize(image: tf.Tensor, bits: int) -> tf.Tensor:
  """Equivalent of PIL Posterize: keeps only the top `bits` bits of each pixel."""
  dropped = 8 - bits
  # Shift the low bits out and back in as zeros.
  truncated = tf.bitwise.right_shift(image, dropped)
  return tf.bitwise.left_shift(truncated, dropped)


def wrapped_rotate(image: tf.Tensor, degrees: float, replace: int) -> tf.Tensor:
  """Rotates `image` by `degrees` between `wrap` and `unwrap`.

  `wrap`/`unwrap` are defined elsewhere in this module. Presumably they mark
  and then fill (with `replace`) the pixels the rotation pulled from outside
  the image; verify against their definitions.
  """
  return unwrap(rotate(wrap(image), degrees=degrees), replace)


def translate_x(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
  """Equivalent of PIL Translate in X dimension.

  Content moves by `-pixels` columns (left for positive `pixels`). Areas
  exposed by the shift are handled by `wrap`/`unwrap` with `replace`.
  """
  shifted = translate(wrap(image), [-pixels, 0])
  return unwrap(shifted, replace)


def translate_y(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
  """Equivalent of PIL Translate in Y dimension.

  Content moves by `-pixels` rows (up for positive `pixels`). Areas exposed by
  the shift are handled by `wrap`/`unwrap` with `replace`.
  """
  shifted = translate(wrap(image), [0, -pixels])
  return unwrap(shifted, replace)


def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
  """Equivalent of PIL Shearing in X dimension.

  Uses the projective transform with matrix [[1, level], [0, 1]]: output pixel
  `(x, y)` samples input pixel `(x + level * y, y)`.
  """
  shear_row = [1., level, 0., 0., 1., 0., 0., 0.]
  sheared = transform(image=wrap(image), transforms=shear_row)
  return unwrap(sheared, replace)


def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
  """Equivalent of PIL Shearing in Y dimension.

  Uses the projective transform with matrix [[1, 0], [level, 1]]: output pixel
  `(x, y)` samples input pixel `(x, level * x + y)`.
  """
  shear_row = [1., 0., 0., level, 1., 0., 0., 0.]
  sheared = transform(image=wrap(image), transforms=shear_row)
  return unwrap(sheared, replace)


def autocontrast(image: tf.Tensor) -> tf.Tensor:
  """Implements Autocontrast function from PIL using TF ops.

  Each of the first three channels is linearly stretched so that its minimum
  maps to 0 and its maximum maps to 255. A channel whose values are all equal
  is returned unchanged.

  Args:
    image: A 3D uint8 tensor.

  Returns:
    The image after it has had autocontrast applied to it and will be of type
    uint8.
  """

  def _stretch_channel(channel: tf.Tensor) -> tf.Tensor:
    """Stretches one 2D channel so its value range spans [0, 255]."""
    # A cheaper variant could derive min/max from a histogram instead of
    # reducing over the whole channel.
    low = tf.cast(tf.reduce_min(channel), tf.float32)
    high = tf.cast(tf.reduce_max(channel), tf.float32)

    def _rescale():
      gain = 255.0 / (high - low)
      bias = -low * gain
      stretched = tf.cast(channel, tf.float32) * gain + bias
      stretched = tf.clip_by_value(stretched, 0.0, 255.0)
      return tf.cast(stretched, tf.uint8)

    # Only rescale when there is a range to stretch (avoids dividing by 0).
    return tf.cond(high > low, _rescale, lambda: channel)

  # Assumes RGB: the three channels are processed independently, then
  # re-stacked along the last axis.
  stretched_channels = [_stretch_channel(image[..., c]) for c in range(3)]
  return tf.stack(stretched_channels, -1)


def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Implements Sharpness function from PIL using TF ops.

  A smoothed ("degenerate") copy of the image is computed with PIL's 3x3
  SMOOTH kernel. The VALID convolution cannot produce the 1-pixel border, so
  the border is filled from the original image. The result is then blended
  with the original via `blend`.

  Args:
    image: A uint8 image of rank 3 ([H, W, 3]) or rank 4 ([N, H, W, 3]),
      with 3 channels on the last axis.
    factor: Blend factor forwarded to `blend`.

  Returns:
    The sharpness-adjusted image.

  Raises:
    ValueError: If `image` does not have static rank 3 or 4.
  """
  orig_image = image
  rank = orig_image.shape.rank
  if rank not in (3, 4):
    # Validate (and report) the caller's rank before the expand_dims below
    # changes it.
    raise ValueError('Bad image rank: {}'.format(rank))

  image = tf.cast(image, tf.float32)
  # Make image 4D (or 5D for rank-4 input) for the conv operation.
  image = tf.expand_dims(image, 0)
  # SMOOTH PIL Kernel.
  if rank == 3:
    kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                         dtype=tf.float32,
                         shape=[3, 3, 1, 1]) / 13.
    # Tile across channel dimension.
    kernel = tf.tile(kernel, [1, 1, 3, 1])
    strides = [1, 1, 1, 1]
    degenerate = tf.nn.depthwise_conv2d(
        image, kernel, strides, padding='VALID', dilations=[1, 1])
  else:
    kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                         dtype=tf.float32,
                         shape=[1, 3, 3, 1, 1]) / 13.
    strides = [1, 1, 1, 1, 1]
    # Run the kernel across each channel separately; the depth-1 kernel
    # leaves the leading (frame) axis unmixed.
    channels = tf.split(image, 3, axis=-1)
    degenerates = [
        tf.nn.conv3d(channel, kernel, strides, padding='VALID',
                     dilations=[1, 1, 1, 1, 1])
        for channel in channels
    ]
    degenerate = tf.concat(degenerates, -1)
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])

  # For the borders of the resulting image, fill in the values of the
  # original image: the padded mask is 1 only where the convolution produced
  # a value.
  mask = tf.ones_like(degenerate)
  paddings = [[0, 0]] * (rank - 3)
  padded_mask = tf.pad(mask, paddings + [[1, 1], [1, 1], [0, 0]])
  padded_degenerate = tf.pad(degenerate, paddings + [[1, 1], [1, 1], [0, 0]])
  result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)

  # Blend the final result.
  return blend(result, orig_image, factor)


def equalize(image: tf.Tensor) -> tf.Tensor:
  """Implements Equalize function from PIL using TF ops.

  Each of the first three channels is histogram-equalized independently with
  a 256-entry lookup table. The result is uint8.

  Args:
    image: Image tensor whose last axis holds (at least) 3 channels.

  Returns:
    The equalized image as a uint8 tensor with 3 channels.
  """

  def _equalize_channel(img, channel_index):
    """Histogram-equalizes channel `channel_index` of `img`."""
    channel = tf.cast(img[..., channel_index], tf.int32)
    # 256-bin histogram of the channel's values.
    histogram = tf.histogram_fixed_width(channel, [0, 255], nbins=256)

    # The step ignores empty bins and excludes the highest occupied bin.
    occupied = tf.where(tf.not_equal(histogram, 0))
    occupied_counts = tf.reshape(tf.gather(histogram, occupied), [-1])
    step = (tf.reduce_sum(occupied_counts) - occupied_counts[-1]) // 255

    def _lookup_table():
      # Cumulative counts, offset by step // 2 for rounding, normalized by
      # step.
      table = (tf.cumsum(histogram) + (step // 2)) // step
      # Shift the table one entry to the right, starting it with 0.
      table = tf.concat([[0], table[:-1]], 0)
      # Clamp to the valid pixel range, as the C code for image.point does.
      return tf.clip_by_value(table, 0, 255)

    # A zero step leaves the channel untouched; otherwise map every pixel
    # through the lookup table.
    equalized = tf.cond(
        tf.equal(step, 0),
        lambda: channel,
        lambda: tf.gather(_lookup_table(), channel))
    return tf.cast(equalized, tf.uint8)

  # Assumes RGB for now.
  equalized_channels = [_equalize_channel(image, c) for c in range(3)]
  return tf.stack(equalized_channels, -1)


def invert(image: tf.Tensor) -> tf.Tensor:
  """Inverts the image pixels by computing `255 - image`.

  Args:
    image: Image tensor (or tensor-convertible value).

  Returns:
    The inverted image.
  """
  as_tensor = tf.convert_to_tensor(image)
  return 255 - as_tensor


def wrap(image: tf.Tensor) -> tf.Tensor:
  """Appends a trailing channel filled with ones to `image`.

  The extra channel acts as a marker that `unwrap` later inspects.

  Args:
    image: Image tensor with channels on the last axis.

  Returns:
    `image` with one additional channel of ones, same dtype as `image`.
  """
  spatial_shape = tf.shape(image)[:-1]
  ones_channel = tf.expand_dims(tf.ones(spatial_shape, image.dtype), axis=-1)
  return tf.concat([image, ones_channel], axis=-1)


def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
  """Unwraps an image produced by wrap.

  `wrap` appends a channel of ones. Operations like translate and shear on a
  wrapped Tensor will leave 0s in that channel at empty locations. Every
  position whose last channel is 0 has its remaining channels set to
  `replace`, so empty pixels take a chosen fill value (e.g. gray) rather than
  pure black. The marker channel is then removed.

  Args:
    image: An image Tensor with 4 channels (3 image channels plus the marker
      channel added by `wrap`).
    replace: A 1D tensor with exactly three values (one per image channel),
      with the same dtype as `image`, used to fill empty pixels. A 1-value
      tensor does not broadcast against the 4-channel image.

  Returns:
    image: The image Tensor with empty pixels filled and 3 channels.
  """
  image_shape = tf.shape(image)
  # Flatten the spatial dimensions.
  flattened_image = tf.reshape(image, [-1, image_shape[-1]])

  # Find all pixels where the last (marker) channel is zero.
  alpha_channel = tf.expand_dims(flattened_image[..., 3], axis=-1)

  # Append a 1 for the marker channel so `replace` spans all 4 channels.
  replace = tf.concat([replace, tf.ones([1], image.dtype)], 0)

  # Where they are zero, fill them in with 'replace'.
  flattened_image = tf.where(
      tf.equal(alpha_channel, 0),
      tf.ones_like(flattened_image, dtype=image.dtype) * replace,
      flattened_image)

  image = tf.reshape(flattened_image, image_shape)
  # Keep only the first 3 channels, dropping the marker channel.
  image = tf.slice(
      image,
      [0] * image.shape.rank,
      tf.concat([image_shape[:-1], [3]], -1))
  return image


def _scale_bbox_only_op_probability(prob):
  """Returns the reduced probability used for bbox-only operations.

  Dividing the probability by 3.0 avoids distorting too many bounding boxes
  that are close to each other. The 3.0 divisor is a hyperparameter chosen
  empirically when designing the AutoAugment algorithm.

  Args:
    prob: Float probability of applying the bbox-only operation.

  Returns:
    `prob / 3.0`.
  """
  reduction = 3.0
  return prob / reduction


def _apply_bbox_augmentation(image, bbox, augmentation_func, *args):
  """Applies `augmentation_func` only to the region of `image` inside `bbox`.

  Args:
    image: 3D uint8 Tensor.
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    augmentation_func: Augmentation function that will be applied to the
      crop of image covered by bbox.
    *args: Additional parameters that will be passed into augmentation_func
      when it is called.

  Returns:
    A copy of image where the pixels inside bbox are replaced by the output of
    augmentation_func on that crop; pixels outside bbox are unchanged.
  """
  height_f = tf.cast(tf.shape(image)[0], tf.float32)
  width_f = tf.cast(tf.shape(image)[1], tf.float32)
  height = tf.cast(height_f, tf.int32)
  width = tf.cast(width_f, tf.int32)

  def _to_pixel(extent_f, fraction):
    """Converts a normalized coordinate to a truncated integer pixel index."""
    return tf.cast(extent_f * fraction, tf.int32)

  top = _to_pixel(height_f, bbox[0])
  left = _to_pixel(width_f, bbox[1])
  # Clamp the far edges so the inclusive crop never runs past the image.
  bottom = tf.minimum(_to_pixel(height_f, bbox[2]), height - 1)
  right = tf.minimum(_to_pixel(width_f, bbox[3]), width - 1)

  crop = image[top:bottom + 1, left:right + 1, :]
  augmented_crop = augmentation_func(crop, *args)

  # Padding that places the crop back at its original location.
  paddings = [[top, (height - 1) - bottom],
              [left, (width - 1) - right],
              [0, 0]]
  augmented_full = tf.pad(augmented_crop, paddings)
  # Mask that is 0 inside the box and 1 everywhere else.
  keep_mask = tf.pad(tf.zeros_like(crop), paddings, constant_values=1)

  # Zero out the box in the original, then add the augmented content.
  return image * keep_mask + augmented_full


def _concat_bbox(bbox, bboxes):
  """Appends `bbox` as a new row of `bboxes`.

  If `bboxes` is the `_INVALID_BOX` placeholder (recognized by its elements
  summing to -4.0), it is discarded. The result is then a tensor holding only
  `bbox`.

  Args:
    bbox: 1D Tensor with 4 coordinates.
    bboxes: 2D Tensor of boxes (or the placeholder).

  Returns:
    A 2D Tensor of boxes ending with `bbox`.
  """
  new_row = tf.expand_dims(bbox, 0)
  is_placeholder = tf.equal(tf.reduce_sum(bboxes), -4.0)
  return tf.cond(is_placeholder,
                 lambda: new_row,
                 lambda: tf.concat([bboxes, new_row], 0))


def _apply_bbox_augmentation_wrapper(image, bbox, new_bboxes, prob,
                                     augmentation_func, func_changes_bbox,
                                     *args):
  """Applies a per-bbox augmentation with probability `prob`.

  Args:
    image: 3D uint8 Tensor.
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    new_bboxes: 2D Tensor accumulating the bboxes processed so far; the
      (possibly altered) bbox is appended to it.
    prob: Float that is the probability of applying the augmentation.
    augmentation_func: Augmentation function to apply.
    func_changes_bbox: Boolean. If True, augmentation_func is called as
      augmentation_func(image, bbox, *args) and must return (image, bbox).
      Otherwise it is applied only to the bbox crop through
      `_apply_bbox_augmentation`.
    *args: Additional parameters that will be passed into augmentation_func
      when it is called.

  Returns:
    A tuple `(image, new_bboxes)`: the possibly augmented image and
    new_bboxes with the (possibly altered) bbox appended.
  """
  # floor(U[0, 1) + prob) is 1 (True) with probability prob for prob in
  # [0, 1].
  apply_op = tf.cast(
      tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
  if not func_changes_bbox:
    augmented_image = tf.cond(
        apply_op,
        lambda: _apply_bbox_augmentation(image, bbox, augmentation_func, *args),
        lambda: image)
  else:
    augmented_image, bbox = tf.cond(
        apply_op,
        lambda: augmentation_func(image, bbox, *args),
        lambda: (image, bbox))
  return augmented_image, _concat_bbox(bbox, new_bboxes)


def _apply_multi_bbox_augmentation_wrapper(image, bboxes, prob, aug_func,
                                           func_changes_bbox, *args):
  """Runs `_apply_multi_bbox_augmentation` only when there is at least one box.

  With zero boxes, `image` and `bboxes` are returned unchanged.
  """

  def _augment():
    return _apply_multi_bbox_augmentation(
        image, bboxes, prob, aug_func, func_changes_bbox, *args)

  has_no_boxes = tf.equal(tf.shape(bboxes)[0], 0)
  out_image, out_bboxes = tf.cond(
      has_no_boxes, lambda: (image, bboxes), _augment)
  return out_image, out_bboxes


# Placeholder bounding box used to seed (and to stand in for empty) lists of
# bounding-box coordinates in the multi-bbox augmentation helpers. Its four
# -1.0 coordinates sum to -4.0, which is how `_concat_bbox` recognizes it.
_INVALID_BOX = [[-1.0] * 4]


def _apply_multi_bbox_augmentation(image, bboxes, prob, aug_func,
                                   func_changes_bbox, *args):
  """Applies aug_func to the image for each bbox in bboxes.

  Boxes are visited one at a time in a `tf.while_loop`. On each iteration,
  `_apply_bbox_augmentation_wrapper` applies aug_func for that box with
  probability `prob`.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float.
    prob: Float that is the probability of applying aug_func to a specific
      bounding box within the image.
    aug_func: Augmentation function that will be applied to the
      subsections of image indicated by the bbox values in bboxes.
    func_changes_bbox: Boolean. Whether aug_func returns a bbox in addition
      to the image. If True, aug_func is called on the whole image with the
      bbox; otherwise it is applied only to the bbox crop.
    *args: Additional parameters that will be passed into aug_func
      when it is called.

  Returns:
    A tuple `(image, bboxes)`. The image has aug_func applied at each bbox
    location with probability prob, independently across bboxes. If
    func_changes_bbox is False the original bboxes (unshuffled) are returned.
    Otherwise the bboxes produced by aug_func are returned, in loop order.

  Raises:
    ValueError: If image has static rank 4 (video).
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  # Will keep track of the new altered bboxes after aug_func is repeatedly
  # applied. The -1 values are a dummy value and this first Tensor will be
  # removed upon appending the first real bbox (see _concat_bbox).
  new_bboxes = tf.constant(_INVALID_BOX)

  # If the bboxes are empty, then just give it _INVALID_BOX. The result
  # will be thrown away.
  bboxes = tf.cond(tf.equal(tf.size(bboxes), 0),
                   lambda: tf.constant(_INVALID_BOX),
                   lambda: bboxes)

  # Pin the static shape so the per-iteration slice below is a 4-vector.
  bboxes = tf.ensure_shape(bboxes, (None, 4))

  # Binds the fixed arguments so the loop body only threads image, bbox and
  # the growing list of new bboxes.
  # pylint:disable=g-long-lambda
  wrapped_aug_func = (
      lambda _image, bbox, _new_bboxes: _apply_bbox_augmentation_wrapper(
          _image, bbox, _new_bboxes, prob, aug_func, func_changes_bbox, *args))
  # pylint:enable=g-long-lambda

  # Setup the while_loop.
  num_bboxes = tf.shape(bboxes)[0]  # We loop until we go over all bboxes.
  idx = tf.constant(0)  # Counter for the while loop.

  # Conditional function when to end the loop once we go over all bboxes
  # images_and_bboxes contain (_image, _new_bboxes)
  cond = lambda _idx, _images_and_bboxes: tf.less(_idx, num_bboxes)

  # Shuffle the bboxes so that the augmentation order is not deterministic if
  # we are not changing the bboxes with aug_func. When aug_func does change
  # the bboxes, the loop order determines the order of the returned boxes, so
  # the input order is kept.
  if not func_changes_bbox:
    loop_bboxes = tf.random.shuffle(bboxes)
  else:
    loop_bboxes = bboxes

  # Main function of while_loop where we repeatedly apply augmentation on the
  # bboxes in the image.
  # pylint:disable=g-long-lambda
  body = lambda _idx, _images_and_bboxes: [
      _idx + 1, wrapped_aug_func(_images_and_bboxes[0],
                                 loop_bboxes[_idx],
                                 _images_and_bboxes[1])]
  # pylint:enable=g-long-lambda

  # new_bboxes grows by one row per iteration, hence the [None, 4] invariant.
  _, (image, new_bboxes) = tf.while_loop(
      cond, body, [idx, (image, new_bboxes)],
      shape_invariants=[idx.get_shape(),
                        (image.get_shape(), tf.TensorShape([None, 4]))])

  # Either return the altered bboxes or the original ones depending on if
  # we altered them in anyway.
  if func_changes_bbox:
    final_bboxes = new_bboxes
  else:
    final_bboxes = bboxes
  return image, final_bboxes


def _clip_bbox(min_y, min_x, max_y, max_x):
  """Clips each normalized bbox coordinate into the range [0, 1].

  Args:
    min_y: Normalized bbox coordinate of type float.
    min_x: Normalized bbox coordinate of type float.
    max_y: Normalized bbox coordinate of type float.
    max_x: Normalized bbox coordinate of type float.

  Returns:
    Tuple `(min_y, min_x, max_y, max_x)` with every value clipped to [0, 1].
  """
  return tuple(
      tf.clip_by_value(coord, 0.0, 1.0)
      for coord in (min_y, min_x, max_y, max_x))


def _check_bbox_area(min_y, min_x, max_y, max_x, delta=0.05):
  """Widens zero-height/zero-width bboxes that sit on the image boundary.

  When the box has zero height (or zero width), the degenerate pair of
  coordinates is adjusted in two ways:
    * max is raised to at least `delta`;
    * min is lowered to at most `1 - delta`.

  For a collapsed coordinate strictly closer than `delta` to 0 or to 1, this
  opens a positive gap of at most `delta`. For example, (0, 0) becomes
  (0, delta) and (1, 1) becomes (1 - delta, 1). A zero-size extent elsewhere
  in the image is left unchanged, so its area remains 0.

  Args:
    min_y: Normalized bbox coordinate of type float between 0 and 1.
    min_x: Normalized bbox coordinate of type float between 0 and 1.
    max_y: Normalized bbox coordinate of type float between 0 and 1.
    max_x: Normalized bbox coordinate of type float between 0 and 1.
    delta: Float bounding how far a boundary-degenerate coordinate is moved.

  Returns:
    Tuple of bbox coordinates `(min_y, min_x, max_y, max_x)` between 0 and 1.
    The area is > 0 whenever a degenerate extent was within `delta` of the
    image boundary.
  """
  height = max_y - min_y
  width = max_x - min_x

  def _adjust_bbox_boundaries(min_coord, max_coord):
    # Make sure max is never below delta and min is never above 1 - delta.
    max_coord = tf.maximum(max_coord, 0.0 + delta)
    min_coord = tf.minimum(min_coord, 1.0 - delta)
    return min_coord, max_coord

  min_y, max_y = tf.cond(tf.equal(height, 0.0),
                         lambda: _adjust_bbox_boundaries(min_y, max_y),
                         lambda: (min_y, max_y))
  min_x, max_x = tf.cond(tf.equal(width, 0.0),
                         lambda: _adjust_bbox_boundaries(min_x, max_x),
                         lambda: (min_x, max_x))
  return min_y, min_x, max_y, max_x


def _rotate_bbox(bbox, image_height, image_width, degrees):
  """Rotates the bbox coordinates by degrees around the image center.

  The procedure is:
    1. Convert the four box corners to integer pixel offsets from the image
       center, with y negated so it points up.
    2. Rotate the corners with a 2x2 rotation matrix and truncate the results
       back to integers.
    3. Take the axis-aligned box enclosing the rotated corners and convert it
       back to normalized coordinates.
    4. Clip the result to [0, 1] and pass it through `_check_bbox_area`.

  Args:
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    image_height: Int, height of the image.
    image_width: Int, width of the image.
    degrees: Float, a scalar angle in degrees to rotate all images by. If
      degrees is positive the image will be rotated clockwise otherwise it will
      be rotated counterclockwise.

  Returns:
    A tensor of the same shape as bbox, but now with the rotated coordinates.
  """
  image_height, image_width = (
      tf.cast(image_height, tf.float32), tf.cast(image_width, tf.float32))

  # Convert from degrees to radians.
  degrees_to_radians = math.pi / 180.0
  radians = degrees * degrees_to_radians

  # Translate the bbox to the center of the image and turn the normalized 0-1
  # coordinates to absolute pixel locations.
  # Y coordinates are made negative as the y axis of images goes down with
  # increasing pixel values, so we negate to make sure x axis and y axis points
  # are in the traditionally positive direction.
  min_y = -tf.cast(image_height * (bbox[0] - 0.5), tf.int32)
  min_x = tf.cast(image_width * (bbox[1] - 0.5), tf.int32)
  max_y = -tf.cast(image_height * (bbox[2] - 0.5), tf.int32)
  max_x = tf.cast(image_width * (bbox[3] - 0.5), tf.int32)
  coordinates = tf.stack(
      [[min_y, min_x], [min_y, max_x], [max_y, min_x], [max_y, max_x]])
  coordinates = tf.cast(coordinates, tf.float32)
  # Rotate the coordinates according to the rotation matrix clockwise if
  # radians is positive, else negative
  rotation_matrix = tf.stack(
      [[tf.cos(radians), tf.sin(radians)],
       [-tf.sin(radians), tf.cos(radians)]])
  new_coords = tf.cast(
      tf.matmul(rotation_matrix, tf.transpose(coordinates)), tf.int32)
  # Find min/max values and convert them back to normalized 0-1 floats.
  # Row 0 holds the (negated) y values, row 1 the x values; undo the y
  # negation and the centering.
  min_y = -(
      tf.cast(tf.reduce_max(new_coords[0, :]), tf.float32) / image_height - 0.5)
  min_x = tf.cast(tf.reduce_min(new_coords[1, :]),
                  tf.float32) / image_width + 0.5
  max_y = -(
      tf.cast(tf.reduce_min(new_coords[0, :]), tf.float32) / image_height - 0.5)
  max_x = tf.cast(tf.reduce_max(new_coords[1, :]),
                  tf.float32) / image_width + 0.5

  # Clip the bboxes to be sure they fall between [0, 1].
  min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x)
  min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x)
  return tf.stack([min_y, min_x, max_y, max_x])


def rotate_with_bboxes(image, bboxes, degrees, replace):
  """Equivalent of PIL Rotate that rotates both the image and its bboxes.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float.
    degrees: Float, a scalar angle in degrees to rotate all images by. If
      degrees is positive the image will be rotated clockwise otherwise it will
      be rotated counterclockwise.
    replace: Fill values for empty pixels, forwarded to `wrapped_rotate`.

  Returns:
    A tuple `(image, bboxes)`: the rotated image and the bboxes with each box
    transformed by `_rotate_bbox`.

  Raises:
    ValueError: If applied to video.
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  rotated_image = wrapped_rotate(image, degrees, replace)

  height = tf.shape(rotated_image)[0]
  width = tf.shape(rotated_image)[1]

  def _rotate_one(bbox):
    return _rotate_bbox(bbox, height, width, degrees)

  rotated_bboxes = tf.map_fn(_rotate_one, bboxes)
  return rotated_image, rotated_bboxes


def _shear_bbox(bbox, image_height, image_width, level, shear_horizontal):
  """Shifts the bbox according to how the image was sheared.

  The four box corners are converted to integer pixel coordinates and
  transformed with a shear matrix:
    * horizontal: x' = x - level * y;
    * vertical: y' = y - level * x.
  The results are truncated to integers. The axis-aligned box enclosing the
  sheared corners is normalized, clipped to [0, 1], and passed through
  `_check_bbox_area`.

  Args:
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    image_height: Int, height of the image.
    image_width: Int, width of the image.
    level: Float. How much to shear the image.
    shear_horizontal: If true then shear in X dimension else shear in
      the Y dimension.

  Returns:
    A tensor of the same shape as bbox, but now with the shifted coordinates.
  """
  image_height, image_width = (
      tf.cast(image_height, tf.float32), tf.cast(image_width, tf.float32))

  # Change bbox coordinates to be pixels.
  min_y = tf.cast(image_height * bbox[0], tf.int32)
  min_x = tf.cast(image_width * bbox[1], tf.int32)
  max_y = tf.cast(image_height * bbox[2], tf.int32)
  max_x = tf.cast(image_width * bbox[3], tf.int32)
  coordinates = tf.stack(
      [[min_y, min_x], [min_y, max_x], [max_y, min_x], [max_y, max_x]])
  coordinates = tf.cast(coordinates, tf.float32)

  # Shear the corner coordinates (rows are [y, x] after the transpose below).
  if shear_horizontal:
    shear_matrix = tf.stack(
        [[1, 0], [-level, 1]])
  else:
    shear_matrix = tf.stack(
        [[1, -level], [0, 1]])
  shear_matrix = tf.cast(shear_matrix, tf.float32)
  new_coords = tf.cast(
      tf.matmul(shear_matrix, tf.transpose(coordinates)), tf.int32)

  # Find min/max values and convert them back to floats.
  min_y = tf.cast(tf.reduce_min(new_coords[0, :]), tf.float32) / image_height
  min_x = tf.cast(tf.reduce_min(new_coords[1, :]), tf.float32) / image_width
  max_y = tf.cast(tf.reduce_max(new_coords[0, :]), tf.float32) / image_height
  max_x = tf.cast(tf.reduce_max(new_coords[1, :]), tf.float32) / image_width

  # Clip the bboxes to be sure they fall between [0, 1].
  min_y, min_x, max_y, max_x = _clip_bbox(min_y, min_x, max_y, max_x)
  min_y, min_x, max_y, max_x = _check_bbox_area(min_y, min_x, max_y, max_x)
  return tf.stack([min_y, min_x, max_y, max_x])


def shear_with_bboxes(image, bboxes, level, replace, shear_horizontal):
  """Applies a shear transformation to the image and shifts the bboxes.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float with values
      between [0, 1].
    level: Float. How much to shear the image. This value will be between
      -0.3 to 0.3.
    replace: Fill values for empty pixels, forwarded to the shear op.
    shear_horizontal: Boolean. If true then shear in X dimension else shear in
      the Y dimension.

  Returns:
    A tuple `(image, bboxes)`: the sheared image and the bboxes with each box
    transformed by `_shear_bbox`.

  Raises:
    ValueError: If applied to video.
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  shear_op = shear_x if shear_horizontal else shear_y
  sheared_image = shear_op(image, level, replace)

  height = tf.shape(sheared_image)[0]
  width = tf.shape(sheared_image)[1]

  def _shear_one(bbox):
    return _shear_bbox(bbox, height, width, level, shear_horizontal)

  sheared_bboxes = tf.map_fn(_shear_one, bboxes)
  return sheared_image, sheared_bboxes


def _shift_bbox(bbox, image_height, image_width, pixels, shift_horizontal):
  """Shifts the bbox coordinates by `pixels` along one axis.

  The box is converted to truncated integer pixel coordinates. Both edges
  along the chosen axis have `pixels` subtracted, with the min edge floored
  at 0 and the max edge capped at the image extent. The box is then
  normalized, clipped to [0, 1], and passed through `_check_bbox_area`.

  Args:
    bbox: 1D Tensor that has 4 elements (min_y, min_x, max_y, max_x)
      of type float that represents the normalized coordinates between 0 and 1.
    image_height: Int, height of the image.
    image_width: Int, width of the image.
    pixels: An int. How many pixels to shift the bbox.
    shift_horizontal: Boolean. If true then shift in X dimension else shift in
      Y dimension.

  Returns:
    A tensor of the same shape as bbox, but now with the shifted coordinates.
  """
  pixels = tf.cast(pixels, tf.int32)
  height_f = tf.cast(image_height, tf.float32)
  width_f = tf.cast(image_width, tf.float32)

  # Normalized coordinates -> truncated integer pixel locations.
  top = tf.cast(height_f * bbox[0], tf.int32)
  left = tf.cast(width_f * bbox[1], tf.int32)
  bottom = tf.cast(height_f * bbox[2], tf.int32)
  right = tf.cast(width_f * bbox[3], tf.int32)

  if not shift_horizontal:
    top = tf.maximum(0, top - pixels)
    bottom = tf.minimum(image_height, bottom - pixels)
  else:
    left = tf.maximum(0, left - pixels)
    right = tf.minimum(image_width, right - pixels)

  # Pixel locations -> normalized floats.
  top = tf.cast(top, tf.float32) / height_f
  left = tf.cast(left, tf.float32) / width_f
  bottom = tf.cast(bottom, tf.float32) / height_f
  right = tf.cast(right, tf.float32) / width_f

  top, left, bottom, right = _clip_bbox(top, left, bottom, right)
  top, left, bottom, right = _check_bbox_area(top, left, bottom, right)
  return tf.stack([top, left, bottom, right])


def translate_bbox(image, bboxes, pixels, replace, shift_horizontal):
  """Equivalent of PIL Translate in X/Y dimension that shifts image and bbox.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor that is a list of the bboxes in the image. Each bbox
      has 4 elements (min_y, min_x, max_y, max_x) of type float with values
      between [0, 1].
    pixels: An int. How many pixels to shift the image and bboxes.
    replace: Fill values for empty pixels, forwarded to the translate op.
    shift_horizontal: Boolean. If true then shift in X dimension else shift in
      Y dimension.

  Returns:
    A tuple `(image, bboxes)`: the translated image and the bboxes with each
    box transformed by `_shift_bbox`.

  Raises:
    ValueError if applied to video.
  """
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  translate_op = translate_x if shift_horizontal else translate_y
  shifted_image = translate_op(image, pixels, replace)

  height = tf.shape(shifted_image)[0]
  width = tf.shape(shifted_image)[1]

  def _shift_one(bbox):
    return _shift_bbox(bbox, height, width, pixels, shift_horizontal)

  shifted_bboxes = tf.map_fn(_shift_one, bboxes)
  return shifted_image, shifted_bboxes


def translate_y_only_bboxes(
    image: tf.Tensor, bboxes: tf.Tensor, prob: float, pixels: int, replace):
  """Applies translate_y inside each bbox of the image with probability prob.

  The per-box probability is first reduced by
  `_scale_bbox_only_op_probability`. The bboxes themselves are returned
  unchanged; only the image content inside them is translated.

  Args:
    image: 3D uint8 Tensor.
    bboxes: 2D Tensor of bboxes (min_y, min_x, max_y, max_x), normalized.
    prob: Float probability (before scaling) of augmenting each bbox.
    pixels: How many pixels to translate the content of each bbox.
    replace: Fill values for empty pixels, forwarded to `translate_y`.

  Returns:
    A tuple `(image, bboxes)`.

  Raises:
    ValueError: If `bboxes` has rank 4, or if `image` has rank 4 (video).
  """
  if bboxes.shape.rank == 4:
    raise ValueError('translate_y_only_bboxes does not support rank 4 boxes')
  # The rank check inside _apply_multi_bbox_augmentation sits in a tf.cond
  # branch that is skipped (in eager mode) when there are no boxes. Check
  # here so video input is rejected consistently, like the other bbox ops.
  if image.shape.rank == 4:
    raise ValueError('Image rank 4 is not supported')

  func_changes_bbox = False
  prob = _scale_bbox_only_op_probability(prob)
  return _apply_multi_bbox_augmentation_wrapper(
      image, bboxes, prob, translate_y, func_changes_bbox, pixels, replace)


def _randomly_negate_tensor(tensor):
  """Returns `tensor` or `-tensor`, each with 50% probability."""
  # floor(U[0, 1) + 0.5) is 1 (True) exactly when U >= 0.5.
  keep_sign = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
  return tf.cond(keep_sign, lambda: tensor, lambda: -tensor)


def _rotate_level_to_arg(level: float):
  """Maps `level` to a rotation angle (up to 30 degrees) with a random sign.

  Returns:
    A 1-tuple holding the signed angle in degrees.
  """
  max_degrees = 30.
  degrees = (level / _MAX_LEVEL) * max_degrees
  return (_randomly_negate_tensor(degrees),)


def _shrink_level_to_arg(level: float):
  """Converts level to the ratio by which the image content is shrunk.

  Returns:
    A 1-tuple with the ratio. The ratio is 1.0 (no shrinking) when level is
    zero, and otherwise `2 / (_MAX_LEVEL / level) + 0.9`, reaching 2.9 at
    `level == _MAX_LEVEL`.
  """
  if level != 0:
    return (2. / (_MAX_LEVEL / level) + 0.9,)
  # A zero level means the image content is left at its original size.
  return (1.0,)


def _enhance_level_to_arg(level: float):
  """Maps level linearly to an enhancement factor (0.1 at 0, 1.9 at max)."""
  factor = (level / _MAX_LEVEL) * 1.8 + 0.1
  return (factor,)


def _shear_level_to_arg(level: float):
  """Maps level to a shear amount (magnitude up to 0.3) with a random sign."""
  magnitude = (level / _MAX_LEVEL) * 0.3
  # Flip to negative with 50% chance.
  return (_randomly_negate_tensor(magnitude),)


def _translate_level_to_arg(level: float, translate_const: float):
  """Maps level to a pixel shift (up to translate_const) with a random sign."""
  magnitude = (level / _MAX_LEVEL) * float(translate_const)
  # Flip to negative with 50% chance.
  return (_randomly_negate_tensor(magnitude),)


def _gaussian_noise_level_to_arg(level: float, translate_const: float):
  """Returns `(low_std, high_std)` noise bounds derived from level.

  `low_std` is `level / _MAX_LEVEL`; `high_std` is that value scaled by
  translate_const.
  """
  low_std = level / _MAX_LEVEL
  return low_std, translate_const * low_std


def _mult_to_arg(level: float, multiplier: float = 1.):
  """Scales level by `multiplier / _MAX_LEVEL` and truncates to an int."""
  scaled = (level / _MAX_LEVEL) * multiplier
  return (int(scaled),)


def _apply_func_with_prob(func: Any, image: tf.Tensor,
                          bboxes: Optional[tf.Tensor], args: Any, prob: float):
  """Applies `func(image, bboxes, *args)` with probability `prob`.

  Returns:
    A tuple `(image, bboxes)`, either from `func` or the unchanged inputs.
  """
  assert isinstance(args, tuple)
  # `func` must take the bboxes as its second positional parameter.
  assert inspect.getfullargspec(func).args[1] == 'bboxes'

  # floor(U[0, 1) + prob) is 1 (True) with probability prob for prob in
  # [0, 1].
  apply_op = tf.cast(
      tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)

  def _apply():
    return func(image, bboxes, *args)

  def _skip():
    return image, bboxes

  augmented_image, augmented_bboxes = tf.cond(apply_op, _apply, _skip)
  return augmented_image, augmented_bboxes


def select_and_apply_random_policy(
    policies: Any, image: tf.Tensor, bboxes: Optional[tf.Tensor] = None
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
  """Select a random policy from `policies` and apply it to `image`.

  Args:
    policies: Sequence of callables taking `(image, bboxes)` and returning
      `(image, bboxes)`.
    image: Image tensor to augment.
    bboxes: Optional bounding boxes forwarded to the selected policy.

  Returns:
    `(image, bboxes)` after applying exactly one uniformly chosen policy.
  """
  chosen_index = tf.random.uniform([], maxval=len(policies), dtype=tf.int32)
  # Note that using tf.case instead of tf.conds would result in significantly
  # larger graphs and would even break export for some larger policies.
  for index, candidate in enumerate(policies):
    is_chosen = tf.equal(index, chosen_index)
    # tf.cond invokes the branch immediately, so reading the current
    # `image`/`bboxes` inside the lambdas is intended.
    image, bboxes = tf.cond(
        is_chosen,
        lambda policy_fn=candidate: policy_fn(image, bboxes),  # pylint: disable=cell-var-from-loop
        lambda: (image, bboxes))  # pylint: disable=cell-var-from-loop
  return image, bboxes


# Maps every op name usable in a policy to the callable that implements it.
# `_parse_policy_info` wraps callables whose signature lacks a `bboxes`
# parameter with `bbox_wrapper`. `_apply_func_with_prob` asserts that the
# second positional parameter is named `bboxes`, so the lambdas below must keep
# that parameter name.
NAME_TO_FUNC = {
    'AutoContrast': autocontrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': wrapped_rotate,
    'Posterize': posterize,
    'Solarize': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'Contrast': contrast,
    'Brightness': brightness,
    'Sharpness': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x,
    'TranslateY': translate_y,
    'Cutout': cutout,
    'Rotate_BBox': rotate_with_bboxes,
    'Grayscale': grayscale,
    'Gaussian_Noise': gaussian_noise,
    # pylint:disable=g-long-lambda
    # The *_BBox entries below share one box-aware implementation per
    # transform and only differ in the fixed direction flag they pass.
    'ShearX_BBox': lambda image, bboxes, level, replace: shear_with_bboxes(
        image, bboxes, level, replace, shear_horizontal=True),
    'ShearY_BBox': lambda image, bboxes, level, replace: shear_with_bboxes(
        image, bboxes, level, replace, shear_horizontal=False),
    'TranslateX_BBox': lambda image, bboxes, pixels, replace: translate_bbox(
        image, bboxes, pixels, replace, shift_horizontal=True),
    'TranslateY_BBox': lambda image, bboxes, pixels, replace: translate_bbox(
        image, bboxes, pixels, replace, shift_horizontal=False),
    # pylint:enable=g-long-lambda
    'TranslateY_Only_BBoxes': translate_y_only_bboxes,
}

# Names of the box-aware ops: their implementations in `NAME_TO_FUNC` take a
# `bboxes` parameter of their own.
REQUIRE_BOXES_FUNCS = frozenset([
    'Rotate_BBox',
    'ShearX_BBox',
    'ShearY_BBox',
    'TranslateX_BBox',
    'TranslateY_BBox',
    'TranslateY_Only_BBoxes',
])

# Names of ops whose function takes the sub-policy probability as its first
# argument (`_parse_policy_info` prepends it to the op arguments).
PROB_FUNCS = frozenset(['TranslateY_Only_BBoxes'])

# Names of ops whose function takes a `replace` fill value as its last
# argument (`_parse_policy_info` appends it to the op arguments).
REPLACE_FUNCS = frozenset([
    # Image-only geometric ops and Cutout.
    'Cutout',
    'Rotate',
    'ShearX',
    'ShearY',
    'TranslateX',
    'TranslateY',
    # Box-aware geometric ops.
    'Rotate_BBox',
    'ShearX_BBox',
    'ShearY_BBox',
    'TranslateX_BBox',
    'TranslateY_BBox',
    'TranslateY_Only_BBoxes',
])


def level_to_arg(cutout_const: float, translate_const: float):
  """Creates a dict mapping image operation names to their arguments.

  Each value is a callable that takes a magnitude `level` and returns the
  tuple of op-specific positional arguments. `_parse_policy_info` later adds
  the `prob` / `replace` arguments where the op needs them.

  Args:
    cutout_const: Multiplier used for `Cutout`.
    translate_const: Multiplier used for `TranslateX`/`TranslateY`, their
      `_BBox` variants, and `Gaussian_Noise`.

  Returns:
    A dict from op name to its level-to-arguments callable.
  """

  def no_arg(level):
    del level  # These ops take no magnitude-dependent arguments.
    return ()

  def scaled_int_arg(multiplier):
    return lambda level: _mult_to_arg(level, multiplier)

  def translate_arg(level):
    return _translate_level_to_arg(level, translate_const)

  def translate_only_bboxes_arg(level):
    # Uses a fixed constant of 120 rather than `translate_const`.
    return _translate_level_to_arg(level, 120)

  def gaussian_noise_arg(level):
    return _gaussian_noise_level_to_arg(level, translate_const)

  return {
      'AutoContrast': no_arg,
      'Equalize': no_arg,
      'Invert': no_arg,
      'Rotate': _rotate_level_to_arg,
      'Posterize': scaled_int_arg(4),
      'Solarize': scaled_int_arg(256),
      'SolarizeAdd': scaled_int_arg(110),
      'Color': _enhance_level_to_arg,
      'Contrast': _enhance_level_to_arg,
      'Brightness': _enhance_level_to_arg,
      'Sharpness': _enhance_level_to_arg,
      'ShearX': _shear_level_to_arg,
      'ShearY': _shear_level_to_arg,
      'Cutout': scaled_int_arg(cutout_const),
      'TranslateX': translate_arg,
      'TranslateY': translate_arg,
      'Rotate_BBox': _rotate_level_to_arg,
      'ShearX_BBox': _shear_level_to_arg,
      'ShearY_BBox': _shear_level_to_arg,
      'Grayscale': no_arg,
      'Gaussian_Noise': gaussian_noise_arg,
      'TranslateX_BBox': translate_arg,
      'TranslateY_BBox': translate_arg,
      'TranslateY_Only_BBoxes': translate_only_bboxes_arg,
  }


def bbox_wrapper(func):
  """Adapts an image-only op to the `(images, bboxes, *args)` calling form.

  The returned callable forwards everything except `bboxes` to `func` and
  passes `bboxes` through untouched. Its second positional parameter is named
  `bboxes` because `_apply_func_with_prob` checks for that name via
  `inspect.getfullargspec`.

  Args:
    func: Callable taking `(images, *args, **kwargs)`.

  Returns:
    A callable returning `(func(images, *args, **kwargs), bboxes)`.
  """
  def wrapper(images, bboxes, *args, **kwargs):
    transformed = func(images, *args, **kwargs)
    return transformed, bboxes
  return wrapper


def _parse_policy_info(name: str,
                       prob: float,
                       level: float,
                       replace_value: List[int],
                       cutout_const: float,
                       translate_const: float,
                       level_std: float = 0.) -> Tuple[Any, float, Any]:
  """Return the function that corresponds to `name` and update `level` param.

  Args:
    name: Op name; a key of `NAME_TO_FUNC` and of `level_to_arg(...)`.
    prob: Probability of applying the op. Prepended to the op arguments for
      ops in `PROB_FUNCS`, and returned unchanged.
    level: Magnitude of the op.
    replace_value: Fill value appended to the op arguments for ops in
      `REPLACE_FUNCS`.
    cutout_const: Multiplier forwarded to `level_to_arg`.
    translate_const: Multiplier forwarded to `level_to_arg`.
    level_std: If positive, `level` is perturbed with Gaussian noise of this
      standard deviation and clipped to `[0, _MAX_LEVEL]` before use.

  Returns:
    A `(func, prob, args)` tuple. `func` has a `bboxes` parameter (it is
    wrapped with `bbox_wrapper` when the original op lacks one) and is meant
    to be called as `func(image, bboxes, *args)`.
  """
  func = NAME_TO_FUNC[name]

  if level_std > 0:
    # Sample the magnitude from N(level, level_std**2), as with timm's
    # `magnitude_std`; `level_std` must set the spread of the noise.
    level += tf.random.normal([], stddev=level_std, dtype=tf.float32)
    level = tf.clip_by_value(level, 0., _MAX_LEVEL)

  args = level_to_arg(cutout_const, translate_const)[name](level)

  if name in PROB_FUNCS:
    # Add in the prob arg if it is required for the function that is called.
    args = tuple([prob] + list(args))

  if name in REPLACE_FUNCS:
    # Add in replace arg if it is required for the function that is called.
    args = tuple(list(args) + [replace_value])

  # Add bboxes as the second positional argument for the function if it does
  # not already exist.
  if 'bboxes' not in inspect.getfullargspec(func)[0]:
    func = bbox_wrapper(func)

  return func, prob, args


class ImageAugment:
  """Base class for image augmenters.

  Subclasses override `distort`, and also `distort_with_boxes` when they can
  transform bounding boxes. The defaults here raise `NotImplementedError`.
  """

  def distort(
      self,
      image: tf.Tensor
  ) -> tf.Tensor:
    """Returns a distorted copy of `image` with the same shape.

    Image values are expected to lie in the range [0, 255].

    Args:
      image: `Tensor` of shape [height, width, 3] or
        [num_frames, height, width, 3] representing an image or image sequence.

    Returns:
      The augmented version of `image`.
    """
    raise NotImplementedError

  def distort_with_boxes(
      self,
      image: tf.Tensor,
      bboxes: tf.Tensor
  ) -> Tuple[tf.Tensor, tf.Tensor]:
    """Distorts `image` together with its bounding boxes.

    Image values are expected to lie in the range [0, 255].

    Args:
      image: `Tensor` of shape [height, width, 3] or
        [num_frames, height, width, 3] representing an image or image sequence.
      bboxes: `Tensor` of shape [num_boxes, 4] or [num_frames, num_boxes, 4]
        representing bounding boxes for an image or image sequence.

    Returns:
      The augmented version of `image` and `bboxes`.
    """
    raise NotImplementedError


class AutoAugment(ImageAugment):
  """Applies the AutoAugment policy to images.

  AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
  """

  def __init__(self,
               augmentation_name: str = 'v0',
               policies: Optional[Iterable[Iterable[Tuple[str, float,
                                                          float]]]] = None,
               cutout_const: float = 100,
               translate_const: float = 250):
    """Applies the AutoAugment policy to images.

    Args:
      augmentation_name: The name of the AutoAugment policy to use. The
        available options are `detection_v0`, `v0`, `test`, `simple`,
        `reduced_cifar10`, `svhn`, `reduced_imagenet`,
        `panoptic_deeplab_policy`, `vit` and `deit3_three_augment` (see the
        static method of the corresponding name for each policy's contents).
        `detection_v0` is the policy from the AutoAugment for Object Detection
        paper and uses box-aware ops. Make sure to set `policies` to `None`
        (the default) if you want to set options using `augmentation_name`.
      policies: list of lists of tuples in the form `(func, prob, level)`,
        `func` is a string name of the augmentation function, `prob` is the
        probability of applying the `func` operation, `level` (or magnitude) is
        the input argument for `func`. For example:
        ```
        [[('Equalize', 0.9, 3), ('Color', 0.7, 8)],
         [('Invert', 0.6, 5), ('Rotate', 0.2, 9), ('ShearX', 0.1, 2)], ...]
        ```
        The policy as a whole must be 3-d: a list of sub-policies, each a
        list of `(func, prob, level)` triples. The number of operations in a
        sub-policy can vary from one sub-policy to another.
        If you provide `policies` as input, any option set with
        `augmentation_name` will get overridden as they are mutually exclusive.
      cutout_const: multiplier for applying cutout.
      translate_const: multiplier for applying translation.

    Raises:
      ValueError if `augmentation_name` is unsupported, or if `policies` does
      not have the structure described above.
    """
    super().__init__()

    self.augmentation_name = augmentation_name
    self.cutout_const = float(cutout_const)
    self.translate_const = float(translate_const)
    self.available_policies = {
        'detection_v0': self.detection_policy_v0(),
        'v0': self.policy_v0(),
        'test': self.policy_test(),
        'simple': self.policy_simple(),
        'reduced_cifar10': self.policy_reduced_cifar10(),
        'svhn': self.policy_svhn(),
        'reduced_imagenet': self.policy_reduced_imagenet(),
        'panoptic_deeplab_policy': self.panoptic_deeplab_policy(),
        'vit': self.vit(),
        'deit3_three_augment': self.deit3_three_augment(),
    }

    if not policies:
      if augmentation_name not in self.available_policies:
        raise ValueError(
            'Invalid augmentation_name: {}'.format(augmentation_name))

      self.policies = self.available_policies[augmentation_name]

    else:
      self._check_policy_shape(policies)
      self.policies = policies

  def _check_policy_shape(self, policies):
    """Checks the structure of a custom policy.

    `policies` must be a sequence of sub-policies, and each sub-policy must be
    a non-empty sequence of `(func, prob, level)` triples. Sub-policies may
    contain different numbers of operations.

    Args:
      policies: List of list of tuples in the form `(func, prob, level)`, i.e.
        of shape `(:, :, 3)` where the middle dimension may be ragged.

    Raises:
      ValueError if the structure of `policies` is unexpected.
    """

    def _is_sequence(value):
      return (isinstance(value, (Sequence, np.ndarray)) and
              not isinstance(value, (str, bytes)))

    def _shape_error(got):
      return ValueError('Wrong shape detected for custom policy. Expected '
                        '(:, :, 3) but got {!r}.'.format(got))

    # The structure is checked element by element instead of through
    # `np.array(policies).shape`, which cannot represent sub-policies of
    # different lengths (it yields an object array or raises in NumPy).
    if not _is_sequence(policies):
      raise _shape_error(policies)
    for sub_policy in policies:
      if not _is_sequence(sub_policy) or len(sub_policy) == 0:  # pylint: disable=g-explicit-length-test
        raise _shape_error(sub_policy)
      for policy_info in sub_policy:
        if not _is_sequence(policy_info) or len(policy_info) != 3:
          raise _shape_error(policy_info)

  def _make_tf_policies(self):
    """Prepares the TF functions for augmentations based on the policies.

    Returns:
      A list with one callable per sub-policy in `self.policies`. Each
      callable takes `(image, bboxes)`, applies the sub-policy's ops in order
      (each with its own probability) and returns the updated
      `(image, bboxes)`.
    """
    replace_value = [128] * 3

    # func is the string name of the augmentation function, prob is the
    # probability of applying the operation and level is the parameter
    # associated with the tf op.

    # tf_policies are functions that take in an image and return an augmented
    # image.
    tf_policies = []
    for policy in self.policies:
      tf_policy = []
      assert_ranges = []
      # Link string name to the correct python function and make sure the
      # correct argument is passed into that function.
      for policy_info in policy:
        _, prob, level = policy_info
        assert_ranges.append(tf.Assert(tf.less_equal(prob, 1.), [prob]))
        assert_ranges.append(
            tf.Assert(tf.less_equal(level, int(_MAX_LEVEL)), [level]))

        policy_info = list(policy_info) + [
            replace_value, self.cutout_const, self.translate_const
        ]
        tf_policy.append(_parse_policy_info(*policy_info))
      # Now build the tf policy that will apply the augmentation procedure
      # on image.
      def make_final_policy(tf_policy_):

        def final_policy(image_, bboxes_):
          for func, prob, args in tf_policy_:
            image_, bboxes_ = _apply_func_with_prob(func, image_, bboxes_, args,
                                                    prob)
          return image_, bboxes_

        return final_policy

      # NOTE(review): only a Python closure is built inside this context (no
      # TF op is created), so it presumably adds no control dependency on
      # `assert_ranges`; the range checks only take effect where `tf.Assert`
      # runs on its own (e.g. eager execution). Confirm intent.
      with tf.control_dependencies(assert_ranges):
        tf_policies.append(make_final_policy(tf_policy))

    return tf_policies

  def _distort_common(
      self,
      image: tf.Tensor,
      bboxes: Optional[tf.Tensor] = None
  ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
    """Applies one randomly selected sub-policy to `image` and `bboxes`.

    Args:
      image: Image tensor with values expected in [0, 255]. Non-`uint8`
        inputs are clipped to that range and cast to `uint8` for the policy
        ops, then cast back to the input dtype.
      bboxes: Optional bounding boxes forwarded to the policy ops.

    Returns:
      The augmented image (in the input dtype) and the augmented `bboxes`
      (`None` if `bboxes` was `None`).
    """
    input_image_type = image.dtype
    if input_image_type != tf.uint8:
      image = tf.clip_by_value(image, 0.0, 255.0)
      image = tf.cast(image, dtype=tf.uint8)

    tf_policies = self._make_tf_policies()
    image, bboxes = select_and_apply_random_policy(tf_policies, image, bboxes)
    image = tf.cast(image, dtype=input_image_type)
    return image, bboxes

  def distort(self, image: tf.Tensor) -> tf.Tensor:
    """See base class."""
    image, _ = self._distort_common(image)
    return image

  def distort_with_boxes(self, image: tf.Tensor,
                         bboxes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """See base class."""
    image, bboxes = self._distort_common(image, bboxes)
    assert bboxes is not None
    return image, bboxes

  @staticmethod
  def detection_policy_v0():
    """Autoaugment policy that was used in AutoAugment Paper for Detection.

    https://arxiv.org/pdf/1906.11172

    Each tuple is an augmentation operation of the form
    (operation, probability, magnitude). Each element in policy is a
    sub-policy that will be applied sequentially on the image.

    Returns:
      the policy.
    """
    policy = [
        [('TranslateX_BBox', 0.6, 4), ('Equalize', 0.8, 10)],
        [('TranslateY_Only_BBoxes', 0.2, 2), ('Cutout', 0.8, 8)],
        [('Sharpness', 0.0, 8), ('ShearX_BBox', 0.4, 0)],
        [('ShearY_BBox', 1.0, 2), ('TranslateY_Only_BBoxes', 0.6, 6)],
        [('Rotate_BBox', 0.6, 10), ('Color', 1.0, 6)],
    ]
    return policy

  @staticmethod
  def policy_v0():
    """Autoaugment policy that was used in AutoAugment Paper.

    Each tuple is an augmentation operation of the form
    (operation, probability, magnitude). Each element in policy is a
    sub-policy that will be applied sequentially on the image.

    Returns:
      the policy.
    """

    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return policy

  @staticmethod
  def policy_reduced_cifar10():
    """Autoaugment policy for reduced CIFAR-10 dataset.

    Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.

    Each tuple is an augmentation operation of the form
    (operation, probability, magnitude). Each element in policy is a
    sub-policy that will be applied sequentially on the image.

    Returns:
      the policy.
    """
    policy = [
        [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
        [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
        [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
        [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
        [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
        [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
        [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
        [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
        [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)],
        [('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
        [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
        [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
        [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
        [('Solarize', 0.5, 2), ('Invert', 0.0, 3)],
        [('Equalize', 0.2, 0), ('AutoContrast', 0.6, 0)],
        [('Equalize', 0.2, 8), ('Equalize', 0.6, 4)],
        [('Color', 0.9, 9), ('Equalize', 0.6, 6)],
        [('AutoContrast', 0.8, 4), ('Solarize', 0.2, 8)],
        [('Brightness', 0.1, 3), ('Color', 0.7, 0)],
        [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
        [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
        [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
        [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)],
    ]
    return policy

  @staticmethod
  def policy_svhn():
    """Autoaugment policy for SVHN dataset.

    Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.

    Each tuple is an augmentation operation of the form
    (operation, probability, magnitude). Each element in policy is a
    sub-policy that will be applied sequentially on the image.

    Returns:
      the policy.
    """
    policy = [
        [('ShearX', 0.9, 4), ('Invert', 0.2, 3)],
        [('ShearY', 0.9, 8), ('Invert', 0.7, 5)],
        [('Equalize', 0.6, 5), ('Solarize', 0.6, 6)],
        [('Invert', 0.9, 3), ('Equalize', 0.6, 3)],
        [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
        [('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)],
        [('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
        [('ShearY', 0.9, 5), ('Solarize', 0.2, 6)],
        [('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)],
        [('Equalize', 0.6, 3), ('Rotate', 0.9, 3)],
        [('ShearX', 0.9, 4), ('Solarize', 0.3, 3)],
        [('ShearY', 0.8, 8), ('Invert', 0.7, 4)],
        [('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)],
        [('Invert', 0.9, 4), ('Equalize', 0.6, 7)],
        [('Contrast', 0.3, 3), ('Rotate', 0.8, 4)],
        [('Invert', 0.8, 5), ('TranslateY', 0.0, 2)],
        [('ShearY', 0.7, 6), ('Solarize', 0.4, 8)],
        [('Invert', 0.6, 4), ('Rotate', 0.8, 4)],
        [('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)],
        [('ShearX', 0.1, 6), ('Invert', 0.6, 5)],
        [('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)],
        [('ShearY', 0.8, 4), ('Invert', 0.8, 8)],
        [('ShearX', 0.7, 9), ('TranslateY', 0.8, 3)],
        [('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)],
        [('ShearX', 0.7, 2), ('Invert', 0.1, 5)],
    ]
    return policy

  @staticmethod
  def policy_reduced_imagenet():
    """Autoaugment policy for reduced ImageNet dataset.

    Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.

    Each tuple is an augmentation operation of the form
    (operation, probability, magnitude). Each element in policy is a
    sub-policy that will be applied sequentially on the image.

    Returns:
      the policy.
    """
    policy = [
        [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)]
    ]
    return policy

  @staticmethod
  def policy_simple():
    """Same as `policy_v0`, except with custom ops removed."""

    policy = [
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
    ]
    return policy

  @staticmethod
  def panoptic_deeplab_policy():
    """Autoaugment policy used by the `panoptic_deeplab_policy` option."""
    policy = [
        [('Sharpness', 0.4, 1.4), ('Brightness', 0.2, 2.0)],
        [('Equalize', 0.0, 1.8), ('Contrast', 0.2, 2.0)],
        [('Sharpness', 0.2, 1.8), ('Color', 0.2, 1.8)],
        [('Solarize', 0.2, 1.4), ('Equalize', 0.6, 1.8)],
        [('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4)]]
    return policy

  @staticmethod
  def vit():
    """Autoaugment policy for a generic ViT."""
    policy = [
        [('Sharpness', 0.4, 1.4), ('Brightness', 0.2, 2.0), ('Cutout', 0.8, 8)],
        [('Equalize', 0.0, 1.8), ('Contrast', 0.2, 2.0), ('Cutout', 0.8, 8)],
        [('Sharpness', 0.2, 1.8), ('Color', 0.2, 1.8), ('Cutout', 0.8, 8)],
        [('Solarize', 0.2, 1.4), ('Equalize', 0.6, 1.8), ('Cutout', 0.8, 8)],
        [('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4), ('Cutout', 0.8, 8)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8), ('Cutout', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8), ('Cutout', 0.8, 8)],
        [('Posterize', 0.6, 7), ('Posterize', 0.6, 6), ('Cutout', 0.8, 8)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5), ('Cutout', 0.8, 8)],
        ]
    return policy

  @staticmethod
  def deit3_three_augment():
    """Autoaugment policy for three augmentations.

    Proposed in paper: https://arxiv.org/abs/2204.07118.

    Each tuple is an augmentation operation of the form
    (operation, probability, magnitude). Each element in policy is a
    sub-policy that will be applied on the image. Randomly chooses one of the
    three augmentation to apply on image.

    Returns:
      the policy.
    """
    policy = [
        [('Grayscale', 1.0, 0)],
        [('Solarize', 1.0, 5)],  # to have threshold as 128
        [('Gaussian_Noise', 1.0, 1)],  # to have low_std as 0.1
        ]
    return policy

  @staticmethod
  def policy_test():
    """Autoaugment test policy for debugging."""
    policy = [
        [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
    ]
    return policy


def _maybe_identity(x: Optional[tf.Tensor]) -> Optional[tf.Tensor]:
  """Returns `tf.identity(x)`, or `None` when `x` is `None`."""
  if x is None:
    return None
  return tf.identity(x)


class RandAugment(ImageAugment):
  """Applies the RandAugment policy to images.

  RandAugment is from the paper https://arxiv.org/abs/1909.13719.
  """

  def __init__(self,
               num_layers: int = 2,
               magnitude: float = 10.,
               cutout_const: float = 40.,
               translate_const: float = 100.,
               magnitude_std: float = 0.0,
               prob_to_apply: Optional[float] = None,
               exclude_ops: Optional[List[str]] = None):
    """Applies the RandAugment policy to images.

    Args:
      num_layers: Integer, the number of augmentation transformations to apply
        sequentially to an image. Represented as (N) in the paper. Usually best
        values will be in the range [1, 3].
      magnitude: Shared magnitude across all augmentation operations (stored
        as a float). Represented as (M) in the paper. Usually best values are
        in the range [5, 10].
      cutout_const: multiplier for applying cutout.
      translate_const: multiplier for applying translation.
      magnitude_std: randomness of the severity as proposed by the authors of
        the timm library. When positive, the magnitude is randomly perturbed
        for each op at each layer.
      prob_to_apply: The probability to apply the selected augmentation at each
        layer. If `None`, the selected augmentation is always kept.
      exclude_ops: names of operations to remove from the candidate set.
    """
    super().__init__()

    self.num_layers = num_layers
    self.magnitude = float(magnitude)
    self.cutout_const = float(cutout_const)
    self.translate_const = float(translate_const)
    self.prob_to_apply = None if prob_to_apply is None else float(prob_to_apply)
    candidate_ops = [
        'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize',
        'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
        'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
    ]
    self.magnitude_std = magnitude_std
    if exclude_ops:
      candidate_ops = [
          name for name in candidate_ops if name not in exclude_ops
      ]
    self.available_ops = candidate_ops

  @classmethod
  def build_for_detection(cls,
                          num_layers: int = 2,
                          magnitude: float = 10.,
                          cutout_const: float = 40.,
                          translate_const: float = 100.,
                          magnitude_std: float = 0.0,
                          prob_to_apply: Optional[float] = None,
                          exclude_ops: Optional[List[str]] = None):
    """Builds a RandAugment that modifies bboxes for geometric transforms.

    The geometric ops of a regular `RandAugment` are swapped for their
    box-aware `_BBox` counterparts; all other ops are kept as-is. Note that
    `exclude_ops` is applied to the base op names before the swap.
    """
    augmenter = cls(
        num_layers=num_layers,
        magnitude=magnitude,
        cutout_const=cutout_const,
        translate_const=translate_const,
        magnitude_std=magnitude_std,
        prob_to_apply=prob_to_apply,
        exclude_ops=exclude_ops)
    box_aware_name = {
        'Rotate': 'Rotate_BBox',
        'ShearX': 'ShearX_BBox',
        'ShearY': 'ShearY_BBox',
        'TranslateX': 'TranslateX_BBox',
        'TranslateY': 'TranslateY_BBox',
    }
    augmenter.available_ops = [
        box_aware_name.get(name, name) for name in augmenter.available_ops
    ]
    return augmenter

  def _distort_common(
      self,
      image: tf.Tensor,
      bboxes: Optional[tf.Tensor] = None
  ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
    """Distorts the image and optionally bounding boxes.

    For each of `num_layers` layers one op is drawn uniformly from
    `available_ops` plus one extra index that selects the identity branch.
    When `prob_to_apply` is set, the layer's result is kept only with that
    probability.

    Args:
      image: Image tensor with values expected in [0, 255]; non-`uint8`
        inputs are clipped and cast to `uint8`, then cast back at the end.
      bboxes: Optional bounding boxes forwarded to every op.

    Returns:
      The augmented image (in the input dtype) and the augmented `bboxes`.
    """
    original_dtype = image.dtype
    if original_dtype != tf.uint8:
      image = tf.cast(tf.clip_by_value(image, 0.0, 255.0), dtype=tf.uint8)

    fill_value = [128] * 3
    min_prob, max_prob = 0.2, 0.8
    num_ops = len(self.available_ops)

    for _ in range(self.num_layers):
      # Index `num_ops` is not a branch key, so it falls through to `default`.
      op_index = tf.random.uniform([], maxval=num_ops + 1, dtype=tf.int32)

      branches = []
      for index, op_name in enumerate(self.available_ops):
        op_prob = tf.random.uniform([],
                                    minval=min_prob,
                                    maxval=max_prob,
                                    dtype=tf.float32)
        op_func, _, op_args = _parse_policy_info(op_name, op_prob,
                                                 self.magnitude, fill_value,
                                                 self.cutout_const,
                                                 self.translate_const,
                                                 self.magnitude_std)
        # Bind func/args per op via defaults; `image`/`bboxes` are read when
        # tf.switch_case invokes the branch below, i.e. this layer's values.
        branches.append(
            (index,
             lambda fn=op_func, fn_args=op_args: fn(image, bboxes, *fn_args)))  # pylint: disable=cell-var-from-loop

      new_image, new_bboxes = tf.switch_case(
          branch_index=op_index,
          branch_fns=branches,
          default=lambda: (tf.identity(image), _maybe_identity(bboxes)))  # pylint: disable=cell-var-from-loop

      if self.prob_to_apply is not None:
        keep_new = (
            tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply)
        new_image, new_bboxes = tf.cond(
            keep_new,
            lambda: (tf.identity(new_image), _maybe_identity(new_bboxes)),  # pylint: disable=cell-var-from-loop
            lambda: (tf.identity(image), _maybe_identity(bboxes)))  # pylint: disable=cell-var-from-loop
      image, bboxes = new_image, new_bboxes

    return tf.cast(image, dtype=original_dtype), bboxes

  def distort(self, image: tf.Tensor) -> tf.Tensor:
    """See base class."""
    distorted, _ = self._distort_common(image)
    return distorted

  def distort_with_boxes(self, image: tf.Tensor,
                         bboxes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """See base class."""
    distorted, new_bboxes = self._distort_common(image, bboxes)
    assert new_bboxes is not None
    return distorted, new_bboxes


class RandomErasing(ImageAugment):
  """Applies RandomErasing to a single image.

  Reference: https://arxiv.org/abs/1708.04896

  Implementation is inspired by
  https://github.com/rwightman/pytorch-image-models.
  """

  def __init__(self,
               probability: float = 0.25,
               min_area: float = 0.02,
               max_area: float = 1 / 3,
               min_aspect: float = 0.3,
               max_aspect: Optional[float] = None,
               min_count: int = 1,
               max_count: int = 1,
               trials: int = 10):
    """Applies RandomErasing to a single image.

    Args:
      probability: Probability of augmenting the image. Defaults to `0.25`.
      min_area: Minimum area of the random erasing rectangle, as a fraction of
        the image area. Defaults to `0.02`.
      max_area: Maximum area of the random erasing rectangle, as a fraction of
        the image area. Defaults to `1/3`.
      min_aspect: Minimum aspect ratio (height / width) of the random erasing
        rectangle. Defaults to `0.3`.
      max_aspect: Maximum aspect ratio (height / width) of the random erasing
        rectangle. Defaults to `None`, meaning `1 / min_aspect`.
      min_count: Minimum number of erased rectangles (inclusive). Defaults to
        `1`.
      max_count: Maximum number of erased rectangles (inclusive). Defaults to
        `1`.
      trials: Maximum number of trials to randomly sample a rectangle that
        fulfills constraint. Defaults to `10`.

    Raises:
      ValueError: If `min_count` is greater than `max_count`.
    """
    if min_count > max_count:
      raise ValueError(
          'min_count ({}) must not be greater than max_count ({}).'.format(
              min_count, max_count))
    self._probability = probability
    self._min_area = float(min_area)
    self._max_area = float(max_area)
    self._min_log_aspect = math.log(min_aspect)
    self._max_log_aspect = math.log(max_aspect or 1 / min_aspect)
    self._min_count = min_count
    self._max_count = max_count
    self._trials = trials

  def distort(self, image: tf.Tensor) -> tf.Tensor:
    """Applies RandomErasing to single `image`.

    Args:
      image (tf.Tensor): Of shape [height, width, 3] representing an image.

    Returns:
      tf.Tensor: The augmented version of `image`.
    """
    uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
    should_erase = tf.less(uniform_random, self._probability)
    image = tf.cond(should_erase, lambda: self._erase(image), lambda: image)
    return image

  @tf.function
  def _erase(self, image: tf.Tensor) -> tf.Tensor:
    """Erases between `min_count` and `max_count` rectangles from `image`.

    For each rectangle, up to `trials` random (area, aspect ratio) samples
    are drawn; the first one that fits inside the image is filled via
    `_fill_rectangle` and the remaining trials are skipped.
    """
    if self._min_count == self._max_count:
      count = self._min_count
    else:
      # `maxval` of tf.random.uniform is exclusive, so `+ 1` makes
      # `max_count` itself reachable.
      count = tf.random.uniform(
          shape=[],
          minval=int(self._min_count),
          maxval=int(self._max_count) + 1,
          dtype=tf.int32)

    image_height = tf.shape(image)[0]
    image_width = tf.shape(image)[1]
    area = tf.cast(image_width * image_height, tf.float32)

    for _ in range(count):
      # Work around since break is not supported in tf.function
      is_trial_successfull = False
      for _ in range(self._trials):
        if not is_trial_successfull:
          erase_area = tf.random.uniform(
              shape=[],
              minval=area * self._min_area,
              maxval=area * self._max_area)
          aspect_ratio = tf.math.exp(
              tf.random.uniform(
                  shape=[],
                  minval=self._min_log_aspect,
                  maxval=self._max_log_aspect))

          half_height = tf.cast(
              tf.math.round(tf.math.sqrt(erase_area * aspect_ratio) / 2),
              dtype=tf.int32)
          half_width = tf.cast(
              tf.math.round(tf.math.sqrt(erase_area / aspect_ratio) / 2),
              dtype=tf.int32)

          if 2 * half_height < image_height and 2 * half_width < image_width:
            # NOTE(review): centers are drawn from [0, size - 2 * half). If
            # `_fill_rectangle` fills [center - half, center + half), this
            # biases rectangles toward the top/left edge and can clip them;
            # confirm intent.
            center_height = tf.random.uniform(
                shape=[],
                minval=0,
                maxval=int(image_height - 2 * half_height),
                dtype=tf.int32)
            center_width = tf.random.uniform(
                shape=[],
                minval=0,
                maxval=int(image_width - 2 * half_width),
                dtype=tf.int32)

            image = _fill_rectangle(
                image,
                center_width,
                center_height,
                half_width,
                half_height,
                replace=None)

            is_trial_successfull = True

    return image


class MixupAndCutmix:
  """Applies Mixup and/or Cutmix to a batch of images.

  - Mixup: https://arxiv.org/abs/1710.09412
  - Cutmix: https://arxiv.org/abs/1905.04899

  Implementation is inspired by https://github.com/rwightman/pytorch-image-models
  """

  def __init__(self,
               num_classes: int,
               mixup_alpha: float = .8,
               cutmix_alpha: float = 1.,
               prob: float = 1.0,
               switch_prob: float = 0.5,
               label_smoothing: float = 0.1):
    """Applies Mixup and/or Cutmix to a batch of images.

    Args:
      num_classes (int): Number of classes.
      mixup_alpha (float, optional): Concentration of the symmetric
        Beta(mixup_alpha, mixup_alpha) distribution from which the per-image
        mixing weight (`lam`) is drawn. If zero Mixup is deactivated.
        Defaults to .8.
      cutmix_alpha (float, optional): Concentration of the symmetric
        Beta(cutmix_alpha, cutmix_alpha) distribution used to draw the
        per-image box size. If zero Cutmix is deactivated. Defaults to 1..
      prob (float, optional): Probability of augmenting the batch. Defaults to
        1.0.
      switch_prob (float, optional): Probability of applying Cutmix (rather
        than Mixup) to the batch when both are enabled. Defaults to 0.5.
      label_smoothing (float, optional): Constant for label smoothing. Defaults
        to 0.1.
    """
    self.mixup_alpha = mixup_alpha
    self.cutmix_alpha = cutmix_alpha
    self.mix_prob = prob
    self.switch_prob = switch_prob
    self.label_smoothing = label_smoothing
    self.num_classes = num_classes
    self.mode = 'batch'
    # With both alphas zero, the Beta sampler would be called with a zero
    # concentration, which is not a valid Beta parameter. In that case only
    # label smoothing is applied.
    self.mixup_enabled = bool(self.mixup_alpha or self.cutmix_alpha)

    # Pin the Mixup/Cutmix switch when only one of the two is enabled.
    if self.mixup_alpha and not self.cutmix_alpha:
      self.switch_prob = -1
    elif not self.mixup_alpha and self.cutmix_alpha:
      self.switch_prob = 1

  def __call__(self, images: tf.Tensor,
               labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    return self.distort(images, labels)

  def distort(self, images: tf.Tensor,
              labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Applies Mixup and/or Cutmix to batch of images and transforms labels.

    Args:
      images (tf.Tensor): Of shape [batch_size, height, width, 3] representing a
        batch of image, or [batch_size, time, height, width, 3] representing a
        batch of video.
      labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
        each image of the batch.

    Returns:
      Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and the
        (smoothed, possibly mixed) labels of shape [batch_size, num_classes].
    """
    labels = tf.reshape(labels, [-1])
    if not self.mixup_enabled:
      return images, self._smooth_labels(labels)

    def _mix():
      use_cutmix = tf.less(
          tf.random.uniform(shape=[], minval=0., maxval=1.0), self.switch_prob)
      mixed = tf.cond(use_cutmix,
                      lambda: self._cutmix(images, labels),
                      lambda: self._mixup(images, labels))
      return self._update_labels(*mixed)

    def _no_mix():
      return images, self._smooth_labels(labels)

    augment_cond = tf.less(
        tf.random.uniform(shape=[], minval=0., maxval=1.0), self.mix_prob)
    return tf.cond(augment_cond, _mix, _no_mix)

  @staticmethod
  def _sample_from_beta(alpha, beta, shape):
    """Samples Beta(alpha, beta) using two Gamma variates.

    If X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1) are independent, then
    X / (X + Y) ~ Beta(alpha, beta). The concentration must be passed as the
    Gamma *shape* (second positional argument of `tf.random.gamma`), not as
    its rate `beta`.
    """
    sample_alpha = tf.random.gamma(shape, alpha)
    sample_beta = tf.random.gamma(shape, beta)
    return sample_alpha / (sample_alpha + sample_beta)

  def _cutmix(self, images: tf.Tensor,
              labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Pastes a box from the reversed batch into each image.

    Returns the mixed images, the unchanged labels and a per-image `lam`, the
    fraction of each image that was NOT overwritten.
    """
    lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha,
                                           tf.shape(labels))

    ratio = tf.math.sqrt(1 - lam)

    batch_size = tf.shape(images)[0]

    if images.shape.rank == 4:
      image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
      fill_fn = _fill_rectangle
    elif images.shape.rank == 5:
      image_height, image_width = tf.shape(images)[2], tf.shape(images)[3]
      fill_fn = _fill_rectangle_video
    else:
      raise ValueError('Bad image rank: {}'.format(images.shape.rank))

    cut_height = tf.cast(
        ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
    cut_width = tf.cast(
        ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32)
    half_height = cut_height // 2
    half_width = cut_width // 2

    center_height = tf.random.uniform(
        shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32)
    center_width = tf.random.uniform(
        shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32)

    # The box is centered on a random pixel, so it may be clipped by the image
    # border, and odd cut sizes lose a pixel to `// 2`. Recompute `lam` from
    # the box that is actually pasted, assuming `fill_fn` overwrites
    # [center - half, center + half) clipped to the image (confirm against
    # `_fill_rectangle` / `_fill_rectangle_video`).
    box_height = (tf.minimum(center_height + half_height, image_height) -
                  tf.maximum(center_height - half_height, 0))
    box_width = (tf.minimum(center_width + half_width, image_width) -
                 tf.maximum(center_width - half_width, 0))
    lam = 1. - tf.cast(box_height * box_width, tf.float32) / tf.cast(
        image_height * image_width, tf.float32)

    images = tf.map_fn(
        lambda x: fill_fn(*x),
        (images, center_width, center_height, half_width, half_height,
         tf.reverse(images, [0])),
        fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=images.dtype))

    return images, labels, lam

  def _mixup(self, images: tf.Tensor,
             labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Blends each image with its counterpart in the reversed batch.

    `lam` is cast to the image dtype, so this presumably expects
    floating-point images (integer images would truncate `lam`).
    """
    lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
                                           tf.shape(labels))
    if images.shape.rank == 4:
      lam = tf.reshape(lam, [-1, 1, 1, 1])
    elif images.shape.rank == 5:
      lam = tf.reshape(lam, [-1, 1, 1, 1, 1])
    else:
      raise ValueError('Bad image rank: {}'.format(images.shape.rank))

    lam_cast = tf.cast(lam, dtype=images.dtype)
    images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0])

    # Keep `lam` rank-1 (like the Cutmix branch) even when batch_size == 1.
    return images, labels, tf.reshape(lam, [-1])

  def _smooth_labels(self, labels: tf.Tensor) -> tf.Tensor:
    """One-hot encodes `labels` with label smoothing."""
    off_value = self.label_smoothing / self.num_classes
    on_value = 1. - self.label_smoothing + off_value

    smooth_labels = tf.one_hot(
        labels, self.num_classes, on_value=on_value, off_value=off_value)
    return smooth_labels

  def _update_labels(self, images: tf.Tensor, labels: tf.Tensor,
                     lam: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Mixes smoothed labels with the reversed batch's labels, weighted by lam."""
    labels_1 = self._smooth_labels(labels)
    labels_2 = tf.reverse(labels_1, [0])

    lam = tf.reshape(lam, [-1, 1])
    labels = lam * labels_1 + (1. - lam) * labels_2

    return images, labels


def filter_boxes_by_ioa(
    bboxes: tf.Tensor, crop_box: tf.Tensor, min_box_overlap: float
) -> tf.Tensor:
  """Zeroes out boxes that overlap the crop box too little.

  Each box's intersection over area (IOA) with `crop_box` is computed via
  `box_ops.bbox_intersection_over_area`. Boxes whose IOA is below
  `min_box_overlap` are multiplied by zero, becoming (0, 0, 0, 0), so a later
  step can drop them. All other boxes pass through unchanged.

  Args:
    bboxes: a float tensor of shape [N, 4] representing normalized bounding box
      coordinates.
    crop_box: a float tensor of shape [1, 1, 4] representing the normalized crop
      box.
    min_box_overlap: minimum overlap of the box with the crop box to keep the
      box.

  Returns:
    a tensor of shape [N, 4] with filtered box coordinates replaced by 0.
  """
  # Add a leading batch axis for the batched IOA helper, then take batch 0.
  batched_boxes = tf.expand_dims(bboxes, axis=0)
  ioa_per_box = box_ops.bbox_intersection_over_area(batched_boxes, crop_box)[0]
  keep_mask = tf.cast(ioa_per_box >= min_box_overlap, dtype=bboxes.dtype)
  return bboxes * keep_mask


def crop_normalized_boxes(
    bboxes: tf.Tensor,
    ori_image_size: tf.Tensor,
    new_image_size: tf.Tensor,
    offset: tf.Tensor,
) -> tf.Tensor:
  """Re-expresses normalized boxes in the coordinate frame of a crop.

  Boxes are converted to pixel coordinates of the original image and shifted
  by the crop offset. They are then re-normalized by the cropped image size
  and clipped to [0, 1].

  Args:
    bboxes: a float tensor of shape [N, 4] representing normalized box
      coordinates.
    ori_image_size: an int tensor of shape [2] representing the original image
      size.
    new_image_size: an int tensor of shape [2] representing the cropped image
      size.
    offset: an int tensor of shape [2] representing the offset of the crop.

  Returns:
    a tensor of shape [N, 4] representing the new normalized bounding box
    coordinates in the new cropped image.
  """
  pixel_boxes = box_ops.denormalize_boxes(bboxes, ori_image_size)
  offset_float = tf.cast(offset, dtype=tf.float32)
  # The same 2-element offset is subtracted from both box corners, i.e. the
  # shift vector is [o0, o1, o0, o1] (presumably (y, x) order -- confirm).
  corner_shift = tf.concat([offset_float, offset_float], axis=0)
  pixel_boxes -= corner_shift[tf.newaxis, :]
  cropped_boxes = box_ops.normalize_boxes(pixel_boxes, new_image_size)
  return tf.clip_by_value(cropped_boxes, 0.0, 1.0)


class SSDRandomCrop(ImageAugment):
  """Random crop preprocessing as in the SSD paper.

  Liu et al., SSD: Single shot multibox detector
  https://arxiv.org/abs/1512.02325.

  The implementation originated from TF Object Detection API:
  https://github.com/tensorflow/models/blob/f36581036d3346a9496de06c8fd678d23cfe2103/research/object_detection/core/preprocessor.py#L3529
  """

  def __init__(
      self,
      params: Sequence[configs.SSDRandomCropParam] | None = None,
      aspect_ratio_range: tuple[float, float] = (0.5, 2.0),
      area_range: tuple[float, float] = (0.1, 1.0),
      max_attempts: int = 100,
  ):
    """Apply random crop to the image as in the SSD paper.

    The SSD random crop will randomly select one set of the parameters.

    Args:
      params: a sequence of SSDRandomCropParam that contains:
        min_object_covered - a float representing minimum the cropped image
          must cover at least this fraction with at least one of the input
          bounding boxes.
        min_box_overlap - a float representing minimum overlap of the bounding
          box with the cropped image to keep the box.
        prob_to_apply - a float representing the probability to crop.
        Defaults to `configs.SSDRandomCrop().ssd_random_crop_params`.
      aspect_ratio_range: allowed range for aspect ratio of the cropped image.
      area_range: allowed range for area ratio between cropped image and the
        original image.
      max_attempts: number of crop candidates `sample_distorted_bounding_box`
        may try before falling back to the whole image.

    Raises:
      ValueError: if `params` is empty.
    """
    if params is None:
      params = configs.SSDRandomCrop().ssd_random_crop_params
    self.num_cases = len(params)
    if self.num_cases == 0:
      # A parameter set is sampled with `maxval=num_cases`, which needs >= 1.
      raise ValueError('SSDRandomCrop requires at least one crop param.')
    self.min_object_covered = tf.constant(
        [param.min_object_covered for param in params], dtype=tf.float32,
    )
    self.min_box_overlap = tf.constant(
        [param.min_box_overlap for param in params], dtype=tf.float32,
    )
    self.prob_to_apply = tf.constant(
        [param.prob_to_apply for param in params], dtype=tf.float32,
    )
    self.aspect_ratio_range = aspect_ratio_range
    self.area_range = area_range
    self.max_attempts = max_attempts

  def distort_with_boxes(
      self, image: tf.Tensor, bboxes: tf.Tensor
  ) -> tuple[tf.Tensor, tf.Tensor]:
    """Randomly crops `image` and updates `bboxes` accordingly.

    One parameter set is chosen uniformly at random. With probability
    `1 - prob_to_apply` the inputs are returned unchanged. Otherwise a crop is
    sampled with `tf.image.sample_distorted_bounding_box`. Boxes overlapping
    the crop by less than `min_box_overlap` (IOA) are zeroed, and the rest are
    re-normalized to the crop.

    Args:
      image: the image to crop (presumably [H, W, C] -- confirm at caller).
      bboxes: a float tensor of shape [N, 4] of normalized box coordinates.

    Returns:
      The (possibly) cropped image and the corresponding boxes.
    """
    i_params = tf.random.uniform([], maxval=self.num_cases, dtype=tf.int32)

    # The uniform sample lies in [0, 1). `>=` makes prob_to_apply == 0 never
    # crop and prob_to_apply == 1 always crop.
    if tf.random.uniform(shape=[], maxval=1.0) >= self.prob_to_apply[i_params]:
      return image, bboxes

    image_size = tf.shape(image)
    bboxes = tf.clip_by_value(bboxes, 0., 1.)
    offset, new_image_size, crop_box = tf.image.sample_distorted_bounding_box(
        image_size=image_size,
        bounding_boxes=bboxes[None, ...],
        min_object_covered=self.min_object_covered[i_params],
        aspect_ratio_range=self.aspect_ratio_range,
        area_range=self.area_range,
        max_attempts=self.max_attempts,
        use_image_if_no_bounding_boxes=True,
    )
    new_image = tf.slice(image, offset, new_image_size)

    new_bboxes = filter_boxes_by_ioa(
        bboxes, crop_box, self.min_box_overlap[i_params]
    )
    new_bboxes = crop_normalized_boxes(
        new_bboxes, image_size[:2], new_image_size[:2], offset[:2]
    )
    return new_image, new_bboxes