deeplearning4j/deeplearning4j

View on GitHub
contrib/blas-lapack-generator/src/main/java/org/deeplearning4j/BlasLapackGenerator.java

Summary

Maintainability
B
4 hrs
Test Coverage
package org.deeplearning4j;

import com.github.javaparser.ParserConfiguration;
import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.symbolsolver.JavaSymbolSolver;
import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver;
import com.github.javaparser.symbolsolver.resolution.typesolvers.JavaParserTypeSolver;
import com.github.javaparser.symbolsolver.resolution.typesolvers.ReflectionTypeSolver;
import com.github.javaparser.utils.SourceRoot;
import com.squareup.javapoet.*;
import org.apache.commons.io.FileUtils;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.openblas.global.openblas;

import javax.lang.model.element.Modifier;
import java.io.File;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class BlasLapackGenerator {

    /** Package of the generated delegator interface. */
    private static final String PACKAGE_NAME = "org.nd4j.linalg.api.blas";

    /** Simple name of the generated delegator interface. */
    private static final String INTERFACE_NAME = "BLASLapackDelegator";

    /** nd4j-api source root used by {@link #main} when no path argument is given. */
    private static final String DEFAULT_ND4J_API_ROOT =
            "../../nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/";

    // JavaParser source root over the nd4j-api sources; parse() does not currently read it.
    private SourceRoot sourceRoot;
    // Directory the generated package tree is written under.
    private File rootDir;
    // Location of the generated file; (re)assigned by parse().
    private File targetFile;

    private static final String copyright =
            "/*\n" +
                    " *  ******************************************************************************\n" +
                    " *  *\n" +
                    " *  *\n" +
                    " *  * This program and the accompanying materials are made available under the\n" +
                    " *  * terms of the Apache License, Version 2.0 which is available at\n" +
                    " *  * https://www.apache.org/licenses/LICENSE-2.0.\n" +
                    " *  *\n" +
                    " *  *  See the NOTICE file distributed with this work for additional\n" +
                    " *  *  information regarding copyright ownership.\n" +
                    " *  * Unless required by applicable law or agreed to in writing, software\n" +
                    " *  * distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT\n" +
                    " *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\n" +
                    " *  * License for the specific language governing permissions and limitations\n" +
                    " *  * under the License.\n" +
                    " *  *\n" +
                    " *  * SPDX-License-Identifier: Apache-2.0\n" +
                    " *  *****************************************************************************\n" +
                    " */\n";
    private static final String codeGenWarning =
            "\n//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================\n\n";


    /**
     * Creates a generator that writes into the given nd4j-api source root.
     *
     * @param nd4jApiRootDir root of the nd4j-api Java sources (the directory containing {@code org/})
     */
    public BlasLapackGenerator(File nd4jApiRootDir) {
        this.sourceRoot = initSourceRoot(nd4jApiRootDir);
        this.rootDir = nd4jApiRootDir;
    }

    public SourceRoot getSourceRoot() {
        return sourceRoot;
    }

    public void setSourceRoot(SourceRoot sourceRoot) {
        this.sourceRoot = sourceRoot;
    }

    /** @return the generated file's location; set by {@link #parse()}. */
    public File getTargetFile() {
        return targetFile;
    }

    /** Note: {@link #parse()} overwrites this with the location it actually writes to. */
    public void setTargetFile(File targetFile) {
        this.targetFile = targetFile;
    }

    /**
     * Generates the {@code BLASLapackDelegator} interface from the public methods of
     * {@link openblas} and writes it under the root directory given to the constructor.
     *
     * <p>Every public method of {@code openblas} is mirrored as an abstract interface method
     * with the same name, return type and parameter names, except:
     * <ul>
     *   <li>methods inherited from {@link Object}, and</li>
     *   <li>methods named {@code map} or {@code init}.</li>
     * </ul>
     * Parameters whose type is one of the classes recognized by {@link #lapackType(Class)}
     * are declared as {@link Pointer} instead.
     *
     * <p>Methods are emitted in a deterministic order (name, then full signature) so that
     * regenerating produces a stable file. After this call {@link #getTargetFile()} points
     * at the written source file.
     *
     * @throws Exception if the generated file cannot be written
     */
    public void parse() throws Exception {
        targetFile = new File(rootDir,
                PACKAGE_NAME.replace('.', File.separatorChar) + File.separator + INTERFACE_NAME + ".java");

        TypeSpec.Builder delegator = TypeSpec.interfaceBuilder(INTERFACE_NAME)
                .addModifiers(Modifier.PUBLIC);

        List<Method> objectMethods = Arrays.asList(Object.class.getMethods());
        Arrays.stream(openblas.class.getMethods())
                .filter(method -> !objectMethods.contains(method))
                .filter(method -> !method.getName().equals("map") && !method.getName().equals("init"))
                // Class.getMethods() has no guaranteed order; sort for reproducible output.
                .sorted(Comparator.comparing(Method::getName).thenComparing(Method::toString))
                .map(this::toInterfaceMethod)
                .forEach(delegator::addMethod);

        JavaFile.builder(PACKAGE_NAME, delegator.build())
                // $L inserts the text verbatim, so it is never interpreted as a format string.
                .addFileComment("$L\n$L", copyright, codeGenWarning.trim())
                .build()
                .writeTo(rootDir);
    }

    /**
     * Builds an abstract interface method mirroring {@code method}'s name, return type and
     * parameters. Emitting abstract methods directly means JavaPoet prints plain
     * {@code void foo(...);} declarations, so no textual post-processing of the file is needed.
     */
    private MethodSpec toInterfaceMethod(Method method) {
        MethodSpec.Builder builder = MethodSpec.methodBuilder(method.getName())
                .addModifiers(Modifier.PUBLIC, Modifier.ABSTRACT)
                .returns(method.getReturnType());
        Arrays.stream(method.getParameters())
                .map(param -> ParameterSpec.builder(parameterType(param.getType()), param.getName()).build())
                .forEach(builder::addParameter);
        return builder.build();
    }

    /** Maps a reflected parameter type to the type declared in the generated interface. */
    private TypeName parameterType(Class<?> type) {
        return lapackType(type) ? TypeName.get(Pointer.class) : TypeName.get(type);
    }

    /**
     * Returns whether {@code clazz} is one of the {@code openblas.LAPACK_*_SELECT*} nested
     * types; such parameters are declared as {@link Pointer} in the generated interface.
     */
    private boolean lapackType(Class<?> clazz) {
        return clazz.equals(openblas.LAPACK_C_SELECT1.class)
                || clazz.equals(openblas.LAPACK_C_SELECT2.class)
                || clazz.equals(openblas.LAPACK_D_SELECT2.class)
                || clazz.equals(openblas.LAPACK_S_SELECT2.class)
                || clazz.equals(openblas.LAPACK_Z_SELECT1.class)
                || clazz.equals(openblas.LAPACK_Z_SELECT2.class)
                || clazz.equals(openblas.LAPACK_D_SELECT3.class)
                || clazz.equals(openblas.LAPACK_S_SELECT3.class);
    }


    /**
     * Creates a symbol-resolving JavaParser {@link SourceRoot} over {@code nd4jApiRootDir}.
     * Side effect: also installs the same symbol resolver on the global
     * {@link StaticJavaParser} configuration.
     */
    private SourceRoot initSourceRoot(File nd4jApiRootDir) {
        CombinedTypeSolver typeSolver = new CombinedTypeSolver();
        typeSolver.add(new ReflectionTypeSolver(false));
        typeSolver.add(new JavaParserTypeSolver(nd4jApiRootDir));
        JavaSymbolSolver symbolSolver = new JavaSymbolSolver(typeSolver);
        StaticJavaParser.getConfiguration().setSymbolResolver(symbolSolver);
        return new SourceRoot(nd4jApiRootDir.toPath(), new ParserConfiguration().setSymbolResolver(symbolSolver));
    }


    /**
     * Regenerates {@code BLASLapackDelegator.java}.
     *
     * @param args optional; {@code args[0]} overrides the nd4j-api source root, which otherwise
     *             defaults to the path relative to this module's directory
     * @throws Exception if generation or writing fails
     */
    public static void main(String... args) throws Exception {
        File apiRoot = new File(args.length > 0 ? args[0] : DEFAULT_ND4J_API_ROOT);
        BlasLapackGenerator generator = new BlasLapackGenerator(apiRoot);
        generator.parse();
    }

}