tensorflow/tensorflow

View on GitHub
tensorflow/python/tpu/tensor_tracer.py

Summary

Maintainability
F
2 wks
Test Coverage
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""A utility to trace tensor values on TPU."""

import collections
import hashlib
import operator
import os
import os.path
import sys

import numpy as np

from tensorflow.core.framework import summary_pb2
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_case
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import summary_ops_v2 as summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import analytics
from tensorflow.python.platform import gfile
from tensorflow.python.platform import remote_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
from tensorflow.python.tpu import tensor_tracer_flags
from tensorflow.python.tpu import tensor_tracer_report
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import training_util

_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
_TRACE_MODE_PART_TENSOR_SIZE = 3

_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op'
_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_SKIP_SCALAR = 'not-traced-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

_OUTPUT_STREAM_ESCAPE = 'file://'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '
_SKIP_REPORT_FILE = 'None'  # Do not write report proto if --report_file=None

_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE
_TT_SUMMARY_SPARSITY = tensor_tracer_flags.TT_SUMMARY_SPARSITY

_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'

_TT_SUMMARY_MAX_QUEUE = 10

tt_gauge = monitoring.BoolGauge('/tensorflow/api/tensor_tracer/v1',
                                'tensor tracer usage', 'method')


def _graph_summary_tag(graph):
  """Generates and returns a summary tag name for the given graph."""

  if graph is None:
    raise RuntimeError('graph is None')
  # The chance of collision with md5 is effectively 0.
  hash_id = hashlib.md5()
  hash_id.update(repr(graph).encode('utf-8'))
  # hexdigest() returns a string.
  return hash_id.hexdigest()


def set_parameters(tensor_tracer_params=None):
  """Enables tensor tracer and sets its parameters.

  Example usage:
    tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir',
                                'trace_mode': 'norm',
                                'report_file': '/usr/tmp/trace_dir/report.all'}
    tensor_tracer.set_parameters(tensor_tracer_parameters)

  This sets up the parameters for tensor tracer. A call to tensor tracer as
  below is necessary to enable debugging on CPUs and GPUs. On TPUs below can be
  skipped as this call is hooked into tpu.rewrite.
    tt = tensor_tracer.TensorTracer()
    loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss)

  Args:
    tensor_tracer_params: Tensor tracer parameter dictionary. Below gives
    examples of these parameters: See tensor_tracer_report.py for all
      parameters.
        - enable: If set, tensor tracer will be enabled. Calling
          enable_tensor_tracer automatically adds this parameters.
        - trace_mode: The trace_mode to be used by tensor tracer. These include:
          - summary: Collects multiple statistics for traced tensors, and writes
            them a summary file that can be visualized using tensorboard. This
            mode currently only works for TPUEstimator. It can be also be used
            for other models, but outfeed must be handled by the user.
          - norm: Collects norm of each traced tensor and writes them into a
            text file pointed by 'trace_dir' flag. (Default mode).
          - nan-inf: Checks the existince of NaNs and Infs in the tensor, and
            writes a boolean value to a text file pointed by 'trace_dir' flag.
            Note that 'norm' mode can also capture this information with more
            numerical info.
          - max-abs: Collects the absolute max for each traced tensors and
            writes it into a text file pointed by 'trace_dir' flag.
          - full-tensor: Writes the full tensor content of the traced tensors
            into a text file pointed by 'trace_dir' flag.
          - part-tensor: Writes a part of the tensor content of the traced
            tensors into a text file pointed by 'trace_dir' flag.
          - full_tensor_summary: Writes the full tensors as binary event files.
            The outputs can be read using: trace =
              tensor_tracer.read_tensor_tracer_event_file(event_file_path)

        - report_file: Path to the metadata file that is written during graph
          construction. If not set, metadata will be printed to stdout during
          graph construction.
        - trace_dir: Path where the execution traces will be written during the
          graph execution. If not set, trace will be printed to stderr.
        - trace_level: Tensor tracer aims to trace everything it can. This
          introduces some overhead on graph execution and graph compilation
          times. Using trace_level parameter, it is possible to trace operation
          based on their priorities. For example, - trace_level=7 is the highest
          trace_level, in which every op is traced. - trace_level=6 will skip
          constant operations such as tf.constant. - trace_level=5 will skip
          less important ops such as tf.identities. - The default trace_level=3,
          that will skip concat ops, or random number generators. - To reduce
          the graph compile time overhead, trace_level can be set to 0, that
          will skip additions, and substractions, and multiplications as well.
        - excluded_opnames: If set, any matching op name will not be traced.
          excluded_opnames can be set as a regular expression. E.g,
          excluded_opnames=.* will exclude everything.
        - excluded_optypes: If set, any matching op type will not be traced.
          excluded_optypes can be set as a regular expression. E.g,
          excluded_optypes=.* will exclude everything. excluded_optypes=MatMul
          will exclude all MatMul ops from tracing.
        - included_opnames: If set, any matching op name will be forced to be
          traced. included_opnames can be set as a regular expression. E.g,
          '--included_opnames=some_op --excluded_opname=*.' will only trace
          some_op.
        - included_optypes: If set, any matching op type will be forced to be
          traced. included_optypes can be set as a regular expression. E.g,
          '--included_optypes=some_op_type --excluded_optypes=*.' will trace
          only the ops with type 'some_op_type'
        - flush_summaries: If summary mode is used, flush_summaries=1 will
          flush summaries using outside compilation. Note that, if used with
          low level APIs, flush_summaries=1 is necessary to obtain results.
        Advanced Flags:
        - trace_scalar: Scalar values are not traced by default. If this flag is
          set, scalar values will also be traced.
        - op_range: In the form of '%d:%d' that limits the tracing to the ops
          within this limit. --op_range='5:10' will trace only the ops that have
            topological order between 5-10.
        - submode: 'brief' or 'detailed'. If the trace mode is not compact,
          brief mode will print only the id of each traced tensor to save some
          space. 'detailed' mode prints the full tensor name.
        - use_fingerprint_subdirectory: The trace directory will be chosen as
          using the fingerprint of the trace metadata under the provided
          trace_dir.
  """
  enable_flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE
  if tensor_tracer_params:
    for key, value in tensor_tracer_params.items():
      enable_flags += ' --%s=%s' % (key, value)
  os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = enable_flags


def op_priority(op_type):
  """Returns the priority of the op.

  If the priority of the op is k, it will be traced if trace_level>=k.
  Args:
    op_type: String name of the operation type.
  Returns:
    Integer value corresponding the priority of the op.
  """
  if op_type in ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
                 'VariableShape', 'Fill', 'OneHot', 'ShapeN'):
    # Lowest priority ops, e.g., constant ops across different steps,
    # They will be traced only if trace_level>=7
    return 7

  if op_type in ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
                 'PreventGradient', 'Squeeze', 'Gather', 'GatherNd'):
    # Operations without numerical effects.
    # They will be only if trace_level>=6
    return 6
  if op_type in ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
                 'CollectivePermute', 'SplitV', 'DynamicPartition'):
    # Operations that merge or slice an input, will be traced if trace_level>=5
    return 5
  if op_type in ('Pad', 'RandomUniformInt', 'GreaterEqual'):
    # Operations less likely to provide useful information,
    # will be traced if trace_level>=4
    return 4
  if op_type in ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum'):
    # Add operations that are less likely create any issues, will be traced
    # if trace_level>=3 (default=3)
    return 3
  if op_type in ('Neg', 'Sub'):
    # Sub operations that are less likely create any issues, will be traced
    # trace_level>=2
    return 2
  if op_type in ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
                 'Maximum', 'Mean', 'Variance', 'Exp', 'Rsqrt'):
    # Multiplication and some other operations, will be traced if trace_level>=1
    return 1

  # Unclassified op_types default to being traced at level 2 and above.
  return 2


def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  This can be used to read the full tensors written into binary event files by
  by TensorTracer with trace_mode=full_tensor_summary.

  Example usage:
    result_dict_list = tensor_tracer.read_tensor_tracer_event_file(
      event_file_path)
    for result_dict in result_dict_list:
      for step, tensor_dict in result_dict.items():
        for tensor_name, full_tensor_content in tensor_dict.items():
          logging.info(tensor_name, full_tensor_content)

  Args:
    event_file: Path to the event file that contains only tensor tracer events.
  Returns:
    A list of event dictionaries, each of which with the form:
    {step_number: {tensor_name: tensor_content}}. This is a list instead of
    a single event dictionary because it is possible that an event file may
    have multiple event traces, each of them covering the same step ranges.
  Raises:
    ValueError: If an unexpected trace is found.
  """

  # Keeps track of how many times that a step number shows up in these events.
  step_occurrence_count = collections.defaultdict(int)

  # List of step occurrences.
  step_occurrence_list = []

  for trace_event in summary_iterator.summary_iterator(event_file):
    # First event is an event with file_version: "brain.Event:2"
    if not trace_event.HasField('summary'):
      continue
    if len(trace_event.summary.value) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(trace_event.summary.value))
    step = trace_event.step
    step_occurrence_count[step] += 1  # a new occurrence for this step.

    occurrence_idx = step_occurrence_count[step] - 1
    occurrence_size = len(step_occurrence_list)

    if occurrence_idx == occurrence_size:
      # This particular occurrence isn't yet recorded on step_occurrence_list.
      # So append this new occurrence to the end of step_occurrence_list.
      new_occurrence = collections.defaultdict(dict)
      step_occurrence_list.append(new_occurrence)
    else:
      # This particular occurrence must be already recorded on
      # step_occurrence_list (i.e. occurrence_idx < occurrence_size).
      if occurrence_idx > occurrence_size:
        raise ValueError('Unexpected: occurrence_idx (%d) > '
                         'occurrence_size (%d)' % (occurrence_idx,
                                                   occurrence_size))
    tensor_value = trace_event.summary.value[0]
    tensor_name = tensor_value.tag

    real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
    tensor_content = np.frombuffer(
        tensor_value.tensor.tensor_content,
        dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype()
        ).reshape(real_shape)
    step_occurrence_list[occurrence_idx][step][tensor_name] = tensor_content
  return step_occurrence_list


