
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.models.word2vec;

import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.StreamLineIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.type.CollectionType;

import java.io.IOException;
import java.util.*;

public class Word2Vec extends SequenceVectors<VocabWord> {
    private static final long serialVersionUID = 78249242142L;

    /** Tokenizer applied to sentences from {@link #sentenceIter}; transient, so excluded from Java serialization. */
    protected transient TokenizerFactory tokenizerFactory;

    /** Raw sentence source from which the training sequence iterator is derived; transient. */
    protected transient SentenceIterator sentenceIter;

     * This method defines TokenizerFactory instance to be using during model building
     * @param tokenizerFactory TokenizerFactory instance
    public void setTokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;

        if (sentenceIter != null) {
            SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(sentenceIter)
            this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();

     * This method defines SentenceIterator instance, that will be used as training corpus source
     * @param iterator SentenceIterator instance
    public void setSentenceIterator(@NonNull SentenceIterator iterator) {

        if (tokenizerFactory != null) {
            SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator)
                    .allowMultithreading(configuration == null || configuration.isAllowParallelTokenization())
            this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();
        } else
            log.error("Please call setTokenizerFactory() prior to setSentenceIter() call.");

     * This method defines SequenceIterator instance, that will be used as training corpus source.
     * Main difference with other iterators here: it allows you to pass already tokenized Sequence<VocabWord> for training
     * @param iterator
    public void setSequenceIterator(@NonNull SequenceIterator<VocabWord> iterator) {
        this.iterator = iterator;

    private static ObjectMapper mapper = null;
    private static final Object lock = new Object();

    private static ObjectMapper mapper() {
        if (mapper == null) {
            synchronized (lock) {
                if (mapper == null) {
                    mapper = new ObjectMapper();
                    mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
                    mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
                    return mapper;
        return mapper;

    /** JSON key under which toJson() stores the serialized VocabCache. */
    private static final String VOCAB_LIST_FIELD = "VocabCache";
    /** JSON key under which toJson() stores the model class name. */
    private static final String CLASS_FIELD = "@class";

    public String toJson() throws JsonProcessingException {

        JsonObject retVal = new JsonObject();
        ObjectMapper mapper = mapper();

        retVal.addProperty(CLASS_FIELD, mapper.writeValueAsString(this.getClass().getName()));

        if (this.vocab instanceof AbstractCache) {
            retVal.addProperty(VOCAB_LIST_FIELD, ((AbstractCache<VocabWord>) this.vocab).toJson());

        return retVal.toString();

    public static Word2Vec fromJson(String jsonString)  throws IOException {

        Word2Vec ret = new Word2Vec();

        JsonParser parser = new JsonParser();
        JsonObject json = parser.parse(jsonString).getAsJsonObject();

        VocabCache cache = AbstractCache.fromJson(json.get(VOCAB_LIST_FIELD).getAsString());

        return ret;

    public String toString() {
        return "Word2Vec{" +
                "sentenceIter=" + sentenceIter +
                ", tokenizerFactory=" + tokenizerFactory +
                ", iterator=" + iterator +
                ", elementsLearningAlgorithm=" + elementsLearningAlgorithm +
                ", sequenceLearningAlgorithm=" + sequenceLearningAlgorithm +
                ", configuration=" + configuration +
                ", existingModel=" + existingModel +
                ", intersectModel=" + intersectModel +
                ", unknownElement=" + unknownElement +
                ", scoreElements=" + scoreElements +
                ", scoreSequences=" + scoreSequences +
                ", configured=" + configured +
                ", lockFactor=" + lockFactor +
                ", enableScavenger=" + enableScavenger +
                ", vocabLimit=" + vocabLimit +
                ", eventListeners=" + eventListeners +
                ", minWordFrequency=" + minWordFrequency +
                ", lookupTable=" + lookupTable +
                ", vocab=" + vocab +
                ", layerSize=" + layerSize +
                ", modelUtils=" + modelUtils +
                ", numIterations=" + numIterations +
                ", numEpochs=" + numEpochs +
                ", negative=" + negative +
                ", sampling=" + sampling +
                ", learningRate=" + learningRate +
                ", minLearningRate=" + minLearningRate +
                ", window=" + window +
                ", batchSize=" + batchSize +
                ", learningRateDecayWords=" + learningRateDecayWords +
                ", resetModel=" + resetModel +
                ", useAdeGrad=" + useAdeGrad +
                ", workers=" + workers +
                ", trainSequenceVectors=" + trainSequenceVectors +
                ", trainElementsVectors=" + trainElementsVectors +
                ", seed=" + seed +
                ", useUnknown=" + useUnknown +
                ", variableWindows=" + Arrays.toString(variableWindows) +
                ", stopWords=" + stopWords +

    public static class Builder extends SequenceVectors.Builder<VocabWord> {
        /** Whether sentence tokenization may run multi-threaded; taken from the configuration when one is given. */
        protected boolean allowParallelTokenization = true;
        /** Tokenizer for raw text; build() falls back to DefaultTokenizerFactory when this is unset. */
        protected TokenizerFactory tokenizerFactory;
        /** Raw-sentence corpus source, converted into a sequence iterator by build(). */
        protected SentenceIterator sentenceIterator;
        /** Labelled-document corpus source, converted into a sequence iterator by build(). */
        protected LabelAwareIterator labelAwareIterator;

        /**
         * Creates a builder with default settings.
         */
        public Builder() {
        }

         * This method has no effect for Word2Vec
         * @param vec existing WordVectors model
         * @return
        protected Builder useExistingWordVectors(@NonNull WordVectors vec) {
            return this;

        /**
         * Creates a builder pre-populated from an existing VectorsConfiguration.
         *
         * @param configuration configuration to start from, must not be null
         */
        public Builder(@NonNull VectorsConfiguration configuration) {
            // Hand the configuration to the parent builder; the implicit no-arg super() would never see it.
            super(configuration);
            this.allowParallelTokenization = configuration.isAllowParallelTokenization();
        }

        public Builder iterate(@NonNull DocumentIterator iterator) {
            this.sentenceIterator = new StreamLineIterator.Builder(iterator).setFetchSize(100).build();
            return this;

         * This method used to feed SentenceIterator, that contains training corpus, into ParagraphVectors
         * @param iterator
         * @return
        public Builder iterate(@NonNull SentenceIterator iterator) {
            this.sentenceIterator = iterator;
            return this;

         * This method defines TokenizerFactory to be used for strings tokenization during training
         * PLEASE NOTE: If external VocabCache is used, the same TokenizerFactory should be used to keep derived tokens equal.
         * @param tokenizerFactory
         * @return
        public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;

         * This method used to feed SequenceIterator, that contains training corpus, into ParagraphVectors
         * @param iterator
         * @return
        public Builder iterate(@NonNull SequenceIterator<VocabWord> iterator) {
            return this;

         * This method used to feed LabelAwareIterator, that is usually used
         * @param iterator
         * @return
        public Builder iterate(@NonNull LabelAwareIterator iterator) {
            this.labelAwareIterator = iterator;
            return this;

         * This method defines mini-batch size
         * @param batchSize
         * @return
        public Builder batchSize(int batchSize) {
            return this;

         * This method defines number of iterations done for each mini-batch during training
         * @param iterations
         * @return
        public Builder iterations(int iterations) {
            return this;

         * This method defines number of epochs (iterations over whole training corpus) for training
         * @param numEpochs
         * @return
        public Builder epochs(int numEpochs) {
            return this;

         * This method defines number of dimensions for output vectors
         * @param layerSize
         * @return
        public Builder layerSize(int layerSize) {
            return this;

         * This method defines initial learning rate for model training
         * @param learningRate
         * @return
        public Builder learningRate(double learningRate) {
            return this;

         * This method defines minimal word frequency in training corpus. All words below this threshold will be removed prior model training
         * @param minWordFrequency
         * @return
        public Builder minWordFrequency(int minWordFrequency) {
            return this;

         * This method defines minimal learning rate value for training
         * @param minLearningRate
         * @return
        public Builder minLearningRate(double minLearningRate) {
            return this;

         * This method defines whether model should be totally wiped out prior building, or not
         * @param reallyReset
         * @return
        public Builder resetModel(boolean reallyReset) {
            return this;

         * This method sets vocabulary limit during construction.
         * Default value: 0. Means no limit
         * @param limit
         * @return
        public Builder limitVocabularySize(int limit) {
            return this;

         * This method allows to define external VocabCache to be used
         * @param vocabCache
         * @return
        public Builder vocabCache(@NonNull VocabCache<VocabWord> vocabCache) {
            return this;

         * This method allows to define external WeightLookupTable to be used
         * @param lookupTable
         * @return
        public Builder lookupTable(@NonNull WeightLookupTable<VocabWord> lookupTable) {
            return this;

         * This method defines whether subsampling should be used or not
         * @param sampling set > 0 to subsampling argument, or 0 to disable
         * @return
        public Builder sampling(double sampling) {
            return this;

         * This method defines whether adaptive gradients should be used or not
         * @param reallyUse
         * @return
        public Builder useAdaGrad(boolean reallyUse) {
            return this;

         * This method defines whether negative sampling should be used or not
         * PLEASE NOTE: If you're going to use negative sampling, you might want to disable HierarchicSoftmax, which is enabled by default
         * Default value: 0
         * @param negative set > 0 as negative sampling argument, or 0 to disable
         * @return
        public Builder negativeSample(double negative) {
            return this;

         * This method defines stop words that should be ignored during training
         * @param stopList
         * @return
        public Builder stopWords(@NonNull List<String> stopList) {
            return this;

         * This method is hardcoded to TRUE, since that's whole point of Word2Vec
         * @param trainElements
         * @return
        public Builder trainElementsRepresentation(boolean trainElements) {
            throw new IllegalStateException("You can't change this option for Word2Vec");

         * This method is hardcoded to FALSE, since that's whole point of Word2Vec
         * @param trainSequences
         * @return
        public Builder trainSequencesRepresentation(boolean trainSequences) {
            throw new IllegalStateException("You can't change this option for Word2Vec");

         * This method defines stop words that should be ignored during training
         * @param stopList
         * @return
        public Builder stopWords(@NonNull Collection<VocabWord> stopList) {
            return this;

         * This method defines context window size
         * @param windowSize
         * @return
        public Builder windowSize(int windowSize) {
            return this;

         * This method defines random seed for random numbers generator
         * @param randomSeed
         * @return
        public Builder seed(long randomSeed) {
            return this;

         * Sets number of threads running calculations.
         * Note this is different from workers which affect
         * the number of threads used to compute updates.
         * This should be balanced with the number of workers.
         * High number of threads will actually hinder performance.
         * @param vectorCalcThreads the number of threads to compute updates
         * @return
        public Builder vectorCalcThreads(int vectorCalcThreads) {
            return this;

         * This method defines maximum number of concurrent threads available for training
         * @param numWorkers
         * @return
        public Builder workers(int numWorkers) {
            return this;

         * Sets ModelUtils that gonna be used as provider for utility methods: similarity(), wordsNearest(), accuracy(), etc
         * @param modelUtils model utils to be used
         * @return
        public Builder modelUtils(@NonNull ModelUtils<VocabWord> modelUtils) {
            return this;

         * This method allows to use variable window size. In this case, every batch gets processed using one of predefined window sizes
         * @param windows
         * @return
        public Builder useVariableWindow(int... windows) {
            return this;

         * This method allows you to specify SequenceElement that will be used as UNK element, if UNK is used
         * @param element
         * @return
        public Builder unknownElement(VocabWord element) {
            return this;

         * This method allows you to specify, if UNK word should be used internally
         * @param reallyUse
         * @return
        public Builder useUnknown(boolean reallyUse) {
            if (this.unknownElement == null) {
                this.unknownElement(new VocabWord(1.0, Word2Vec.DEFAULT_UNK));
            return this;

         * This method sets VectorsListeners for this SequenceVectors model
         * @param vectorsListeners
         * @return
        public Builder setVectorsListeners(@NonNull Collection<VectorsListener<VocabWord>> vectorsListeners) {
            return this;

        public Builder elementsLearningAlgorithm(String algorithm) {
            if(algorithm == null)
                return this;
            return this;

        public Builder elementsLearningAlgorithm(ElementsLearningAlgorithm<VocabWord> algorithm) {
          if(algorithm == null)
              return this;
            return this;

         * This method enables/disables parallel tokenization.
         * Default value: TRUE
         * @param allow
         * @return
        public Builder allowParallelTokenization(boolean allow) {
            this.allowParallelTokenization = allow;
            return this;

         * This method ebables/disables periodical vocab truncation during construction
         * Default value: disabled
         * @param reallyEnable
         * @return
        public Builder enableScavenger(boolean reallyEnable) {
            return this;

         * This method enables/disables Hierarchic softmax
         * Default value: enabled
         * @param reallyUse
         * @return
        public Builder useHierarchicSoftmax(boolean reallyUse) {
            return this;

        public Builder usePreciseWeightInit(boolean reallyUse) {
            return this;

        public Builder usePreciseMode(boolean reallyUse) {
            return this;

        public Builder intersectModel(@NonNull SequenceVectors vectors, boolean isLocked) {
            super.intersectModel(vectors, isLocked);
            return this;

        /**
         * Assembles a {@link Word2Vec} instance from the options collected by this builder.
         * <p>
         * If a SentenceIterator or LabelAwareIterator was supplied, it is converted into the training
         * sequence iterator. When no TokenizerFactory was set, a DefaultTokenizerFactory is used. The
         * resulting model always trains element (word) vectors and never sequence vectors.
         *
         * @return the configured Word2Vec model
         */
        public Word2Vec build() {

            Word2Vec ret = new Word2Vec();

            // Raw sentences: tokenize them (default tokenizer if none was given) into a sequence iterator.
            // NOTE(review): in this copy the SentenceTransformer builder chain below is unterminated
            // (no further builder calls, no build(), no ';') and this if-block's closing brace is missing
            // — restore from upstream before compiling.
            if (sentenceIterator != null) {
                if (tokenizerFactory == null)
                    tokenizerFactory = new DefaultTokenizerFactory();

                SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(sentenceIterator)
                this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();

            // Labelled documents: same conversion. This assignment to this.iterator comes after the
            // sentence-based one, so it presumably wins when both sources are set — confirm once braces are restored.
            // NOTE(review): same unterminated builder chain / missing closing brace as above.
            if (this.labelAwareIterator != null) {
                if (tokenizerFactory == null)
                    tokenizerFactory = new DefaultTokenizerFactory();

                SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(labelAwareIterator)
                this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build();

            // Copy training hyper-parameters from the builder into the model.
            // NOTE(review): learningRate is not copied here — presumably it reaches the model via configuration; confirm.
            ret.numEpochs = this.numEpochs;
            ret.numIterations = this.iterations;
            ret.vocab = this.vocabCache;
            ret.minWordFrequency = this.minWordFrequency;
            ret.minLearningRate = this.minLearningRate;
            ret.sampling = this.sampling;
            ret.negative = this.negative;
            ret.layerSize = this.layerSize;
            ret.batchSize = this.batchSize;
            ret.learningRateDecayWords = this.learningRateDecayWords;
            ret.window = this.window;
            ret.resetModel = this.resetModel;
            ret.useAdeGrad = this.useAdaGrad;
            ret.stopWords = this.stopWords;
            ret.workers = this.workers;
            ret.useUnknown = this.useUnknown;
            ret.unknownElement = this.unknownElement;
            ret.variableWindows = this.variableWindows;
            ret.seed = this.seed;
            ret.enableScavenger = this.enableScavenger;
            ret.vocabLimit = this.vocabLimit;

            // Guarantee a non-null UNK element even if useUnknown()/unknownElement() were never called.
            if (ret.unknownElement == null)
                ret.unknownElement = new VocabWord(1.0,SequenceVectors.DEFAULT_UNK);

            // Collaborators: corpus iterator, weights, tokenizer and utility provider.
            ret.iterator = this.iterator;
            ret.lookupTable = this.lookupTable;
            ret.tokenizerFactory = this.tokenizerFactory;
            ret.modelUtils = this.modelUtils;

            ret.elementsLearningAlgorithm = this.elementsLearningAlgorithm;
            ret.sequenceLearningAlgorithm = this.sequenceLearningAlgorithm;

            // Optional model intersection (see intersectModel(...)).
            ret.intersectModel = this.intersectVectors;
            ret.lockFactor = this.lockFactor;

            // NOTE(review): the body of this branch and its closing brace are missing in this copy;
            // presumably it copied builder settings into this.configuration when no configuration was supplied — confirm upstream.
            if(!this.configurationSpecified) {


            // NOTE(review): this block is also truncated — the statement guarded by the inner 'if'
            // (and the closing brace) are missing; presumably they record the tokenizer/pre-processor in the configuration.
            if (tokenizerFactory != null) {
                if (tokenizerFactory.getTokenPreProcessor() != null)

            ret.configuration = this.configuration;

            // we hardcode: Word2Vec learns element (word) vectors only, never sequence vectors
            ret.trainSequenceVectors = false;
            ret.trainElementsVectors = true;

            ret.eventListeners = this.vectorsListeners;

            return ret;