syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default. proto3 requires the first
  // enum value to be zero.
  // NOTE(review): value restored; confirm the identifier against history.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S2 = 26;
  S4 = 21;
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U2 = 27;
  U4 = 22;
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
  // and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433
  // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the
  // existing IEEE types.
  // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only
  // Finite and NaN values are supported. Unlike IEEE types, infinities are not
  // supported.  NaN is represented when the exponent and mantissa bits are all
  // 1s. All other values are finite.
  // F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The
  // "FNUZ" means only Finite and NaN values are supported; zero is unsigned.
  // Unlike IEEE types, infinities are not supported.  NaN is represented when
  // the exponent and mantissa bits are all 0s with a sign bit of 1. All other
  // values are finite.
  // Support for these dtypes is under development. They do not yet work
  // properly in most cases.
  // TODO(b/259609697): Fully support FP8.
  F8E5M2 = 19;
  F8E4M3FN = 20;
  F8E4M3B11FNUZ = 23;

  // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915
  // F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits.
  // F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits.
  // The "FNUZ" means only Finite and NaN values are supported; zero is
  // unsigned. Unlike IEEE types, infinities are not supported.  NaN is
  // represented when the exponent and mantissa bits are all 0s with a sign bit
  // of 1. All other values are finite.
  // These differences mean there's an additional exponent value available. To
  // keep the same dynamic range as an IEEE-like FP8 type, the exponent is
  // biased one more than would be expected given the number of exponent bits
  // (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
  F8E5M2FNUZ = 24;
  F8E4M3FNUZ = 25;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000])
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  // (OPAQUE would be a better name for this identifier, but that conflicts with
  // a macro defined in windows.h.)
  // NOTE(review): value restored; 14 is the only number unused by the other
  // values of this enum. Confirm the identifier against history.
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;
}

// Describes the padding configuration for Pad operation. The padding amount on
// both edges as well as between the elements are specified for each dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}

// A DimLevelType indicates the encoding method for a dimension in an array.
// The semantics of this field are identical to those of the MLIR SparseTensor
// dialect.
// This should be kept in sync with the SparseTensor DimLevelType enum:
// https://github.com/llvm/llvm-project/blob/5674a3c88088e668b684326c2194a6282e8270ff/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td#L86
enum DimLevelType {
  // The corresponding dimension is Dense, every entry is stored.
  DIM_DENSE = 0;
  // The corresponding dimension is Compressed, only nonzeros are stored.
  // NOTE(review): this value and the two below were restored from their
  // surviving comments; confirm identifiers and numbers against history.
  DIM_COMPRESSED = 1;
  // The corresponding dimension contains a single coordinate, no sibling
  // elements for each parent.
  DIM_SINGLETON = 2;
  // The corresponding dimension is Compressed, but with potential trailing
  // zeros, thus an extra upper bound (high) is used to exclude those zeros.
  // E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)].
  DIM_LOOSE_COMPRESSED = 3;
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile, ordered from the most
  // major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}

// Describes how data should be split between different memories.
message SplitConfigProto {
  // The dimension that is split.
  int64 dimension = 1;

  // The indices where each split point occurs. For example, if the dimension
  // size is 1024, a split_indices value of {512} indicates a two-way split of
  // data through the middle.
  repeated int64 split_indices = 2;
}

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape.
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
// See the XLA documentation for more information on shapes and layouts.
// LINT.IfChange
message LayoutProto {
  // The dimension level type list for this array, specifying the way in which
  // each array dimension is represented in memory. If this list is empty, the
  // array is assumed to be dense.
  repeated DimLevelType dim_level_types = 9;

  // Whether each dimension is unique or ordered.  Each of the following lists
  // must be empty, or have one entry for each entry of dim_level_types.  If
  // either list is empty, all dimensions are assumed to be unique and ordered,
  // respectively.  Entries in this list may not be false for some DimLevelType
  // values (such as DIM_DENSE in particular).
  repeated bool dim_unique = 13;
  repeated bool dim_ordered = 14;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // The shape is padded at the end to a multiple of this value, in terms of
  // number of elements. This is useful when tiling does not bring the shape to
  // certain desired granules. Tiling effectively pads/reshapes/transposes the
  // shape to another shape. This field pads the total number of elements of
  // that new shape to a multiple of a certain number of elements. This is
  // useful, for example, when we want a layout which does not tile the data
  // but still requires it to be padded to a certain number of elements.
  int64 tail_padding_alignment_in_elements = 16;

  // (Optional) Bit size of each element. When unspecified or being 0, default
  // to ShapeUtil::ByteSizeOfPrimitiveType.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // The integer types to be used for indices and pointers.  These fields must
  // not be used unless the layout represents a sparse array.  The PrimitiveType
  // must correspond to an unsigned integer (U8, U16, U32, or U64).
  // If not provided, the compiler will use the largest unsigned integer
  // that is naturally supported by the target device (U32 or U64 in currently
  // supported devices).
  PrimitiveType index_primitive_type = 11;
  PrimitiveType pointer_primitive_type = 12;

  // The physical, on-device shape used to represent the shape this layout
  // belongs to. Only used for sparse arrays.
  // The layout(s) contained within the physical shape should not also contain
  // a physical shape.
  ShapeProto physical_shape = 10;

  // The dynamic shape metadata size in bytes in front of the shape data. The
  // field may be non-zero for a static shape whose associated buffer is for a
  // dynamic shape, e.g. a result of SliceToDynamic.
  int64 dynamic_shape_metadata_prefix_bytes = 15;

  // The split configurations which describe if/how the data is split between
  // different memories.
  repeated SplitConfigProto split_configs = 17;
}

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
// See the XLA documentation for more information on shapes and layouts.
message ShapeProto {
  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic.  In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is the
  // size of dimension 0, the second element is the size of dimension 1, and so
  // forth.  Empty list indicates a scalar.
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of constituent shapes in the tuple sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;
}

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  // Shapes of the computation's parameters.
  repeated ShapeProto parameters = 1;
  // Shape of the computation's output.
  ShapeProto result = 2;
  // Names of the parameters.
  // NOTE(review): presumably index-aligned with `parameters`; confirm.
  repeated string parameter_names = 3;
}

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// The type optimization profiles in use for Op-level optimizations.
enum ProfileType {
  // Default value; no valid profile type.
  INVALID = 0;
  WINDOW = 1;
  FLAG = 2;
  INTEGER = 3;
}

