# Re-apply each original state's shape to the corresponding new state.
# Entries are paired positionally from flat_state and flat_new_state. zip()
# stops at the shorter sequence, so a length mismatch would silently skip
# the trailing entries. Presumably the caller guarantees both flattened
# sequences have the same structure and length; TODO confirm.
# Only new_state values that are tensor_lib.Tensor instances are updated.
# Any other value (non-Tensor) is left untouched.
for state, new_state in zip(flat_state, flat_new_state):
          if isinstance(new_state, tensor_lib.Tensor):
            # NOTE(review): state.shape is read without a type check on
            # `state`. This assumes every original state exposes `.shape`;
            # verify against the caller.
            # Assuming tensor_lib is TensorFlow's tensor module, set_shape()
            # refines new_state's static shape in place rather than reshaping
            # at runtime, and raises if the shapes are incompatible. Confirm.
            new_state.set_shape(state.shape)