deeplearning4j/deeplearning4j

View on GitHub
codegen/op-codegen/src/main/java/org/nd4j/codegen/cli/PicoCliCodeGen.java

Summary

Maintainability
B
5 hrs
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.codegen.cli;

import com.beust.jcommander.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.codegen.Namespace;
import org.nd4j.codegen.api.LossReduce;
import org.nd4j.linalg.api.buffer.DataType;
import picocli.CommandLine;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Planned picocli-based CLI for generating classes.
 *
 * <p>Command-line arguments ({@code -dir}, {@code -docsdir}, {@code -namespaces}) are parsed with JCommander.
 * The selected namespaces are then turned into a picocli {@link CommandLine.Model.CommandSpec}:
 * one subcommand per namespace, and under it one subcommand per op with an option per op input/argument.
 * The resulting spec is built but not yet executed by a {@link CommandLine} instance.
 */
@Slf4j
public class PicoCliCodeGen {
    /** Location of the nd4j-api Java sources, relative to the repository root passed via {@code -dir}. */
    private static final String relativePath = "nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/";
    /** Special {@code -namespaces} value (matched case-insensitively) that selects every {@link Namespace}. */
    private static final String allProjects = "all";


    @Parameter(names = "-dir", description = "Root directory of deeplearning4j mono repo")
    private String repoRootDir;

    @Parameter(names = "-docsdir", description = "Root directory for generated docs")
    private String docsdir;

    @Parameter(names = "-namespaces", description = "List of namespaces to generate, or 'ALL' to generate all namespaces", required = true)
    private List<String> namespaces;


    /**
     * Resolves the requested namespaces and builds the picocli command model for them.
     *
     * @return the root command spec containing one subcommand per selected namespace
     * @throws IllegalArgumentException if a requested name matches no {@link Namespace} constant
     */
    private CommandLine.Model.CommandSpec generateNamespaces() {
        List<Namespace> usedNamespaces = resolveNamespaces();
        CommandLine.Model.CommandSpec commandSpec = buildCommandSpec(usedNamespaces);
        // TODO: hand commandSpec to a CommandLine instance; for now the model is only built.
        log.info("Complete - generated {} namespaces", usedNamespaces.size());
        return commandSpec;
    }

    /**
     * Maps the raw {@code -namespaces} values onto {@link Namespace} constants.
     *
     * <p>Matching is case-insensitive and ignores surrounding whitespace. The value {@code "all"} selects
     * every namespace and overrides anything else listed. Repeated names are kept only once.
     *
     * @return the selected namespaces, in the order first requested
     * @throws IllegalArgumentException if a value is neither {@code "all"} nor a {@link Namespace} name
     */
    private List<Namespace> resolveNamespaces() {
        List<Namespace> resolved = new ArrayList<>();
        for (String raw : namespaces) {
            String name = raw.trim();
            if (allProjects.equalsIgnoreCase(name)) {
                // "all" supersedes any individually listed namespaces (avoids duplicates).
                resolved.clear();
                Collections.addAll(resolved, Namespace.values());
                break;
            }

            Namespace ns = findNamespace(name);
            if (ns == null) {
                throw new IllegalArgumentException("Unknown namespace: \"" + raw + "\". Valid values: "
                        + java.util.Arrays.toString(Namespace.values()) + " or '" + allProjects + "'");
            }
            // picocli's addSubcommand rejects duplicate names, so each namespace is registered once.
            if (!resolved.contains(ns)) {
                resolved.add(ns);
            }
        }
        return resolved;
    }

    /**
     * Finds the {@link Namespace} constant whose {@code name()} equals {@code name}, ignoring case.
     *
     * @param name candidate namespace name
     * @return the matching constant, or {@code null} if there is none
     */
    private static Namespace findNamespace(String name) {
        for (Namespace ns : Namespace.values()) {
            if (ns.name().equalsIgnoreCase(name)) {
                return ns;
            }
        }
        return null;
    }

