deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/Graph.java

Summary

Maintainability
F
3 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.graph;

import org.deeplearning4j.graph.api.BaseGraph;
import org.deeplearning4j.graph.api.Edge;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.exception.NoEdgesException;
import org.deeplearning4j.graph.vertexfactory.VertexFactory;

import java.lang.reflect.Array;
import java.util.*;

public class Graph<V, E> extends BaseGraph<V, E> {
    private boolean allowMultipleEdges;
    private List<Edge<E>>[] edges; //edge[i].get(j).to = k, then edge from i -> k
    private List<Vertex<V>> vertices;

    public Graph(int numVertices, VertexFactory<V> vertexFactory) {
        this(numVertices, false, vertexFactory);
    }

    @SuppressWarnings("unchecked")
    public Graph(int numVertices, boolean allowMultipleEdges, VertexFactory<V> vertexFactory) {
        if (numVertices <= 0)
            throw new IllegalArgumentException();
        this.allowMultipleEdges = allowMultipleEdges;

        vertices = new ArrayList<>(numVertices);
        for (int i = 0; i < numVertices; i++)
            vertices.add(vertexFactory.create(i));

        edges = (List<Edge<E>>[]) Array.newInstance(List.class, numVertices);
    }

    @SuppressWarnings("unchecked")
    public Graph(List<Vertex<V>> vertices, boolean allowMultipleEdges) {
        this.vertices = new ArrayList<>(vertices);
        this.allowMultipleEdges = allowMultipleEdges;
        edges = (List<Edge<E>>[]) Array.newInstance(List.class, vertices.size());
    }

    public Graph(List<Vertex<V>> vertices) {
        this(vertices, false);
    }

    @Override
    public int numVertices() {
        return vertices.size();
    }

    @Override
    public Vertex<V> getVertex(int idx) {
        if (idx < 0 || idx >= vertices.size())
            throw new IllegalArgumentException("Invalid index: " + idx);
        return vertices.get(idx);
    }

    @Override
    public List<Vertex<V>> getVertices(int[] indexes) {
        List<Vertex<V>> out = new ArrayList<>(indexes.length);
        for (int i : indexes)
            out.add(getVertex(i));
        return out;
    }

    @Override
    public List<Vertex<V>> getVertices(int from, int to) {
        if (to < from || from < 0 || to >= vertices.size())
            throw new IllegalArgumentException("Invalid range: from=" + from + ", to=" + to);
        List<Vertex<V>> out = new ArrayList<>(to - from + 1);
        for (int i = from; i <= to; i++)
            out.add(getVertex(i));
        return out;
    }

    @Override
    public void addEdge(Edge<E> edge) {
        if (edge.getFrom() < 0 || edge.getTo() >= vertices.size())
            throw new IllegalArgumentException("Invalid edge: " + edge + ", from/to indexes out of range");

        List<Edge<E>> fromList = edges[edge.getFrom()];
        if (fromList == null) {
            fromList = new ArrayList<>();
            edges[edge.getFrom()] = fromList;
        }
        addEdgeHelper(edge, fromList);

        if (edge.isDirected())
            return;

        //Add other way too (to allow easy lookup for undirected edges)
        List<Edge<E>> toList = edges[edge.getTo()];
        if (toList == null) {
            toList = new ArrayList<>();
            edges[edge.getTo()] = toList;
        }
        addEdgeHelper(edge, toList);
    }

    @Override
    @SuppressWarnings("unchecked")
    public List<Edge<E>> getEdgesOut(int vertex) {
        if (edges[vertex] == null)
            return Collections.emptyList();
        return new ArrayList<>(edges[vertex]);
    }

    @Override
    public int getVertexDegree(int vertex) {
        if (edges[vertex] == null)
            return 0;
        return edges[vertex].size();
    }

    @Override
    public Vertex<V> getRandomConnectedVertex(int vertex, Random rng) throws NoEdgesException {
        if (vertex < 0 || vertex >= vertices.size())
            throw new IllegalArgumentException("Invalid vertex index: " + vertex);
        if (edges[vertex] == null || edges[vertex].isEmpty())
            throw new NoEdgesException("Cannot generate random connected vertex: vertex " + vertex
                            + " has no outgoing/undirected edges");
        int connectedVertexNum = rng.nextInt(edges[vertex].size());
        Edge<E> edge = edges[vertex].get(connectedVertexNum);
        if (edge.getFrom() == vertex)
            return vertices.get(edge.getTo()); //directed or undirected, vertex -> x
        else
            return vertices.get(edge.getFrom()); //Undirected edge, x -> vertex
    }

    @Override
    public List<Vertex<V>> getConnectedVertices(int vertex) {
        if (vertex < 0 || vertex >= vertices.size())
            throw new IllegalArgumentException("Invalid vertex index: " + vertex);

        if (edges[vertex] == null)
            return Collections.emptyList();
        List<Vertex<V>> list = new ArrayList<>(edges[vertex].size());
        for (Edge<E> edge : edges[vertex]) {
            list.add(vertices.get(edge.getTo()));
        }
        return list;
    }

    @Override
    public int[] getConnectedVertexIndices(int vertex) {
        int[] out = new int[(edges[vertex] == null ? 0 : edges[vertex].size())];
        if (out.length == 0)
            return out;
        for (int i = 0; i < out.length; i++) {
            Edge<E> e = edges[vertex].get(i);
            out[i] = (e.getFrom() == vertex ? e.getTo() : e.getFrom());
        }
        return out;
    }

    private void addEdgeHelper(Edge<E> edge, List<Edge<E>> list) {
        if (!allowMultipleEdges) {
            //Check to avoid multiple edges
            boolean duplicate = false;

            if (edge.isDirected()) {
                for (Edge<E> e : list) {
                    if (e.getTo() == edge.getTo()) {
                        duplicate = true;
                        break;
                    }
                }
            } else {
                for (Edge<E> e : list) {
                    if ((e.getFrom() == edge.getFrom() && e.getTo() == edge.getTo())
                                    || (e.getTo() == edge.getFrom() && e.getFrom() == edge.getTo())) {
                        duplicate = true;
                        break;
                    }
                }
            }

            if (!duplicate) {
                list.add(edge);
            }
        } else {
            //allow multiple/duplicate edges
            list.add(edge);
        }
    }


    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Graph {");
        sb.append("\nVertices {");
        for (Vertex<V> v : vertices) {
            sb.append("\n\t").append(v);
        }
        sb.append("\n}");
        sb.append("\nEdges {");
        for (int i = 0; i < edges.length; i++) {
            sb.append("\n\t");
            if (edges[i] == null)
                continue;
            sb.append(i).append(":");
            for (Edge<E> e : edges[i]) {
                sb.append(" ").append(e);
            }
        }
        sb.append("\n}");
        sb.append("\n}");
        return sb.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof Graph))
            return false;
        Graph g = (Graph) o;
        if (allowMultipleEdges != g.allowMultipleEdges)
            return false;
        if (edges.length != g.edges.length)
            return false;
        if (vertices.size() != g.vertices.size())
            return false;
        for (int i = 0; i < edges.length; i++) {
            if (!edges[i].equals(g.edges[i]))
                return false;
        }
        return vertices.equals(g.vertices);
    }

    @Override
    public int hashCode() {
        int result = 23;
        result = 31 * result + (allowMultipleEdges ? 1 : 0);
        result = 31 * result + Arrays.hashCode(edges);
        result = 31 * result + vertices.hashCode();
        return result;
    }
}