deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java

Summary

Maintainability
C
1 day
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.imports.converters;

import dorkbox.annotation.AnnotationDefaults;
import dorkbox.annotation.AnnotationDetector;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.common.config.ND4JClassLoading;
import org.nd4j.common.config.ND4JSystemProperties;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
import org.nd4j.imports.descriptors.onnx.OpDescriptor;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.CreateView;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.OpDef;

import java.io.IOException;
import java.lang.annotation.ElementType;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;

@Slf4j
public class DifferentialFunctionClassHolder {
    // Op name -> op instance, as returned by ImportClassMapping.getOpNameMapping()
    private final Map<String, DifferentialFunction> nodeConverters = ImportClassMapping.getOpNameMapping();
    // TensorFlow op name -> op instance, as returned by ImportClassMapping.getTFOpMappingFunctions()
    private final Map<String, DifferentialFunction> tensorFlowNames = ImportClassMapping.getTFOpMappingFunctions();
    // ONNX op name -> op instance, as returned by ImportClassMapping.getOnnxOpMappingFunctions()
    private final Map<String, DifferentialFunction> onnxNames = ImportClassMapping.getOnnxOpMappingFunctions();
    // Custom op hash -> op class; if several classes share a hash, the last one registered is kept here
    private final Map<Long,Class<?>> customOpHashToClass = new HashMap<>();
    // Custom op hash -> (op name -> op class). Only contains ops with 1 hash to multiple classes
    private final Map<Long,Map<String,Class<?>>> customOpHashToClasses = new HashMap<>();
    // User defined op name -> class annotated with @UserDefinedOp (filled in by the constructor)
    private final Map<String,Class<?>> udfs = new HashMap<>();
    // Sorted names of the executioner's custom ops that have no entry in nodeConverters
    private final List<String> missingOps = new ArrayList<>();

    // The three maps below are assigned by the constructor
    private Map<String,OpDescriptor> onnxOpDescriptors;
    private Map<String,OpDef> tensorflowOpDescriptors;
    // Op class name -> (field name -> reflected field)
    private Map<String,Map<String,Field>> fieldsForFunction;

    // Field names that are never reported as op properties, whichever class declares them
    private static final Set<String> fieldNamesOpsIgnore = new LinkedHashSet<>(Arrays.asList(
            "extraArgs",
            "arrayInitialized",
            "log",
            "inputArguments",
            "outputArguments",
            "outputShapes",
            "outputVariables",
            "tArguments",
            "iArguments",
            "bArguments",
            "dArguments",
            "hash",
            "opName",
            "sameDiff",
            "ownName"
    ));

    //When determining fields/properties, where should we terminate the search?
    //We don't want to include every single field from every single superclass.
    //BaseOp is deliberately not listed here: its bookkeeping fields are excluded
    //by name through classFieldsToIgnore instead.
    private static final Set<Class<?>> classesToIgnore = new HashSet<>(Arrays.<Class<?>>asList(
            Object.class
    ));

    // Per-class field names to skip in addition to fieldNamesOpsIgnore
    private static final Map<Class<?>,Set<String>> classFieldsToIgnore = new HashMap<>();
    static {
        classFieldsToIgnore.put(BaseOp.class, new HashSet<>(Arrays.asList("x", "y", "z", "n", "numProcessed", "xVertexId", "yVertexId", "zVertexId", "extraArgz")));
    }

    // Number of TensorFlow op descriptors known to the descriptor parser
    @Getter
    private int countTotalTfOps;
    // Number of distinct TensorFlow op names claimed by the ops in nodeConverters
    @Getter
    private int countTotalMappedOps;

    // Must stay below the static collections above: the constructor reads them,
    // and static initializers run in textual order.
    private static final DifferentialFunctionClassHolder INSTANCE = new DifferentialFunctionClassHolder();