def trace_tensor(tensor, tracepoint_name=None):
  """Programmatic interface to trace a tensor with Tensor Tracer.

  Tensor Tracer, by default, traces all tensors in the execution. This function
  can be used to limit traced tensors. If this function is called for a subset
  of the tensors, only those will be traced.

  For example, Tensor Traacer will only trace c below.
    c = tf.MatMul(a, b)
    tensor_tracer.trace_tensor(c)
    d = tf.add(c, 1)
  Args:
     tensor: the tensor object for which the tracing is requested.
     tracepoint_name: an optional tensor tracepoint name string. A tracepoint
       name is an Tensor Tracer internal name for the tensor. It is useful when
       comparing equivalent traces from different models that have different
       tensor namings. Equivalent tensors (with different names) can be mapped
       to each other by assigning a common tracepoint_name.

  Returns:
    The provided tensor.
  """
  if tracepoint_name is None:
    tracepoint_name = tensor.name
  tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION)
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, tracepoint_name))
  return tensor


def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates trace_tensor.

  Args:
     layer: A keras layer.
     checkpoint_name: a string name for the checkpoint. This name has to be a
     unique name if used within model comparison. The tensors that have the same
     checkpoint identifier is compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tf_type(outputs):
      trace_tensor(outputs, '%s' % (checkpoint_name))
    else:
      idx = 0
      for output_tensor in outputs:
        if tensor_util.is_tf_type(outputs):
          trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx))
        idx += 1
  except AttributeError:
    pass
  except RuntimeError:
    pass
  return layer


