if (shape[0] == 1) {
            //Edge case: miniBatchSize = 1
            output2d = output.tensorAlongDimension(0, 1, 2).permutei(1, 0);
        } else if (shape[2] == 1) {
            //Edge case: timeSeriesLength = 1