// The source of the optimization profile.
enum ProfileSource {
  // Unknown source; serves as the zero default that proto3 requires.
  // NOTE(review): the original values of this enum are missing from the file.
  // Only the mandatory zero value is restored here; recover the remaining
  // values (and confirm this identifier) from version history.
  PROFILE_SOURCE_UNKNOWN_SOURCE = 0;
}

// The compilation event that triggered the use of the profile.
enum CompilationEvent {
  // Unknown event; serves as the zero default that proto3 requires.
  // NOTE(review): the original values of this enum are missing from the file.
  // Only the mandatory zero value is restored here; recover the remaining
  // values (and confirm this identifier) from version history.
  COMPILATION_EVENT_UNKNOWN_EVENT = 0;
}

// Symbolization metadata for HLO Instructions.
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;

  // The user-specified name of the op.
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;

  // Indicate a file and line that this op is associated to in a user's program.
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;

  // Deprecated, use [ProfileInfo][profile_type] instead.
  repeated ProfileType profile_type = 5 [deprecated = true];

  // The footprint of the generated code for the instruction.
  int64 size_of_generated_code_in_bytes = 8;

  // The size of the working set, i.e., the amount of memory, used by the
  // instruction in a compiler-managed fast device memory.
  int64 size_of_memory_working_set_in_bytes = 9;

  // Information about the optimization profile that this operation contains.
  message ProfileInfo {
    // The type of optimization profiles that this operation contains.
    repeated ProfileType profile_type = 1;
    // Speedup of tuned config compared to default config.
    // TODO(b/203817882) Set the relative_speedup.
    double relative_speedup = 2;
    // The source of the optimization profiles that this operation contains.
    ProfileSource profile_source = 3;
    // The compilation event that triggered the use of the profiles.
    CompilationEvent compilation_event = 4;
  }

  // Profile information for the Op.
  ProfileInfo profile_info = 10;

  // Deduplicated HLO name for this op. In some cases, we can have multiple
  // instructions (e.g. fusions) that are considered duplicates. We want to
  // group them together under the same name so that we can group them together
  // during analysis (e.g. HLO Op Profile tool in Xprof).
  // E.g. If we have fusion.1, fusion.2, and fusion.3 marked as duplicates,
  // fusion.2 and fusion.3 will have deduplicated_name = fusion.1
  string deduplicated_name = 12;

  // Whether to preserve the layout of the HLO op.
  bool preserve_layout = 13;

  // 1-based position of the frame in frames flat array.
  // Ids are 1-based to keep 0 value as representation of non-set property.
  int32 stack_frame_id = 15;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. Current implementation does not spend any cycles
  // for the input data transfer since the memory is initialized with the proper
  // values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;

  // Whether a warm-up run of the computation was executed before the
  // measured execution.
  bool warmup_run_executed = 8;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  // Opaque identifier of the asynchronous execution.
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  // Opaque identifier of the globally accessible allocation.
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  // Opaque identifier of the replicated virtual device.
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid channel type to serve as default.
    // NOTE(review): all four values of this enum were restored from their
    // surviving comments; confirm identifiers and numbers against history.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int64 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  // Shape of the literal; it implies the structure of the data fields below.
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s2s = 26;
  bytes s4s = 21;
  bytes s8s = 15;
  bytes u2s = 27;
  bytes u4s = 22;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  bytes f8e5m2s = 19;
  bytes f8e4m3fns = 20;
  bytes f8e4m3b11fnuzs = 23;
  bytes f8e5m2fnuzs = 24;
  bytes f8e4m3fnuzs = 25;
  repeated int64 sparse_indices = 14;
}

