if self.output_last_state:
            # Optionally expose the recurrent cell's final state as an extra output.
            # The textual descriptions must agree with the declared dimension
            # lists: the final state has no sequence axis, so SEQ_LEN does not
            # belong in either description.
            if self.cell_type == "LSTM":
                # LSTM state is declared with an extra axis of size 2
                # (presumably hidden state h_n and cell state c_n stacked --
                # confirm against the forward pass).
                d[self.key_output_state] = DataDefinition(
                    [-1, 2, self.num_layers, self.hidden_size],
                    [torch.Tensor],
                    "Batch of LSTM final hidden states (h_n/c_n) [BATCH_SIZE x 2 x NUM_LAYERS x HIDDEN_SIZE]",
                )
            else:
                # Every other cell type is declared with a single state tensor.
                d[self.key_output_state] = DataDefinition(
                    [-1, self.num_layers, self.hidden_size],
                    [torch.Tensor],
                    "Batch of RNN final hidden states [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]",
                )