deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/Word2Vec.java

Summary

Maintainability
F
3 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.models.word2vec;

import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.StreamLineIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.type.CollectionType;

import java.io.IOException;
import java.util.*;

public class Word2Vec extends SequenceVectors<VocabWord> {
    // Serialization version marker for this class.
    private static final long serialVersionUID = 78249242142L;

    // Sentence source read by setTokenizerFactory() to rebuild the training iterator; transient, so not serialized.
    protected transient SentenceIterator sentenceIter;
    // Tokenizer factory used when wrapping sentence sources; transient, so not serialized. Getter generated by Lombok.
    @Getter
    protected transient TokenizerFactory tokenizerFactory;

    /**
     * This method defines TokenizerFactory instance to be using during model building.
     * <p>
     * If a SentenceIterator is already stored in {@code sentenceIter}, the training iterator is rebuilt
     * around it with the new factory. Parallel tokenization follows the same rule as
     * {@link #setSentenceIterator(SentenceIterator)}: it is allowed when there is no configuration,
     * or when the configuration allows it.
     *
     * @param tokenizerFactory TokenizerFactory instance; must not be null
     */
    public void setTokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;

        if (sentenceIter != null) {
            SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(sentenceIter)
                    .vocabCache(vocab)
                    .tokenizerFactory(this.tokenizerFactory)
                    // keep in sync with setSentenceIterator(): honour the configured tokenization mode
                    .allowMultithreading(configuration == null || configuration.isAllowParallelTokenization())
                    .build();
            this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();
        }
    }

    /**
     * This method defines SentenceIterator instance, that will be used as training corpus source.
     * <p>
     * The iterator is always remembered in {@code sentenceIter}. If a TokenizerFactory is already set,
     * the training iterator is built immediately. Otherwise an error is logged, and the stored iterator
     * is picked up by a later {@link #setTokenizerFactory(TokenizerFactory)} call.
     *
     * @param iterator SentenceIterator instance; must not be null
     */
    public void setSentenceIterator(@NonNull SentenceIterator iterator) {
        // Remember the source: setTokenizerFactory() reads this field to wire the pipeline later.
        // Without this assignment the iterator was silently dropped when no tokenizer was set yet.
        this.sentenceIter = iterator;

        if (tokenizerFactory != null) {
            SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator)
                    .tokenizerFactory(tokenizerFactory)
                    .vocabCache(vocab)
                    .allowMultithreading(configuration == null || configuration.isAllowParallelTokenization())
                    .build();
            this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();
        } else {
            log.error("Please call setTokenizerFactory() prior to setSentenceIter() call.");
        }
    }

    /**
     * This method defines SequenceIterator instance, that will be used as training corpus source.
     * Main difference with other iterators here: it allows you to pass already tokenized
     * {@code Sequence<VocabWord>} for training. The given iterator is stored as-is, without being
     * wrapped in a transformer.
     *
     * @param iterator SequenceIterator instance providing the training sequences; must not be null
     */
    public void setSequenceIterator(@NonNull SequenceIterator<VocabWord> iterator) {
        this.iterator = iterator;
    }

    // Lazily created shared mapper. Volatile so that the unsynchronized first check in mapper()
    // can never observe a partially published instance.
    private static volatile ObjectMapper mapper = null;
    // Monitor guarding the one-time creation of the shared mapper.
    private static final Object lock = new Object();

    /**
     * Returns the shared ObjectMapper, creating it on first use.
     * <p>
     * The mapper ignores unknown properties on deserialization and does not fail on empty beans
     * on serialization. The instance is fully configured before it is assigned to the shared field,
     * so no caller can receive a mapper that is still being set up.
     *
     * @return the shared, configured ObjectMapper
     */
    private static ObjectMapper mapper() {
        ObjectMapper result = mapper;
        if (result == null) {
            synchronized (lock) {
                result = mapper;
                if (result == null) {
                    result = new ObjectMapper();
                    result.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
                    result.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
                    // publish only after configuration is complete
                    mapper = result;
                }
            }
        }
        return result;
    }

    // JSON property under which toJson() stores the class name of the serialized model.
    private static final String CLASS_FIELD = "@class";
    // JSON property under which toJson() stores, and fromJson() reads, the serialized vocabulary.
    private static final String VOCAB_LIST_FIELD = "VocabCache";

    /**
     * Serializes this model to a JSON string.
     * <p>
     * The result always carries the class name under {@code "@class"}; the name is first turned
     * into a string by the shared mapper and that string is stored as the property value.
     * The vocabulary is added under {@code "VocabCache"} only when it is an {@link AbstractCache}.
     *
     * @return JSON representation of this model
     * @throws JsonProcessingException if the mapper fails to serialize a value
     */
    public String toJson() throws JsonProcessingException {
        final JsonObject root = new JsonObject();

        final String encodedClassName = mapper().writeValueAsString(this.getClass().getName());
        root.addProperty(CLASS_FIELD, encodedClassName);

        if (this.vocab instanceof AbstractCache) {
            final AbstractCache<VocabWord> cache = (AbstractCache<VocabWord>) this.vocab;
            root.addProperty(VOCAB_LIST_FIELD, cache.toJson());
        }

        return root.toString();
    }

    /**
     * Restores a model from the JSON produced by {@link #toJson()}.
     * <p>
     * Only the vocabulary stored under {@code "VocabCache"} is read back. {@link #toJson()} omits that
     * property when the vocabulary is not an {@link AbstractCache}; such input is rejected with an
     * IOException instead of failing with a NullPointerException.
     *
     * @param jsonString JSON text holding an object with a {@code "VocabCache"} string property
     * @return a new Word2Vec instance with the restored vocabulary
     * @throws IOException if the {@code "VocabCache"} property is missing or null, or if the
     *                     vocabulary cannot be deserialized
     */
    public static Word2Vec fromJson(String jsonString)  throws IOException {

        Word2Vec ret = new Word2Vec();

        JsonParser parser = new JsonParser();
        JsonObject json = parser.parse(jsonString).getAsJsonObject();

        // toJson() may legitimately leave the vocabulary out; report that clearly
        if (!json.has(VOCAB_LIST_FIELD) || json.get(VOCAB_LIST_FIELD).isJsonNull()) {
            throw new IOException("Unable to restore Word2Vec: JSON has no \"" + VOCAB_LIST_FIELD + "\" field");
        }

        VocabCache cache = AbstractCache.fromJson(json.get(VOCAB_LIST_FIELD).getAsString());

        ret.setVocab(cache);
        return ret;
    }

    /**
     * Builds a diagnostic description listing the model's fields as {@code name=value} pairs,
     * wrapped in {@code Word2Vec{...}}.
     *
     * @return textual dump of this model's state
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("Word2Vec{");
        text.append("sentenceIter=").append(sentenceIter);
        text.append(", tokenizerFactory=").append(tokenizerFactory);
        text.append(", iterator=").append(iterator);
        text.append(", elementsLearningAlgorithm=").append(elementsLearningAlgorithm);
        text.append(", sequenceLearningAlgorithm=").append(sequenceLearningAlgorithm);
        text.append(", configuration=").append(configuration);
        text.append(", existingModel=").append(existingModel);
        text.append(", intersectModel=").append(intersectModel);
        text.append(", unknownElement=").append(unknownElement);
        text.append(", scoreElements=").append(scoreElements);
        text.append(", scoreSequences=").append(scoreSequences);
        text.append(", configured=").append(configured);
        text.append(", lockFactor=").append(lockFactor);
        text.append(", enableScavenger=").append(enableScavenger);
        text.append(", vocabLimit=").append(vocabLimit);
        text.append(", eventListeners=").append(eventListeners);
        text.append(", minWordFrequency=").append(minWordFrequency);
        text.append(", lookupTable=").append(lookupTable);
        text.append(", vocab=").append(vocab);
        text.append(", layerSize=").append(layerSize);
        text.append(", modelUtils=").append(modelUtils);
        text.append(", numIterations=").append(numIterations);
        text.append(", numEpochs=").append(numEpochs);
        text.append(", negative=").append(negative);
        text.append(", sampling=").append(sampling);
        text.append(", learningRate=").append(learningRate);
        text.append(", minLearningRate=").append(minLearningRate);
        text.append(", window=").append(window);
        text.append(", batchSize=").append(batchSize);
        text.append(", learningRateDecayWords=").append(learningRateDecayWords);
        text.append(", resetModel=").append(resetModel);
        text.append(", useAdeGrad=").append(useAdeGrad);
        text.append(", workers=").append(workers);
        text.append(", trainSequenceVectors=").append(trainSequenceVectors);
        text.append(", trainElementsVectors=").append(trainElementsVectors);
        text.append(", seed=").append(seed);
        text.append(", useUnknown=").append(useUnknown);
        // arrays need explicit formatting to print their contents
        text.append(", variableWindows=").append(Arrays.toString(variableWindows));
        text.append(", stopWords=").append(stopWords);
        text.append('}');
        return text.toString();
    }

    public static class Builder extends SequenceVectors.Builder<VocabWord> {
        // Sentence-based corpus source; wrapped into a sequence iterator by build() when set.
        protected SentenceIterator sentenceIterator;
        // Label-aware corpus source; wrapped into a sequence iterator by build() when set.
        protected LabelAwareIterator labelAwareIterator;
        // Tokenizer factory used by build(); build() falls back to DefaultTokenizerFactory for text sources.
        protected TokenizerFactory tokenizerFactory;
        // Whether the transformers created in build() may tokenize in multiple threads.
        protected boolean allowParallelTokenization = true;


        /**
         * Creates a builder without passing an explicit VectorsConfiguration.
         */
        public Builder() {

        }

        /**
         * This method has no effect for Word2Vec: the argument is ignored and the builder is
         * returned unchanged.
         *
         * @param vec existing WordVectors model; must not be null
         * @return this builder, for call chaining
         */
        @Override
        protected Builder useExistingWordVectors(@NonNull WordVectors vec) {
            return this;
        }

        /**
         * Creates a builder from an existing configuration. The configuration is handed to the
         * parent builder, and its parallel-tokenization flag is copied into this builder.
         *
         * @param configuration configuration to start from; must not be null
         */
        public Builder(@NonNull VectorsConfiguration configuration) {
            super(configuration);
            this.allowParallelTokenization = configuration.isAllowParallelTokenization();
        }

        /**
         * Uses a DocumentIterator as training corpus source. The documents are exposed as sentences
         * through a StreamLineIterator built with a fetch size of 100.
         *
         * @param iterator document source; must not be null
         * @return this builder, for call chaining
         */
        public Builder iterate(@NonNull DocumentIterator iterator) {
            final SentenceIterator lineSource =
                    new StreamLineIterator.Builder(iterator).setFetchSize(100).build();
            this.sentenceIterator = lineSource;
            return this;
        }

        /**
         * This method is used to feed a SentenceIterator, that contains the training corpus, into
         * the model being built. The iterator is stored and wrapped by {@code build()}.
         *
         * @param iterator sentence source; must not be null
         * @return this builder, for call chaining
         */
        public Builder iterate(@NonNull SentenceIterator iterator) {
            this.sentenceIterator = iterator;
            return this;
        }

        /**
         * This method defines TokenizerFactory to be used for strings tokenization during training
         * PLEASE NOTE: If external VocabCache is used, the same TokenizerFactory should be used to keep derived tokens equal.
         *
         * @param tokenizerFactory tokenizer factory to use; must not be null
         * @return this builder, for call chaining
         */
        public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * This method is used to feed a SequenceIterator, that contains the training corpus, into
         * the model being built. Delegates to the parent builder.
         *
         * @param iterator source of sequences; must not be null
         * @return this builder, for call chaining
         */
        @Override
        public Builder iterate(@NonNull SequenceIterator<VocabWord> iterator) {
            super.iterate(iterator);
            return this;
        }

        /**
         * This method is used to feed a LabelAwareIterator as training corpus source. The iterator
         * is stored and wrapped by {@code build()}.
         *
         * @param iterator label-aware document source; must not be null
         * @return this builder, for call chaining
         */
        public Builder iterate(@NonNull LabelAwareIterator iterator) {
            this.labelAwareIterator = iterator;
            return this;
        }

        /**
         * This method defines mini-batch size. Delegates to the parent builder.
         *
         * @param batchSize mini-batch size
         * @return this builder, for call chaining
         */
        @Override
        public Builder batchSize(int batchSize) {
            super.batchSize(batchSize);
            return this;
        }

        /**
         * This method defines number of iterations done for each mini-batch during training.
         * Delegates to the parent builder.
         *
         * @param iterations number of iterations per mini-batch
         * @return this builder, for call chaining
         */
        @Override
        public Builder iterations(int iterations) {
            super.iterations(iterations);
            return this;
        }

        /**
         * This method defines number of epochs (iterations over whole training corpus) for training.
         * Delegates to the parent builder.
         *
         * @param numEpochs number of epochs
         * @return this builder, for call chaining
         */
        @Override
        public Builder epochs(int numEpochs) {
            super.epochs(numEpochs);
            return this;
        }

        /**
         * This method defines number of dimensions for output vectors. Delegates to the parent builder.
         *
         * @param layerSize number of dimensions of the output vectors
         * @return this builder, for call chaining
         */
        @Override
        public Builder layerSize(int layerSize) {
            super.layerSize(layerSize);
            return this;
        }

        /**
         * This method defines initial learning rate for model training. Delegates to the parent builder.
         *
         * @param learningRate initial learning rate
         * @return this builder, for call chaining
         */
        @Override
        public Builder learningRate(double learningRate) {
            super.learningRate(learningRate);
            return this;
        }

        /**
         * This method defines minimal word frequency in training corpus. All words below this
         * threshold will be removed prior to model training. Delegates to the parent builder.
         *
         * @param minWordFrequency minimal word frequency
         * @return this builder, for call chaining
         */
        @Override
        public Builder minWordFrequency(int minWordFrequency) {
            super.minWordFrequency(minWordFrequency);
            return this;
        }

        /**
         * This method defines minimal learning rate value for training. Delegates to the parent builder.
         *
         * @param minLearningRate minimal learning rate
         * @return this builder, for call chaining
         */
        @Override
        public Builder minLearningRate(double minLearningRate) {
            super.minLearningRate(minLearningRate);
            return this;
        }

        /**
         * This method defines whether model should be totally wiped out prior to building, or not.
         * Delegates to the parent builder.
         *
         * @param reallyReset true to reset the model
         * @return this builder, for call chaining
         */
        @Override
        public Builder resetModel(boolean reallyReset) {
            super.resetModel(reallyReset);
            return this;
        }

        /**
         * This method sets vocabulary limit during construction. Delegates to the parent builder.
         *
         * Default value: 0. Means no limit
         *
         * @param limit vocabulary size limit
         * @return this builder, for call chaining
         */
        @Override
        public Builder limitVocabularySize(int limit) {
            super.limitVocabularySize(limit);
            return this;
        }

        /**
         * This method allows to define external VocabCache to be used. Delegates to the parent builder.
         *
         * @param vocabCache vocabulary cache to use; must not be null
         * @return this builder, for call chaining
         */
        @Override
        public Builder vocabCache(@NonNull VocabCache<VocabWord> vocabCache) {
            super.vocabCache(vocabCache);
            return this;
        }

        /**
         * This method allows to define external WeightLookupTable to be used. Delegates to the
         * parent builder.
         *
         * @param lookupTable lookup table to use; must not be null
         * @return this builder, for call chaining
         */
        @Override
        public Builder lookupTable(@NonNull WeightLookupTable<VocabWord> lookupTable) {
            super.lookupTable(lookupTable);
            return this;
        }

        /**
         * This method defines whether subsampling should be used or not. Delegates to the parent builder.
         *
         * @param sampling set > 0 to subsampling argument, or 0 to disable
         * @return this builder, for call chaining
         */
        @Override
        public Builder sampling(double sampling) {
            super.sampling(sampling);
            return this;
        }

        /**
         * This method defines whether adaptive gradients should be used or not. Delegates to the
         * parent builder.
         *
         * @param reallyUse true to use adaptive gradients
         * @return this builder, for call chaining
         */
        @Override
        public Builder useAdaGrad(boolean reallyUse) {
            super.useAdaGrad(reallyUse);
            return this;
        }

        /**
         * This method defines whether negative sampling should be used or not. Delegates to the
         * parent builder.
         *
         * PLEASE NOTE: If you're going to use negative sampling, you might want to disable HierarchicSoftmax, which is enabled by default
         *
         * Default value: 0
         *
         * @param negative set > 0 as negative sampling argument, or 0 to disable
         * @return this builder, for call chaining
         */
        @Override
        public Builder negativeSample(double negative) {
            super.negativeSample(negative);
            return this;
        }

        /**
         * This method defines stop words that should be ignored during training, given as strings.
         * Delegates to the parent builder.
         *
         * @param stopList stop words; must not be null
         * @return this builder, for call chaining
         */
        @Override
        public Builder stopWords(@NonNull List<String> stopList) {
            super.stopWords(stopList);
            return this;
        }

        /**
         * This option cannot be changed for Word2Vec: {@code build()} always enables training of
         * element vectors, so this method rejects every call.
         *
         * @param trainElements ignored
         * @return never returns normally
         * @throws IllegalStateException always
         */
        @Override
        public Builder trainElementsRepresentation(boolean trainElements) {
            throw new IllegalStateException("You can't change this option for Word2Vec");
        }

        /**
         * This option cannot be changed for Word2Vec: {@code build()} always disables training of
         * sequence vectors, so this method rejects every call.
         *
         * @param trainSequences ignored
         * @return never returns normally
         * @throws IllegalStateException always
         */
        @Override
        public Builder trainSequencesRepresentation(boolean trainSequences) {
            throw new IllegalStateException("You can't change this option for Word2Vec");
        }

        /**
         * This method defines stop words that should be ignored during training, given as
         * VocabWord elements. Delegates to the parent builder.
         *
         * @param stopList stop words; must not be null
         * @return this builder, for call chaining
         */
        @Override
        public Builder stopWords(@NonNull Collection<VocabWord> stopList) {
            super.stopWords(stopList);
            return this;
        }

        /**
         * This method defines context window size. Delegates to the parent builder.
         *
         * @param windowSize context window size
         * @return this builder, for call chaining
         */
        @Override
        public Builder windowSize(int windowSize) {
            super.windowSize(windowSize);
            return this;
        }

        /**
         * This method defines random seed for random numbers generator. Delegates to the parent builder.
         *
         * @param randomSeed seed value
         * @return this builder, for call chaining
         */
        @Override
        public Builder seed(long randomSeed) {
            super.seed(randomSeed);
            return this;
        }


        /**
         * Sets number of threads running calculations.
         * Note this is different from workers which affect
         * the number of threads used to compute updates.
         * This should be balanced with the number of workers.
         * High number of threads will actually hinder performance.
         * Delegates to the parent builder.
         *
         * @param vectorCalcThreads the number of threads to compute updates
         * @return this builder, for call chaining
         */
        public Builder vectorCalcThreads(int vectorCalcThreads) {
            super.vectorCalcThreads(vectorCalcThreads);
            return this;
        }


        /**
         * This method defines maximum number of concurrent threads available for training.
         * Delegates to the parent builder.
         *
         * @param numWorkers maximum number of training threads
         * @return this builder, for call chaining
         */
        @Override
        public Builder workers(int numWorkers) {
            super.workers(numWorkers);
            return this;
        }

        /**
         * Sets ModelUtils that is going to be used as provider for utility methods: similarity(),
         * wordsNearest(), accuracy(), etc. Delegates to the parent builder.
         *
         * @param modelUtils model utils to be used; must not be null
         * @return this builder, for call chaining
         */
        @Override
        public Builder modelUtils(@NonNull ModelUtils<VocabWord> modelUtils) {
            super.modelUtils(modelUtils);
            return this;
        }

        /**
         * This method allows to use variable window size. In this case, every batch gets processed
         * using one of predefined window sizes. Delegates to the parent builder.
         *
         * @param windows the predefined window sizes
         * @return this builder, for call chaining
         */
        @Override
        public Builder useVariableWindow(int... windows) {
            super.useVariableWindow(windows);
            return this;
        }

        /**
         * This method allows you to specify SequenceElement that will be used as UNK element, if UNK
         * is used. Delegates to the parent builder.
         *
         * @param element element to use as the UNK element
         * @return this builder, for call chaining
         */
        @Override
        public Builder unknownElement(VocabWord element) {
            super.unknownElement(element);
            return this;
        }

        /**
         * This method allows you to specify, if UNK word should be used internally.
         * <p>
         * Whatever the flag value, a default UNK element (frequency 1.0, label
         * {@code Word2Vec.DEFAULT_UNK}) is installed when none has been set on this builder yet.
         *
         * @param reallyUse true to use the UNK word
         * @return this builder, for call chaining
         */
        @Override
        public Builder useUnknown(boolean reallyUse) {
            super.useUnknown(reallyUse);

            final boolean unkMissing = (this.unknownElement == null);
            if (unkMissing) {
                final VocabWord defaultUnk = new VocabWord(1.0, Word2Vec.DEFAULT_UNK);
                this.unknownElement(defaultUnk);
            }

            return this;
        }

        /**
         * This method sets VectorsListeners for this SequenceVectors model. Delegates to the
         * parent builder.
         *
         * @param vectorsListeners listeners to register; must not be null
         * @return this builder, for call chaining
         */
        @Override
        public Builder setVectorsListeners(@NonNull Collection<VectorsListener<VocabWord>> vectorsListeners) {
            super.setVectorsListeners(vectorsListeners);
            return this;
        }

        /**
         * Selects the elements learning algorithm by the given string. A null argument is ignored,
         * leaving the current setting untouched; otherwise the call is delegated to the parent builder.
         *
         * @param algorithm algorithm identifier passed to the parent builder, or null for no change
         * @return this builder, for call chaining
         */
        @Override
        public Builder elementsLearningAlgorithm(String algorithm) {
            if (algorithm != null) {
                super.elementsLearningAlgorithm(algorithm);
            }
            return this;
        }

        /**
         * Sets the elements learning algorithm instance. A null argument is ignored, leaving the
         * current setting untouched; otherwise the call is delegated to the parent builder.
         *
         * @param algorithm algorithm instance, or null for no change
         * @return this builder, for call chaining
         */
        @Override
        public Builder elementsLearningAlgorithm(ElementsLearningAlgorithm<VocabWord> algorithm) {
            if (algorithm != null) {
                super.elementsLearningAlgorithm(algorithm);
            }
            return this;
        }

        /**
         * This method enables/disables parallel tokenization. The flag is passed to the sentence
         * transformers created in {@code build()}.
         *
         * Default value: TRUE
         *
         * @param allow true to allow parallel tokenization
         * @return this builder, for call chaining
         */
        public Builder allowParallelTokenization(boolean allow) {
            this.allowParallelTokenization = allow;
            return this;
        }

        /**
         * This method enables/disables periodical vocab truncation during construction.
         * Delegates to the parent builder.
         *
         * Default value: disabled
         *
         * @param reallyEnable true to enable the scavenger
         * @return this builder, for call chaining
         */
        @Override
        public Builder enableScavenger(boolean reallyEnable) {
            super.enableScavenger(reallyEnable);
            return this;
        }

        /**
         * This method enables/disables Hierarchic softmax. Delegates to the parent builder.
         *
         * Default value: enabled
         *
         * @param reallyUse true to use hierarchic softmax
         * @return this builder, for call chaining
         */
        @Override
        public Builder useHierarchicSoftmax(boolean reallyUse) {
            super.useHierarchicSoftmax(reallyUse);
            return this;
        }

        /**
         * Enables/disables the precise weight initialization option. Delegates to the parent builder.
         *
         * @param reallyUse true to enable the option
         * @return this builder, for call chaining
         */
        @Override
        public Builder usePreciseWeightInit(boolean reallyUse) {
            super.usePreciseWeightInit(reallyUse);
            return this;
        }

        /**
         * Enables/disables the precise mode option. Delegates to the parent builder.
         *
         * @param reallyUse true to enable the option
         * @return this builder, for call chaining
         */
        @Override
        public Builder usePreciseMode(boolean reallyUse) {
            super.usePreciseMode(reallyUse);
            return this;
        }

        /**
         * Registers a model to intersect with. Delegates to the parent builder; {@code build()}
         * copies the resulting intersect model and lock factor into the built Word2Vec.
         *
         * @param vectors  model to intersect with; must not be null
         * @param isLocked lock flag passed to the parent builder
         * @return this builder, for call chaining
         */
        @Override
        public Builder intersectModel(@NonNull SequenceVectors vectors, boolean isLocked) {
            super.intersectModel(vectors, isLocked);
            return this;
        }

        /**
         * Assembles a Word2Vec model from the state collected by this builder.
         * <p>
         * Steps, in order:
         * <ol>
         *   <li>{@code presetTables()} is invoked;</li>
         *   <li>a sentence and/or label-aware source, if set, is wrapped into a sequence iterator
         *       (a DefaultTokenizerFactory is used when no tokenizer was given); when both are set,
         *       the label-aware source wins because it is wrapped last;</li>
         *   <li>builder settings are copied into the model, with a default UNK element when none was set;</li>
         *   <li>unless a configuration was specified up front, the settings are mirrored into the
         *       configuration; tokenizer class names are always recorded when a tokenizer is present;</li>
         *   <li>element training is forced on and sequence training forced off.</li>
         * </ol>
         *
         * @return the configured Word2Vec instance
         */
        public Word2Vec build() {
            presetTables();

            final Word2Vec model = new Word2Vec();

            final boolean hasSentenceSource = sentenceIterator != null;
            final boolean hasLabelAwareSource = this.labelAwareIterator != null;

            // any text-based source needs a tokenizer; fall back to the default one
            if ((hasSentenceSource || hasLabelAwareSource) && tokenizerFactory == null) {
                tokenizerFactory = new DefaultTokenizerFactory();
            }

            if (hasSentenceSource) {
                SentenceTransformer sentenceTransformer = new SentenceTransformer.Builder()
                        .iterator(sentenceIterator)
                        .vocabCache(vocabCache)
                        .tokenizerFactory(tokenizerFactory)
                        .allowMultithreading(allowParallelTokenization)
                        .build();
                this.iterator = new AbstractSequenceIterator.Builder<>(sentenceTransformer).build();
            }

            // wrapped last, so it replaces the sentence-based iterator when both are present
            if (hasLabelAwareSource) {
                SentenceTransformer labelAwareTransformer = new SentenceTransformer.Builder()
                        .iterator(labelAwareIterator)
                        .tokenizerFactory(tokenizerFactory)
                        .allowMultithreading(allowParallelTokenization)
                        .vocabCache(vocabCache)
                        .build();
                this.iterator = new AbstractSequenceIterator.Builder<>(labelAwareTransformer).build();
            }

            // training hyper-parameters
            model.numEpochs = this.numEpochs;
            model.numIterations = this.iterations;
            model.vocab = this.vocabCache;
            model.minWordFrequency = this.minWordFrequency;
            model.learningRate.set(this.learningRate);
            model.minLearningRate = this.minLearningRate;
            model.sampling = this.sampling;
            model.negative = this.negative;
            model.layerSize = this.layerSize;
            model.batchSize = this.batchSize;
            model.learningRateDecayWords = this.learningRateDecayWords;
            model.window = this.window;
            model.resetModel = this.resetModel;
            model.useAdeGrad = this.useAdaGrad;
            model.stopWords = this.stopWords;
            model.workers = this.workers;
            model.useUnknown = this.useUnknown;
            model.variableWindows = this.variableWindows;
            model.seed = this.seed;
            model.enableScavenger = this.enableScavenger;
            model.vocabLimit = this.vocabLimit;

            // the model always carries an UNK element, even if none was configured
            model.unknownElement = (this.unknownElement != null)
                    ? this.unknownElement
                    : new VocabWord(1.0,SequenceVectors.DEFAULT_UNK);

            // collaborators
            model.iterator = this.iterator;
            model.lookupTable = this.lookupTable;
            model.tokenizerFactory = this.tokenizerFactory;
            model.modelUtils = this.modelUtils;

            model.elementsLearningAlgorithm = this.elementsLearningAlgorithm;
            model.sequenceLearningAlgorithm = this.sequenceLearningAlgorithm;

            model.intersectModel = this.intersectVectors;
            model.lockFactor = this.lockFactor;

            // mirror builder settings into the configuration unless one was supplied explicitly
            if (!this.configurationSpecified) {
                this.configuration.setLearningRate(this.learningRate);
                this.configuration.setLayersSize(layerSize);
                this.configuration.setHugeModelExpected(hugeModelExpected);
                this.configuration.setWindow(window);
                this.configuration.setMinWordFrequency(minWordFrequency);
                this.configuration.setIterations(iterations);
                this.configuration.setSeed(seed);
                this.configuration.setBatchSize(batchSize);
                this.configuration.setLearningRateDecayWords(learningRateDecayWords);
                this.configuration.setMinLearningRate(minLearningRate);
                this.configuration.setSampling(this.sampling);
                this.configuration.setUseAdaGrad(useAdaGrad);
                this.configuration.setNegative(negative);
                this.configuration.setEpochs(this.numEpochs);
                this.configuration.setStopList(this.stopWords);
                this.configuration.setVariableWindows(variableWindows);
                this.configuration.setUseHierarchicSoftmax(this.useHierarchicSoftmax);
                this.configuration.setPreciseWeightInit(this.preciseWeightInit);
                this.configuration.setModelUtils(this.modelUtils.getClass().getCanonicalName());
                this.configuration.setAllowParallelTokenization(this.allowParallelTokenization);
                this.configuration.setPreciseMode(this.preciseMode);
            }

            // tokenizer class names are recorded regardless of where the configuration came from
            if (tokenizerFactory != null) {
                this.configuration.setTokenizerFactory(tokenizerFactory.getClass().getCanonicalName());
                if (tokenizerFactory.getTokenPreProcessor() != null) {
                    this.configuration.setTokenPreProcessor(
                            tokenizerFactory.getTokenPreProcessor().getClass().getCanonicalName());
                }
            }

            model.configuration = this.configuration;

            // we hardcode
            model.trainSequenceVectors = false;
            model.trainElementsVectors = true;

            model.eventListeners = this.vectorsListeners;

            return model;
        }
    }
}