deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java

Summary

Maintainability
D
1 day
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.models.fasttext;

import com.github.jfasttext.JFastText;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.*;
import java.util.*;

/**
 * {@link WordVectors} implementation that wraps a {@link JFastText} instance.
 *
 * <p>Two kinds of model can be attached to an instance:
 * <ul>
 *   <li>a binary model, loaded through {@link #loadBinaryModel(String)} and served by the
 *       wrapped {@link JFastText} object;</li>
 *   <li>a text vectors model, loaded through {@link #loadPretrainedVectors(File)} and served by
 *       a {@link Word2Vec} instance read with {@link WordVectorSerializer}.</li>
 * </ul>
 * Several query methods are only available when text vectors are loaded; with a binary model
 * they throw {@link IllegalStateException} or return {@code null}, as documented per method.
 *
 * <p>Training parameters are configured through the Lombok builder; numeric options left at
 * their negative default are not passed to fastText (see {@link #fit()}).
 */
@Slf4j
@AllArgsConstructor
@lombok.Builder
public class FastText implements WordVectors, Serializable {

    private static final String METHOD_NOT_AVAILABLE = "This method is available for text (.vec) models only - binary (.bin) model currently loaded";

    // Mandatory
    @Getter private String inputFile;
    @Getter private String outputFile;

    // Optional for dictionary (negative value = option is not passed to fastText)
    @Builder.Default  private int bucket = -1;
    @Builder.Default  private int minCount = -1;
    @Builder.Default  private int minCountLabel = -1;
    @Builder.Default  private int wordNgrams = -1;
    @Builder.Default  private int minNgramLength = -1;
    @Builder.Default  private int maxNgramLength = -1;
    @Builder.Default  private int samplingThreshold = -1;
    private String labelPrefix;

    // Optional for training
    @Getter private boolean supervised;
    @Getter private boolean quantize;
    @Getter private boolean predict;
    @Getter private boolean predict_prob;
    @Getter private boolean skipgram;
    @Getter private boolean cbow;
    @Getter private boolean nn;
    @Getter private boolean analogies;
    @Getter private String pretrainedVectorsFile;
    @Getter
    @Builder.Default
    private double learningRate = -1.0;
    // @Builder.Default is required here: without it the builder ignores the initializer and
    // leaves the field at 0.0, which makeArgs() would then pass on as "-lrUpdateRate 0.0".
    @Getter
    @Builder.Default
    private double learningRateUpdate = -1.0;
    @Getter
    @Builder.Default
    private int dim = -1;
    @Getter
    @Builder.Default
    private int contextWindowSize = -1;
    @Getter
    @Builder.Default
    private int epochs = -1;
    @Getter private String modelName;
    @Getter private String lossName;
    @Getter
    @Builder.Default
    private int negativeSamples = -1;
    @Getter
    @Builder.Default
    private int numThreads = -1;
    @Getter private boolean saveOutput = false;

    // Optional for quantization
    @Getter
    @Builder.Default
    private int cutOff = -1;
    @Getter private boolean retrain;
    @Getter private boolean qnorm;
    @Getter private boolean qout;
    @Getter
    @Builder.Default
    private int dsub = -1;

    @Getter private SentenceIterator iterator;

    @Builder.Default private transient JFastText fastTextImpl = new JFastText();
    private transient Word2Vec word2Vec;
    @Getter private boolean modelLoaded;
    @Getter private boolean modelVectorsLoaded;
    private VocabCache vocabCache;

    /**
     * Creates an instance and immediately loads the binary model stored at {@code modelPath}.
     *
     * @param modelPath file containing the binary model
     */
    public FastText(File modelPath) {
        this();
        loadBinaryModel(modelPath.getAbsolutePath());
    }

    /**
     * Creates an instance with a fresh {@link JFastText} object and no model loaded.
     */
    public FastText() {
        fastTextImpl = new JFastText();
    }

    /**
     * Accumulates command-line style arguments as a flat list of strings.
     */
    private static class ArgsFactory {

        private final List<String> args = new ArrayList<>();

        /** Appends {@code label} followed by {@code value} unconditionally. */
        private void add(String label, String value) {
            args.add(label);
            args.add(value);
        }

