deeplearning4j/deeplearning4j

View on GitHub
contrib/blas-lapack-generator/src/main/java/org/deeplearning4j/OpenblasBlasLapackGenerator.java

Summary

Maintainability
D
3 days
Test Coverage
package org.deeplearning4j;

import com.github.javaparser.ParserConfiguration;
import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.symbolsolver.JavaSymbolSolver;
import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver;
import com.github.javaparser.symbolsolver.resolution.typesolvers.JavaParserTypeSolver;
import com.github.javaparser.symbolsolver.resolution.typesolvers.ReflectionTypeSolver;
import com.github.javaparser.utils.SourceRoot;
import com.squareup.javapoet.*;
import org.apache.commons.io.FileUtils;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.openblas.global.openblas;
import org.bytedeco.openblas.global.openblas_nolapack;
import org.nd4j.linalg.api.blas.BLASLapackDelegator;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;

import javax.lang.model.element.Modifier;
import java.io.File;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.*;

public class OpenblasBlasLapackGenerator {

    private SourceRoot sourceRoot;
    private File rootDir;
    private File targetFile;

    private Map<String,String> casting = new HashMap<String,String>(){{
        put("LAPACKE_sgees","openblas.LAPACK_S_SELECT2");
        put("LAPACKE_dgees","openblas.LAPACK_D_SELECT2");
        put("LAPACKE_cgees","openblas.LAPACK_C_SELECT1");
        put("LAPACKE_zgees","openblas.LAPACK_Z_SELECT1");
        put("LAPACKE_sgeesx","openblas.LAPACK_S_SELECT2");
        put("LAPACKE_dgeesx","openblas.LAPACK_D_SELECT2");
        put("LAPACKE_cgeesx","openblas.LAPACK_C_SELECT1");
        put("LAPACKE_zgeesx","openblas.LAPACK_Z_SELECT1");
        put("LAPACKE_sgges","openblas.LAPACK_S_SELECT3");
        put("LAPACKE_dgges","openblas.LAPACK_D_SELECT3");
        put("LAPACKE_cgges","openblas.LAPACK_C_SELECT2");
        put("LAPACKE_zgges","openblas.LAPACK_Z_SELECT2");
        put("LAPACKE_sgges3","openblas.LAPACK_S_SELECT3");
        put("LAPACKE_dgges3","openblas.LAPACK_D_SELECT3");
        put("LAPACKE_cgges3","openblas.LAPACK_C_SELECT2");
        put("LAPACKE_zgges3","openblas.LAPACK_Z_SELECT2");
        put("LAPACKE_sggesx","openblas.LAPACK_S_SELECT3");
        put("LAPACKE_dggesx","openblas.LAPACK_D_SELECT3");
        put("LAPACKE_cggesx","openblas.LAPACK_C_SELECT2");
        put("LAPACKE_zggesx","openblas.LAPACK_Z_SELECT2");
        put("LAPACKE_sgees_work","openblas.LAPACK_S_SELECT2");
        put("LAPACKE_dgees_work","openblas.LAPACK_D_SELECT2");
        put("LAPACKE_cgees_work","openblas.LAPACK_C_SELECT1");
        put("LAPACKE_zgees_work","openblas.LAPACK_Z_SELECT1");
        put("LAPACKE_sgeesx_work","openblas.LAPACK_S_SELECT2");
        put("LAPACKE_dgeesx_work","openblas.LAPACK_D_SELECT2");
        put("LAPACKE_cgeesx_work","openblas.LAPACK_C_SELECT1");
        put("LAPACKE_zgeesx_work","openblas.LAPACK_Z_SELECT1");
        put("LAPACKE_sgges_work","openblas.LAPACK_S_SELECT3");
        put("LAPACKE_dgges_work","openblas.LAPACK_D_SELECT3");
        put("LAPACKE_cgges_work","openblas.LAPACK_C_SELECT2");
        put("LAPACKE_zgges_work","openblas.LAPACK_Z_SELECT2");
        put("LAPACKE_sgges3_work","openblas.LAPACK_S_SELECT3");
        put("LAPACKE_dgges3_work","openblas.LAPACK_D_SELECT3");
        put("LAPACKE_cgges3_work","openblas.LAPACK_C_SELECT2");
        put("LAPACKE_zgges3_work","openblas.LAPACK_Z_SELECT2");
        put("LAPACKE_sggesx_work","openblas.LAPACK_S_SELECT3");
        put("LAPACKE_dggesx_work","openblas.LAPACK_D_SELECT3");
        put("LAPACKE_cggesx_work","openblas.LAPACK_C_SELECT2");
        put("LAPACKE_zggesx_work","openblas.LAPACK_Z_SELECT2");

        put("LAPACK_sgges3","openblas.LAPACK_S_SELECT3");
        put("LAPACK_dgges3","openblas.LAPACK_D_SELECT3");
        put("LAPACK_cgges3","openblas.LAPACK_C_SELECT2");
        put("LAPACK_zgges3","openblas.LAPACK_Z_SELECT2");


        put("LAPACK_sgges","openblas.LAPACK_S_SELECT3");
        put("LAPACK_dgges","openblas.LAPACK_D_SELECT3");
        put("LAPACK_cgges","openblas.LAPACK_C_SELECT2");
        put("LAPACK_zgges","openblas.LAPACK_Z_SELECT2");


        put("LAPACK_sggesx","openblas.LAPACK_S_SELECT3");
        put("LAPACK_dggesx","openblas.LAPACK_D_SELECT3");
        put("LAPACK_cggesx","openblas.LAPACK_C_SELECT2");
        put("LAPACK_zggesx","openblas.LAPACK_Z_SELECT2");

        //LAPACK_zgeesx
        put("LAPACK_cgees","openblas.LAPACK_C_SELECT1");
        put("LAPACK_dgees","openblas.LAPACK_D_SELECT2");
        put("LAPACK_zgees","openblas.LAPACK_Z_SELECT1");
        put("LAPACK_sgees","openblas.LAPACK_S_SELECT2");


        put("LAPACK_cgeesx","openblas.LAPACK_C_SELECT1");
        put("LAPACK_dgeesx","openblas.LAPACK_D_SELECT2");
        put("LAPACK_zgeesx","openblas.LAPACK_Z_SELECT1");
        put("LAPACK_sgeesx","openblas.LAPACK_S_SELECT2");

    }};
    private static String copyright =
            "/*\n" +
                    " *  ******************************************************************************\n" +
                    " *  *\n" +
                    " *  *\n" +
                    " *  * This program and the accompanying materials are made available under the\n" +
                    " *  * terms of the Apache License, Version 2.0 which is available at\n" +
                    " *  * https://www.apache.org/licenses/LICENSE-2.0.\n" +
                    " *  *\n" +
                    " *  *  See the NOTICE file distributed with this work for additional\n" +
                    " *  *  information regarding copyright ownership.\n" +
                    " *  * Unless required by applicable law or agreed to in writing, software\n" +
                    " *  * distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n" +
                    " *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n" +
                    " *  * License for the specific language governing permissions and limitations\n" +
                    " *  * under the License.\n" +
                    " *  *\n" +
                    " *  * SPDX-License-Identifier: Apache-2.0\n" +
                    " *  *****************************************************************************\n" +
                    " */\n";
    private static String codeGenWarning =
            "\n//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================\n\n";


