for (int i = 0; i < dimensions.length; i++) {
            // For each requested axis index in `dimensions`, look up that axis's size on `tensor`.
            // Fold the size into the running product `tadLength`, and record it at position i of `shape`.
            // After the loop, shape[0 .. dimensions.length-1] holds the selected axis sizes, in the
            // order given by `dimensions`. Note that tensor.size(...) is evaluated twice per axis.
            // NOTE(review): presumably tadLength is initialised to 1 before this loop (not visible here).
            // If so, it ends as the element count of one "tensor along dimension" slice. Confirm with the caller.
            // NOTE(review): `shape` must have length >= dimensions.length, or the store below throws
            // ArrayIndexOutOfBoundsException. The compound multiply wraps silently on overflow.
            tadLength *= tensor.size(dimensions[i]);
            shape[i] = tensor.size(dimensions[i]);
        }