if zero_time:
      # Replace the last entry along the final axis of `states` with zeros
      # (presumably the time channel, given the flag name -- confirm with the
      # caller). This is done with slice + concat rather than multiplying by a
      # Python-built [1., ..., 1., 0.] mask, for two reasons:
      #   * the mask needed the *static* size `states.shape[-1]`, which is None
      #     when the last dimension is only known at run time (e.g. inside a
      #     tf.function with a partially-defined shape) and made `None - 1`
      #     raise TypeError;
      #   * multiplying by 0 does not clear inf/NaN (inf * 0 == NaN), whereas
      #     concatenating zeros yields exact zeros in that channel.
      # NOTE(review): `states` is rebound to a new tf.Tensor. The old `*=` also
      # rebound tf tensors (they are immutable), but mutated a NumPy array in
      # place -- confirm no caller relies on in-place mutation of a NumPy input.
      states = tf.concat(
          [states[..., :-1], tf.zeros_like(states[..., -1:])],
          axis=-1,
      )