    public OpenblasBlasLapackGenerator(File nd4jApiRootDir) {
        this.sourceRoot = initSourceRoot(nd4jApiRootDir);
        this.rootDir = nd4jApiRootDir;
    }

    public void parse() throws Exception {
        targetFile = new File(rootDir,"org/nd4j/linalg/cpu/nativecpu/OpenblasLapackDelegator.java");
        String packageName = "org.nd4j.linalg.cpu.nativecpu";
        TypeSpec.Builder openblasLapackDelegator = TypeSpec.classBuilder("OpenblasLapackDelegator");
        openblasLapackDelegator.addModifiers(Modifier.PUBLIC);
        openblasLapackDelegator.addSuperinterface(BLASLapackDelegator.class);

        Class<BLASLapackDelegator> clazz = BLASLapackDelegator.class;
        List<Method> objectMethods = Arrays.asList(Object.class.getMethods());
        Set<MethodSpec> addedCodeLines = new HashSet<>();
        Arrays.stream(clazz.getMethods())
                .filter(input -> !objectMethods.contains(input))
                .forEach(method -> {
                    MethodSpec.Builder builder = MethodSpec.methodBuilder(
                                    method.getName()
                            ).addModifiers(Modifier.PUBLIC)
                            .returns(method.getReturnType())
                            .addAnnotation(Override.class);
                    StringBuilder codeStatement = new StringBuilder();
                    //don't return anything when void
                    if(method.getReturnType().equals(Void.TYPE)) {
                        codeStatement.append("openblas." + method.getName() + "(");

                    } else if(method.getReturnType().equals(int.class)){
                        //codeStatement.append("return 0;");
                        codeStatement.append("return openblas." + method.getName() + "(");

                    } else if(method.getReturnType().equals(double.class)) {
                        //codeStatement.append("return 0.0;");
                        codeStatement.append("return openblas." + method.getName() + "(");

                    } else if(method.getReturnType().equals(float.class)) {
                        //codeStatement.append("return 0.0f;");
                        codeStatement.append("return openblas." + method.getName() + "(");

                    }
                    else if(method.getReturnType().equals(long.class)) {
                        //codeStatement.append("return 0L;");
                        codeStatement.append("return openblas." + method.getName() + "(");

                    }

                    //TODO: LAPACK_cgees
                    //TODO: LAPACK_dgees
                    //TODO: LAPACK_zgees
                    //TODO: LAPACK_cgeesx
                    //TODO: LAPACK_dgeesx
                    //TODO: LAPACK_sgeesx
                    //TODO: LAPACK_zgeesx
                    //TODO: LAPACK_cgges
                    //TODO: LAPACK_dgges
                    //TODO: LAPACK_sgges
                    //TODO: LAPACK_zgges
                    //TODO: LAPACK_cgges3
                    //TODO: LAPACK_dgges3
                    //TODO: LAPACK_sgges3
                    //TODO: LAPACK_zgges3
                    //TODO: LAPACK_cggesx
                    //TODO: LAPACK_dggesx
                    //TODO: LAPACK_sggesx
                    //TODO: LAPACK_zggesx


                    //TODO: issue could be LAPACK_Z_SELECT_2
                    //TODO: LAPACK_S_SELECT_3
                    Arrays.stream(method.getParameters()).forEach(param -> {
                        if(casting.containsKey(method.getName()) && param.getType().equals(Pointer.class)) {
                            System.out.println("In function casting for " + method.getName());
                            codeStatement.append("((" + casting.get(method.getName()) + ")" + param.getName() + ")");
                            codeStatement.append(",");
                        } else {
                            codeStatement.append(param.getName());
                            codeStatement.append(",");
                        }

                        builder.addParameter(ParameterSpec.builder(param.getType(),param.getName())
                                .build());

                    });

                    codeStatement.append(")");

                    builder.addCode(CodeBlock
                            .builder()
                            .addStatement(codeStatement.toString().replace(",)",")"))
                            .build());

                    MethodSpec build = builder.build();
                    openblasLapackDelegator.addMethod(build);
                    addedCodeLines.add(build);


                });

        JavaFile.builder(packageName,openblasLapackDelegator.build())
                .addFileComment(copyright)
                .addStaticImport(openblas.class,"*")
                .addStaticImport(openblas_nolapack.class,"*")
                .build()
                .writeTo(rootDir);
    }