    /**
     * Builds the picocli command model for the given namespaces.
     *
     * <p>Each namespace becomes a subcommand named after the enum constant. Each op in it becomes a nested
     * subcommand. Op inputs become required {@code --name} options, and op args become {@code --name}
     * options that are required only when the arg has no default value.
     *
     * @param usedNamespaces namespaces to include
     * @return the root command spec
     */
    private CommandLine.Model.CommandSpec buildCommandSpec(List<Namespace> usedNamespaces) {
        CommandLine.Model.CommandSpec commandSpec = CommandLine.Model.CommandSpec.create();

        for (Namespace ns : usedNamespaces) {
            log.info("Starting generation of namespace: {}", ns);

            CommandLine.Model.CommandSpec subCommand = CommandLine.Model.CommandSpec.create();
            commandSpec.addSubcommand(ns.name(), subCommand);
            ns.getNamespace().getOps().forEach(op -> {
                CommandLine.Model.CommandSpec opCommand = CommandLine.Model.CommandSpec.create();
                subCommand.addSubcommand(op.name(), opCommand);

                op.inputs().forEach(input -> {
                    //TODO: Add SDVariable converter for picocli and figure out where to put that converter
                    opCommand.addOption(CommandLine.Model.OptionSpec.builder("--" + input.getName())
                            .type(SDVariable.class)
                            .required(true)
                            .description(input.getDescription())
                            .build());
                });

                op.getArgs().forEach(arg -> {
                    CommandLine.Model.OptionSpec.Builder builder = CommandLine.Model.OptionSpec.builder("--" + arg.getName())
                            .description(arg.getDescription());

                    switch (arg.getType()) {
                        case INT:
                            builder.type(Integer.class);
                            break;
                        case BOOL:
                            builder.type(Boolean.class);
                            break;
                        case LONG:
                            builder.type(Long.class);
                            break;
                        case STRING:
                            builder.type(String.class);
                            break;
                        case DATA_TYPE:
                            builder.type(DataType.class);
                            break;
                        case LOSS_REDUCE:
                            builder.type(LossReduce.class);
                            break;
                        case ENUM:
                        case NDARRAY:
                        case NUMERIC:
                        case CONDITION:
                        case FLOATING_POINT:
                            // TODO: no picocli type mapping/converter chosen yet for these arg types.
                            break;
                    }

                    boolean hasDefault = arg.getDefaultValue() != null;
                    builder.required(!hasDefault);
                    if (hasDefault) {
                        builder.defaultValue(arg.getDefaultValue().toString());
                    }

                    opCommand.addOption(builder.build());
                });
            });
        }

        return commandSpec;
    }


    /**
     * Entry point: parses {@code args} and runs this CLI.
     *
     * @param args command-line arguments ({@code -dir}, {@code -docsdir}, {@code -namespaces})
     * @throws Exception if argument parsing or validation fails
     */
    public static void main(String[] args) throws Exception {
        new PicoCliCodeGen().runMain(args);
    }

    /**
     * Parses {@code args} into this instance's fields, validates them, then builds the namespace commands.
     *
     * @param args command-line arguments
     * @throws IllegalStateException if neither {@code -dir} nor {@code -docsdir} is given, if the root or
     *                               expected output directory is missing, or if no namespaces are given
     * @throws IllegalArgumentException if an unknown namespace is requested
     * @throws Exception propagated from argument parsing
     */
    public void runMain(String[] args) throws Exception {
        JCommander.newBuilder()
                .addObject(this)
                .build()
                .parse(args);

        // Either root directory for source code generation or docs directory must be present. If root directory is
        // absent - then it's "generate docs only" mode.
        if (StringUtils.isEmpty(repoRootDir) && StringUtils.isEmpty(docsdir)) {
            throw new IllegalStateException("Provide one or both of arguments : -dir, -docsdir");
        }

        File outputDir = null;
        if (StringUtils.isNotEmpty(repoRootDir)) {
            //First: Check root directory.
            File dir = new File(repoRootDir);
            if (!dir.exists() || !dir.isDirectory()) {
                throw new IllegalStateException("Provided root directory does not exist (or not a directory): " + dir.getAbsolutePath());
            }

            outputDir = new File(dir, relativePath);
            // The output location itself (not just the repo root) must be an existing directory.
            if (!outputDir.exists() || !outputDir.isDirectory()) {
                throw new IllegalStateException("Expected output directory does not exist: " + outputDir.getAbsolutePath());
            }
        }

        if(namespaces == null || namespaces.isEmpty() ) {
            throw new IllegalStateException("No namespaces were provided");
        }

        generateNamespaces();

    }
}