// Declares the NUMERIC input named "keys". Its description text documents two layouts:
// a 3D array [batchSize, featureKeys, timesteps] or a 4D array that adds a numHeads axis.
val k = Input(NUMERIC, "keys") {
    description =
        "input 3D array \"keys\" of shape [batchSize, featureKeys, timesteps]\n" +
        "or 4D array of shape [batchSize, numHeads, featureKeys, timesteps]"
}