    private SourceRoot initSourceRoot(File nd4jApiRootDir) {
        CombinedTypeSolver typeSolver = new CombinedTypeSolver();
        typeSolver.add(new ReflectionTypeSolver(false));
        typeSolver.add(new JavaParserTypeSolver(nd4jApiRootDir));
        JavaSymbolSolver symbolSolver = new JavaSymbolSolver(typeSolver);
        StaticJavaParser.getConfiguration().setSymbolResolver(symbolSolver);
        SourceRoot sourceRoot = new SourceRoot(nd4jApiRootDir.toPath(),new ParserConfiguration().setSymbolResolver(symbolSolver));
        return sourceRoot;
    }

    public SourceRoot getSourceRoot() {
        return sourceRoot;
    }

    public File getRootDir() {
        return rootDir;
    }

    public File getTargetFile() {
        return targetFile;
    }

    public static void main(String...args) throws Exception {
        OpenblasBlasLapackGenerator openblasBlasLapackGenerator = new OpenblasBlasLapackGenerator(new File("../../nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java"));
        openblasBlasLapackGenerator.parse();
        String generated = FileUtils.readFileToString(openblasBlasLapackGenerator.getTargetFile(), Charset.defaultCharset());
        generated = generated.replace(";;",";");
        generated = generated.replaceAll("import static org.bytedeco.openblas.global.openblas\\.\\*","import org.bytedeco.openblas.global.openblas");
        generated = generated.replaceAll("import static org.bytedeco.openblas.global.openblas_nolapack\\.\\*","import org.bytedeco.openblas.global.openblas_nolapack");
        FileUtils.write(openblasBlasLapackGenerator.getTargetFile(),generated,Charset.defaultCharset());

    }

}