if self.initial_state == "Input":
            # The initial hidden state is not created locally but received through
            # the `key_input_state` stream (passed from another recurrent component),
            # so declare the shape/type that stream must carry.
            if self.cell_type == "LSTM":
                # LSTM state is an (h0, c0) pair, hence the extra dimension of size 2.
                # Description lists exactly the four declared dimensions.
                d[self.key_input_state] = DataDefinition(
                    [-1, 2, self.num_layers, self.hidden_size],
                    [torch.Tensor],
                    "Batch of LSTM last hidden states (h0/c0) passed from another LSTM that will be used as initial state [BATCH_SIZE x 2 x NUM_LAYERS x HIDDEN_SIZE]")
            else:
                # Plain RNN/GRU-style state: a single hidden tensor per layer.
                # Description lists exactly the three declared dimensions.
                d[self.key_input_state] = DataDefinition(
                    [-1, self.num_layers, self.hidden_size],
                    [torch.Tensor],
                    "Batch of RNN last hidden states passed from another RNN that will be used as initial state [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]")