def _validate_states(self, states):
    """Raises a value error if `states` does not have the expected shape.

    Args:
      states: A tensor.