# When `zero_time` is truthy, zero the last entry along the final axis of
# `states`. It does this by multiplying by a constant mask made of (n - 1)
# ones followed by a single 0., where n = states.shape[-1]. The mask is
# created with the same dtype as `states`, so the product keeps that dtype.
# NOTE(review): the flag name suggests the last feature channel holds a time
# value. This is an assumption; confirm it against the code that builds
# `states`.
# NOTE(review): this needs a statically known last dimension. If
# states.shape[-1] is None, `None - 1` raises TypeError.
# NOTE(review): masking by multiplication does not clear non-finite values.
# 0. * inf and 0. * nan both give nan, so a NaN/Inf in the last channel
# survives.
# NOTE(review): for a tf.Tensor, `*=` just rebinds `states` to a new tensor.
# If `states` can be a NumPy array, this mutates it in place. Confirm the
# type at the call site.
if zero_time:
        states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)