# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import

import builtins
import enum
import functools
import math
import numbers

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.types import core as core_tf_types
from tensorflow.python.util import nest
from tensorflow.python.util import tf_export


newaxis = np.newaxis
tf_export.tf_export('experimental.numpy.newaxis', v1=[]).export_constant(
    __name__, 'newaxis'
)


@tf_export.tf_export('experimental.numpy.empty', v1=[])
@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  return zeros(shape, dtype)


@tf_export.tf_export('experimental.numpy.empty_like', v1=[])
@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  return zeros_like(a, dtype)


@tf_export.tf_export('experimental.numpy.zeros', v1=[])
@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type()
  )
  return array_ops.zeros(shape, dtype=dtype)


@tf_export.tf_export('experimental.numpy.zeros_like', v1=[])
@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  dtype = np_utils.result_type_unary(a, dtype)

  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
  return array_ops.zeros_like(a, dtype)


@tf_export.tf_export('experimental.numpy.ones', v1=[])
@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  if dtype:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones(shape, dtype=dtype)


@tf_export.tf_export('experimental.numpy.ones_like', v1=[])
@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  dtype = np_utils.result_type_unary(a, dtype)
  return array_ops.ones_like(a, dtype)


@tf_export.tf_export('experimental.numpy.eye', v1=[])
@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  if not M:
    M = N
  # Making sure N, M and k are `int`
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # tf.linalg.diag will raise an error in this case
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error
  diag_len = builtins.min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k <= 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
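
# Example (an illustrative sketch, assuming eager execution): a shifted
# identity goes through the matrix_diag path above, e.g.
#   eye(3, 4, k=1)
#   # -> [[0., 1., 0., 0.],
#   #     [0., 0., 1., 0.],
#   #     [0., 0., 0., 1.]]
# while an out-of-range offset (k >= M or -k >= N) returns all zeros.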


@tf_export.tf_export('experimental.numpy.identity', v1=[])
@np_utils.np_doc('identity')
def identity(n, dtype=float):
  return eye(N=n, M=n, dtype=dtype)


@tf_export.tf_export('experimental.numpy.full', v1=[])
@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  shape = atleast_1d(shape)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, shape)


# Using doc only here since np full_like signature doesn't seem to have the
# shape argument (even though it exists in the documentation online).
@tf_export.tf_export('experimental.numpy.full_like', v1=[])
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
  """order, subok and shape arguments mustn't be changed."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  dtype = dtype or np_utils.result_type(a)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, array_ops.shape(a))


def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array()."""
  result_t = val

  if not isinstance(result_t, tensor_lib.Tensor):
    dtype = np_utils.result_type_unary(result_t, dtype)
    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
    # while np.array allows them. We need to convert-then-cast.

    # EagerTensor conversion complains about "mixed types" when converting
    # tensors with no dtype information. This is because it infers types based
    # on one selected item in the list. So e.g. when converting [2., 2j]
    # to a tensor, it will select float32 as the inferred type and not be able
    # to convert the list to a float32 tensor.
    # Since we have some information about the final dtype we care about, we
    # supply that information so that convert_to_tensor will do best-effort
    # conversion to that dtype first.
    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
    result_t = math_ops.cast(result_t, dtype=dtype)
  elif dtype:
    result_t = math_ops.cast(result_t, dtype)

  if copy:
    result_t = array_ops.identity(result_t)

  max_ndmin = 32
  if ndmin > max_ndmin:
    raise ValueError(
        f'ndmin bigger than allowable number of dimensions: {max_ndmin}.'
    )

  if ndmin == 0:
    return result_t

  ndims = array_ops.rank(result_t)

  def true_fn():
    old_shape = array_ops.shape(result_t)
    new_shape = array_ops.concat(
        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0
    )
    return array_ops.reshape(result_t, new_shape)

  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t
  )
  return result_t
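
# Example of the convert-then-cast and `ndmin` behavior above (an illustrative
# sketch, assuming eager execution; the public `array` wrapper below forwards
# here):
#   array(5.5, dtype=np.int64)  # -> 5 (converted to a tensor, then cast)
#   array([1, 2], ndmin=3)      # -> tensor of shape (1, 1, 2)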


# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@tf_export.tf_export('experimental.numpy.array', v1=[])
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndmin`. This is used only if `val`
  is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  return _array_internal(val, dtype, copy, ndmin)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args


@tf_export.tf_export('experimental.numpy.asarray', v1=[])
@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(a, np_arrays.ndarray) and (
      not dtype or dtype == a.dtype.as_numpy_dtype
  ):
    return a
  return array(a, dtype, copy=False)


@tf_export.tf_export('experimental.numpy.asanyarray', v1=[])
@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  return asarray(a, dtype)


@tf_export.tf_export('experimental.numpy.ascontiguousarray', v1=[])
@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  return array(a, dtype, ndmin=1)


# Numerical ranges.
@tf_export.tf_export('experimental.numpy.arange', v1=[])
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If not specified, the interval becomes
      [0, start), i.e. the range starts at 0 and `start` is used as the end.
      If specified, it is not included in the range if `step` is integer. When
      `step` is floating point, it may or may not be included.
    step: The difference between 2 consecutive values in the output range. It is
      recommended to use `linspace` instead of using non-integer values for
      `step`.
    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
      `start`, `stop`, `step` is used.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    if stop is None:
      dtype = np_utils.result_type(start, step)
    else:
      dtype = np_utils.result_type(start, step, stop)
  if step > 0 and (
      (stop is not None and start > stop) or (stop is None and start < 0)
  ):
    return array([], dtype=dtype)
  if step < 0 and (
      (stop is not None and start < stop) or (stop is None and start > 0)
  ):
    return array([], dtype=dtype)
  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  return math_ops.cast(
      math_ops.range(start, limit=stop, delta=step), dtype=dtype
  )
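
# Example (an illustrative sketch, assuming eager execution):
#   arange(5)         # -> [0, 1, 2, 3, 4]
#   arange(2, 10, 3)  # -> [2, 5, 8]
#   arange(5, 2)      # -> [] (start > stop with a positive step)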


# Building matrices.
@tf_export.tf_export('experimental.numpy.diag', v1=[])
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  control_flow_assert.Assert(
      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
      [v_rank],
  )

  def _diag(v, k):
    return np_utils.cond(
        math_ops.equal(array_ops.size(v), 0),
        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
        lambda: array_ops.matrix_diag(v, k=k),
    )

  def _diag_part(v, k):
    v_shape = array_ops.shape(v)
    v, k = np_utils.cond(
        np_utils.logical_or(
            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
        ),
        lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0),
        lambda: (v, k),
    )
    result = array_ops.matrix_diag_part(v, k=k)
    return result

  result = np_utils.cond(
      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k)
  )
  return result
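
# Example (an illustrative sketch, assuming eager execution): a 1-d input
# builds a matrix, a 2-d input extracts a diagonal.
#   diag([1, 2, 3])              # -> [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
#   diag([[1, 2], [3, 4]])       # -> [1, 4]
#   diag([[1, 2], [3, 4]], k=1)  # -> [2]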


@tf_export.tf_export('experimental.numpy.diagonal', v1=[])
@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  maybe_rank = a.shape.rank
  if (
      maybe_rank is not None
      and offset == 0
      and (axis1 == maybe_rank - 2 or axis1 == -2)
      and (axis2 == maybe_rank - 1 or axis2 == -1)
  ):
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))

  a_shape = array_ops.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    return (
        array_ops.zeros(
            array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype
        ),
        0,
    )

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = np_utils.cond(
      np_utils.logical_or(
          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
      ),
      _zeros,
      lambda: (a, offset),
  )

  a = array_ops.matrix_diag_part(a, k=offset)
  return a


@tf_export.tf_export('experimental.numpy.diagflat', v1=[])
@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  v = asarray(v)
  return diag(array_ops.reshape(v, [-1]), k)


def _promote_dtype(*arrays):
  dtype = np_utils.result_type(*arrays)

  def _fast_asarray(a):
    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
      return a
    return _array_internal(a, dtype=dtype, copy=False)

  return [_fast_asarray(a) for a in arrays]


def _promote_dtype_binary(t1, t2):
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
  if not (
      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype
  ):
    t1 = _array_internal(t1, dtype=dtype, copy=False)
  if not (
      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype
  ):
    t2 = _array_internal(t2, dtype=dtype, copy=False)
  return t1, t2


@tf_export.tf_export('experimental.numpy.all', v1=[])
@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)


@tf_export.tf_export('experimental.numpy.any', v1=[])
@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)


@tf_export.tf_export('experimental.numpy.compress', v1=[])
@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  if axis < 0:
    axis += a.ndim

  assert axis >= 0 and axis < a.ndim

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  a_t = a
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)
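
# Example (an illustrative sketch, assuming eager execution): a short
# condition is padded with False, matching np.compress.
#   compress([False, True], [[1, 2], [3, 4], [5, 6]], axis=0)  # -> [[3, 4]]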


@tf_export.tf_export('experimental.numpy.copy', v1=[])
@np_utils.np_doc('copy')
def copy(a):
  return array(a, copy=True)


def _maybe_promote_to_int(a):
  if dtypes.as_dtype(a.dtype).is_integer:
    # If a is an integer type and its precision is less than that of `int`,
    # the output type will be `int`.
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    if output_type != a_numpy_dtype:
      a = asarray(a, dtype=output_type)

  return a


@tf_export.tf_export('experimental.numpy.cumprod', v1=[])
@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)


@tf_export.tf_export('experimental.numpy.cumsum', v1=[])
@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)


@tf_export.tf_export('experimental.numpy.imag', v1=[])
@np_utils.np_doc('imag')
def imag(val):
  val = asarray(val)
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always
  # return an ndarray.
  return math_ops.imag(val)


_TO_INT_ = 0
_TO_FLOAT = 1


def _reduce(
    tf_fn,
    a,
    axis=None,
    dtype=None,
    keepdims=None,
    promote_int=_TO_INT_,
    tf_bool_fn=None,
    preserve_bool=False,
):
  """A general reduction function.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis along which to do the reduction. If None, all
      dimensions are reduced.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s).
    promote_int: how to promote integer and bool inputs. There are three
      choices. (1) `_TO_INT_` always promotes them to np.int_ or np.uint; (2)
      `_TO_FLOAT` always promotes them to a float type (determined by
      np_dtypes.default_float_type); (3) None: don't promote.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
      only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
      is `np.bool_` and `preserve_bool` is True.
    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype
      is `np.bool_` (some reductions such as np.sum convert bools to integers,
      while others such as np.max preserve bools).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  if keepdims is None:
    keepdims = False
  a = asarray(a, dtype=dtype)
  if (
      dtype == np.bool_ or preserve_bool and a.dtype == np.bool_
  ) and tf_bool_fn is not None:
    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
    if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
      if promote_int == _TO_INT_:
        # If a has an integer/bool dtype whose bit width is less than that of
        # np.int_, numpy up-casts it to np.int_ based on the documentation at
        # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html
        if dtype == np.bool_:
          is_signed = True
          width = 8  # We can use any number here that is less than 64
        else:
          is_signed = np.issubdtype(dtype, np.signedinteger)
          width = np.iinfo(dtype).bits
        # Numpy int_ and uint are defined as 'long' and 'unsigned long', so
        # should have the same bit width.
        if ops.is_auto_dtype_conversion_enabled():
          # We default to 32 bits when using auto dtype conversion semantics.
          if width < np.iinfo(np.int32).bits:
            if is_signed:
              dtype = np.int32
            else:
              dtype = np.uint32
        else:
          if width < np.iinfo(np.int_).bits:
            if is_signed:
              dtype = np.int_
            else:
              dtype = np.uint
        a = math_ops.cast(a, dtype)
      elif promote_int == _TO_FLOAT:
        # Use a default float type.
        a = math_ops.cast(a, np_utils.result_type(float))

  if isinstance(axis, tensor_lib.Tensor) and axis.dtype not in (
      dtypes.int32,
      dtypes.int64,
  ):
    axis = math_ops.cast(axis, dtypes.int64)

  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
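
# Example of `promote_int` (an illustrative sketch, assuming eager execution):
# reducing a narrow integer dtype widens the accumulator, mirroring np.sum.
#   sum(np.array([1, 2, 3], dtype=np.int8))
#   # -> 6, with dtype np.int_ (or int32 under auto dtype conversion)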


# TODO (DarrenZhang01): Add `axis` support to the `size` API.
@tf_export.tf_export('experimental.numpy.size', v1=[])
@np_utils.np_doc('size')
def size(x, axis=None):  # pylint: disable=missing-docstring
  if axis is not None:
    raise NotImplementedError(
        'axis argument is not supported in the current `np.size` implementation'
    )
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
    return 1
  x = asarray(x)
  if x.shape.is_fully_defined():
    return np.prod(x.shape.as_list(), dtype=int)
  else:
    return array_ops.size_v2(x)


@tf_export.tf_export('experimental.numpy.sum', v1=[])
@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  return _reduce(
      math_ops.reduce_sum,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_any,
  )


@tf_export.tf_export('experimental.numpy.prod', v1=[])
@np_utils.np_doc('prod')
def prod(a, axis=None, dtype=None, keepdims=None):
  return _reduce(
      math_ops.reduce_prod,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_all,
  )


@tf_export.tf_export('experimental.numpy.mean', v1=[])
@np_utils.np_doc('mean', unsupported_params=['out'])
def mean(a, axis=None, dtype=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_mean,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT,
  )


@tf_export.tf_export('experimental.numpy.amax', v1=[])
@np_utils.np_doc('amax', unsupported_params=['out'])
def amax(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_max,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_any,
      preserve_bool=True,
  )


@tf_export.tf_export('experimental.numpy.amin', v1=[])
@np_utils.np_doc('amin', unsupported_params=['out'])
def amin(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_min,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_all,
      preserve_bool=True,
  )


@tf_export.tf_export('experimental.numpy.var', v1=[])
@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype,
        )
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims
      )

      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis)
        )
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)

  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT,
  )
  if dtype:
    result = math_ops.cast(result, dtype)
  return result


@tf_export.tf_export('experimental.numpy.std', v1=[])
@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  return _reduce(
      math_ops.reduce_std,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=_TO_FLOAT,
  )


@tf_export.tf_export('experimental.numpy.ravel', v1=[])
@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  a = asarray(a)
  return array_ops.reshape(a, [-1])


@tf_export.tf_export('experimental.numpy.real', v1=[])
@np_utils.np_doc('real')
def real(val):
  val = asarray(val)
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(val)


@tf_export.tf_export('experimental.numpy.repeat', v1=[])
@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best effort recovery of the shape.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    if not original_shape:
      original_shape = (repeats,)
    else:
      repeats_np = np.ravel(np.array(repeats))
      if repeats_np.size == 1:
        repeats_np = repeats_np.item()
        if axis is None:
          original_shape = (repeats_np * np.prod(original_shape),)
        else:
          original_shape[axis] = repeats_np * original_shape[axis]
      else:
        if axis is None:
          original_shape = (repeats_np.sum(),)
        else:
          original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result
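
# Example (an illustrative sketch, assuming eager execution):
#   repeat([1, 2], 2)                         # -> [1, 1, 2, 2]
#   repeat([[1, 2], [3, 4]], [1, 2], axis=0)  # -> [[1, 2], [3, 4], [3, 4]]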


@tf_export.tf_export('experimental.numpy.around', v1=[])
@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  dtype = a.dtype.as_numpy_dtype
  factor = math.pow(10, decimals)
  if np.issubdtype(dtype, np.inexact):
    factor = math_ops.cast(factor, dtype)
  else:
    # Use float as the working dtype when a.dtype is exact (e.g. integer),
    # because `decimals` can be negative.
    float_dtype = np_utils.result_type(float)
    a = a.astype(float_dtype)
    factor = math_ops.cast(factor, float_dtype)
  a = math_ops.multiply(a, factor)
  a = math_ops.round(a)
  a = math_ops.divide(a, factor)
  return a.astype(dtype)
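
# Example (an illustrative sketch, assuming eager execution): like np.around,
# halves round to the nearest even value.
#   around([0.5, 1.5, 2.5])  # -> [0., 2., 2.]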


setattr(np_arrays.ndarray, '__round__', around)


@tf_export.tf_export('experimental.numpy.reshape', v1=[])
@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
  """order argument can only b 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'F':
    r = array_ops.transpose(
        array_ops.reshape(array_ops.transpose(a), newshape[::-1])
    )
  else:
    r = array_ops.reshape(a, newshape)

  return r


def _reshape_method_wrapper(a, *newshape, **kwargs):
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  if len(newshape) == 1 and not isinstance(newshape[0], int):
    newshape = newshape[0]

  return reshape(a, newshape, order=order)


@tf_export.tf_export('experimental.numpy.expand_dims', v1=[])
@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  a = asarray(a)
  return array_ops.expand_dims(a, axis=axis)


@tf_export.tf_export('experimental.numpy.squeeze', v1=[])
@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  a = asarray(a)
  return array_ops.squeeze(a, axis)


@tf_export.tf_export('experimental.numpy.flatten', v1=[])
@np_utils.np_doc('flatten', link=np_utils.NoLink())
def flatten(a, order='C'):
  a = asarray(a)
  if order == 'C' or order == 'A' or order == 'K':
    # Row major.
    return array_ops.reshape(a, [-1])
  elif order == 'F':
    # Column major
    return array_ops.reshape(array_ops.transpose(a), [-1])
  else:
    raise ValueError(
        'order can only be C, A, K (all row major) or F (column major).'
    )


@tf_export.tf_export('experimental.numpy.transpose', v1=[])
@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  a = asarray(a)
  if axes is not None:
    axes = asarray(axes)
  return array_ops.transpose(a=a, perm=axes)


@tf_export.tf_export('experimental.numpy.swapaxes', v1=[])
@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)

  def adjust_axes(axes, rank):
    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
      return x

    return nest.map_structure(f, axes)

  if (
      a.shape.rank is not None
      and isinstance(axis1, int)
      and isinstance(axis2, int)
  ):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(
        perm, [[axis1], [axis2]], [axis2, axis1]
    )
  a = array_ops.transpose(a, perm)
  return a
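
# Example (an illustrative sketch, assuming eager execution, with `x` a
# hypothetical array of shape (2, 3, 4)):
#   swapaxes(x, 0, 2)  # -> shape (4, 3, 2)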


@tf_export.tf_export('experimental.numpy.moveaxis', v1=[])
@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a

  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must be equal')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops_stack.unstack(
          sort_ops.sort(array_ops_stack.stack(b)), num=len(b)
      )

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank]
    )
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source
    )
  a = array_ops.transpose(a, perm)

  return a
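
# Example (an illustrative sketch, assuming eager execution, with `x` a
# hypothetical array of shape (2, 3, 4)):
#   moveaxis(x, 0, -1)  # -> shape (3, 4, 2)
#   moveaxis(x, -1, 0)  # -> shape (4, 2, 3)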


@tf_export.tf_export('experimental.numpy.pad', v1=[])
@np_utils.np_doc('pad')
def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
  constant_values = kwargs.get('constant_values', 0)
  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
    raise ValueError('Unsupported padding mode: ' + mode)
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values,
  )


@tf_export.tf_export('experimental.numpy.take', v1=[])
@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """out argument is not supported, and default mode is clip."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')

  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = array_ops.reshape(a, [-1])
    axis = 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  elif mode == 'wrap':
    indices = math_ops.floormod(indices, axis_size)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")

  return array_ops.gather(a, indices, axis=axis)
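
# Example (an illustrative sketch, assuming eager execution): unlike np.take,
# the default mode is 'clip', so out-of-range indices are clamped rather than
# raising.
#   take([10, 20, 30], [-1, 5])               # -> [10, 30]
#   take([10, 20, 30], [-1, 5], mode='wrap')  # -> [30, 30]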


@tf_export.tf_export('experimental.numpy.where', v1=[])
@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is None and y is None:
    return nonzero(condition)
  elif x is not None and y is not None:
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')


@tf_export.tf_export('experimental.numpy.select', v1=[])
@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  choices = _promote_dtype(default, *choicelist)
  choicelist = choices[1:]
  output = choices[0]
  # The traversal is in reverse order so we can return the first value in
  # choicelist where condlist is True.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output


@tf_export.tf_export('experimental.numpy.shape', v1=[])
@np_utils.np_doc(
    'shape',
    link=np_utils.Link(
        'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'
    ),
)
def shape(a):
  a = asarray(a)
  return a.shape


@tf_export.tf_export('experimental.numpy.ndim', v1=[])
@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  a = asarray(a)
  return a.ndim


@tf_export.tf_export('experimental.numpy.isscalar', v1=[])
@np_utils.np_doc('isscalar')
def isscalar(num):
  return ndim(num) == 0


def _boundaries_to_sizes(a, boundaries, axis):
  """Converting boundaries of splits to sizes of splits.

  Args:
    a: the array to be split.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split.

  Returns:
    A list of sizes of the splits, as in tf.split.
  """
  if axis >= len(a.shape):
    raise ValueError(
        'axis %s is out of bounds for shape %s' % (axis, a.shape)
    )
  total_size = a.shape[axis]
  sizes = []
  sizes_sum = 0
  prev = 0
  for i, b in enumerate(boundaries):
    size = b - prev
    if size < 0:
      raise ValueError(
          'The %s-th boundary %s is smaller than the previous boundary %s'
          % (i, b, prev)
      )
    size = builtins.min(size, builtins.max(0, total_size - sizes_sum))
    sizes.append(size)
    sizes_sum += size
    prev = b
  sizes.append(builtins.max(0, total_size - sizes_sum))
  return sizes
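
# Example (an illustrative sketch): for an array whose size along `axis` is
# 10, boundaries [3, 7] become sizes [3, 4, 3], i.e. the segments a[:3],
# a[3:7] and a[7:]. Boundaries past the end are clamped: [3, 12] -> [3, 7, 0].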


@tf_export.tf_export('experimental.numpy.split', v1=[])
@np_utils.np_doc('split')
def split(ary, indices_or_sections, axis=0):
  ary = asarray(ary)
  if not isinstance(indices_or_sections, int):
    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
  return array_ops.split(ary, indices_or_sections, axis=axis)


def _split_on_axis(np_fun_name, axis):  # pylint: disable=missing-function-docstring
  @np_utils.np_doc(np_fun_name)
  def f(ary, indices_or_sections):
    # for 1-D array, hsplit becomes vsplit
    new_axis = np_utils.cond(
        math_ops.equal(axis, 1),
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            math_ops.equal(array_ops.rank(ary), 1), lambda: 0, lambda: axis
        ),
        lambda: axis,
    )
    if isinstance(indices_or_sections, int):
      ary_shape = ary.shape[new_axis]
      if ary_shape is not None and ary_shape % indices_or_sections:
        raise ValueError('array split does not result in an equal division')
    return split(ary, indices_or_sections, axis=new_axis)

  return f


vsplit = tf_export.tf_export('experimental.numpy.vsplit', v1=[])(
    _split_on_axis('vsplit', axis=0)
)
hsplit = tf_export.tf_export('experimental.numpy.hsplit', v1=[])(
    _split_on_axis('hsplit', axis=1)
)
dsplit = tf_export.tf_export('experimental.numpy.dsplit', v1=[])(
    _split_on_axis('dsplit', axis=2)
)


@tf_export.tf_export('experimental.numpy.broadcast_to', v1=[])
@np_utils.np_doc('broadcast_to')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  return full(shape, array)


@tf_export.tf_export('experimental.numpy.stack', v1=[])
@np_utils.np_doc('stack')
def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
  if isinstance(arrays, (np_arrays.ndarray, tensor_lib.Tensor)):
    arrays = asarray(arrays)
    if axis == 0:
      return arrays
    else:
      return swapaxes(arrays, 0, axis)
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return asarray(array_ops_stack.stack(unwrapped_arrays, axis))


@tf_export.tf_export('experimental.numpy.hstack', v1=[])
@np_utils.np_doc('hstack')
def hstack(tup):
  arrays = [atleast_1d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  rank = array_ops.rank(unwrapped_arrays[0])
  return np_utils.cond(
      math_ops.equal(rank, 1),
      lambda: array_ops.concat(unwrapped_arrays, axis=0),
      lambda: array_ops.concat(unwrapped_arrays, axis=1),
  )


@tf_export.tf_export('experimental.numpy.vstack', v1=[])
@np_utils.np_doc('vstack')
def vstack(tup):
  arrays = [atleast_2d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=0)


@tf_export.tf_export('experimental.numpy.dstack', v1=[])
@np_utils.np_doc('dstack')
def dstack(tup):
  arrays = [atleast_3d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=2)


def _pad_left_to(n, old_shape):
  old_shape = asarray(old_shape, dtype=np.int32)
  new_shape = array_ops.pad(
      old_shape,
      [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
      constant_values=1,
  )
  return asarray(new_shape)


def _atleast_nd(n, new_shape, *arys):
  """Reshape arrays to be at least `n`-dimensional.

  Args:
    n: The minimal rank.
    new_shape: a function that takes `n` and the old shape and returns the
      desired new shape.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The reshaped array(s).
  """

  def f(x):
    # pylint: disable=g-long-lambda
    x = asarray(x)
    return asarray(
        np_utils.cond(
            np_utils.greater(n, array_ops.rank(x)),
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x,
        )
    )

  arys = list(map(f, arys))
  if len(arys) == 1:
    return arys[0]
  else:
    return arys


@tf_export.tf_export('experimental.numpy.atleast_1d', v1=[])
@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  return _atleast_nd(1, _pad_left_to, *arys)


@tf_export.tf_export('experimental.numpy.atleast_2d', v1=[])
@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  return _atleast_nd(2, _pad_left_to, *arys)


@tf_export.tf_export('experimental.numpy.atleast_3d', v1=[])
@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring
  def new_shape(_, old_shape):
    # pylint: disable=g-long-lambda
    ndim_ = array_ops.size(old_shape)
    return np_utils.cond(
        math_ops.equal(ndim_, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(ndim_, 1),
            lambda: array_ops.pad(old_shape, [[1, 1]], constant_values=1),
            lambda: array_ops.pad(old_shape, [[0, 1]], constant_values=1),
        ),
    )

  return _atleast_nd(3, new_shape, *arys)
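
# Example (an illustrative sketch, assuming eager execution): ranks are padded
# the same way np.atleast_3d pads them.
#   atleast_3d(1.0)              # -> shape (1, 1, 1)
#   atleast_3d([1.0, 2.0])       # -> shape (1, 2, 1)
#   atleast_3d(np.ones((2, 3)))  # -> shape (2, 3, 1)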


@tf_export.tf_export('experimental.numpy.nonzero', v1=[])
@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  if a.shape.rank is None:
    raise ValueError(
        "The rank of `a` is unknown, so we can't decide how many "
        'arrays to return.'
    )
  return array_ops_stack.unstack(
      array_ops.where_v2(math_ops.cast(a, dtypes.bool)), a.shape.rank, axis=1
  )


@tf_export.tf_export('experimental.numpy.diag_indices', v1=[])
@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n)
    )
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(ndim)
    )

  return (math_ops.range(n),) * ndim


@tf_export.tf_export('experimental.numpy.tri', v1=[])
@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  M = M if M is not None else N
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  else:
    # Use a default float type.
    dtype = np_utils.result_type(float)

  if k < 0:
    lower = -k - 1
    if lower > N:
      r = array_ops.zeros([N, M], dtype)
    else:
      # Keep as tf bool, since we create an upper triangular matrix and invert
      # it.
      o = array_ops.ones([N, M], dtype=dtypes.bool)
      r = math_ops.cast(
          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)), dtype
      )
  else:
    o = array_ops.ones([N, M], dtype)
    if k > M:
      r = o
    else:
      r = array_ops.matrix_band_part(o, -1, k)
  return r
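
# Example (an illustrative sketch, assuming eager execution):
#   tri(3)
#   # -> [[1., 0., 0.],
#   #     [1., 1., 0.],
#   #     [1., 1., 1.]]
#   tri(3, k=-1)
#   # -> [[0., 0., 0.],
#   #     [1., 0., 0.],
#   #     [1., 1., 0.]]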


@tf_export.tf_export('experimental.numpy.tril', v1=[])
@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError(
        'Currently, the last two dimensions of the input array '
        'need to be known.'
    )

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z
  )


@tf_export.tf_export('experimental.numpy.triu', v1=[])
@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError(
        'Currently, the last two dimensions of the input array '
        'need to be known.'
    )

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m
  )


@tf_export.tf_export('experimental.numpy.flip', v1=[])
@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  m = asarray(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))

  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access

  return array_ops.reverse(m, [axis])


@tf_export.tf_export('experimental.numpy.flipud', v1=[])
@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  return flip(m, 0)


@tf_export.tf_export('experimental.numpy.fliplr', v1=[])
@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  return flip(m, 1)


@tf_export.tf_export('experimental.numpy.roll', v1=[])
@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)

  if axis is not None:
    return manip_ops.roll(a, shift, axis)

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = array_ops.shape(a)
  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
  return array_ops.reshape(a, original_shape)


@tf_export.tf_export('experimental.numpy.rot90', v1=[])
@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = math_ops.range(m_rank)
    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)


@tf_export.tf_export('experimental.numpy.vander', v1=[])
@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  x_shape = array_ops.shape(x)
  if N is None:
    N = x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_assert.Assert(N >= 0, [N])

  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_assert.Assert(math_ops.equal(rank, 1), [rank])

  if increasing:
    start = 0
    limit = N
    delta = 1
  else:
    start = N - 1
    limit = -1
    delta = -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)
  )
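
# Example (an illustrative sketch, assuming eager execution): powers decrease
# across columns by default, matching np.vander.
#   vander([1, 2, 3], N=3)
#   # -> [[1, 1, 1],
#   #     [4, 2, 1],
#   #     [9, 3, 1]]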


@tf_export.tf_export('experimental.numpy.ix_', v1=[])
@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    a_rank = array_ops.rank(a)
    a_rank_temp = np_utils.get_static_value(a_rank)
    if a_rank_temp is not None:
      a_rank = a_rank_temp
      if a_rank != 1:
        raise ValueError(
            'Arguments must be 1-d, got arg {} of rank {}'.format(i, a_rank)
        )
    else:
      control_flow_assert.Assert(math_ops.equal(a_rank, 1), [a_rank])

    new_shape = [1] * n
    new_shape[i] = -1
    dtype = a.dtype
    if dtype == dtypes.bool:
      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
    elif dtype.is_integer:
      output.append(array_ops.reshape(a, new_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype)
      )

  return output


@tf_export.tf_export('experimental.numpy.broadcast_arrays', v1=[])
@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  subok = kwargs.pop('subok', False)
  if subok:
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  args = [asarray(arg) for arg in args]
  return np_utils.tf_broadcast(*args)


@tf_export.tf_export('experimental.numpy.sign', v1=[])
@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  if out:
    raise ValueError("tf.numpy doesn't support setting out.")
  if where:
    raise ValueError("tf.numpy doesn't support setting where.")
  if kwargs:
    raise ValueError("tf.numpy doesn't support setting {}".format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result


# Note that np.take_along_axis may not be present in some supported versions of
# numpy.
@tf_export.tf_export('experimental.numpy.take_along_axis', v1=[])
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensure that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr, out_type=indices.dtype)
  indices_shape_original = array_ops.shape(indices, out_type=indices.dtype)
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(
      indices_shape_original, [[axis]], [1]
  )
  broadcasted_shape = array_ops.broadcast_dynamic_shape(
      arr_shape, indices_shape
  )
  arr_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [arr_shape_original[axis]]
  )
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]]
  )
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(
      dont_move_axis_to_end, lambda: arr, lambda: swapaxes_(arr)
  )
  indices = np_utils.cond(
      dont_move_axis_to_end, lambda: indices, lambda: swapaxes_(indices)
  )

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(
      dont_move_axis_to_end, lambda: result, lambda: swapaxes_(result)
  )
  result.set_shape(possible_result_shape)

  return result
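
# Example (an illustrative sketch, assuming eager execution): `indices` picks
# one entry per position along `axis`, as in np.take_along_axis.
#   take_along_axis([[1, 2], [3, 4]], [[0, 0], [1, 0]], axis=1)
#   # -> [[1, 1], [4, 3]]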


# pylint: disable=redefined-builtin,undefined-variable
@tf_export.tf_export('experimental.numpy.max', v1=[])
@np_utils.np_doc('max', link=np_utils.AliasOf('amax'))
def max(a, axis=None, keepdims=None):
  return amax(a, axis=axis, keepdims=keepdims)


@tf_export.tf_export('experimental.numpy.min', v1=[])
@np_utils.np_doc('min', link=np_utils.AliasOf('amin'))
def min(a, axis=None, keepdims=None):
  return amin(a, axis=axis, keepdims=keepdims)


@tf_export.tf_export('experimental.numpy.round', v1=[])
@np_utils.np_doc('round', link=np_utils.AliasOf('around'))
def round(a, decimals=0):
  return around(a, decimals=decimals)


# pylint: enable=redefined-builtin,undefined-variable


_SLICE_ERROR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices'
)


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (indx, bool). First one is the parsed index and can be a tensor,
    or scalar integer / Dimension. Second one is True if rank is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0


class _UpdateMethod(enum.Enum):
  UPDATE = 0
  ADD = 1
  MIN = 2
  MAX = 3


def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
  """Helper function for __getitem__ and _with_index_update_helper.

  This function collects the indices in `slice_spec` into two buckets, which we
  can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2
  `gather`.  They also correspond to "basic indices" and "advanced indices" in
  numpy.  This function supports both reading and writing at the indices. The
  reading path can be summarized as `gather(stride_slice(tensor, idx1),
  idx2)`. The writing path can be summarized as `strided_slice_update(tensor,
  idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`.  (`gather` here
  means `tf.gather` or `tf.gather_nd`; `scatter` here means
  `tf.tensor_scatter_update`.)  The writing path is inefficient because it needs
  to first read out a portion (probably much larger than `updates`) of `tensor`
  using `strided_slice`, update it, and then write the portion back. An
  alternative approach is to only use `scatter`, which amounts to using the
  indexing mechanism of gather/scatter to implement
  strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter
  because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but
  not TF gather/scatter because they don't support spans (except those that
  cover entire dimensions, i.e. `:`).  If we materialize spans into individual
  indices, the size of the index tensor would explode.  (Note that XLA
  Gather/Scatter have a similar problem for stride > 1 because they don't
  support strides.  Indices such as `1:8:2` will need to be materialized into
  individual indices such as [1, 3, 5, 7].)

  Args:
    tensor: the tensor to be read from or write into.
    slice_spec: the indices.
    update_method: (optional) a member of `_UpdateMethod`, indicating how to
      update the values (replacement, add, etc.). `None` indicates just reading.
    updates: (optional) the new values to write into `tensor`. It must have the
      same dtype as `tensor`.

  Returns:
    The result of reading (if `update_method` is `None`) or the updated `tensor`
    after writing.
  """
  begin, end, strides = [], [], []
  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  advanced_indices = []
  shrink_indices = []
  for index, s in enumerate(slice_spec):
    if isinstance(s, slice):
      if s.start is not None:
        begin.append(_as_index(s.start)[0])
      else:
        begin.append(0)
        begin_mask |= 1 << index
      if s.stop is not None:
        end.append(_as_index(s.stop)[0])
      else:
        end.append(0)
        end_mask |= 1 << index
      if s.step is not None:
        strides.append(_as_index(s.step)[0])
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= 1 << index
    elif s is array_ops.newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= 1 << index
    else:
      s, is_scalar = _as_index(s, False)
      if is_scalar:
        begin.append(s)
        end.append(s + 1)
        strides.append(1)
        shrink_axis_mask |= 1 << index
        shrink_indices.append(index)
      else:
        begin.append(0)
        end.append(0)
        strides.append(1)
        begin_mask |= 1 << index
        end_mask |= 1 << index
        advanced_indices.append((index, s, ellipsis_mask != 0))
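
  # Worked example (slice_spec assumed for illustration): for slice_spec
  # (slice(1, None), newaxis, 2) the loop above produces begin=[1, 0, 2],
  # end=[0, 0, 3], strides=[1, 1, 1], begin_mask=0, end_mask=0b001,
  # new_axis_mask=0b010, shrink_axis_mask=0b100 and ellipsis_mask=0.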

  # `stack` possibly involves no tensors, so we must use op_scope to pick the
  # correct graph.
  with ops.name_scope(
      None,
      'strided_slice',
      [tensor] + begin + end + strides,
      skip_on_eager=False,
  ) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (
          array_ops_stack.stack(begin),
          array_ops_stack.stack(end),
          array_ops_stack.stack(strides),
      )
      if (
          packed_begin.dtype == dtypes.int64
          or packed_end.dtype == dtypes.int64
          or packed_strides.dtype == dtypes.int64
      ):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant_op.constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    if update_method == _UpdateMethod.UPDATE and not advanced_indices:
      return array_ops.tensor_strided_slice_update(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          updates,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name,
      )
    else:
      # TODO(b/164251540): Find a better way to support update that does not
      #   involve one read + two writes.
      if updates is not None:
        original_tensor = tensor
      # TODO(agarwal): set_shape on tensor to set rank.
      tensor = array_ops.strided_slice(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name,
      )
    if not advanced_indices:
      if update_method is None:
        return tensor
      assert update_method != _UpdateMethod.UPDATE
      # TF lacks TensorStridedSliceAdd and the like, so we need to do
      # read + add + update.
      if update_method == _UpdateMethod.ADD:
        update_op = math_ops.add
      elif update_method == _UpdateMethod.MIN:
        update_op = math_ops.minimum
      elif update_method == _UpdateMethod.MAX:
        update_op = math_ops.maximum
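      # For example (illustrative): with ADD this computes
      # strided_slice_update(original_tensor, idx1,
      #                      strided_slice(original_tensor, idx1) + updates),
      # i.e. one read, one add and one write-back.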
      return array_ops.tensor_strided_slice_update(
          original_tensor,
          packed_begin,
          packed_end,
          packed_strides,
          update_op(tensor, updates),
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name + '_2',
      )
    advanced_indices_map = {}
    for index, data, had_ellipsis in advanced_indices:
      if had_ellipsis:
        num_shrink = len([x for x in shrink_indices if x > index])
        dim = index - len(slice_spec) + num_shrink
      else:
        num_shrink = len([x for x in shrink_indices if x < index])
        dim = index - num_shrink
      advanced_indices_map[dim] = data
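    # Worked example (slice_spec assumed for illustration): for
    # slice_spec = (Ellipsis, [1, 2]) the advanced index sits at position 1
    # after an ellipsis, so num_shrink = 0 and dim = 1 - 2 + 0 = -1, i.e. the
    # index array applies to the last dimension of `tensor`.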
    dims = sorted(advanced_indices_map.keys())
    dims_contiguous = True
    if len(dims) > 1:
      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
        dims_contiguous = False
      else:
        for i in range(len(dims) - 1):
          if dims[i] + 1 != dims[i + 1]:
            dims_contiguous = False
            break
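    # E.g. dims [1, 2] are contiguous while [0, 2] are not; dims that mix
    # negative and non-negative values (say [-1, 0]) are conservatively
    # treated as non-contiguous, since contiguity cannot be checked here
    # without the rank.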
    indices = [advanced_indices_map[x] for x in dims]
    indices = _promote_dtype(*indices)
    indices = np_utils.tf_broadcast(*indices)
    stacked_indices = array_ops_stack.stack(indices, axis=-1)
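    # After broadcasting and stacking, stacked_indices has shape
    # broadcast_shape + [len(dims)]: the last axis holds one coordinate per
    # advanced dimension.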
    # Skip the contiguous-dims optimization for update because there is no
    # tf.*scatter* op that supports the `axis` argument.
    if not dims_contiguous or updates is not None:
      if list(range(len(dims))) != dims:
        tensor = moveaxis(tensor, dims, range(len(dims)))
      tensor_shape_prefix = array_ops.shape(
          tensor, out_type=stacked_indices.dtype
      )[: len(dims)]
      stacked_indices = array_ops.where_v2(
          stacked_indices < 0,
          stacked_indices + tensor_shape_prefix,
          stacked_indices,
      )
      if updates is None:
        return array_ops.gather_nd(tensor, stacked_indices)
      else:
        # We only need to move-axis `updates` in the contiguous case, because
        # only in this case are the result dimensions of advanced indexing in
        # the middle of `updates`. In the non-contiguous case, those dimensions
        # are always at the front.
        if dims_contiguous:
          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
          #   `updates`)
          if stacked_indices.shape.rank is None:
            raise NotImplementedError(
                'Rank of the advanced indices must currently be known'
            )
          batch_size = stacked_indices.shape.rank - 1
          batch_start = dims[0]
          if batch_start < 0:
            batch_start += len(dims) - batch_size

          def range_(start, length):
            return range(start, start + length)

          updates = moveaxis(
              updates, range_(batch_start, batch_size), range(batch_size)
          )
        if update_method == _UpdateMethod.UPDATE:
          update_op = array_ops.tensor_scatter_update
        elif update_method == _UpdateMethod.ADD:
          update_op = array_ops.tensor_scatter_add
        elif update_method == _UpdateMethod.MIN:
          update_op = array_ops.tensor_scatter_min
        elif update_method == _UpdateMethod.MAX:
          update_op = array_ops.tensor_scatter_max
        tensor = update_op(tensor, stacked_indices, updates)
        if list(range(len(dims))) != dims:
          tensor = moveaxis(tensor, range(len(dims)), dims)
        return array_ops.tensor_strided_slice_update(
            original_tensor,
            packed_begin,
            packed_end,
            packed_strides,
            tensor,
            begin_mask=begin_mask,
            end_mask=end_mask,
            shrink_axis_mask=shrink_axis_mask,
            new_axis_mask=new_axis_mask,
            ellipsis_mask=ellipsis_mask,
            name=name + '_2',
        )
    # Note that gather_nd only gathers along the leading dimensions and has no
    # `axis` argument. To avoid shuffling data back and forth, we transform the
    # indices into flat indices over the (contiguous) advanced dims and do a
    # gather instead.
    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
    dims = [(x + rank if x < 0 else x) for x in dims]
    shape_tensor = array_ops.shape(tensor)
    dim_sizes = array_ops.gather(shape_tensor, dims)
    if len(dims) == 1:
      stacked_indices = indices[0]
    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
    stacked_indices = array_ops.where_v2(
        stacked_indices < 0, stacked_indices + dim_sizes, stacked_indices
    )
    axis = dims[0]
    if len(dims) > 1:
      index_scaling = math_ops.cumprod(dim_sizes, reverse=True, exclusive=True)
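      # Illustrative example (shapes assumed): with dim_sizes = [3, 4],
      # index_scaling is [4, 1], so a coordinate pair (i, j) on the last axis
      # of stacked_indices is folded into the flat index 4 * i + j, and the
      # two advanced dims are later reshaped into a single axis of size 12.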

      def _tensordot(a, b):
        # TODO(b/168657656): This function should be replaced by
        # tensordot(axis=1) once MatMul has int32 XLA kernel.
        b = array_ops.broadcast_to(b, array_ops.shape(a))
        return math_ops.reduce_sum(a * b, axis=-1)

      stacked_indices = _tensordot(stacked_indices, index_scaling)
      flat_shape = array_ops.concat(
          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims) :]], axis=0
      )
      tensor = array_ops.reshape(tensor, flat_shape)

    return array_ops.gather(tensor, stacked_indices, axis=axis)


def _as_spec_tuple(slice_spec):
  """Convert slice_spec to tuple."""
  if isinstance(slice_spec, (list, tuple)) and not isinstance(
      slice_spec, np.ndarray
  ):
    is_index = True
    for s in slice_spec:
      if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
        is_index = False
        break
      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
        is_index = False
        break
    if not is_index:
      return tuple(slice_spec)
  return (slice_spec,)
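
# For illustration (behavior of _as_spec_tuple above): a flat list of scalars
# such as [1, 2] is kept as a single advanced index and returned as ([1, 2],),
# whereas [1, slice(None)] contains a slice and is unpacked into
# (1, slice(None, None, None)).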


def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
  if (
      isinstance(slice_spec, bool)
      or (
          isinstance(slice_spec, core_tf_types.Tensor)
          and slice_spec.dtype == dtypes.bool
      )
      or (
          isinstance(slice_spec, (np.ndarray, np_arrays.ndarray))
          and slice_spec.dtype == np.bool_
      )
  ):
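    # A boolean mask selects elements the way NumPy boolean indexing does;
    # boolean_mask flattens the masked dimensions into a single one.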
    return array_ops.boolean_mask(tensor=self, mask=slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  result_t = _slice_helper(self, slice_spec)
  return result_t


def _with_index_update_helper(update_method, a, slice_spec, updates):
  """Implementation of ndarray._with_index_*."""
  if (
      isinstance(slice_spec, bool)
      or (
          isinstance(slice_spec, core_tf_types.Tensor)
          and slice_spec.dtype == dtypes.bool
      )
      or (
          isinstance(slice_spec, (np.ndarray, np_arrays.ndarray))
          and slice_spec.dtype == np.bool_
      )
  ):
    slice_spec = nonzero(slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  a_dtype = a.dtype
  a, updates = _promote_dtype_binary(a, updates)
  result_t = _slice_helper(a, slice_spec, update_method, updates)
  return result_t.astype(a_dtype)


setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
setattr(
    np_arrays.ndarray,
    '_with_index_update',
    functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE),
)
setattr(
    np_arrays.ndarray,
    '_with_index_add',
    functools.partial(_with_index_update_helper, _UpdateMethod.ADD),
)
setattr(
    np_arrays.ndarray,
    '_with_index_min',
    functools.partial(_with_index_update_helper, _UpdateMethod.MIN),
)
setattr(
    np_arrays.ndarray,
    '_with_index_max',
    functools.partial(_with_index_update_helper, _UpdateMethod.MAX),
)
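

# Hypothetical usage sketch (illustrative only; assumes numpy behavior has been
# enabled via tf.experimental.numpy.experimental_enable_numpy_behavior()):
#
#   import tensorflow.experimental.numpy as tnp
#   x = tnp.arange(6).reshape(2, 3)
#   y = x[0, 1:]    # expected to route through _getitem -> _slice_helper
#   z = x[x > 2]    # expected to take the boolean-mask path in _getitem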