        /** Appends {@code label} and {@code value} only when {@code value} is non-negative. */
        private void addOptional(String label, int value) {
            if (value >= 0) {
                args.add(label);
                args.add(Integer.toString(value));
            }
        }

        /** Appends {@code label} and {@code value} only when {@code value} is non-negative. */
        private void addOptional(String label, double value) {
            if (value >= 0.0) {
                args.add(label);
                args.add(Double.toString(value));
            }
        }

        /** Appends {@code label} and {@code value} only when {@code value} is not empty. */
        private void addOptional(String label, String value) {
            if (StringUtils.isNotEmpty(value)) {
                args.add(label);
                args.add(value);
            }
        }

        /** Appends the bare {@code label} (no value) only when {@code value} is true. */
        private void addOptional(String label, boolean value) {
            if (value) {
                args.add(label);
            }
        }

        /** @return the collected arguments, in insertion order */
        public String[] args() {
            String[] asArray = new String[args.size()];
            return args.toArray(asArray);
        }
    }

    /**
     * Translates the configured fields into the argument array handed to
     * {@link JFastText#runCmd(String[])}.
     *
     * <p>{@code -input} and {@code -output} are always emitted; every other option is emitted
     * only when set (true flag, non-empty string, or non-negative number). The order of the
     * entries is kept exactly as listed below.
     * NOTE(review): the {@code nn} flag is never added here, and {@code analogies} is emitted
     * after {@code -input}/{@code -output} rather than with the other command words - confirm
     * whether that is intended.
     */
    private String[] makeArgs() {
        ArgsFactory argsFactory = new ArgsFactory();

        argsFactory.addOptional("cbow", cbow);
        argsFactory.addOptional("skipgram", skipgram);
        argsFactory.addOptional("supervised", supervised);
        argsFactory.addOptional("quantize", quantize);
        argsFactory.addOptional("predict", predict);
        argsFactory.addOptional("predict_prob", predict_prob);

        argsFactory.add("-input", inputFile);
        argsFactory.add("-output", outputFile );

        argsFactory.addOptional("-pretrainedVectors", pretrainedVectorsFile);

        argsFactory.addOptional("-bucket", bucket);
        argsFactory.addOptional("-minCount", minCount);
        argsFactory.addOptional("-minCountLabel", minCountLabel);
        argsFactory.addOptional("-wordNgrams", wordNgrams);
        argsFactory.addOptional("-minn", minNgramLength);
        argsFactory.addOptional("-maxn", maxNgramLength);
        argsFactory.addOptional("-t", samplingThreshold);
        argsFactory.addOptional("-label", labelPrefix);
        argsFactory.addOptional("analogies",analogies);
        argsFactory.addOptional("-lr", learningRate);
        argsFactory.addOptional("-lrUpdateRate", learningRateUpdate);
        argsFactory.addOptional("-dim", dim);
        argsFactory.addOptional("-ws", contextWindowSize);
        argsFactory.addOptional("-epoch", epochs);
        argsFactory.addOptional("-loss", lossName);
        argsFactory.addOptional("-neg", negativeSamples);
        argsFactory.addOptional("-thread", numThreads);
        argsFactory.addOptional("-saveOutput", saveOutput);
        argsFactory.addOptional("-cutoff", cutOff);
        argsFactory.addOptional("-retrain", retrain);
        argsFactory.addOptional("-qnorm", qnorm);
        argsFactory.addOptional("-qout", qout);
        argsFactory.addOptional("-dsub", dsub);

        return argsFactory.args();
    }

    /**
     * Runs fastText with the arguments derived from the current configuration.
     */
    public void fit() {
        String[] cmd = makeArgs();
        fastTextImpl.runCmd(cmd);
    }

