// deeplearning4j/deeplearning4j

// View on GitHub
// deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java

// Summary

// Maintainability
// F
// 4 days
// Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.models.embeddings.learning.impl.elements;

import lombok.*;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.nlp.CbowInference;
import org.nd4j.linalg.api.ops.impl.nlp.CbowRound;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.shade.guava.cache.Cache;
import org.nd4j.shade.guava.cache.CacheBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;

public class CBOW<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
    // Collaborators captured by configure(); all three are null until it has been called.
    private VocabCache<T> vocabCache;
    private WeightLookupTable<T> lookupTable;
    private VectorsConfiguration configuration;

    private static final Logger logger = LoggerFactory.getLogger(CBOW.class);


    // Training parameters copied from the VectorsConfiguration in configure().
    protected int window;
    protected boolean useAdaGrad;
    protected double negative;
    protected double sampling;
    // Optional set of window sizes; when non-empty, learnSequence() picks one at random per sequence.
    protected int[] variableWindows;
    // Defaults to the number of available processors; overwritten by configuration.getWorkers() in configure().
    protected int workers = Runtime.getRuntime().availableProcessors();
    // Pool of reusable IterationArrays queues, keyed by (batch size, max code length).
    // Maximum number of cache entries comes from the NLP_CACHE_SIZE system property (default "10000").
    private Cache<IterationArraysKey, Queue<IterationArrays>> iterationArrays = CacheBuilder.newBuilder()
            .maximumSize(Integer.parseInt(System.getProperty(DL4JSystemProperties.NLP_CACHE_SIZE,"10000")))
            .build();

    // Upper bound on how many IterationArrays doExec() returns to a single pooled queue;
    // read from the NLP_QUEUE_SIZE system property (default "1000").
    protected int maxQueueSize = Integer.parseInt(System.getProperty(DL4JSystemProperties.NLP_QUEUE_SIZE,"1000"));



    /**
     * Returns the worker count handed to the native CBOW ops.
     *
     * @return the currently configured number of workers
     */
    public int getWorkers() {
        return this.workers;
    }

    /**
     * Overrides the worker count handed to the native CBOW ops.
     *
     * @param workers new number of workers
     */
    public void setWorkers(int workers) {
        this.workers = workers;
    }

    // Wrappers around the lookup table's weight/auxiliary arrays; all five are (re)created in configure().
    // Lombok generates a getter and a setter for each of them.
    @Getter
    @Setter
    protected DeviceLocalNDArray syn0, syn1, syn1Neg, expTable, table;

    // Per-thread list of pending batch items; lazily created by getBatch() or seeded by iterateBatchesIfReady().
    protected ThreadLocal<List<BatchItem<T>>> batches = new ThreadLocal<>();

    /**
     * Returns the calling thread's pending batch, creating an empty one on first use.
     *
     * @return the thread-local list of pending batch items, never {@code null}
     */
    public List<BatchItem<T>> getBatch() {
        if (batches.get() != null) {
            return batches.get();
        }
        batches.set(new ArrayList<>());
        return batches.get();
    }


    /**
     * Appends an item to the calling thread's pending batch.
     *
     * @param batchItem item to queue
     */
    public void addBatchItem(BatchItem<T> batchItem) {
        List<BatchItem<T>> pending = getBatch();
        pending.add(batchItem);
    }


    /**
     * Drops every item from the calling thread's pending batch.
     */
    public void clearBatch() {
        List<BatchItem<T>> pending = getBatch();
        pending.clear();
    }

    /**
     * Returns the code name of this learning algorithm.
     *
     * @return the constant string {@code "CBOW"}
     */
    @Override
    public String getCodeName() {
        final String codeName = "CBOW";
        return codeName;
    }

    /**
     * Binds this algorithm to a vocabulary, a lookup table and a configuration.
     * <p>
     * Copies the training parameters out of the configuration and wraps the lookup table's
     * arrays in {@link DeviceLocalNDArray} holders. When negative sampling is requested and
     * the table has no syn1Neg array yet, the table's weights are reset first.
     * The lookup table is cast to {@link InMemoryLookupTable}.
     *
     * @param vocabCache    vocabulary used for subsampling
     * @param lookupTable   lookup table supplying syn0/syn1/syn1Neg, the exp table and the negative table
     * @param configuration source of window, sampling, negative, workers and variable-window settings
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
                          @NonNull VectorsConfiguration configuration) {
        this.vocabCache = vocabCache;
        this.lookupTable = lookupTable;
        this.configuration = configuration;

        this.window = configuration.getWindow();
        this.useAdaGrad = configuration.isUseAdaGrad();
        this.negative = configuration.getNegative();
        this.sampling = configuration.getSampling();
        this.workers = configuration.getWorkers();

        boolean negativeSamplingRequested = configuration.getNegative() > 0;
        InMemoryLookupTable<T> inMemoryTable = (InMemoryLookupTable<T>) lookupTable;
        if (negativeSamplingRequested && inMemoryTable.getSyn1Neg() == null) {
            logger.info("Initializing syn1Neg...");
            inMemoryTable.setUseHS(configuration.isUseHierarchicSoftmax());
            inMemoryTable.setNegative(configuration.getNegative());
            lookupTable.resetWeights(false);
        }

        this.syn0 = new DeviceLocalNDArray(inMemoryTable.getSyn0());
        this.syn1 = new DeviceLocalNDArray(inMemoryTable.getSyn1());
        this.syn1Neg = new DeviceLocalNDArray(inMemoryTable.getSyn1Neg());

        // The exp table follows syn0's data type, falling back to DOUBLE when syn0 holds no array.
        long[] expTableShape = new long[]{inMemoryTable.getExpTable().length};
        DataType expTableType = syn0.get() == null ? DataType.DOUBLE : syn0.get().dataType();
        this.expTable = new DeviceLocalNDArray(Nd4j.create(inMemoryTable.getExpTable(), expTableShape, expTableType));

        this.table = new DeviceLocalNDArray(inMemoryTable.getTable());
        this.variableWindows = configuration.getVariableWindows();
    }

    /**
     * No-op: CBOW has no pretraining phase.
     *
     * @param iterator sequence iterator; ignored
     */
    @Override
    public void pretrain(SequenceIterator<T> iterator) {
        // intentionally empty
    }

    /**
     * Flushes whatever is still queued in the calling thread's batch, in training mode
     * (no inference vector), and then empties the batch.
     */
    @Override
    public void finish() {
        if (batches == null) {
            return;
        }
        List<BatchItem<T>> pending = batches.get();
        if (pending == null || pending.isEmpty()) {
            return;
        }
        doExec(pending, null);
        batches.get().clear();
    }

    /**
     * Flushes whatever is still queued in the calling thread's batch, passing the given
     * inference vector through to {@code doExec}, and then empties the batch.
     *
     * @param inferenceVector vector handed to {@code doExec}; may be {@code null}
     */
    @Override
    public void finish(INDArray inferenceVector) {
        if (batches == null) {
            return;
        }
        List<BatchItem<T>> pending = batches.get();
        if (pending == null || pending.isEmpty()) {
            return;
        }
        doExec(pending, inferenceVector);
        batches.get().clear();
    }


    /**
     * Queues one CBOW training sample per element of the sequence and flushes the
     * calling thread's batch once it has reached the configured batch size.
     * <p>
     * When {@code sampling > 0} the sequence is subsampled first. When variable windows
     * are configured, one of them is picked at random for the whole sequence.
     *
     * @param sequence     sequence to learn from
     * @param nextRandom   random state; advanced once per element
     * @param learningRate learning rate stored with every queued sample
     * @return always {@code 0}
     */
    @Override
    public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate) {
        Sequence<T> tempSequence = sequence;
        if (sampling > 0)
            tempSequence = applySubsampling(sequence, nextRandom);

        int currentWindow = window;

        if (variableWindows != null && variableWindows.length != 0) {
            currentWindow = variableWindows[RandomUtils.nextInt(0, variableWindows.length)];
        }

        List<T> elements = tempSequence.getElements();
        for (int i = 0; i < elements.size(); i++) {
            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

            // Reduce the 64-bit state modulo the window BEFORE narrowing to int: casting first
            // can yield a negative int and therefore a negative window offset. floorMod also
            // stays non-negative for Long.MIN_VALUE, which Math.abs cannot fix.
            // A non-positive window gets offset 0 instead of dividing by zero.
            int windowOffset = currentWindow > 0
                    ? (int) Math.floorMod(nextRandom.get(), (long) currentWindow)
                    : 0;

            // cbow() appends to the list it is given, so it must never receive null.
            // A fresh list per sample is merged into the thread-local batch by cbow().
            cbow(i, elements, windowOffset, nextRandom, learningRate, currentWindow, new ArrayList<>());
        }

        List<BatchItem<T>> pending = getBatch();
        // Never hand an empty batch to doExec, even if the configured batch size is 0.
        if (!pending.isEmpty() && pending.size() >= configuration.getBatchSize()) {
            doExec(pending, null);
            pending.clear();
        }

        return 0;
    }

    /**
     * Reports whether training should stop early.
     *
     * @return always {@code false}
     */
    @Override
    public boolean isEarlyTerminationHit() {
        final boolean earlyTermination = false;
        return earlyTermination;
    }

    /**
     * Cache key for pooled {@code IterationArrays}: the number of items in a batch and the
     * maximum code length among them. Accessors, equals/hashCode, constructors and the
     * builder are generated by the Lombok annotations.
     */
    @Data
    @AllArgsConstructor
    @Builder
    @NoArgsConstructor
    public static class IterationArraysKey {
        // Number of items in the batch.
        private int itemSize;
        // Maximum code length among the batch's target words (at least 1 as computed by doExec).
        private int maxCols;
    }


    /**
     * Runs the native CBOW op over the given batch items.
     * <p>
     * Two or more items are packed into dense arrays and executed as one {@link CbowRound};
     * exactly one item is executed through {@link CbowInference}. A {@code null} or empty
     * list is a no-op.
     * <p>
     * The multi-item path borrows its scratch arrays from a pool keyed by batch size and
     * maximum code length, returns them afterwards, and clears the calling thread's pending
     * batch (if it has one).
     *
     * @param items           batch items to execute
     * @param inferenceVector vector passed to the op in inference mode, or {@code null} for plain training
     * @return always {@code 0.0}
     */
    public double doExec(List<BatchItem<T>> items, INDArray inferenceVector) {
        // Nothing to do; also keeps the single-item branch from calling items.get(0) on an empty list.
        if (items == null || items.isEmpty())
            return 0.0;

        try(MemoryWorkspace workspace = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            boolean useHS = configuration.isUseHierarchicSoftmax();
            boolean useNegative = configuration.getNegative() > 0;
            boolean useInference = inferenceVector != null;
            if (items.size() > 1) {
                // maxCols: widest code path among the target words.
                // maxWinWordsCols: widest context window among the items. It must come from the
                // window words, not from the code length, because it sizes the context matrix.
                // Both are floored at 1 so no matrix ends up with zero columns; an empty
                // window is then represented by -1 padding like any other short window.
                int maxCols = 1;
                int maxWinWordsCols = 1;
                for (BatchItem<T> item : items) {
                    maxCols = Math.max(maxCols, item.getWord().getCodeLength());
                    maxWinWordsCols = Math.max(maxWinWordsCols, item.getWindowWords().length);
                }

                boolean hasNumLabels = false;

                IterationArraysKey key = IterationArraysKey.builder()
                        .itemSize(items.size())
                        .maxCols(maxCols).build();
                Queue<IterationArrays> iterationArraysQueue = iterationArrays.getIfPresent(key);
                if (iterationArraysQueue == null) {
                    iterationArraysQueue = new ConcurrentLinkedQueue<>();
                    iterationArrays.put(key, iterationArraysQueue);
                }

                // poll() returns null on an empty queue, so there is no gap between an
                // isEmpty() check and remove() for another thread to slip into.
                IterationArrays iterationArrays1 = iterationArraysQueue.poll();

                // The cache key does not include the window width, so a pooled instance may
                // have been sized for a different width; such an instance is not reused.
                boolean reusable = iterationArrays1 != null
                        && iterationArrays1.inputWindowWordsArr.length > 0
                        && iterationArrays1.inputWindowWordsArr[0].length == maxWinWordsCols;
                if (reusable) {
                    iterationArrays1.initCodes();
                } else {
                    iterationArrays1 = new IterationArrays(items.size(), maxCols, maxWinWordsCols);
                }

                int[][] inputWindowWordsArr = iterationArrays1.inputWindowWordsArr;
                int[][] inputWindowWordStatuses = iterationArrays1.inputWindowWordStatuses;
                int[]  currentWindowIndexesArr = iterationArrays1.currentWindowIndexes;
                double[] alphasArr = iterationArrays1.alphas;
                int[][] indicesArr = iterationArrays1.indicesArr;
                int[][]  codesArr = iterationArrays1.codesArr;
                long[] randomValues = iterationArrays1.randomValues;
                int[] numLabelsArr = iterationArrays1.numLabels;

                // Loop-invariant: make sure the negative-sampling weights exist before packing.
                if (negative > 0 && syn1Neg == null) {
                    ((InMemoryLookupTable<T>) lookupTable).initNegative();
                    syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
                }

                for (int cnt = 0; cnt < items.size(); cnt++) {
                    BatchItem<T> item = items.get(cnt);
                    T currentWord = item.getWord();

                    // Each item owns row/slot cnt; writing to slot 0 would leave every other target unset.
                    currentWindowIndexesArr[cnt] = currentWord.getIndex();

                    int[] windowWords = item.getWindowWords();
                    boolean[] windowStatuses = item.getWordStatuses();

                    // Copy the context; positions past the item's own window are padded with -1.
                    for (int i = 0; i < maxWinWordsCols; i++) {
                        if (i < windowWords.length) {
                            inputWindowWordsArr[cnt][i] = windowWords[i];
                            inputWindowWordStatuses[cnt][i] =
                                    i < windowStatuses.length ? (windowStatuses[i] ? 1 : 0) : -1;
                        } else {
                            inputWindowWordsArr[cnt][i] = -1;
                            inputWindowWordStatuses[cnt][i] = -1;
                        }
                    }

                    alphasArr[cnt] = item.getAlpha();
                    randomValues[cnt] = item.getRandomValue();
                    numLabelsArr[cnt] = item.getNumLabel();
                    if (item.getNumLabel() > 0)
                        hasNumLabels = true;

                    if (useHS) {
                        for (int p = 0; p < currentWord.getCodeLength(); p++) {
                            // Negative points are skipped; their slots keep whatever the arrays held.
                            if (currentWord.getPoints().get(p) < 0)
                                continue;

                            codesArr[cnt][p] = currentWord.getCodes().get(p);
                            indicesArr[cnt][p] =  currentWord.getPoints().get(p);
                        }
                    }
                }

                // Arrays are created only after the scratch buffers are completely filled.
                INDArray currentWindowIndexes = Nd4j.createFromArray(currentWindowIndexesArr);
                INDArray inputWindowWords = Nd4j.createFromArray(inputWindowWordsArr);
                INDArray inputWordsStatuses = Nd4j.createFromArray(inputWindowWordStatuses);
                INDArray numLabelsArray = Nd4j.createFromArray(numLabelsArr);
                INDArray indices = Nd4j.createFromArray(indicesArr);
                INDArray codes = Nd4j.createFromArray(codesArr);
                INDArray alphas = Nd4j.createFromArray(alphasArr);
                INDArray randoms = Nd4j.createFromArray(randomValues);

                try {
                    CbowRound cbow = CbowRound.builder()
                            .target(currentWindowIndexes)
                            .context(inputWindowWords)
                            .lockedWords(inputWordsStatuses)
                            .ngStarter(currentWindowIndexes)
                            .syn0(syn0.get())
                            .syn1(useHS ? syn1.get() : Nd4j.empty(syn0.get().dataType()))
                            .syn1Neg((negative > 0) ? syn1Neg.get() : Nd4j.empty(syn0.get().dataType()))
                            .expTable(expTable.get())
                            .negTable(negative > 0 ? table.get() : Nd4j.empty(syn0.get().dataType()))
                            .indices(useHS ? indices : Nd4j.empty(DataType.INT32))
                            .codes(useHS ? codes : Nd4j.empty(DataType.INT8))
                            .nsRounds((int) negative)
                            .alpha(alphas)
                            .nextRandom(randoms)
                            .inferenceVector(inferenceVector != null ? inferenceVector : Nd4j.empty(syn0.get().dataType()))
                            .numLabels(hasNumLabels ? numLabelsArray : Nd4j.empty(DataType.INT32))
                            .trainWords(configuration.isTrainElementsVectors())
                            .numWorkers(workers)
                            .iterations(useInference ? configuration.getIterations() * configuration.getEpochs() : 1)
                            .build();

                    Nd4j.getExecutioner().exec(cbow);
                } finally {
                    // Release every temporary array, including the locked-word statuses,
                    // even when the op throws.
                    Nd4j.close(currentWindowIndexes, inputWindowWords, inputWordsStatuses, alphas, randoms,
                            codes, numLabelsArray, indices);
                }

                if(iterationArraysQueue.size() < maxQueueSize)
                    iterationArraysQueue.add(iterationArrays1);

                // The calling thread may have no thread-local batch when doExec is invoked directly.
                List<BatchItem<T>> pending = batches.get();
                if (pending != null)
                    pending.clear();
                return 0.0;
            } else {
                BatchItem<T> item = items.get(0);
                T currentWord = item.getWord();

                // Cloned because the arrays are handed to the op builder directly.
                int[] windowWords = item.getWindowWords().clone();
                boolean[] windowStatuses = item.getWordStatuses().clone();
                byte[] codes = new byte[currentWord.getCodeLength()];
                int[] points = new int[currentWord.getCodeLength()];
                long randomValue = item.getRandomValue();
                double alpha = item.getAlpha();
                int numLabels = item.getNumLabel();
                int[] inputStatuses = new int[windowWords.length];
                for (int i = 0; i < windowWords.length; ++i) {
                    if (i < windowStatuses.length)
                        inputStatuses[i] = windowStatuses[i] ? 1 : 0;
                    else
                        inputStatuses[i] = -1;
                }

                if (useHS) {
                    for (int p = 0; p < currentWord.getCodeLength(); p++) {
                        if (currentWord.getPoints().get(p) < 0)
                            continue;

                        codes[p] = currentWord.getCodes().get(p);
                        points[p] = currentWord.getPoints().get(p);
                    }
                }

                if (negative > 0) {
                    if (syn1Neg == null) {
                        ((InMemoryLookupTable<T>) lookupTable).initNegative();
                        syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
                    }
                } else {
                    syn1Neg.set(Nd4j.empty(syn0.get().dataType()));
                }

                CbowInference cbowInference = CbowInference.builder()
                        .target(currentWord.getIndex())
                        .ngStarter(currentWord.getIndex())
                        .negTable(table.get() == null ? Nd4j.empty(syn0.get().dataType()) : table.get())
                        .expTable(expTable.get() == null ? Nd4j.empty(syn0.get().dataType()) : expTable.get())
                        .syn0(syn0.get())
                        .syn1(!useHS ? Nd4j.empty(syn0.get().dataType()) : syn1.get())
                        .syn1Neg(syn1Neg.get())
                        .alpha(alpha)
                        .context(windowWords)
                        .indices(useHS || useNegative ? points : new int[0])
                        .codes(useHS ? codes : new byte[0])
                        .lockedWords(inputStatuses)
                        .randomValue((int) randomValue)
                        .numWorkers(workers)
                        .numLabels(numLabels)
                        .nsRounds(useNegative ? (int) negative : 0)
                        .preciseMode(configuration.isPreciseMode())
                        .inferenceVector(inferenceVector != null ? inferenceVector : Nd4j.empty(syn0.get().dataType()))
                        .iterations(useInference ? configuration.getIterations() * configuration.getEpochs() : 1)
                        .build();

                Nd4j.getExecutioner().exec(cbowInference);
            }

        }
        return 0.0;
    }

    /**
     * Builds one CBOW training sample for the element at position {@code i} and queues it.
     * <p>
     * The context is every in-bounds neighbour within {@code currentWindow} positions of
     * {@code i}, with the window narrowed by {@code b} on both sides. The sample is added to
     * {@code batch} and then merged into the calling thread's pending batch, which is
     * executed once it reaches the configured batch size.
     *
     * @param i             index of the target element in {@code sentence}
     * @param sentence      elements of the current sequence
     * @param b             number of positions by which the window is narrowed on each side
     * @param nextRandom    random state; its current value is stored with the sample
     * @param alpha         learning rate stored with the sample
     * @param currentWindow window size to use for this sample
     * @param batch         list the sample is appended to; {@code null} means "no caller-owned list"
     */
    public void cbow(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow,
                     List<BatchItem<T>> batch) {

        // The span must be derived from currentWindow, the same value used for the centre
        // test and the position offset below. Using the 'window' field makes the window
        // lopsided whenever a variable window differs from it.
        int end = currentWindow * 2 + 1 - b;

        T currentWord = sentence.get(i);

        List<Integer> intsList = new ArrayList<>();
        List<Boolean> statusesList = new ArrayList<>();
        for (int a = b; a < end; a++) {
            // a == currentWindow is the target word itself.
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sentence.size()) {
                    T lastWord = sentence.get(c);

                    intsList.add(lastWord.getIndex());
                    statusesList.add(lastWord.isLocked());
                }
            }
        }

        int[] windowWords = new int[intsList.size()];
        boolean[] statuses = new boolean[intsList.size()];
        for (int x = 0; x < windowWords.length; x++) {
            windowWords[x] = intsList.get(x);
            statuses[x] = statusesList.get(x);
        }

        BatchItem<T> batchItem = new BatchItem<>(currentWord,windowWords,statuses,nextRandom.get(),alpha);

        // A null batch used to fail on batch.add(). Use a private one-element list instead:
        // it is a distinct object from the thread-local batch, so merging it adds the
        // sample exactly once.
        List<BatchItem<T>> target = batch;
        if (target == null) {
            target = new ArrayList<>(1);
        }
        target.add(batchItem);
        iterateBatchesIfReady(target);
    }

    /**
     * Merges {@code batch} into the calling thread's pending batch and executes the pending
     * batch once it holds at least {@code configuration.getBatchSize()} items.
     * <p>
     * If the thread has no pending batch yet, {@code batch} itself becomes the pending batch.
     *
     * @param batch items to merge; may be {@code null} or the pending batch itself
     * @return the value returned by {@code doExec} when a flush happened, otherwise {@code 0.0}
     */
    private double iterateBatchesIfReady(List<BatchItem<T>> batch) {
        List<BatchItem<T>> pending = batches.get();
        if (pending == null) {
            // First use on this thread: adopt the caller's list, or start an empty one.
            if (batch != null) {
                pending = batch;
            } else {
                pending = new ArrayList<>();
            }
            batches.set(pending);
        } else if (batch != null && batch != pending) {
            // addAll on the very same list would duplicate every queued item, so the
            // merge is skipped when the caller passed the pending batch itself.
            pending.addAll(batch);
        }

        double score = 0.0;
        // Never hand an empty batch to doExec, even if the configured batch size is 0.
        if (!pending.isEmpty() && pending.size() >= configuration.getBatchSize()) {
            score = doExec(pending, null);
            pending.clear();
        }
        return score;
    }

    /**
     * Returns a subsampled copy of the sequence.
     * <p>
     * For every element a keep-threshold is computed from its frequency, the sampling
     * parameter and the vocabulary's total word occurrences. The random state is advanced
     * once per element, and the element is dropped when the threshold is below the next
     * random fraction. Sequence id and labels are carried over to the copy.
     * When {@code sampling} is not positive the input sequence is returned as is.
     *
     * @param sequence   sequence to subsample
     * @param nextRandom random state; advanced once per element
     * @return the filtered copy, or {@code sequence} itself when subsampling is disabled
     */
    public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom) {
        Sequence<T> filtered = new Sequence<>();

        // Subsampling disabled: hand the input back untouched.
        if (!(sampling > 0)) {
            return sequence;
        }

        filtered.setSequenceId(sequence.getSequenceId());
        if (sequence.getSequenceLabels() != null) {
            filtered.setSequenceLabels(sequence.getSequenceLabels());
        }
        if (sequence.getSequenceLabel() != null) {
            filtered.setSequenceLabel(sequence.getSequenceLabel());
        }

        for (T candidate : sequence.getElements()) {
            double numWords = vocabCache.totalWordOccurrences();
            double ran = (Math.sqrt(candidate.getElementFrequency() / (sampling * numWords)) + 1)
                    * (sampling * numWords) / candidate.getElementFrequency();

            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

            double randomFraction = (nextRandom.get() & 0xFFFF) / (double) 65536;
            boolean dropped = ran < randomFraction;
            if (!dropped) {
                filtered.addElement(candidate);
            }
        }
        return filtered;
    }
}