nlpub/watset-java

View on GitHub
src/main/java/org/nlpub/watset/util/ClusteringAlgorithmProvider.java

Summary

Maintainability
B
4 hrs
Test Coverage
/*
 * Copyright 2019 Dmitry Ustalov
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.nlpub.watset.util;

import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.clustering.MultiKMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.jgrapht.Graph;
import org.jgrapht.alg.clustering.GirvanNewmanClustering;
import org.jgrapht.alg.clustering.KSpanningTreeClustering;
import org.jgrapht.alg.interfaces.ClusteringAlgorithm;
import org.nlpub.watset.graph.*;

import java.nio.file.Path;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;

import static java.util.Objects.requireNonNull;
import static java.util.Objects.requireNonNullElse;

/**
 * A utility class that creates instances of the graph clustering algorithms.
 * <p>
 * The algorithm is selected by name at construction time (see {@link ProvidingAlgorithm} and
 * {@link #normalize(String)} for the accepted names and aliases) and is configured using
 * a string-to-string parameter map. Calling {@link #apply(Graph)} builds the configured
 * algorithm for the given graph.
 *
 * @param <V> the type of nodes in the graph
 * @param <E> the type of edges in the graph
 */
public class ClusteringAlgorithmProvider<V, E> implements ClusteringAlgorithmBuilder<V, E, ClusteringAlgorithm<V>> {
    /**
     * Clustering algorithms that {@link ClusteringAlgorithmProvider} knows how to provide.
     */
    public enum ProvidingAlgorithm {
        /**
         * Label for {@link EmptyClustering}.
         */
        EMPTY,

        /**
         * Label for {@link TogetherClustering}.
         */
        TOGETHER,

        /**
         * Label for {@link SingletonClustering}.
         */
        SINGLETON,

        /**
         * Label for {@link ComponentsClustering}.
         */
        COMPONENTS,

        /**
         * Label for {@link KSpanningTreeClustering}.
         */
        K_SPANNING_TREE,

        /**
         * Label for {@link GirvanNewmanClustering}.
         */
        GIRVAN_NEWMAN,

        /**
         * Label for {@link SpectralClustering}.
         */
        SPECTRAL,

        /**
         * Label for {@link ChineseWhispers}.
         */
        CHINESE_WHISPERS,

        /**
         * Label for {@link MarkovClustering}.
         */
        MARKOV_CLUSTERING,

        /**
         * Label for {@link MarkovClusteringExternal}.
         */
        MARKOV_CLUSTERING_EXTERNAL,

        /**
         * Label for {@link MaxMax}.
         */
        MAXMAX,
    }

    private final ProvidingAlgorithm algorithm;
    private final Map<String, String> params;
    private final NodeWeighting<V, E> weighting;
    private final JDKRandomGenerator random;

    /**
     * Create an instance of this utility class.
     *
     * @param algorithm the algorithm identifier; it is normalized with {@link #normalize(String)}
     *                  before being matched against {@link ProvidingAlgorithm}
     * @param params    the parameter map for the algorithm; {@code null} is treated as an empty map
     * @param random    the random number generator; {@code null} is replaced by a new {@link JDKRandomGenerator}
     * @throws NullPointerException     if {@code algorithm} is {@code null}
     * @throws IllegalArgumentException if the normalized name is not a {@link ProvidingAlgorithm} constant
     */
    public ClusteringAlgorithmProvider(String algorithm, Map<String, String> params, JDKRandomGenerator random) {
        this.algorithm = ProvidingAlgorithm.valueOf(normalize(requireNonNull(algorithm, "algorithm is not specified")));
        this.params = requireNonNullElse(params, Collections.emptyMap());
        // Read from the null-safe field, not from the raw argument, so that a null map does not crash here.
        this.weighting = NodeWeightings.parse(this.params.get("mode"));
        this.random = requireNonNullElse(random, new JDKRandomGenerator());
    }

