# Right-pad next_states_seq along dim 0 with zero rows so it has
# self._truncation_length rows in total.
# new_zeros inherits dtype and device from next_states_seq. Plain torch.zeros
# would allocate a CPU float32 tensor, which breaks the concatenation for CUDA
# inputs and promotes non-float32 inputs to another dtype.
# Using shape[1:] keeps every trailing dimension, so 1-D and >2-D sequences
# are padded correctly too (not only (T, D) ones).
# NOTE(review): if next_states_seq.shape[0] > self._truncation_length the pad
# length is negative and new_zeros raises; presumably callers truncate the
# sequence before this point. Confirm.
padded_next_states = torch.cat([
    next_states_seq,
    next_states_seq.new_zeros(
        (self._truncation_length - next_states_seq.shape[0],
         *next_states_seq.shape[1:])
    ),
])