# Right-pad states_seq along dim 0 with zero rows so that it has exactly
# self._truncation_length rows.
# new_zeros allocates the padding with the same dtype and device as
# states_seq. A bare torch.zeros would always be a CPU tensor with the default
# dtype, and torch.concatenate rejects inputs on different devices.
# Using shape[1:] preserves every trailing dimension, not just one.
# NOTE(review): if states_seq already has more than _truncation_length rows,
# the pad size is negative and tensor creation raises. Presumably the caller
# truncates beforehand; confirm.
padded_states = torch.concatenate([
    states_seq,
    states_seq.new_zeros(
        (self._truncation_length - states_seq.shape[0], *states_seq.shape[1:])
    ),
])