def _zero_state_tensors(state_size, batch_size, dtype):
  """Create tensors of zeros based on state_size, batch_size, and dtype."""

  def get_state_shape(s):
    """Combine s with batch_size to get a proper tensor shape."""