class TensorTracer:
  """A software construct for tracing tensor values in a TF graph.

  This utility is disabled by default. It is hooked into tpu.rewrite, so it can
  easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as
  below without a code change.
    export TENSOR_TRACER_FLAGS="--enable=1"

  Below is the use example to enable it on CPUs or GPUs, or for more advance use
  cases on TPUs.

    a = x + 1
    b = a * 2
    rs = tf.reduce_sum(b)
    tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir',
                             'report_file: 'path/to/report/file'})
    tt = tensor_tracer.TensorTracer()
    if on_tpu:
      rs = tt.trace_tpu(tf.get_default_graph(),
                          tensor_fetches=rs)
    else:
      rs = tt.trace_cpu(tf.get_default_graph(),
                          tensor_fetches=rs)
    session.run(rs)

  If it is enabled, it will trace the output tensor values of
  selected Ops in the graph. It has two outputs: (1) the traces and (2)
  a report. The traces are dumped to a specified directory during the graph
  execution, while the report is dumped during the graph construction.
  By passing options via the env variable, users can change:
     (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
         full tensor values)
     (2) which Ops to be traced (via op.name or op.type)
     (3) output trace file path.

  """
  # The set of graphs that are rewritten by tensor tracer.
  _traced_graphs = set()

  @staticmethod
  def is_enabled():
    """Returns True if TensorTracer is enabled."""
    try:
      enable = tensor_tracer_flags.TTParameters().is_enabled()
      # Add metrics to determine API usage.
      if enable: tt_gauge.get_cell('is_enabled').set(True)
      return enable
    except (ValueError, RuntimeError) as e:
      logging.warning(
          'Tensor Tracer V1 flags processing error encountered in is_enabled '
          'check. %s', e)
      # TODO(b/210212559): Find a more robust fix.
      # Should only produce exception if Tensor Tracer is enabled.
      return True

  @staticmethod
  def check_device_type(device_type):
    """Checks if the given device type is valid."""

    if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU):
      raise ValueError('Invalid device_type "%s"'%device_type)

  @staticmethod
  def check_trace_mode(device_type, trace_mode):
    """Checks if the given trace mode work on the given device type.

    Args:
      device_type: Device type, TPU, GPU, CPU.
      trace_mode: Tensor tracer trace mode.
    Raises:
      ValueError: If the given trace mode is not supported for the device.
    """
    if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY:
      if device_type != _DEVICE_TYPE_TPU:
        raise ValueError('Device_type "%s" is not yet supported for '
                         'trace mode "%s"' % (device_type, trace_mode))

  @staticmethod
  def loop_cond_op(op):
    return op.type in ('LoopCond', 'RefLoopCond')

  @staticmethod
  def while_loop_op(op):
    """Returns true if op is one of the special ops of in a while loop.

    Args:
       op: A tf.Operation.

    Returns:
       True if the given op is one of [Switch, Merge, Enter, Exit,
       NextIteration, LoopCond], which are all building blocks for TF while
       loops.
    """
    return  (control_flow_util.IsLoopSwitch(op) or
             control_flow_util.IsLoopMerge(op) or
             control_flow_util.IsLoopEnter(op) or
             control_flow_util.IsLoopExit(op) or
             TensorTracer.loop_cond_op(op) or
             op.type in ('RefNextIteration', 'NextIteration'))

  @staticmethod
  def control_flow_op(op):
    """Returns true if op is one of the special ops of in a while loop.

    Args:
       op: A tf.Operation.

    Returns:
       True if the given op is one of [Switch, Merge, Enter, Exit,
       NextIteration, LoopCond], which are all building blocks for TF while
       loops.
    """
    return  (control_flow_util.IsSwitch(op) or
             control_flow_util.IsMerge(op))

  @staticmethod
  def unsafe_op(op):
    """Returns True if this op is not safe to be traced."""

    # Reasons for not including following op types:
    #    Assign: cause incorrect result with CPU tracing.
    if op.type == 'Assign':
      return True
    return False

  @staticmethod
  def device_mismatch(device_type, op):
    if device_type == _DEVICE_TYPE_TPU:
      # pylint: disable=protected-access
      return tpu_replication._TPU_REPLICATE_ATTR not in op.node_def.attr
      # pylint: enable=protected-access
    return False

  @staticmethod
  def unsafe_scalar_trace(op):
    """Return true if scalar output tensor from Op is not safe to be traced."""

    # Tracing the following causes cycle in the graph on TPU.
    if op.type in ('LoopCond', 'Enter', 'Merge', 'Const',
                   'Switch', 'Less', 'ReadVariableOp'):
      return True
    # Tracing the following will cause casting-issue
    # with the norm tracing mode or other compilation issues on CPU.
    if op.type in ('VarHandleOp', 'IteratorToStringHandle',
                   'IteratorGetNext', 'OneShotIterator',
                   'IteratorV2', 'MakeIterator',
                   'BatchDatasetV2', 'MapDataset',
                   'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'):
      return True
    return False

  def _is_interesting_op(self, op):
    """Returns True if the given op is not an interesting one to be traced."""
    return op_priority(op.type) <= self._parameters.trace_level

  @staticmethod
  def reason(op_idx, details):
    """Returns reason why the Op at op_idx is traced or not."""

    return '%d %s'%(op_idx, details)

  def __init__(self):
    """Initializes a TensorTracer.

    Sets the various member fields from the flags (if given) or the defaults.
    """
    self._replica_id = None
    self._tt_config = tensor_tracer_report.TensorTracerConfig()
    self._parameters = tensor_tracer_flags.TTParameters()
    self._host_call_fn = {}
    # _cache_variables is a dict (key = graph, value = dicts
    # (key = name, value = tensors))
    self._cache_variables = {}
    self._history_value_cache = {}

    self._traced_op_names = set()
    self._report_proto = None
    # _temp_cache_var is a dict (key = graph, value = [])
    self._temp_cache_var = {}
    self._report_proto_path = ''
    self._outmost_context = None

  def report_proto(self):
    """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes.

    Returns:
      A tensor_tracer.proto object.
    Raises:
      ValueError if called before tracing happens, or when trace mode is not
      summary or full_tensor_summary.
    """
    if self._report_proto:
      return self._report_proto
    else:
      raise ValueError('Call to report_proto must be done after tracing.'
                       'Report proto only exists for '
                       'trace_mode=[summary|full_tensor_summary]')

  def report_proto_path(self):
    """Getter for path where tensor_tracer.proto object should be written.

    Returns:
      A string path.
    """
    return self._report_proto_path

  def _escape_namescopes(self, variable_name):
    return variable_name.replace('/', '_').replace(':', '_')

  def _cache_variable_for_graph(self, graph):
    if graph not in self._cache_variables:
      self._cache_variables[graph] = {}
    return self._cache_variables[graph]

  def _create_or_get_tensor_history_values_cache(self,
                                                 cache_name,
                                                 graph,
                                                 shape=None,
                                                 dtype=dtypes.float32):
    """Creates a variable as the cache to store historic intermediate tensor values.

    Args:
      cache_name: Name to be given to the cache (an instance of tf.variable).
      graph: Tensorflow graph.
      shape: A list of dimensions.
      dtype: Data type of created cache.
    Returns:
      A ref to newly created or existing cache with the given dimensions.
    Raises:
      ValueError:
        (1) If graph is None, or
        (2) shape is None when a new cache needs to be created.
    """
    if graph is None:
      raise ValueError('Invalid graph.')

    if graph not in self._history_value_cache:
      self._history_value_cache[graph] = {}

    if cache_name not in self._history_value_cache[graph]:
      if shape is None:
        raise ValueError('shape must be provided at cache creation.')
      if dtype.is_integer:
        init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE)
      else:
        init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE

      # Create in proper graph and base name_scope.
      with graph.as_default() as g, g.name_scope(None):
        self._history_value_cache[graph][
            cache_name] = variable_scope.get_variable(
                'tt_history' + '_' + self._escape_namescopes(cache_name),
                shape=shape,
                dtype=dtype,
                initializer=init_ops.constant_initializer(init_val),
                trainable=False,
                use_resource=True,
                collections=[
                    _TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES
                ])

    return self._history_value_cache[graph][cache_name]

  def _create_or_get_tensor_values_cache(self, cache_name, graph,
                                         shape=None, dtype=dtypes.float32):
    """Creates a variable as the cache to store intermediate tensor values.

    Args:
      cache_name: Name to be given to the cache (an instance of tf.variable).
      graph: Tensorflow graph.
      shape: A list of dimensions.
      dtype: Data type of created cache.
    Returns:
      A ref to newly created or existing cache with the given dimensions.
    Raises:
      ValueError:
        (1) If graph is None, or
        (2) shape is None when a new cache needs to be created.
    """
    if graph is None:
      raise ValueError('Invalid graph.')

    graph_cache_var = self._cache_variable_for_graph(graph)

    if cache_name not in graph_cache_var:
      if shape is None:
        raise ValueError('shape must be provided at cache creation.')
      if dtype.is_integer:
        init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE)
      else:
        init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE

      # Create in proper graph and base name_scope.
      with graph.as_default() as g, g.name_scope(None):
        graph_cache_var[cache_name] = variable_scope.get_variable(
            _TT_SNAPSHOT + '_' + self._escape_namescopes(cache_name),
            shape=shape, dtype=dtype,
            initializer=init_ops.constant_initializer(init_val),
            trainable=False,
            use_resource=True,
            collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
    return graph_cache_var[cache_name]

  def _add_replica_id_to_graph(self):
    """Adds nodes for computing the replica ID to the graph."""

    if self._tt_config.num_replicas:
      with ops.control_dependencies(None):
        # Uses None as dependency to run outside of TPU graph rewrites.
        self._replica_id = tpu_ops.tpu_replicated_input(
            list(range(self._tt_config.num_replicas)),
            name='tt_replica_id')
    else:
      self._replica_id = 'unknown'

  def _inside_op_range(self, idx):
    """Return True if the given index is inside the selected range."""

    if idx < self._parameters.op_range[0]:
      return False
    return (self._parameters.op_range[1] < 0 or
            idx <= self._parameters.op_range[1])

  def _is_user_included_op(self, op):
    """Checks whether the op is included in the tensor tracer flags.

    Args:
      op: tf Operation
    Returns:
      True, if the op is included.
      An op is included if:
      - Its op name is given in included_opnames
      - Its op type is given in included_optypes
      - The op is at most _trace_ops_before_included hops before an included op
      - The op is at most _trace_ops_after_included hops after an included op
    """
    for opname_re in self._parameters.included_opname_re_list:
      if opname_re.match(op.name):
        return True

    for optype_re in self._parameters.included_optype_re_list:
      if optype_re.match(op.type):
        return True
    return False

  def _is_user_excluded_op(self, op):
    for opname_re in self._parameters.excluded_opname_re_list:
      if opname_re.match(op.name):
        return True
    for optype_re in self._parameters.excluded_optype_re_list:
      if optype_re.match(op.type):
        return True
    return False

  def _signature_types(self):
    """Returns a dictionary holding the order of signatures in the cache for the selected trace mode."""
    if self._parameters.trace_mode in set([
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_HISTORY,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
      return {self._parameters.trace_mode: 0}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      return self._parameters.summary_signatures
    return {}

  def _num_signature_dimensions(self):
    return len(self._signature_types())

  def _use_temp_cache(self):
    """Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable.

    Returns:
      A boolean, denoting whether to use a temporary cache or not.
    """
    # If full tensors need to be stored tf.variables, then do not use temp
    # variables to store them.
    if self._use_tensor_buffer():
      return False
    if self._use_tensor_values_cache():
      return self._parameters.use_temp_cache_var
    else:
      # Temporary caches only replaces tf.Variables caches. If no cache is used
      # return False.
      return False

  def _use_tensor_values_cache(self):
    """Returns True if immediate tensors should be first saved to a cache."""
    return self._parameters.use_compact_trace

  def _use_tensor_buffer(self):
    """Returns true if the whole tensor needs to be cached/buffered in memory."""
    return (self._parameters.trace_mode ==
            tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)

  def _merge_tensor_signatures(self, signatures):
    """Returns a tensor that merges the given signatures.

    Args:
      signatures: A dictionary of the signature updates from signature name to
      a tensor of dimension [1].
    Returns:
      A tensor that concats the signature values in a predefined order.
    Raises:
      ValueError: Unable to merge signatures.
    """
    sorted_update = []
    if self._num_signature_dimensions() > 1:
      signature_indices = self._signature_types()
      for _, val in sorted(signatures.items(),
                           key=lambda item: signature_indices[item[0]]):
        sorted_update.append(val)
      updates = array_ops_stack.stack(
          sorted_update, axis=0, name='merge_single_op_signatures')
    elif self._num_signature_dimensions() == 1:
      # Avoid stack operation if there is only a single signature.
      (_, val), = signatures.items()
      updates = val
    else:
      raise ValueError('Cannot merge 0 signatures. Check the value passed for '
                       'flag --signatures.')
    return updates

  def _save_tensor_value_to_tmp_cache(self, cache_idx, updates, graph):
    """Returns an op that will save the given updates to an entry in the cache.

    Args:
      cache_idx: The cache index of the tensor within the cache.
      updates: A dictionary of the signature updates from signature name to
      a tensor of dimension [1].
      graph: A TensorFlow graph.
    Raises:
      RuntimeError:
        (1) graph is not already in self._temp_cache_var, or
        (2) cache_idx is out of range.
    """
    updates = self._merge_tensor_signatures(updates)
    updates = array_ops.reshape(updates,
                                [self._num_signature_dimensions()])
    if graph not in self._temp_cache_var:
      raise RuntimeError('graph is not in self._temp_cache_var')
    if cache_idx >= len(self._temp_cache_var[graph]):
      raise RuntimeError('cache_idx (%d) is out of range (%d)' % (
          cache_idx, len(self._temp_cache_var[graph])))
    self._temp_cache_var[graph][cache_idx] = updates

  def _save_tensor_value_to_cache_op(self, cache_idx, updates, graph):
    """Returns an op that will save the given updates to an entry in the cache.

    Args:
      cache_idx: The cache index of the tensor within the cache.
      updates: A dictionary of the signature updates.
      graph: A TensorFlow graph.
    Returns:
      Cache update operation.
    """
    # state_ops.scatter_update allows updates only along the first dimension.
    # Make a compact array by concatenating different signatures, and update
    # them all together.
    updates = self._merge_tensor_signatures(updates)
    updates = array_ops.reshape(updates,
                                [1, self._num_signature_dimensions()])
    indices = constant_op.constant([cache_idx])
    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
    return state_ops.scatter_update(cache, indices, updates).op

  def _snapshot_tensor(self, tensor):
    """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable.

    Args:
      tensor: tensor whose values will be stored in a new tf.Variable.
    Returns:
      An assignment operation.
    """

    snapshot_variable = self._create_or_get_tensor_values_cache(
        tensor.name, tensor.op.graph,
        tensor.shape.as_list(), tensor.dtype)
    return state_ops.assign(snapshot_variable, tensor).op

  def _preprocess_traced_tensor(self, tensor):
    """Computes NAN/Norm/Max on TPUs before sending to CPU.

    Args:
      tensor: The tensor to be traced.
    Returns:
      A tensor that should be input to the trace_function.
    Raises:
      RuntimeError: If the signature is invalid.
    """

    def _detect_nan_inf(tensor):
      """Trace function for detecting any NaN/Inf in the tensor."""

      if tensor.dtype.is_floating:
        mask = math_ops.reduce_any(
            gen_math_ops.logical_or(
                gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
        output_tensor = cond.cond(
            mask,
            lambda: constant_op.constant([1.0]),
            lambda: constant_op.constant([0.0]))
      else:
        output_tensor = constant_op.constant([0.0])
      return output_tensor

    def _compute_signature(tensor, tf_op, cast_to_f32=True):
      if cast_to_f32:
        tensor = math_ops.cast(tensor, dtypes.float32)
      output_tensor = tf_op(tensor)
      # Return type should be scalar. Set it if it does not have the
      # information.
      if not output_tensor.get_shape().is_fully_defined():
        output_tensor = array_ops.reshape(output_tensor, [])
      return output_tensor

    def _show_size(tensor):
      # In order to check the size of a tensor.
      # Not all sizes are known at the compile time, also, different replicas
      # sometimes get different sizes of tensors.
      # Collect it here to be used in merging replica data.
      tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
      # Cast to float32, so that it can be placed into same cache with other
      # signatures.
      return math_ops.cast(tsize, dtypes.float32)

    def _show_max(tensor, cast_to_f32=True):
      # returns -inf for empty tensor
      return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)

    def _show_min(tensor, cast_to_f32=True):
      # returns inf for empty tensor
      return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)

    def _show_norm(tensor, cast_to_f32=True):
      # returns 0 for empty tensor
      return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)

    def _show_sparsity(tensor, cast_to_f32=True, tolerance=1e-06):
      # returns nan for empty tensor and treats nans as non-zero numbers
      def sparsity_fn(tensor):
        non_zeros = math_ops.greater_equal(math_ops.abs(tensor), tolerance)
        nans = math_ops.is_nan(tensor)
        return nn_impl.zero_fraction(math_ops.logical_or(non_zeros, nans))

      return _compute_signature(tensor, sparsity_fn, cast_to_f32)

    def _show_mean_and_variance(tensor, cast_to_f32=True):
      """Returns the mean and variance of the given tensor."""
      if cast_to_f32:
        tensor = math_ops.cast(tensor, dtypes.float32)
      # returns nan for empty tensor
      mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
      # The shape has to be 1. Set it if it does not have the information.
      if not mean.get_shape().is_fully_defined():
        mean = array_ops.reshape(mean, [])
      if not var.get_shape().is_fully_defined():
        var = array_ops.reshape(var, [])
      return mean, var

    def _show_max_abs(tensor, cast_to_f32=True):
      return _compute_signature(
          tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32)

    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
      return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return {self._parameters.trace_mode: tensor}
    if (self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
      return {self._parameters.trace_mode: tensor}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
      return {self._parameters.trace_mode: array_ops.reshape(
          _show_norm(tensor), [1])}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_HISTORY:
      return {self._parameters.trace_mode: array_ops.reshape(
          _show_norm(tensor), [1])}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
      return {self._parameters.trace_mode: _show_max_abs(tensor)}

    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      tensor = math_ops.cast(tensor, dtypes.float32)
      result_dict = {}
      # Call mean and variance computation here to avoid adding the same nodes
      # twice.
      if (_TT_SUMMARY_MEAN in self._signature_types() or
          _TT_SUMMARY_VAR in self._signature_types()):
        mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)

      for signature_name, _ in sorted(self._signature_types().items(),
                                      key=lambda x: x[1]):
        if signature_name == _TT_SUMMARY_NORM:
          signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MAX:
          signature_result_tensor = _show_max(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MAX_ABS:
          signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MIN:
          signature_result_tensor = _show_min(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_SPARSITY:
          signature_result_tensor = _show_sparsity(tensor)
        elif signature_name == _TT_SUMMARY_SIZE:
          signature_result_tensor = _show_size(tensor)
        elif signature_name == _TT_SUMMARY_MEAN:
          signature_result_tensor = mean
        elif signature_name == _TT_SUMMARY_VAR:
          signature_result_tensor = variance
        else:
          raise ValueError('Unknown signature type :%s.' % signature_name)

        result_dict[signature_name] = signature_result_tensor
      return result_dict

    raise RuntimeError(
        'Unsupported signature for trace mode %s.'
        % self._parameters.trace_mode)

  def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
    """Makes the tensor tracing function called by outside compilation.

    Args:
      tensor_name: name of the tensor being traced.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
    Returns:
      A function to be passed as the first argument to outside compilation.

    Raises:
      RuntimeError: If the trace mode is invalid.
    """

    def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
      """Prints a tensor value to a file.

      Args:
        tensor_name: name of the tensor being traced.
        num_elements: number of elements to print (-1 means print all).
        tensor: the tensor needs to be returned.
        output_tensor: the tensor needs to be printed.

      Returns:
        The same tensor passed via the "tensor" argument.

      Raises:
        ValueError: If tensor_name is not already in
                    tensor_trace_order.tensorname_to_cache_idx.
      """

      if self._parameters.is_brief_mode():
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          raise ValueError(
              'Tensor %s with name %s is not in the tensorname_to_cache_idx' %
              (tensor, tensor_name))
        msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name]
      else:
        msg = '"%s"' % tensor_name

      if self._parameters.trace_dir:
        output_path = os.path.join(
            self._parameters.trace_dir,
            _TRACE_FILE_NAME + self._get_outfile_suffix())
        output_stream = _OUTPUT_STREAM_ESCAPE + output_path
      else:
        output_stream = sys.stderr
      return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
                                  '@', self._replica_id,
                                  '\n', output_tensor, '\n',
                                  summarize=num_elements,
                                  output_stream=output_stream)

    def _show_part_tensor(tensor):
      """Trace function for printing part of the tensor."""

      return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
                           tensor, tensor)

    def _show_full_tensor(tensor):
      """Trace function for printing the entire tensor."""

      return _print_tensor(tensor_name, -1, tensor, tensor)

    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return _show_part_tensor
    # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
    # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
    # performed within TPUs and only their results are transferred to CPU.
    # Simply, print the full tensor for these trace modes.
    if self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
        tensor_tracer_flags.TRACE_MODE_SUMMARY,
        tensor_tracer_flags.TRACE_MODE_HISTORY
        ):
      return _show_full_tensor

    raise RuntimeError('Full tensor support is not available with trace mode %s'
                       %self._parameters.trace_mode)

  def _is_in_control_flow(self, op):
    """Returns true if the given op is inside a tf.cond or in tf.while_loop.

    Args:
      op: A tensorflow op that should be checked whether in control flow or not.
    Returns:
      A boolean value whether the op is in control flow or not.
    """
    return control_flow_util.IsInCond(op)

  def _is_in_outmost_while_loop(self, op):
    """Returns true if the op is at the same level with the training loop.

    Returns false if the op is in an inner while loop or if it is outside of the
    training loop.
    Args:
      op: tf.Operation

    Returns:
      A boolean.
    """
    ctxt = self._get_op_control_flow_context(op)
    outer_while_context = control_flow_util.GetContainingWhileContext(ctxt)
    return outer_while_context == control_flow_util.GetContainingWhileContext(
        self._outmost_context)

  def _should_trace_in_control_flow(self):
    """Returns false incase it is not safe to trace ops in tf.cond or tf.while_loop."""
    # As different from the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY
    # forces the execution of the traced tensors. We should not trace the ops
    # that may not be executed due to control flow.
    if self._use_temp_cache():
      return False
    elif self._tt_config.device_type == _DEVICE_TYPE_TPU:
      # On TPUs do not trace in control flow unless we use caches to store
      # intermediate values as calling outside compilation within an inner loop
      # causes errors.
      return self._use_tensor_values_cache() or self._use_tensor_buffer()
    return True

  def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
    """Returns True if we should not trace Op.

    Args:
      op_id: Topological index of the op.
      op: tf.Operation
      ops_in_exec_path: Set of operations that are in the execution path.
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      True if the op should not be traced, false otherwise.
    """
    if TensorTracer.while_loop_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
      return True
    if TensorTracer.control_flow_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP))
      return True
    if TensorTracer.unsafe_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
      return True
    if TensorTracer.device_mismatch(self._tt_config.device_type, op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
      return True
    if op not in ops_in_exec_path:
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
      return True
    # TensorTracer will not trace the operations that are in an inner while loop
    # or tf.cond when a temporary cache is used. Temporary cache adds direct
    # data dependencies to traced operations, and needs a static number of
    # traced operations. For these cases,
    # - We do not know the number of slots required when there are inner while
    # loops. TensorTracer can only trace the result of a while loop.
    # - We do not know ahead of time which branch of the tf.cond
    # will be taken, so we avoid introducing data dependencies for the
    # operations inside a tf.cond.
    # - We also cannot have a data dependency to an operation in a different
    # while context.
    if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op):
      if not self._should_trace_in_control_flow():
        report_handler.instrument_op(
            op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW))
        return True
    if self._is_user_included_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_INCLUDED op %s', op.name)
      return False

    if not self._inside_op_range(op_id):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
      return True
    if not self._is_interesting_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
      return True
    if self._is_user_excluded_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_EXCLUDED op %s', op.name)
      return True
    return False

  def _skip_tensor(self, op_id, out_tensor, report_handler):
    """Returns True if we should not trace out_tensor.

    Args:
      op_id: Topological index of the op producing tensor.
      out_tensor: tf.Tensor
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      True if the tensor should not be traced, false otherwise.
    """

    # Skips a tensor if the tensor has a non-numeric type.
    #   Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
    #         because it also excludes tensors with dtypes, bool, and
    #         float32_ref, which we actually want to trace.
    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
                                    dtypes.string])
    if out_tensor.dtype in non_numeric_tensor_types:

      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
      return True
    # Skip a tensor if it feeds a special while loop op.
    if [consumer for consumer in out_tensor.consumers() if
        TensorTracer.while_loop_op(consumer)]:
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
      return True
    if self._is_user_included_op(out_tensor.op):
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_INCLUDED tensor %s', out_tensor.name)
      return False
    if self._is_user_excluded_op(out_tensor.op):
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_EXCLUDED tensor %s', out_tensor.name)
      return True
    if not out_tensor.get_shape().is_fully_defined():
      # If trace mode is nan-inf, norm or max, then the tensor will be reduced
      # to a scalar before the outside compilation call.
      if self._parameters.trace_mode in (
          tensor_tracer_flags.TRACE_MODE_NAN_INF,
          tensor_tracer_flags.TRACE_MODE_NORM,
          tensor_tracer_flags.TRACE_MODE_HISTORY,
          tensor_tracer_flags.TRACE_MODE_MAX_ABS,
          tensor_tracer_flags.TRACE_MODE_SUMMARY
          ):
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
        return False
      else:
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
        return True
    rank = len(out_tensor.shape)
    if rank < 1:
      # scalar
      if self._parameters.trace_scalar_ops:
        if TensorTracer.unsafe_scalar_trace(out_tensor.op):
          report_handler.instrument_tensor(
              out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
          return True
        else:
          report_handler.instrument_tensor(
              out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
          return False
      else:
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
        return True
    else:
      # tensor
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
      return False

  def _filter_execution_path_operations(self, operations, fetches):
    """Returns the set of ops in the execution path to compute given fetches."""

    # If no fetch provided, then return all operations.
    if fetches is None:
      return set(operations)
    # Convert to list, if a single element is provided.
    if not isinstance(fetches, (list, tuple)):
      fetches = [fetches]
    # If a tensor is given as fetch, convert it to op.
    op_fetches = []
    for fetch in fetches:
      if isinstance(fetch, ops.Operation):
        op_fetches.append(fetch)
      elif isinstance(fetch, tensor_lib.Tensor):
        op_fetches.append(fetch.op)
      else:
        raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
                           %fetch)

    execution_path_operations = set(op_fetches)
    traverse_stack = list(op_fetches)
    while True:
      if not traverse_stack:
        break
      head_op = traverse_stack.pop()
      input_ops = [tensor_input.op for tensor_input in head_op.inputs]
      input_ops.extend(head_op.control_inputs)

      for input_op in input_ops:
        if input_op not in execution_path_operations:
          # Filter out loop condition operations, tracing them causes a cycle.
          # Trace only the loop-body.
          if TensorTracer.loop_cond_op(input_op):
            continue
          execution_path_operations.add(input_op)
          traverse_stack.append(input_op)
    return execution_path_operations

  def _determine_and_instrument_traced_tensors(self, graph_order,
                                               ops_in_exec_path,
                                               tensor_trace_points,
                                               report_handler):
    """Determines the tensors to trace and instruments the trace details.

    Args:
      graph_order: graph_order tuple containing graph (tf.graph), operations
        (list of operations), op_to_idx (op id mapping), (tensors) list of
        tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
        there is a cycle in the graph), topological_order_or_cycle (list of ops
        in topological order or list of ops creating a cycle).
      ops_in_exec_path: Set of ops in the execution path.
      tensor_trace_points: Collection of programatic tensor trace points.
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      List of tensors to be traced.
    """

    traced_tensors = []
    checkpoint_operations = set([tensor.op
                                 for (tensor, _) in tensor_trace_points])
    for op_id, op in enumerate(graph_order.operations):
      if checkpoint_operations and op not in checkpoint_operations:
        continue
      if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
        continue
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        if not self._skip_tensor(op_id, out_tensor, report_handler):
          traced_tensors.append(out_tensor)
    return traced_tensors

  def _check_trace_files(self):
    """Checks if any requirements for trace files are satisfied."""

    if not self._parameters.trace_dir:
      # traces will be written to stderr. No need to check trace files.
      return
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      # Output files are handled by tf.summary operations, no need to precreate
      # them.
      return
    if not gfile.Exists(self._parameters.trace_dir):
      file_io.recursive_create_dir(self._parameters.trace_dir)
      if not gfile.Exists(self._parameters.trace_dir):
        raise RuntimeError('Failed to create trace directory at %s' %
                           self._parameters.trace_dir)

  def _create_temp_cache(self, num_traced_tensors, num_signatures, graph):
    """Creates a temporary cache with the given dimensions.

    Fills the self._temp_cache_var with num_traced_tensors tf.constant() ops
    that have shape of [num_signatures].
    Args:
      num_traced_tensors: Int, denoting total number of traced tensors.
      num_signatures: Int, denoting the number of statistics collected per
        tensors.
      graph: TensorFlow graph.
    """
    init_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                      dtype=dtypes.float32,
                                      shape=[num_signatures])
    self._temp_cache_var[graph] = [
        init_value for _ in range(num_traced_tensors)]

  def _determine_trace_and_create_report(self, graph, ops_in_exec_path,
                                         graph_summary_tag):
    """Work needs to be done prior to TPU or CPU tracing.

    Args:
      graph: tf.graph
      ops_in_exec_path: Set of operations in the execution path.
      graph_summary_tag: the summary tag name for the given graph.
    Returns:
      An instance of tensor_tracer_report.TensorTraceOrder, containing list of
      tensors to be traced with their topological order information.
    Raises:
      RuntimeError: If opname filtering is incorrectly set.
    """

    self._check_trace_files()

    graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
    tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)

    report_handler = tensor_tracer_report.TTReportHandle()
    traced_tensors = self._determine_and_instrument_traced_tensors(
        graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
    logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors))
    if traced_tensors and tensor_tracer_flags.TT_CHECK_FILTER.value:
      raise RuntimeError('Verify ops being traced by tensor tracer.')

    tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
                                                               traced_tensors)
    num_signatures = self._num_signature_dimensions()
    # Create a cache variable if compact_tracing is used.
    if num_signatures and self._use_tensor_values_cache():
      if self._use_temp_cache():
        self._create_temp_cache(len(traced_tensors), num_signatures, graph)
      else:
        self._create_or_get_tensor_values_cache(
            _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
        if self._parameters.trace_mode in (
            tensor_tracer_flags.TRACE_MODE_HISTORY):
          self._create_or_get_tensor_history_values_cache(
              _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
    if self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_SUMMARY,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
      self._report_proto = report_handler.create_report_proto(
          self._tt_config, self._parameters, tensor_trace_order,
          tensor_trace_points, self._signature_types())
      if self._parameters.use_fingerprint_subdir:
        self._parameters.trace_dir = os.path.join(
            self._parameters.trace_dir, self._report_proto.fingerprint)
        logging.info('TensorTracer updating trace_dir to %s',
                     self._parameters.trace_dir)
      self._report_proto_path = report_handler.report_proto_path(
          self._parameters.trace_dir, graph_summary_tag)

      if self._parameters.report_file_path != _SKIP_REPORT_FILE:
        report_handler.write_report_proto(self._report_proto_path,
                                          self._report_proto, self._parameters)
    else:
      if self._parameters.trace_mode not in (
          tensor_tracer_flags.TRACE_MODE_HISTORY):
        report_handler.create_report(self._tt_config, self._parameters,
                                     tensor_trace_order, tensor_trace_points)
    return tensor_trace_order

  def _create_host_call(self):
    return self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_SUMMARY,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)

  def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream,
                             tensor_trace_order):
    """Generates a print operation to print trace inspection.

    Args:
      cache: Tensor storing the trace results for the step.
      replica_id: Tensor storing the replica id of the running core.
      step_num: Step number.
      output_stream: Where to print the outputs, e.g., file path, or sys.stderr.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.

    Returns:
      The Op to flush the cache to file.
    """
    def _inspect_tensor(tensor):
      """Returns the text to be printed for inspection output."""
      if (self._parameters.trace_mode ==
          tensor_tracer_flags.TRACE_MODE_NAN_INF):
        return cond.cond(
            math_ops.greater(tensor, 0.0),
            lambda: 'has NaNs/Infs!',
            lambda: 'has no NaNs or Infs.')
      else:
        return tensor

    # Check if there are graph operations being profiled.
    if not tensor_trace_order.traced_tensors:
      logging.warn('Inspect mode has no tensors in the cache to check.')
      return control_flow_ops.no_op

    # Check if the cache includes any nan or inf
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
      # Cache has 1s or 0s if the mode is NaN_INF
      step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0)
    else:
      # Cache has the actual numerics for other modes.
      step_has_nan_or_inf = math_ops.reduce_any(
          gen_math_ops.logical_or(
              gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache)))

    # Summarizing message for each step.
    step_error_message = cond.cond(
        step_has_nan_or_inf,
        lambda: 'NaNs or Infs in the step!',
        lambda: 'No numerical issues have been found for the step.')

    # No need to print core numbers if the cache is merged already.
    if self._parameters.collect_summary_per_core:
      stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->',
               step_error_message,
               'Printing tensors for mode:%s...' % self._parameters.trace_mode]
    else:
      stats = ['\n\n', 'step:', step_num, '-->', step_error_message,
               'Printing tensors for mode:%s...' % self._parameters.trace_mode]

    for tensor_name, cache_idx in sorted(
        tensor_trace_order.tensorname_to_cache_idx.items(),
        key=lambda item: item[1]):
      if self._parameters.collect_summary_per_core:
        stats.extend([
            '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
      else:
        stats.extend([
            '\n', 'step:', step_num, ',',
            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
    return logging_ops.print_v2(*stats, summarize=-1,
                                output_stream=output_stream)

  def _inspect_history_cache(self, cache, replica_id, step_num,
                             tensor_trace_order):
    """Generates a conditional print operation to log differences in tensor values.

    Args:
      cache: Tensor storing the trace results for the step.
      replica_id: Tensor storing the replica id of the running core.
      step_num: Step number.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.

    Returns:
      The Op to flush the cache to file.
    """
    # Check if there are graph operations being profiled.
    if not tensor_trace_order.traced_tensors:
      logging.warn('TT history mode has no tensors in the cache to check.')
      return control_flow_ops.no_op

    stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num]
    diffs = []
    for tensor_name, cache_idx in sorted(
        tensor_trace_order.tensorname_to_cache_idx.items(),
        key=lambda item: item[1]):

      tensor_to_write = cache[cache_idx, 0]
      snapshot_variable = self._create_or_get_tensor_history_values_cache(
          tensor_to_write.name, tensor_to_write.op.graph,
          tensor_to_write.shape.as_list(), tensor_to_write.dtype)

      with ops.control_dependencies([snapshot_variable]):
        old_value = state_ops.assign_add(snapshot_variable, 0.0)

      with ops.control_dependencies([old_value]):
        new_value = math_ops.cast(tensor_to_write, dtypes.float32)
        delta = math_ops.abs(math_ops.subtract(old_value, new_value))
        updated = state_ops.assign(snapshot_variable, new_value)
        diffs.append(delta)
      with ops.control_dependencies([updated]):
        new_value_from_var = state_ops.assign_add(snapshot_variable, 0.0)

      stats.extend([
          '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
          tensor_name, '-->', old_value, new_value_from_var, delta])

    diff_stack = array_ops_stack.stack(diffs)
    step_max = math_ops.reduce_max(diff_stack)

    return cond.cond(
        math_ops.greater(step_max, tensor_tracer_flags.DELTA_THRESHOLD.value),
        lambda: logging_ops.print_v2(*stats, summarize=-1),
        lambda: control_flow_ops.no_op())  # pylint: disable=unnecessary-lambda

  def _get_outfile_suffix(self):
    if remote_utils.is_remote_path(self._parameters.trace_dir):
      return remote_utils.get_appendable_file_encoding()
    else:
      return ''

  def _generate_flush_cache_op(self, num_replicas, on_tpu,
                               tensor_trace_order, graph):
    """Generates an Op that will flush the cache to file.

    Args:
      num_replicas: total number of replicas.
      on_tpu: if the graph is executed on TPU.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
      graph: TensorFlow graph.

    Returns:
      The Op to flush the cache to file.
    """

    def _flush_fun(cache, replica_id, step_num):
      """Flushes the cache to a file corresponding to replica_id."""

      def _f(file_index):
        """Generates a func that flushes the cache to a file."""
        def _print_cache():
          """Flushes the cache to a file."""
          replica_str = ('%d' % file_index)
          if self._parameters.trace_dir:
            output_path = (os.path.join(self._parameters.trace_dir,
                                        _COMPACT_TRACE_FILE_PREFIX)
                           + replica_str + self._get_outfile_suffix())
            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
          else:
            output_stream = sys.stderr

          new_step_line = _REPLICA_ID_TAG + replica_str
          print_ops = []
          if self._parameters.inspect_trace:
            if self._num_signature_dimensions() > 1:
              raise ValueError('Inspecting multi signatures are not supported.')
            if self._parameters.trace_mode in (
                tensor_tracer_flags.TRACE_MODE_HISTORY):
              print_ops.append(
                  self._inspect_history_cache(
                      cache=cache,
                      replica_id=replica_id,
                      step_num=step_num,
                      tensor_trace_order=tensor_trace_order))
            else:
              print_ops.append(
                  self._inspect_summary_cache(
                      cache=cache,
                      replica_id=replica_id,
                      step_num=step_num,
                      output_stream=output_stream,
                      tensor_trace_order=tensor_trace_order))
          else:
            for i in range(self._num_signature_dimensions()):
              print_ops.append(logging_ops.print_v2(
                  new_step_line, '\n',
                  cache[:, i], '\n',
                  summarize=-1,
                  output_stream=output_stream))
          with ops.control_dependencies(print_ops):
            return constant_op.constant(0).op
        return _print_cache

      def _eq(file_index):
        return math_ops.equal(replica_id, file_index)

      flush_op_cases = {}
      flush_op_cases[_eq(0)] = _f(0)
      for i in range(1, num_replicas):
        if on_tpu and not self._parameters.collect_summary_per_core:
          # If this is the case, the cache is already merged for all cores.
          # Only first core flushes the cache.
          flush_op_cases[_eq(i)] = control_flow_ops.no_op
        else:
          flush_op_cases[_eq(i)] = _f(i)
      # Each replica needs to determine where to write their output.
      # To do this, we check if replica_id is 0, then 1, ..., and then
      # num_replicas - 1 statically; and return the corresponding static file
      # name. We cannot simply set the file name in python, as replica_id is
      # only known during tf runtime, and we cannot create dynamic filenames.
      return control_flow_case.case(flush_op_cases, exclusive=True)

    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
    if self._use_temp_cache():
      cache_val = cache
    else:
      cache_val = cache.value()

    if on_tpu:
      # If we do not need to collect traces for all cores, merge and aggregate
      # per core trace.
      if not self._parameters.collect_summary_per_core:
        cache_val = self.merge_caches_on_tpu(cache_val)
        cache_val = self.aggregate_global_cache(cache_val)[0]

      flush_op = tpu_replication.outside_compilation(
          _flush_fun, cache_val, self._replica_id,
          array_ops.identity(training_util.get_or_create_global_step()))
    else:
      global_step = training_util.get_or_create_global_step()
      flush_op = _flush_fun(cache_val, self._replica_id, global_step)

    if self._use_temp_cache():
      with ops.control_dependencies([flush_op]):
        return constant_op.constant(0).op
    else:
      # Re-initialize the local cache variable.
      with ops.control_dependencies([flush_op]):
        reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                           dtype=cache.dtype,
                                           shape=cache.shape)
        assign_op = state_ops.assign(cache, reset_value).op
        with ops.control_dependencies([assign_op]):
          return constant_op.constant(0).op

  def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu,
                                 tensor_trace_order, graph):
    """Flushes the intermediate tensor values in the graph to the cache.

    Args:
      tensor_fetches: list of tensor results returned by the model_fn.
      op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
      on_tpu: if the graph is executed on TPU.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
      graph: TensorFlow graph.

    Returns:
      An identical copy of tensor_fetches.
    """
    # Add a dependency to op and tensor fetches to make sure that all tracing
    # ops are executed before flushing trace results.
    if not tensor_trace_order.traced_tensors:
      logging.warn('No tensor values being traced. No flush cache op added.')
      return tensor_fetches
    with ops.control_dependencies(op_fetches +
                                  [tensor.op for tensor in tensor_fetches]):
      flush_cache_op = self._generate_flush_cache_op(
          self._tt_config.num_replicas, on_tpu, tensor_trace_order, graph)
      return control_flow_ops.tuple(tensor_fetches,
                                    control_inputs=[flush_cache_op])

  def _process_tensor_fetches(self, tensor_fetches):
    """Check that tensor_fetches is not empty and have valid tensors."""
    # If none or empty list.
    if tensor_fetches is None:
      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
                         'None.')
    if not isinstance(tensor_fetches, (list, tuple)):
      tensor_fetches = [tensor_fetches]
    elif not tensor_fetches:
      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
                         'empty list.')
    fetches = []
    for fetch in tensor_fetches:
      if isinstance(fetch, tensor_lib.Tensor):
        fetches.append(fetch)
      else:
        raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch)
    return fetches

  def _process_op_fetches(self, op_fetches):
    """Check that op_fetches have valid ops."""
    if op_fetches is None:
      return []

    if not isinstance(op_fetches, (list, tuple)):
      op_fetches = [op_fetches]

    fetches = []
    for fetch in op_fetches:
      if isinstance(fetch, ops.Operation):
        fetches.append(fetch)
      elif isinstance(fetch, tensor_lib.Tensor):
        fetches.append(fetch.op)
      else:
        logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
                        fetch)
    return fetches

  def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
    """Changes current_fetches' format, so that it matches input_fetches."""
    if isinstance(input_fetches, tensor_lib.Tensor):
      if len(current_fetches) != 1:
        raise RuntimeError('Tensor tracer input/output fetches do not match.')
      return current_fetches[0]
    else:
      if len(current_fetches) != len(current_fetches):
        raise RuntimeError('Tensor tracer input/output fetches do not match.')
      elif isinstance(input_fetches, tuple):
        return tuple(current_fetches)
      else:
        return current_fetches

  def _get_op_control_flow_context(self, op):
    """Returns the control flow of the given op.

    Args:
      op: tf.Operation for which the control flow context is requested.
    Returns:
      op_control_flow_context: which the is control flow context of the given
      op. If the operation type is LoopExit, returns the outer control flow
      context.
    """
    # pylint: disable=protected-access
    op_control_flow_context = op._control_flow_context
    # pylint: enable=protected-access
    if control_flow_util.IsLoopExit(op):
      op_control_flow_context = op_control_flow_context.outer_context
    return op_control_flow_context

  def merge_caches_on_tpu(self, local_tpu_cache_tensor):
    """Merges the given caches on tpu.

    Args:
      local_tpu_cache_tensor: A local tensor that needs to be merged
        by concanting data from other tpu cores.
    Returns:
      A merged tf.Tensor.
    """
    x = array_ops.broadcast_to(
        local_tpu_cache_tensor,
        shape=[self._tt_config.num_replicas] +
        local_tpu_cache_tensor.shape.as_list())

    if tensor_tracer_flags.TT_SINGLE_CORE_SUMMARIES.value:
      return x

    return tpu_ops.all_to_all(
        x, concat_dimension=0, split_dimension=0,
        split_count=self._tt_config.num_replicas,
        group_assignment=[list(range(self._tt_config.num_replicas))])

  def aggregate_global_cache(self, global_tt_summary_cache):
    """Merges the given caches on tpu.

    Args:
      global_tt_summary_cache: The global tensor tracer summary cache tensor
        with shape (num_cores, num_traced_tensors, num_traced_signatures). First
        dimension corresponds to core_id, where global_tpu_cache_tensor[i]
        correspond to the local cache from core-i.
    Returns:
      An aggregated tf.Tensor.
    Raises:
      RuntimeError: if there is no aggregate function defined for a signature.
    """

    # Merge only statistics tensor, if it is any other tensor we simply,
    # concatenate them.
    agg_fn_map = self._parameters.get_signature_to_agg_fn_map()
    signature_idx_map = self._signature_types()
    aggregation_result = []
    for signature, idx in sorted(signature_idx_map.items(),
                                 key=operator.itemgetter(1)):
      if signature not in agg_fn_map:
        raise RuntimeError('No aggregation function is defined for '
                           'signature %s.' % signature)
      # The dimensions of the statistics tensor is
      # num_cores x num_traced_tensors x num_signatures
      # value[:,:,idx] will return the portion of the tensor related
      # to signature.
      signature_tensor = global_tt_summary_cache[:, :, idx]
      # Merge it along the first (core) axis.
      agg_fn = agg_fn_map[signature]
      agg_tensor = agg_fn(signature_tensor, axis=0)
      aggregation_result.append(agg_tensor)
    # Merge results corresponding to different signatures

    merged_signatures = array_ops_stack.stack(aggregation_result)
    # merged_signatures has dimensions
    # num_signatures x num_traced_tensors, transpose it so that it
    # will match with the original structure
    # num_traced_tensors x num_signatures.
    transposed_signatures = array_ops.transpose(merged_signatures)
    # Expand 1 more dimension so that it will match with the expected
    # structure num_cores x num_traced_tensors x num_signatures.
    return array_ops.expand_dims(transposed_signatures, axis=0)

  def _prepare_host_call_fn(self, processed_t_fetches,
                            op_fetches, graph, graph_summary_tag):
    """Creates a host call function that will write the cache as tb summary.

    Args:
      processed_t_fetches: List of tensor provided to session.run.
      op_fetches: List of operations provided to session.run.
      graph: TensorFlow graph.
      graph_summary_tag: the summary_tag name for the given graph.
    Raises:
      ValueError if trace_dir is not set.
    """
    if self._parameters.trace_dir is None:
      raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
                       '--trace_dir=/model/dir')

    def _write_cache(step, event_file_suffix=None, **kwargs):
      """Writes the given caches as tensor summary.

      Args:
        step: Step tensor with dimension [num_cores].
        event_file_suffix: Event filename suffix tensor.
        **kwargs: The dictionary of tensors that needs to be written as
          summaries. Key and value pairs within kwargs correspond to the tag
          name, and tensor content that will be written using summary.write.
          The trace_modes that use this function are:
            - summary: In summary mode, kwargs includes a single (tag, content)
            pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
            variable. The dimension of the signature_cache is:
              num_cores x num_traced_tensors x num_signatures.
            - full_tensor_summary: kwargs will include all traced tensors. Tag
            and content correspond to the name of the tensor, and its actual
            content.
      Returns:
        A tf.Operation that needs to be executed for the host call dependencies.
      """
      file_suffix = _TT_EVENT_FILE_SUFFIX
      if event_file_suffix is not None:
        file_suffix = string_ops.string_join([file_suffix, event_file_suffix],
                                             separator='.')
      # TODO(deveci): Parametrize max_queue, so that flushing op can be called
      # less frequently.
      # Setting max_queue to 100 appears to be safe even when the number of
      # iterations are much lower, as the destructor of the writer flushes it.
      summary_write_ops = []
      summary_writer = summary.create_file_writer_v2(
          self._parameters.trace_dir,
          filename_suffix=file_suffix,
          max_queue=_TT_SUMMARY_MAX_QUEUE)
      graph.add_to_collection(
          TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)

      step_value = step[0]
      dt = step_value.dtype

      # The step parameter to a summary write call must be 64-bit.
      if dt.__ne__(dtypes.int64) and dt.__ne__(
          dtypes.uint64) and dt.__ne__(dtypes.float64):
        step_value = math_ops.cast(step_value, dtypes.int64)

      with summary_writer.as_default():
        summary_metadata = summary_pb2.SummaryMetadata(
            plugin_data=summary_pb2.SummaryMetadata.PluginData(
                plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
        for key, value in kwargs.items():
          # Check whether we need to compute aggregated statistics that merge
          # all cores statistics.
          if not self._parameters.collect_summary_per_core:
            # Merge only statistics tensor, if it is any other tensor we simply,
            # concatenate them.
            # Also, if there is only a single core (first dim. is 0), then skip
            # aggregation.
            if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1:
              value = self.aggregate_global_cache(value)
          with ops.control_dependencies([summary_writer.init()]):
            summary_write_ops.append(summary.write(
                _TT_SUMMARY_TAG + '/' + key + '.' + graph_summary_tag,
                value, metadata=summary_metadata,
                step=step_value))
      return control_flow_ops.group(summary_write_ops)

    global_step = training_util.get_or_create_global_step()
    step = array_ops.reshape(global_step, [1])
    self._host_call_fn = {}

    host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]

    caches_to_write = {}
    with ops.control_dependencies(host_call_deps):
      all_caches = self._cache_variable_for_graph(graph)
      for cache_name, cache_variable in all_caches.items():
        # Increase the cache rank by 1, so that when host call concatenates
        # tensors from different replicas, we can identify them with [core_id].
        new_cache_shape = [1]
        new_cache_shape.extend(cache_variable.shape.as_list())
        cache = array_ops.reshape(cache_variable, new_cache_shape)
        caches_to_write[cache_name] = cache
    # Add step to parameter dictionary.
    caches_to_write['step'] = step
    # Other options without adding step to parameter dictionary are
    #  * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
    #    considers caches_to_write as a single parameter, rather than a keyword
    #    parameters.
    #  * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
    #    a syntax error.
    self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)

  def host_call_deps_and_fn(self):
    return self._host_call_fn

  def get_traced_op_names(self):
    """Returns the set of traced op names."""
    return self._traced_op_names

  def _trace_execution(self, graph,
                       tensor_fetches,
                       op_fetches=None,
                       on_tpu=True):
    """Commong tracing function for both CPU and TPUs.

    The caller function should set device_type, num_replicas,
    num_replicas_per_host, num_hosts and replica_id before calling
    _trace_execution.


    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with as least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      on_tpu: True if executing on TPU.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
                      dependencies.
    Raises:
      RuntimeError: If tensor_fetches is None or empty.
    """
    def _cast_unsupported_dtypes(tensor):
      """Casts tensor to a supported type."""

      if tensor.dtype.__eq__(dtypes.int64):
        # outside-compilation doesn't support int64 input yet.
        return math_ops.cast(tensor, dtypes.int32)
      if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
          dtypes.float16):
        # Since host can't handle bf16, convert tensor to f32.
        return math_ops.cast(tensor, dtypes.float32)
      return tensor

    trace_mode = self._parameters.trace_mode
    device_type = self._tt_config.device_type
    # pylint: disable=protected-access
    self._outmost_context = graph._get_control_flow_context()
    # pylint: enable=protected-access

    analytics.track_usage('tensor_tracer', [trace_mode, device_type])
    TensorTracer.check_device_type(device_type)
    TensorTracer.check_trace_mode(device_type, trace_mode)
    # Check in_tensor_fetches, and op_fetches and convert them to lists.
    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
    op_fetches = self._process_op_fetches(op_fetches)
    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

    # Filter out the operations that won't be executed.
    # if fetches=None, then ops_in_exec_path = set(operations)
    exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
                                                         all_fetches)
    graph_summary_tag = _graph_summary_tag(graph)

    # Write report file, and determine the traced tensors.
    tensor_trace_order = self._determine_trace_and_create_report(
        graph, exec_op_set, graph_summary_tag)

    tensor_fetch_set = set(processed_t_fetches)
    tracing_ops = []

    sorted_exec_op_list = list(exec_op_set)
    sorted_exec_op_list.sort(key=lambda op: op.name)
    # Trace ops only if they are in the execution path.
    for op in sorted_exec_op_list:
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        tensor_name = out_tensor.name
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          continue
        self._traced_op_names.add(op.name)
        # Create the list of consumers before calling _preprocess_traced_tensor.
        # Otherwise, adding control input below, will introduce a cycle in the
        # graph.
        consumers = out_tensor.consumers()
        # Not all consumers may be in the exec path. Filter out the consumers
        # to keep the graph simpler.
        consumers = [cop for cop in consumers if cop in exec_op_set]

        # If there is no consumer of the tensor, there is no need to trace it;
        # unless the tensor itself is one of the fetches.
        is_a_fetched_tensor = out_tensor in tensor_fetch_set
        if (not consumers) and (not is_a_fetched_tensor):
          continue

        op_control_flow_context = self._get_op_control_flow_context(op)
        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(op_control_flow_context)
          # pylint: enable=protected-access

        processed_tensors = self._preprocess_traced_tensor(out_tensor)

        if on_tpu:
          for signature in processed_tensors.keys():
            processed_tensors[signature] = _cast_unsupported_dtypes(
                processed_tensors[signature])

        if self._use_tensor_values_cache():
          # Use a small cache (either temp cache or tf local variable) to store
          # the characteristics of the tensor.
          if self._use_temp_cache():
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            self._save_tensor_value_to_tmp_cache(cache_idx,
                                                 processed_tensors,
                                                 graph)
            trace_op = None
          else:
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            trace_op = self._save_tensor_value_to_cache_op(cache_idx,
                                                           processed_tensors,
                                                           graph)
        elif self._use_tensor_buffer():
          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          processed_out_tensor = list(processed_tensors.values())[0]
          # Store the whole tensor in a buffer.
          trace_op = self._snapshot_tensor(processed_out_tensor)
        else:

          def tpu_wrap_trace_fn(tensor, out_tensor_name):
            """Wraps the trace_fn with outside compilation if on TPUs."""
            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
                                                          tensor_trace_order)
            if on_tpu:
              return tpu_replication.outside_compilation(
                  tensor_trace_fn, tensor)
            else:
              return tensor_trace_fn(tensor)

          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          # Collecting multiple statistics are only supported in the summary
          # mode that uses compact format(self._use_tensor_values_cache = true).
          # Non-compact mode currently allows single stat per tensor.
          processed_out_tensor = next(iter(processed_tensors.values()))
          trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)

        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(self._outmost_context)
          # pylint: enable=protected-access
        if trace_op:
          if is_a_fetched_tensor:
            tracing_ops.append(trace_op)
            continue
          # Add it to all consumers, as some consumers may not be executed if
          # they are in a control flow.
          for consumer_op in consumers:
            # pylint: disable=protected-access
            consumer_op._add_control_input(trace_op)
            # pylint: enable=protected-access

    # pylint: disable=protected-access
    graph._set_control_flow_context(self._outmost_context)
    # pylint: enable=protected-access
    if tracing_ops:
      # If we are tracing a fetched tensor, their dependency is stored in
      # tracing_ops.
      processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                   control_inputs=tracing_ops)
    if self._use_tensor_values_cache() or self._use_tensor_buffer():
      if self._use_temp_cache():
        # Create the temporary tf cache variable by concantanating all
        # statistics.
        graph_cache_var = self._cache_variable_for_graph(graph)
        if graph not in self._temp_cache_var:
          raise RuntimeError('graph is not in self._temp_cache_var')
        graph_cache_var[_TT_SUMMARY_TAG] = array_ops_stack.stack(
            self._temp_cache_var[graph], axis=0, name='stack_all_op_signatures')
      if self._create_host_call():
        self._prepare_host_call_fn(processed_t_fetches, op_fetches, graph,
                                   graph_summary_tag)
        if not on_tpu:
          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
          cache_write_op = write_cache(**caches_to_write)
          processed_t_fetches = control_flow_ops.tuple(
              processed_t_fetches, control_inputs=[cache_write_op])
          del self._host_call_fn[_TT_HOSTCALL_KEY]
        elif self._parameters.flush_summaries_with_outside_compile:
          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
          if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
            step = caches_to_write['step']
            tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
            tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
            if not self._parameters.collect_summary_per_core:
              tt_core_summary = self.aggregate_global_cache(tt_core_summary)

            def write_if_core_0(step, replica_id, tt_summary):

              return cond.cond(
                  math_ops.equal(replica_id, 0),
                  lambda: write_cache(step=step, event_file_suffix=None,  # pylint: disable=g-long-lambda
                                      tensor_tracer_summary=tt_summary),
                  control_flow_ops.no_op)

            write_op = tpu_replication.outside_compilation(
                write_if_core_0,
                step=step,
                replica_id=self._replica_id,
                tt_summary=tt_core_summary)
            processed_t_fetches = control_flow_ops.tuple(
                processed_t_fetches, control_inputs=[write_op])
            del self._host_call_fn[_TT_HOSTCALL_KEY]
          else:
            raise ValueError('Outside compiled flush in only supported for '
                             'summary mode')
      else:
        processed_t_fetches = self._flush_tensor_values_cache(
            processed_t_fetches, op_fetches, on_tpu=on_tpu,
            tensor_trace_order=tensor_trace_order,
            graph=graph)

    # processed_t_fetches is a list at this point. Convert it to the same
    # format as given in tensor_fetches.
    return self._convert_fetches_to_input_format(tensor_fetches,
                                                 processed_t_fetches)

  def trace_tpu(self, graph,
                tensor_fetches,
                op_fetches=None,
                num_replicas=None,
                num_replicas_per_host=None,
                num_hosts=None):
    """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with as least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      num_replicas: number of replicas used on the TPU.
      num_replicas_per_host: number of replicas per TPU host.
      num_hosts: total number of TPU hosts.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
                      dependencies.
    """
    if isinstance(graph, func_graph.FuncGraph) or isinstance(
        graph, function._FuncGraph):  # pylint: disable=protected-access
      logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                      'Ignoring tracing.')
      return tensor_fetches

    if graph in TensorTracer._traced_graphs:
      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                      'multiple calls.')
      return tensor_fetches
    else:
      TensorTracer._traced_graphs.add(graph)
    # Reset the parameters in case parameters are changed.
    self._parameters = tensor_tracer_flags.TTParameters()
    self._tt_config.device_type = _DEVICE_TYPE_TPU
    self._tt_config.num_replicas = num_replicas
    self._tt_config.num_replicas_per_host = num_replicas_per_host
    self._tt_config.num_hosts = num_hosts
    if self._tt_config.num_replicas is not None:
      if self._tt_config.num_replicas_per_host is None:
        self._tt_config.num_replicas_per_host = 8
      if self._tt_config.num_hosts is None:
        self._tt_config.num_hosts = (
            num_replicas // self._tt_config.num_replicas_per_host +
            (num_replicas % self._tt_config.num_replicas_per_host > 0))

    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_before_tt.pbtxt')
    with graph.as_default():
      self._add_replica_id_to_graph()
      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                             on_tpu=True)
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_after_tt.pbtxt')
    return tensor_fetches

  def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
    """Traces the tensors generated by CPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the CPU.
      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with as least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
                      dependencies.
    """
    if isinstance(graph, func_graph.FuncGraph) or isinstance(
        graph, function._FuncGraph):  # pylint: disable=protected-access
      logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                      'Ignoring tracing.')
      return tensor_fetches

    if graph in TensorTracer._traced_graphs:
      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                      'multiple calls.')
      return tensor_fetches
    else:
      TensorTracer._traced_graphs.add(graph)
    # Reset the parameters in case parameters are changed.
    self._parameters = tensor_tracer_flags.TTParameters()

    self._tt_config.device_type = _DEVICE_TYPE_CPU
    self._tt_config.num_replicas = 1
    self._tt_config.num_replicas_per_host = 1
    self._tt_config.num_hosts = 1
    self._replica_id = 0
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_before_tt.pbtxt')
    with graph.as_default():
      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                             on_tpu=False)
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_after_tt.pbtxt')
    return tensor_fetches