deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java

Summary

Maintainability
F
3 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.models.sequencevectors.graph.walkers.impl;

import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.sequencevectors.graph.enums.SamplingMode;
import org.deeplearning4j.models.sequencevectors.graph.primitives.IGraph;
import org.deeplearning4j.models.sequencevectors.graph.primitives.Vertex;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.common.util.ArrayUtil;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
public class NearestVertexWalker<V extends SequenceElement> implements GraphWalker<V> {
    @Getter
    protected IGraph<V, ?> sourceGraph;
    protected int walkLength = 0;
    protected long seed = 0;
    protected SamplingMode samplingMode = SamplingMode.RANDOM;
    protected int[] order;
    protected Random rng;
    protected int depth;

    private AtomicInteger position = new AtomicInteger(0);

    protected NearestVertexWalker() {

    }

    @Override
    public boolean hasNext() {
        return position.get() < order.length;
    }

    @Override
    public Sequence<V> next() {
        return walk(sourceGraph.getVertex(order[position.getAndIncrement()]), 1);
    }

    @Override
    public void reset(boolean shuffle) {
        position.set(0);
        if (shuffle) {
            log.trace("Calling shuffle() on entries...");
            // https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
            for (int i = order.length - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int temp = order[j];
                order[j] = order[i];
                order[i] = temp;
            }
        }
    }

    protected Sequence<V> walk(Vertex<V> node, int cDepth) {
        Sequence<V> sequence = new Sequence<>();

        int idx = node.vertexID();
        List<Vertex<V>> vertices = sourceGraph.getConnectedVertices(idx);

        sequence.setSequenceLabel(node.getValue());

        if (walkLength == 0) {
            // if walk is unlimited - we use all connected vertices as is
            for (Vertex<V> vertex : vertices)
                sequence.addElement(vertex.getValue());
        } else {
            // if walks are limited, we care about sampling mode
            switch (samplingMode) {
                case MAX_POPULARITY: {
                    Collections.sort(vertices, new VertexComparator<>(sourceGraph));
                    for (int i = 0; i < walkLength; i++) {
                        sequence.addElement(vertices.get(i).getValue());

                        // going for one more depth level
                        if (depth > 1 && cDepth < depth) {
                            Sequence<V> nextDepth = walk(vertices.get(i), ++cDepth);
                            for (V element : nextDepth.getElements()) {
                                if (sequence.getElementByLabel(element.getLabel()) == null)
                                    sequence.addElement(element);
                            }
                        }
                    }

                }
                case MEDIAN_POPULARITY: {
                    Collections.sort(vertices, new VertexComparator<>(sourceGraph));
                    for (int i = (vertices.size() / 2) - (walkLength / 2), e = 0; e < walkLength
                                    && i < vertices.size(); i++, e++) {
                        sequence.addElement(vertices.get(i).getValue());

                        // going for one more depth level
                        if (depth > 1 && cDepth < depth) {
                            Sequence<V> nextDepth = walk(vertices.get(i), ++cDepth);
                            for (V element : nextDepth.getElements()) {
                                if (sequence.getElementByLabel(element.getLabel()) == null)
                                    sequence.addElement(element);
                            }
                        }
                    }

                }
                case MIN_POPULARITY: {
                    Collections.sort(vertices, new VertexComparator<>(sourceGraph));
                    for (int i = vertices.size(), e = 0; e < walkLength && i >= 0; i--, e++) {
                        sequence.addElement(vertices.get(i).getValue());

                        // going for one more depth level
                        if (depth > 1 && cDepth < depth) {
                            Sequence<V> nextDepth = walk(vertices.get(i), ++cDepth);
                            for (V element : nextDepth.getElements()) {
                                if (sequence.getElementByLabel(element.getLabel()) == null)
                                    sequence.addElement(element);
                            }
                        }
                    }
                }
                case RANDOM: {
                    // we randomly sample some number of connected vertices
                    if (vertices.size() <= walkLength)
                        for (Vertex<V> vertex : vertices)
                            sequence.addElement(vertex.getValue());
                    else {
                        Set<V> elements = new HashSet<>();
                        while (elements.size() < walkLength) {
                            Vertex<V> vertex = ArrayUtil.getRandomElement(vertices);
                            elements.add(vertex.getValue());

                            // going for one more depth level
                            if (depth > 1 && cDepth < depth) {
                                Sequence<V> nextDepth = walk(vertex, ++cDepth);
                                for (V element : nextDepth.getElements()) {
                                    if (sequence.getElementByLabel(element.getLabel()) == null)
                                        sequence.addElement(element);
                                }
                            }
                        }

                        sequence.addElements(elements);
                    }
                }
                    break;
                default:
                    throw new ND4JIllegalStateException("Unknown sampling mode was passed in: [" + samplingMode + "]");
            }
        }

        return sequence;
    }

    @Override
    public boolean isLabelEnabled() {
        return true;
    }

    public static class Builder<V extends SequenceElement> {
        protected int walkLength = 0;
        protected IGraph<V, ?> sourceGraph;
        protected SamplingMode samplingMode = SamplingMode.RANDOM;
        protected long seed;
        protected int depth = 1;

        public Builder(@NonNull IGraph<V, ?> graph) {
            this.sourceGraph = graph;
        }

        public Builder setSeed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * This method defines maximal number of nodes to be visited during walk.
         *
         * PLEASE NOTE: If set to 0 - no limits will be used.
         *
         * Default value: 0
         * @param length
         * @return
         */
        public Builder setWalkLength(int length) {
            walkLength = length;
            return this;
        }

        /**
         * This method specifies, how deep walker goes from starting point
         *
         * Default value: 1
         * @param depth
         * @return
         */
        public Builder setDepth(int depth) {
            this.depth = depth;
            return this;
        }

        /**
         * This method defines sorting which will be used to generate walks.
         *
         * PLEASE NOTE: This option has effect only if walkLength is limited (>0).
         *
         * @param mode
         * @return
         */
        public Builder setSamplingMode(@NonNull SamplingMode mode) {
            this.samplingMode = mode;
            return this;
        }

        /**
         * This method returns you new GraphWalker instance
         *
         * @return
         */
        public NearestVertexWalker<V> build() {
            NearestVertexWalker<V> walker = new NearestVertexWalker<>();
            walker.sourceGraph = this.sourceGraph;
            walker.walkLength = this.walkLength;
            walker.samplingMode = this.samplingMode;
            walker.depth = this.depth;

            walker.order = new int[sourceGraph.numVertices()];
            for (int i = 0; i < walker.order.length; i++) {
                walker.order[i] = i;
            }

            walker.rng = new Random(seed);

            walker.reset(true);

            return walker;
        }
    }

    protected class VertexComparator<V extends SequenceElement, E extends Number> implements Comparator<Vertex<V>> {
        private IGraph<V, E> graph;

        public VertexComparator(@NonNull IGraph<V, E> graph) {
            this.graph = graph;
        }

        @Override
        public int compare(Vertex<V> o1, Vertex<V> o2) {
            return Integer.compare(graph.getConnectedVertices(o2.vertexID()).size(),
                            graph.getConnectedVertices(o1.vertexID()).size());
        }
    }
}