    /**
     * Look up the reflected fields recorded for the runtime class of a
     * {@link DifferentialFunction}.
     *
     * @param function the function whose class is used as the lookup key
     * @return the field name to {@link Field} map recorded for that class,
     *         or an empty map when nothing was recorded for it
     */
    public Map<String,Field> getFieldsForFunction(DifferentialFunction function) {
        final String className = function.getClass().getName();
        // getOrDefault only falls back when the key is absent, matching a containsKey check
        return fieldsForFunction.getOrDefault(className, Collections.<String, Field>emptyMap());
    }

    /**
     * Fetch the TensorFlow op definition registered under a name.
     *
     * @param name the TensorFlow name of the op
     * @return the op definition stored for {@code name}
     * @throws ND4JIllegalStateException if no descriptor is registered
     *                                   under {@code name}
     */
    public OpDef getOpDefByTensorflowName(String name) {
        if (tensorflowOpDescriptors.containsKey(name)) {
            return tensorflowOpDescriptors.get(name);
        }

        throw new ND4JIllegalStateException("No op found with name " + name);
    }

    /**
     * Fetch the ONNX op descriptor registered under a name.
     *
     * @param name the ONNX name of the op
     * @return the op descriptor stored for {@code name}
     * @throws ND4JIllegalStateException if no descriptor is registered
     *                                   under {@code name}
     */
    public OpDescriptor getOpDescriptorForOnnx(String name) {
        if (onnxOpDescriptors.containsKey(name)) {
            return onnxOpDescriptors.get(name);
        }

        throw new ND4JIllegalStateException("No op found with name " + name);
    }

    /**
     * Look up the op instance mapped to a TensorFlow op name.
     *
     * @param tensorflowName the TensorFlow name of the op
     * @return whatever the TensorFlow name mapping holds for
     *         {@code tensorflowName}; {@code null} when there is no mapping
     */
    public DifferentialFunction getOpWithTensorflowName(String tensorflowName) {
        final DifferentialFunction mapped = tensorFlowNames.get(tensorflowName);
        return mapped;
    }

    /**
     * Look up the op instance mapped to an ONNX op name.
     *
     * @param onnxName the ONNX name of the op
     * @return whatever the ONNX name mapping holds for {@code onnxName};
     *         {@code null} when there is no mapping
     */
    public DifferentialFunction getOpWithOnnxName(String onnxName) {
        final DifferentialFunction mapped = onnxNames.get(onnxName);
        return mapped;
    }

