if tensor.dtype == tf.bfloat16:
    return tf.cast(tensor, dtype=tf.float32)
  else:
    return tensor