public static int[] broadcastToShape(int[] inputShapeWithOnes, long seed) {
        Nd4j.getRandom().setSeed(seed);
        int[] shape = new int[inputShapeWithOnes.length];
        for (int i = 0; i < shape.length; i++) {
            if (inputShapeWithOnes[i] == 1) {