// Describes the window of an operation along a single dimension.
message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the low
  // end of this dimension; if negative, its negative means the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  // Per-dimension description of the window.
  repeated WindowDimension dimensions = 1;
}

// Describes the dimension numbers for a gather operation.
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices for
  // which were computed from output_gather_dims (see the operation semantic for
  // how this is defined) and the start_indices tensor.
  // The window indices for a specific output index Out are computed as:
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}

// Describes the dimension numbers for a scatter operation.
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;

  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}

// Describes which dimensions of a convolution's input, kernel (rhs) and output
// hold the batch, feature and spatial data.
message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;
}

enum PaddingType {
  // Invalid padding type to serve as default; proto3 requires the first enum
  // value to be zero.
  // NOTE(review): value restored; confirm the identifier against history.
  PADDING_INVALID = 0;
  PADDING_VALID = 1;  // Only valid portion of the base are covered.
  PADDING_SAME = 2;  // Extra is added to produce same output size as the input.
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;

  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;

  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;

  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}

enum SparsityType {
  // Invalid sparsity type to serve as default; proto3 requires the first enum
  // value to be zero.
  // NOTE(review): both values of this enum were restored; confirm identifiers
  // and numbers against history.
  SPARSITY_INVALID = 0;

  // Structured N:M sparsity.
  SPARSITY_STRUCTURED_N_M = 1;
}

// Contains sparsity metadata for a sparse dot operation.
// The only supported type atm is structured 2:4 sparsity, which is natively
// supported on NVidia GPUs.
// Restrictions:
// - only one operand of the dot operation may be sparse;
// - only the contracting dimension may be sparse.
message SparsityDescriptor {
  // The kind of sparsity described.
  SparsityType type = 1;

  // Sparse operand index (0 or 1).
  int32 index = 2;

  // Sparse dimension number.
  int32 dimension = 3;

  // Structured N:M sparsity (N < M).
  int32 n = 4;
  int32 m = 5;
}

enum RandomDistribution {
  // Invalid distribution to serve as default; proto3 requires the first enum
  // value to be zero.
  // NOTE(review): all values of this enum were restored (the two below from
  // their surviving comments); confirm identifiers and numbers against
  // history.
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;
}

enum RandomAlgorithm {
  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
  // NOTE(review): only the default survives in this file; check version
  // history for further algorithm values that may have been lost.
}

// Options for a triangular solve against the triangular matrix 'a'.
message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    // Invalid value to serve as default; proto3 requires the first enum value
    // to be zero.
    // NOTE(review): value restored; confirm the identifier against history.
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Attributes of the sort custom call (cub::DeviceRadixSort).
message SortOptions {
  // NOTE(review): presumably selects descending sort order; confirm.
  bool descending = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  // Attribute name -> attribute value.
  map<string, string> map = 1;
}

// Represents a single statistic to track.
message Statistic {
  // Must be a single word consisting of any alphanumeric characters.
  string stat_name = 1;

  // Must be within a range of [0, 100], in order for the graph dumper to
  // properly render the statistic onto the graph.
  double stat_val = 2;
}

// Represents the information needed to visualize propagation statistics when
// rendering an HLO graph. This includes an array of statistics as well as the
// index of the statistic to render.
message StatisticsViz {
  // Index into `statistics` of the statistic to render.
  int64 stat_index_to_visualize = 1;
  // The statistics available for rendering.
  repeated Statistic statistics = 2;
}