    /**
     * Drains the configured {@link SentenceIterator} into a temporary file (one sentence per
     * line) and then replaces the wrapped {@link JFastText} instance with a fresh one.
     *
     * <p>Does nothing when no iterator is configured. An {@link IOException} is logged and not
     * rethrown; in that case the wrapped instance is left untouched.
     * NOTE(review): the temporary file is not registered as {@code inputFile} - confirm how it
     * is meant to be consumed.
     */
    public void loadIterator() {
        if (iterator == null) {
            return;
        }
        try {
            File tempFile = File.createTempFile("FTX", ".txt");
            // try-with-resources guarantees the buffered content is flushed and the file
            // handle released, even when writing fails midway.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(tempFile))) {
                while (iterator.hasNext()) {
                    String sentence = iterator.nextSentence();
                    writer.write(sentence);
                    // keep sentences separated instead of gluing them together
                    writer.newLine();
                }
            }

            fastTextImpl = new JFastText();
        } catch (IOException e) {
            log.error(e.getMessage(), e);
        }
    }

    /**
     * Loads a text vectors model through {@link WordVectorSerializer#readWord2VecModel(File)}
     * and marks this instance as serving vectors from it.
     *
     * @param vectorsFile file containing the vectors
     */
    public void loadPretrainedVectors(File vectorsFile) {
        word2Vec = WordVectorSerializer.readWord2VecModel(vectorsFile);
        modelVectorsLoaded = true;
        // SLF4J substitutes "{}" placeholders; "%s" would be printed literally.
        log.info("Loaded vectorized representation from file {}. Functionality will be restricted.",
                vectorsFile.getAbsolutePath());
    }

    /**
     * Loads a binary model into the wrapped {@link JFastText} instance.
     *
     * @param modelPath path of the binary model file
     */
    public void loadBinaryModel(String modelPath) {
        fastTextImpl.loadModel(modelPath);

        modelLoaded = true;
    }

    /**
     * Unloads the binary model from the wrapped {@link JFastText} instance.
     */
    public void unloadBinaryModel() {
        fastTextImpl.unloadModel();
        modelLoaded = false;
    }

    /**
     * Delegates to {@link JFastText#test(String)} with the absolute path of {@code testFile}.
     *
     * @param testFile file passed to the test run
     */
    public void test(File testFile) {
        fastTextImpl.test(testFile.getAbsolutePath());
    }

    /**
     * @throws IllegalStateException when neither a binary model nor text vectors are loaded
     */
    private void assertModelLoaded() {
        if (!modelLoaded && !modelVectorsLoaded)
            throw new IllegalStateException("Model must be loaded before predict!");
    }

    /**
     * Predicts a label for {@code text} using the wrapped {@link JFastText} instance.
     *
     * @param text input text
     * @return the label returned by {@link JFastText#predict(String)}
     * @throws IllegalStateException when no model is loaded
     */
    public String predict(String text) {

        assertModelLoaded();

        String label = fastTextImpl.predict(text);
        return label;
    }

    /**
     * Predicts a label for {@code text} together with its score.
     *
     * @param text input text
     * @return pair of the predicted label and the {@code logProb} value reported by
     *         {@link JFastText#predictProba(String)}
     * @throws IllegalStateException when no model is loaded
     */
    public Pair<String, Float> predictProbability(String text) {

        assertModelLoaded();

        JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text);

        Pair<String,Float> retVal = new Pair<>();
        retVal.setFirst(predictedProbLabel.label);
        retVal.setSecond(predictedProbLabel.logProb);
        return retVal;
    }

    /**
     * Returns the vocabulary of the loaded model.
     *
     * <p>With text vectors loaded this is the vocabulary of the underlying {@link Word2Vec}.
     * With a binary model, the words reported by {@link JFastText#getWords()} are added to an
     * {@link AbstractCache}; note that they are re-added on every call.
     *
     * @throws IllegalStateException when neither text vectors nor a binary model are loaded
     */
    @Override
    public VocabCache vocab() {
        if (modelVectorsLoaded) {
            vocabCache = word2Vec.vocab();
        }
        else {
            if (!modelLoaded)
                throw new IllegalStateException("Load model before calling vocab()");

            if (vocabCache == null) {
                vocabCache = new AbstractCache();
            }
            List<String> words = fastTextImpl.getWords();
            for (int i = 0; i < words.size(); ++i) {
                vocabCache.addWordToIndex(i, words.get(i));
                VocabWord word = new VocabWord();
                word.setWord(words.get(i));
                vocabCache.addToken(word);
            }
        }
        return vocabCache;
    }

    /**
     * @return the number of words of the loaded model
     * @throws IllegalStateException when neither text vectors nor a binary model are loaded
     */
    @Override
    public long vocabSize() {
        long result = 0;
        if (modelVectorsLoaded) {
            result = word2Vec.vocabSize();
        }
        else {
            if (!modelLoaded)
                throw new IllegalStateException("Load model before calling vocab()");
            result = fastTextImpl.getNWords();
        }
        return result;
    }

    /** Not supported; always throws {@link NotImplementedException}. */
    @Override
    public String getUNK() {
        throw new NotImplementedException("FastText.getUNK");
    }

    /** Not supported; always throws {@link NotImplementedException}. */
    @Override
    public void setUNK(String input) {
        throw new NotImplementedException("FastText.setUNK");
    }

    /**
     * @param word word to look up
     * @return the vector of {@code word}, taken from the text vectors when loaded, otherwise
     *         converted from the {@code List<Float>} returned by {@link JFastText#getVector(String)}
     */
    @Override
    public double[] getWordVector(String word) {
        if (modelVectorsLoaded) {
            return word2Vec.getWordVector(word);
        }
        else {
            List<Float> vectors = fastTextImpl.getVector(word);
            double[] retVal = new double[vectors.size()];
            for (int i = 0; i < vectors.size(); ++i) {
                retVal[i] = vectors.get(i);
            }
            return retVal;
        }
    }

    /**
     * @param word word to look up
     * @return the vector of {@code word}; for a binary model it is divided in place by its
     *         2-norm
     */
    @Override
    public INDArray getWordVectorMatrixNormalized(String word) {
        if (modelVectorsLoaded) {
            return word2Vec.getWordVectorMatrixNormalized(word);
        }
        else {
            INDArray r = getWordVectorMatrix(word);
            return r.divi(Nd4j.getBlasWrapper().nrm2(r));
        }
    }

    /**
     * @param word word to look up
     * @return the vector of {@code word} as an {@link INDArray}
     */
    @Override
    public INDArray getWordVectorMatrix(String word) {
        if (modelVectorsLoaded) {
            return word2Vec.getWordVectorMatrix(word);
        }
        else {
            double[] values = getWordVector(word);
            return Nd4j.createFromArray(values);
        }
    }

    /**
     * @return the vectors of {@code labels} when text vectors are loaded, {@code null} otherwise
     */
    @Override
    public INDArray getWordVectors(Collection<String> labels) {
        if (modelVectorsLoaded) {
            return word2Vec.getWordVectors(labels);
        }
        return null;
    }

    /**
     * @return the mean vector of {@code labels} when text vectors are loaded, {@code null}
     *         otherwise
     */
    @Override
    public INDArray getWordVectorsMean(Collection<String> labels) {
        if (modelVectorsLoaded) {
            return word2Vec.getWordVectorsMean(labels);
        }
        return null;
    }

    // Word list of the binary model, fetched lazily by hasWord(). May be null on instances
    // created through the builder, because the initializer is not a @Builder.Default.
    private List<String> words = new ArrayList<>();

    /**
     * Tells whether {@code word} is known to the loaded model.
     *
     * @param word word to look up
     * @return with text vectors loaded, the answer of the underlying {@link Word2Vec};
     *         otherwise whether the word list of the binary model contains {@code word}
     */
    @Override
    public boolean hasWord(String word) {
        if (modelVectorsLoaded) {
            // Ask about this specific word; outOfVocabularySupported() is unrelated to it.
            return word2Vec.hasWord(word);
        }
        if (words == null || words.isEmpty())
            words = fastTextImpl.getWords();
        return words.contains(word);
    }

    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        if (modelVectorsLoaded) {
            return word2Vec.wordsNearest(words, top);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }

    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        if (modelVectorsLoaded) {
            return word2Vec.wordsNearestSum(words, top);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }

    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public Collection<String> wordsNearestSum(String word, int n) {
        if (modelVectorsLoaded) {
            return word2Vec.wordsNearestSum(word, n);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }


    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        if (modelVectorsLoaded) {
            return word2Vec.wordsNearestSum(positive, negative, top);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }

    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public Map<String, Double> accuracy(List<String> questions) {
        if (modelVectorsLoaded) {
            return word2Vec.accuracy(questions);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }

    /**
     * @param word word to look up
     * @return the index of {@code word}, from the text vectors when loaded, otherwise from
     *         {@link #vocab()}
     */
    @Override
    public int indexOf(String word) {
        if (modelVectorsLoaded) {
            return word2Vec.indexOf(word);
        }
        return vocab().indexOf(word);
    }


    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        if (modelVectorsLoaded) {
            return word2Vec.similarWordsInVocabTo(word, accuracy);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }

    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        if (modelVectorsLoaded) {
            return word2Vec.wordsNearest(positive, negative, top);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }


    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public Collection<String> wordsNearest(String word, int n) {
        if (modelVectorsLoaded) {
            return word2Vec.wordsNearest(word,n);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }


    /**
     * Text-vectors only.
     *
     * @throws IllegalStateException when text vectors are not loaded
     */
    @Override
    public double similarity(String word, String word2) {
        if (modelVectorsLoaded) {
            return word2Vec.similarity(word, word2);
        }
        throw new IllegalStateException(METHOD_NOT_AVAILABLE);
    }

    /**
     * @return the lookup table of the text vectors when loaded, {@code null} otherwise
     */
    @Override
    public WeightLookupTable lookupTable() {
        if (modelVectorsLoaded) {
            return word2Vec.lookupTable();
        }
        return null;
    }

    /** No-op: the supplied utils are ignored. */
    @Override
    public void setModelUtils(ModelUtils utils) {
    }

    /** No-op: nothing is written into {@code array}. */
    @Override
    public void loadWeightsInto(INDArray array) {}

    /**
     * @return the vector length of the loaded text vectors, or the dimension reported by the
     *         wrapped {@link JFastText} instance when a binary model is loaded; {@code -1}
     *         when nothing is loaded
     */
    @Override
    public int vectorSize() {
        if (modelVectorsLoaded) {
            return word2Vec.vectorSize();
        }
        if (modelLoaded) {
            return fastTextImpl.getDim();
        }
        return -1;
    }

    /** @return always {@code false} */
    @Override
    public boolean jsonSerializable() {return false;}

    /** @return the learning rate reported by the wrapped {@link JFastText} instance (not the builder field) */
    public double getLearningRate() {
        return fastTextImpl.getLr();
    }

    /** @return the dimension reported by the wrapped {@link JFastText} instance */
    public int getDimension() {
        return fastTextImpl.getDim();
    }

    /** @return the context window size reported by the wrapped {@link JFastText} instance (not the builder field) */
    public int getContextWindowSize() {
        return fastTextImpl.getContextWindowSize();
    }

    /** @return the epoch value reported by the wrapped {@link JFastText} instance */
    public int getEpoch() {
        return fastTextImpl.getEpoch();
    }

    /** @return the number of sampled negatives reported by the wrapped {@link JFastText} instance */
    public int getNegativesNumber() {
        return fastTextImpl.getNSampledNegatives();
    }

    /** @return the word n-grams value reported by the wrapped {@link JFastText} instance */
    public int getWordNgrams() {
        return fastTextImpl.getWordNgrams();
    }

    /** @return the loss name reported by the wrapped {@link JFastText} instance (not the builder field) */
    public String getLossName() {
        return fastTextImpl.getLossName();
    }

    /** @return the model name reported by the wrapped {@link JFastText} instance (not the builder field) */
    public String getModelName() {
        return fastTextImpl.getModelName();
    }

    /** @return the bucket count reported by the wrapped {@link JFastText} instance */
    public int getNumberOfBuckets() {
        return fastTextImpl.getBucket();
    }

    /** @return the label prefix reported by the wrapped {@link JFastText} instance */
    public String getLabelPrefix() {
        return fastTextImpl.getLabelPrefix();
    }

    /** @return always {@code true} */
    @Override
    public boolean outOfVocabularySupported() {
        return true;
    }

}