    /**
     * Build the configured clustering algorithm for the given graph.
     * <p>
     * Parameters read from the parameter map, depending on the algorithm:
     * {@code k} (required) for k-spanning tree, Girvan-Newman and spectral clustering;
     * {@code n} (optional, defaults to 10) as the number of k-means trials for spectral clustering;
     * {@code e} and {@code r} (both optional) for Markov clustering;
     * {@code bin} (required) and {@code r} (optional) for the external Markov clustering.
     *
     * @param graph the graph to cluster
     * @return the clustering algorithm set up for {@code graph}
     * @throws NullPointerException  if a required parameter is missing
     * @throws NumberFormatException if a numeric parameter cannot be parsed
     */
    @Override
    public ClusteringAlgorithm<V> apply(Graph<V, E> graph) {
        switch (algorithm) {
            case EMPTY:
                return EmptyClustering.<V, E>builder().apply(graph);
            case TOGETHER:
                return TogetherClustering.<V, E>builder().apply(graph);
            case SINGLETON:
                return SingletonClustering.<V, E>builder().apply(graph);
            case COMPONENTS:
                return ComponentsClustering.<V, E>builder().apply(graph);
            case K_SPANNING_TREE:
                return new KSpanningTreeClustering<>(graph, requiredInt("k"));
            case GIRVAN_NEWMAN:
                return new GirvanNewmanClustering<>(graph, requiredInt("k"));
            case SPECTRAL:
                final int kSpectral = requiredInt("k");
                // A negative iteration limit is passed on purpose; see the Commons Math documentation
                // of KMeansPlusPlusClusterer for how it is interpreted.
                final var clusterer = new KMeansPlusPlusClusterer<NodeEmbedding<V>>(kSpectral, -1, new EuclideanDistance(), random);
                final int numTrials = params.containsKey("n") ? Integer.parseInt(params.get("n")) : 10;
                final var metaClusterer = new MultiKMeansPlusPlusClusterer<>(clusterer, numTrials);
                return SpectralClustering.<V, E>builder().setClusterer(metaClusterer).setK(kSpectral).apply(graph);
            case CHINESE_WHISPERS:
                return ChineseWhispers.<V, E>builder().setWeighting(weighting).setRandom(random).apply(graph);
            case MARKOV_CLUSTERING:
                final var mcl = MarkovClustering.<V, E>builder();

                if (params.containsKey("e")) mcl.setE(Integer.parseInt(params.get("e")));
                if (params.containsKey("r")) mcl.setR(Double.parseDouble(params.get("r")));

                return mcl.apply(graph);
            case MARKOV_CLUSTERING_EXTERNAL:
                // Fail with an explanatory message instead of a bare NullPointerException from Path.of.
                final var mclOfficial = MarkovClusteringExternal.<V, E>builder().
                        setPath(Path.of(requireNonNull(params.get("bin"), "bin must be specified"))).
                        setThreads(Runtime.getRuntime().availableProcessors());

                if (params.containsKey("r")) mclOfficial.setR(Double.parseDouble(params.get("r")));

                return mclOfficial.apply(graph);
            case MAXMAX:
                return MaxMax.<V, E>builder().apply(graph);
            default:
                throw new IllegalArgumentException("Unknown algorithm: " + algorithm);
        }
    }

    /**
     * Read a mandatory integer parameter from the parameter map.
     *
     * @param key the parameter name
     * @return the parsed integer value
     * @throws NullPointerException  if the parameter is absent
     * @throws NumberFormatException if the value is not a valid integer
     */
    private int requiredInt(String key) {
        return Integer.parseInt(requireNonNull(params.get(key), key + " must be specified"));
    }

    /**
     * Normalize the name of the requested algorithm.
     * <p>
     * The name is upper-cased, dashes are turned into underscores, and the short aliases
     * {@code GN}, {@code KST}, {@code CW}, {@code MCL_BIN} and {@code MCL} are expanded to
     * the names of the corresponding {@link ProvidingAlgorithm} constants.
     *
     * @param algorithm the algorithm
     * @return the normalized name
     */
    protected String normalize(String algorithm) {
        // MCL_BIN must be expanded before MCL, otherwise the shorter alias would match inside the longer one.
        return algorithm.toUpperCase(Locale.ROOT).
                replace('-', '_').
                replace("GN", ProvidingAlgorithm.GIRVAN_NEWMAN.name()).
                replace("KST", ProvidingAlgorithm.K_SPANNING_TREE.name()).
                replace("CW", ProvidingAlgorithm.CHINESE_WHISPERS.name()).
                replace("MCL_BIN", ProvidingAlgorithm.MARKOV_CLUSTERING_EXTERNAL.name()).
                replace("MCL", ProvidingAlgorithm.MARKOV_CLUSTERING.name());
    }
}