// Describes how an op is sharded across devices.
message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    // NOTE(review): value restored from its surviving comment; confirm the
    // identifier against history.
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
    // This op is manually sharded: the shapes are already partitioned and the
    // partitioner should not change this op.
    MANUAL = 4;
    // This sharding is a placeholder sharding with lowest precedence, it can be
    // overwritten by any other shardings.
    UNKNOWN = 5;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  // Only one of tile_assignment_devices and iota_reshape_dims shall be
  // non-empty.
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;

  // Only used for OTHER type. If true, data is sharded according to other
  // dimensions of tile_assignment(), but replicated across devices along the
  // last dimension. (Experimental)
  bool replicate_on_last_tile_dim = 6;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if sharding is
  // combined with other shardings.  Metadata are not to be populated when
  // type == TUPLE and instead metadata should be set on individual tuple
  // elements.
  repeated OpMetadata metadata = 7;

  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping in replicate, manual,
  // unreduced sharding type respectively.
  repeated Type last_tile_dims = 8;

  // Dimensions used to reshape the 1D iota array of device IDs.
  // Only one of tile_assignment_devices and iota_reshape_dims shall be
  // non-empty.
  repeated int64 iota_reshape_dims = 9;

  // Dimension permutations to transpose the iota array reshaped to
  // iota_reshape_dims. This must have the same size as iota_reshape_dims.
  repeated int32 iota_transpose_perm = 10;

  // This field decides whether this op is in a shard group.
  bool is_shard_group = 11;

  // This field is used to store the unique id of the shard group.
  int64 shard_group_id = 12;

  // Used to decide whether this op is to be sharded like some other ops, or to
  // which other ops will be sharded like.
  enum ShardGroupType {
    // This op will be sharded exactly the same as the other op. (hard
    // restriction)
    AS = 0;
    // This op will try to allow sharding propagation within the same group even
    // there is no data dependencies among them, but there is no guarantee that
    // the final shardings within the same group will be exactly the same. (soft
    // restriction)
    LIKE = 1;
  }

  ShardGroupType shard_group_type = 13;
}

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source target pair in the collective permute op.
// NOTE(review): the kind of id stored here (replica vs. partition) is not
// stated in this file -- confirm against the collective permute semantics.
message SourceTarget {
  // The source of the pair.
  int64 source = 1;
  // The target of the pair.
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;
    // Each U8/S8 value in a tensor actually represents 2 nibble values.
    PACKED_NIBBLE = 3;
  }

  // The algorithm used to evaluate the instruction.
  //
  // The naming convention for the dot instruction is
  // ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] where A_TYPE, B_TYPE
  // and ACCUM_TYPE correspond to the types in the "primitive dot operations"
  // (such as TensorCore operations) and NUM_OPS is the number of such
  // operations used per "primitive tile". When the NUM_OPS
  // field is skipped, it is assumed to be 1. The types mentioned in the name
  // are independent of the storage types.
  //
  // In general A_TYPE and B_TYPE are the precisions that the LHS and RHS of
  // the operation are rounded to and ACCUM_TYPE is the accumulation type. If a
  // backend does not support the given algorithm, an error is raised. The
  // Algorithm enum is intended to eventually replace the Precision enum.
  enum Algorithm {
    // If the algorithm is `ALG_UNSET`, we will decide the algorithm based on
    // the operand_precision values (for now).
    ALG_UNSET = 0;
    // The storage type can be any 8-bit floating point type.
    ALG_DOT_ANY_F8_ANY_F8_F32 = 1;
    // The storage type can be any 8-bit floating point type. Intermediate
    // results will not periodically be promoted to a higher precision. This
    // corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's
    // maxNumImpreciseAcc=32 setting may be similar.
    ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM = 2;
    ALG_DOT_F16_F16_F16 = 3;
    ALG_DOT_F16_F16_F32 = 4;
    ALG_DOT_BF16_BF16_BF16 = 5;
    ALG_DOT_BF16_BF16_F32 = 6;
    // An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better
    // precision.
    ALG_DOT_BF16_BF16_F32_X3 = 7;
    // An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better
    // precision (similar to F32).
    ALG_DOT_BF16_BF16_F32_X6 = 8;
    ALG_DOT_TF32_TF32_F32 = 9;
    // An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better
    // precision (similar to F32).
    ALG_DOT_TF32_TF32_F32_X3 = 10;
    ALG_DOT_F32_F32_F32 = 11;
    ALG_DOT_F64_F64_F64 = 12;
  }

  // The requested operand precisions.
  // NOTE(review): presumably one entry per operand of the instruction --
  // confirm against the callers that populate this field.
  repeated Precision operand_precision = 1;

  // Currently doesn't do anything, but we plan to support it for dot and
  // possibly more instructions.
  // TODO(b/316147294): Support this on GPU and add this to StableHLO as well.
  // If this is set, then `operand_precision` should be set to DEFAULT and it
  // will be ignored.
  Algorithm algorithm = 2;
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it is
// known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop.  For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  // Wrapper message holding a known trip count.
  message KnownTripCount {
    // The loop's trip count.
    int64 n = 1;
  }

  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count: a message-typed field has explicit presence, so an
  // unset field means the trip count is unknown.
  KnownTripCount known_trip_count = 1;
}

// Specifies a pair of output/operand buffers that alias each other for
// kCustomCall and kFusion
message OutputOperandAliasing {
  repeated int64 output_shape_index = 1;
  int64 operand_index = 2;
  repeated int64 operand_shape_index = 3;