    /**
     * Builds every lookup table held by the singleton. In order:
     * <ol>
     *     <li>records, per op class, the reflected fields usable as op properties;</li>
     *     <li>loads the TensorFlow and ONNX op descriptors;</li>
     *     <li>works out which executioner custom ops have no op-name mapping;</li>
     *     <li>counts the TensorFlow descriptors and the distinct mapped TensorFlow names;</li>
     *     <li>indexes custom op classes by hash, tracking hashes shared by several classes;</li>
     *     <li>registers user defined ops found on the class path.</li>
     * </ol>
     *
     * @throws RuntimeException if field collection or descriptor loading fails,
     *                          or a user defined op cannot be instantiated
     * @throws IllegalArgumentException if scanning the class path for user defined ops fails
     */
    private DifferentialFunctionClassHolder() {
        fieldsForFunction = new LinkedHashMap<>();

        for (DifferentialFunction df : ImportClassMapping.getOpNameMapping().values()) {
            if (df == null || df.opName() == null) {
                continue;
            }
            try {
                //accumulate the field names for a given function
                //this is mainly used in import
                fieldsForFunction.put(df.getClass().getName(), collectFieldsForFunction(df));
            } catch (NoOpNameFoundException e) {
                log.trace("Skipping function  " + df.getClass());
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        //get the op descriptors for onnx and tensorflow
        //this is used when validating operations
        try {
            tensorflowOpDescriptors = TensorflowDescriptorParser.opDescs();
            onnxOpDescriptors = OnnxDescriptorParser.onnxOpDescriptors();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        //Custom ops known to the executioner that have no op-name mapping
        Set<String> unmapped = new HashSet<>(Nd4j.getExecutioner().getCustomOperations().keySet());
        unmapped.removeAll(nodeConverters.keySet());
        missingOps.addAll(unmapped);
        Collections.sort(missingOps);

        countTotalTfOps = tensorflowOpDescriptors.size();

        //Work out total number of TF ops mapped
        Set<String> tfMappedOps = new HashSet<>();
        for (DifferentialFunction df : nodeConverters.values()) {
            if (df == null) {
                //Same guard as the field collection loop above
                continue;
            }
            try {
                Collections.addAll(tfMappedOps, df.tensorflowNames());
            } catch (NoOpNameFoundException ignored) {
                //Op has no TensorFlow name - nothing to count
            }
        }
        countTotalMappedOps = tfMappedOps.size();

        //Get custom ops - map from hash to class
        Map<String,CustomOpDescriptor> descriptorMap = Nd4j.getExecutioner().getCustomOperations();
        Set<Long> multiClassHashes = new HashSet<>();
        for (Map.Entry<String, CustomOpDescriptor> e : descriptorMap.entrySet()) {
            DifferentialFunction df = getInstance(e.getKey());

            //Can be no class for 2 reasons:
            //(a) op name aliases
            //(b) libnd4j ops with no corresponding ND4J op class
            //Classes that are not custom ops are skipped as well
            if (!(df instanceof CustomOp)) {
                continue;
            }

            long h = e.getValue().getHash();
            if (customOpHashToClass.put(h, df.getClass()) != null) {
                //One op hash mapped to multiple classes
                multiClassHashes.add(h);
            }
        }

        //For hashes shared by several classes, keep a per-name table so the name can disambiguate
        for (Map.Entry<String, CustomOpDescriptor> e : descriptorMap.entrySet()) {
            long h = e.getValue().getHash();
            if (!multiClassHashes.contains(h)) {
                continue;
            }
            Map<String, Class<?>> byName = customOpHashToClasses.computeIfAbsent(h, k -> new HashMap<>());
            DifferentialFunction df = getInstance(e.getKey());
            if (df != null) {
                byName.put(e.getKey(), df.getClass());
            }
        }

        registerUserDefinedOps();
    }

    /**
     * Collects the op property fields of {@code df}, walking from its class up the
     * hierarchy until a class whose superclass is in {@code classesToIgnore} is reached.
     * For the first class only, when {@code df.isConfigProperties()} is true, the
     * fields of the config object's type are used instead of the class's own fields.
     *
     * @param df the op to inspect
     * @return field name to field, in discovery order
     * @throws IllegalStateException if two inspected classes declare a field with the same name
     */
    private static Map<String, Field> collectFieldsForFunction(DifferentialFunction df) {
        Map<String, Field> fieldNames = new LinkedHashMap<>();
        Class<?> current = df.getClass();
        boolean isFirst = true;

        while (current.getSuperclass() != null && !classesToIgnore.contains(current.getSuperclass())) {
            if (df.isConfigProperties() && isFirst) {
                String fieldName = df.configFieldName();
                if (fieldName == null) {
                    fieldName = "config";
                }

                Field configField = findConfigField(current, fieldName);
                // A missing config field used to `continue` here without advancing
                // `current`, which repeated this iteration forever. Now the level is
                // simply skipped and the walk carries on with the superclass.
                if (configField != null) {
                    registerFields(configField.getType().getDeclaredFields(), current, fieldNames);
                }
            } else {
                registerFields(current.getDeclaredFields(), current, fieldNames);
            }

            current = current.getSuperclass();
            isFirst = false;
        }

        return fieldNames;
    }

    /**
     * Finds a declared field called {@code fieldName} on {@code start} or on one of
     * its superclasses; {@code Object} itself is not searched.
     *
     * @return the first matching field, or {@code null} if none of the searched classes declares it
     */
    private static Field findConfigField(Class<?> start, String fieldName) {
        for (Class<?> c = start; c.getSuperclass() != null; c = c.getSuperclass()) {
            try {
                return c.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // not declared at this level - try the superclass
            }
        }
        return null;
    }

    /**
     * Adds every non-static, non-ignored field in {@code candidates} to {@code fieldNames}
     * and makes it accessible. {@code owner} selects the per-class ignore list.
     *
     * @throws IllegalStateException if a field name is already present in {@code fieldNames}
     */
    private static void registerFields(Field[] candidates, Class<?> owner, Map<String, Field> fieldNames) {
        Set<String> ownerExclusions = classFieldsToIgnore.get(owner);
        for (Field field : candidates) {
            String fieldName = field.getName();
            if (Modifier.isStatic(field.getModifiers()) || fieldNamesOpsIgnore.contains(fieldName)) {
                continue;
            }
            if (ownerExclusions != null && ownerExclusions.contains(fieldName)) {
                continue;
            }

            field.setAccessible(true);
            Field previous = fieldNames.get(fieldName);
            if (previous != null) {
                throw new IllegalStateException("Field with name " + fieldName + " exists for multiple classes: "
                        + previous.getDeclaringClass().getName() + " and " + field.getDeclaringClass().getName());
            }
            fieldNames.put(fieldName, field);
        }
    }

    /**
     * Scans the packages listed (comma separated) in the
     * {@link ND4JSystemProperties#UDF_NAME_SPACES} system property for classes annotated
     * with {@link UserDefinedOp}, instantiates each one and records it under its op name.
     * Does nothing when the property is not set.
     */
    private void registerUserDefinedOps() {
        String nameSpaces = System.getProperty(ND4JSystemProperties.UDF_NAME_SPACES);
        if (nameSpaces == null) {
            return;
        }

        try {
            // Get a list of all classes annotated with @UserDefinedOp,
            String[] packageNames = nameSpaces.split(",");
            List<Class<?>> classModules = AnnotationDetector.scanClassPath(ND4JClassLoading.getNd4jClassloader(),packageNames)
                    .forAnnotations(UserDefinedOp.class)  // one or more annotations
                    .on(ElementType.TYPE) // optional, default ElementType.TYPE. One or more element types
                    .collect(AnnotationDefaults.getType);
            for (Class<?> udf : classModules) {
                try {
                    // Class.newInstance() is deprecated; go through the no-arg constructor instead
                    UserDefinedCustomOp o = (UserDefinedCustomOp) udf.getDeclaredConstructor().newInstance();
                    udfs.put(o.opName(), udf);
                } catch (ReflectiveOperationException e) {
                    throw new RuntimeException(e);
                }
            }
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to scan the classpath for user defined ops", e);
        }
    }


    /**
     * Computes the ONNX ops that have a descriptor but no mapped op.
     *
     * @return a new set holding every ONNX descriptor name that is not a key
     *         of the ONNX name mapping
     */
    public Set<String> missingOnnxOps() {
        final Set<String> mapped = onnxNames.keySet();
        final Set<String> unmapped = new HashSet<>(onnxOpDescriptors.keySet());
        unmapped.removeIf(mapped::contains);
        return unmapped;
    }


    /**
     * Computes the TensorFlow ops that have a descriptor but no mapped op.
     *
     * @return a new set holding every TensorFlow descriptor name that is not
     *         a key of the TensorFlow name mapping
     */
    public Set<String> missingTensorflowOps() {
        final Set<String> mapped = tensorFlowNames.keySet();
        final Set<String> unmapped = new HashSet<>(tensorflowOpDescriptors.keySet());
        unmapped.removeIf(mapped::contains);
        return unmapped;
    }

    /**
     * Names of the executioner's custom ops (c++ side) that have no
     * op-name mapping on the java side.
     *
     * @return the internal, sorted list of missing op names; it is returned
     *         directly rather than copied
     */
    public List<String> missingOps() {
        return this.missingOps;
    }

    /**
     * Checks whether an op is registered under the given op name.
     *
     * @param name the op name to check
     * @return {@code true} if the op-name mapping has {@code name} as a key
     */
    public boolean hasName(String name) {
        final boolean known = nodeConverters.containsKey(name);
        return known;
    }


    /**
     * All op names that have a registered op.
     *
     * @return the key set of the op-name mapping; this is the map's own key
     *         view, not a copy
     */
    public Set<String> opNames() {
        final Set<String> names = nodeConverters.keySet();
        return names;
    }

    /**
     * Look up the op instance registered under an op name.
     *
     * @param name the op name
     * @return whatever the op-name mapping holds for {@code name};
     *         {@code null} when there is no mapping
     */
    public DifferentialFunction getInstance(String name) {
        final DifferentialFunction op = nodeConverters.get(name);
        return op;
    }

    /**
     * Resolve the op class for a custom op hash and op name.
     * <p>
     * Lookups are tried in this order, and the first hit wins:
     * <ol>
     *     <li>a fixed set of op names (CreateView, the control flow compat ops and
     *     ExternalErrorsFunction), matched by name only;</li>
     *     <li>user defined ops, by name;</li>
     *     <li>the per-name table kept for hashes shared by several classes;</li>
     *     <li>the plain hash to class table;</li>
     *     <li>the op-name mapping from {@link ImportClassMapping}.</li>
     * </ol>
     * A miss in the per-name table (hash known, name not) falls through to the
     * remaining lookups instead of returning {@code null}.
     *
     * @param customOpHash the hash of the custom op
     * @param name         the op name; must not be {@code null}
     * @return the class to use for the op, never {@code null}
     * @throws IllegalStateException if none of the lookups knows the hash or the name
     */
    public Class<?> customOpClassForHashAndName(long customOpHash, String name) {
        switch (name) {
            case CreateView.OP_NAME:
                return CreateView.class;
            case Enter.OP_NAME:
                return Enter.class;
            case Exit.OP_NAME:
                return Exit.class;
            case NextIteration.OP_NAME:
                return NextIteration.class;
            case Merge.OP_NAME:
                return Merge.class;
            case Switch.OP_NAME:
                return Switch.class;
            case LoopCond.OP_NAME:
                return LoopCond.class;
            case ExternalErrorsFunction.OP_NAME:
                return ExternalErrorsFunction.class;
            default:
                break;
        }

        Class<?> userDefined = udfs.get(name);
        if (userDefined != null) {
            return userDefined;
        }

        Map<String, Class<?>> sharedHashClasses = customOpHashToClasses.get(customOpHash);
        if (sharedHashClasses != null) {
            Class<?> byName = sharedHashClasses.get(name);
            if (byName != null) {
                return byName;
            }
            // Name not registered for this shared hash: keep looking rather than hand back null
        }

        Class<?> byHash = customOpHashToClass.get(customOpHash);
        if (byHash != null) {
            return byHash;
        }

        DifferentialFunction byOpName = ImportClassMapping.getOpNameMapping().get(name);
        if (byOpName != null) {
            return byOpName.getClass();
        }

        throw new IllegalStateException("No op known for hash: " + customOpHash + " and name " + name);
    }

    /**
     * Access the shared holder.
     *
     * @return the single instance created when this class is initialized
     */
    public static DifferentialFunctionClassHolder getInstance() {
        return DifferentialFunctionClassHolder.INSTANCE;
    }

    /**
     * Read-only access to the TensorFlow name mapping.
     *
     * @return an unmodifiable view of the TensorFlow op name to op instance map
     */
    public Map<String,DifferentialFunction> getTensorFlowNames(){
        final Map<String,DifferentialFunction> readOnly = Collections.unmodifiableMap(tensorFlowNames);
        return readOnly;
    }
}