deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/CnnSentenceDataSetIterator.java

Summary

Maintainability
F
3 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.iterator;

import lombok.AllArgsConstructor;
import lombok.NonNull;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.documentiterator.LabelAwareDocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.primitives.Pair;

import java.util.*;

@AllArgsConstructor
public class CnnSentenceDataSetIterator implements DataSetIterator {

    /**
     * How tokens that have no vector in the configured {@link WordVectors} are handled during tokenization.
     * Only applies when {@code wordVectors.outOfVocabularySupported()} returns false.<br>
     * RemoveWord: the token is dropped from the sentence.<br>
     * UseUnknownVector: the token is kept and mapped to the "unknown" vector (the vector of
     * {@code wordVectors.getUNK()} if present, otherwise an array created via {@code like()} from the vector of the
     * first vocabulary word).
     */
    public enum UnknownWordHandling {
        RemoveWord, UseUnknownVector
    }

    /**
     * Format of features:<br>
     * RNN: For use with recurrent layers: Shape [minibatch, vectorSize, sentenceLength], 'f' ordered<br>
     * CNN1D: For use with 1d convolution layers: Shape [minibatch, vectorSize, sentenceLength], 'c' ordered<br>
     * CNN2D: For use with 2d convolution layers: Shape [minibatch, 1, vectorSize, sentenceLength] or [minibatch, 1, sentenceLength, vectorSize],
     * depending on the setting for 'sentencesAlongHeight' configuration.
     */
    public enum Format {
        RNN, CNN1D, CNN2D
    }

    /**
     * Placeholder inserted by {@link #tokenizeSentence(String)} for unknown words. It is matched by reference
     * identity in {@link #getVector(String)}, so a real token with the same text is not mistaken for it.
     */
    private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";

    // NOTE: the set and order of the instance fields below define the signature of the public constructor generated
    // by Lombok's @AllArgsConstructor. Do not add, remove or reorder fields without considering that API.
    private Format format;
    private LabeledSentenceProvider sentenceProvider;
    private WordVectors wordVectors;
    private TokenizerFactory tokenizerFactory;
    private UnknownWordHandling unknownWordHandling;
    private boolean useNormalizedWordVectors;
    private int minibatchSize;
    /** Maximum number of tokens per sentence; values <= 0 mean "no limit". */
    private int maxSentenceLength;
    /** CNN2D only: if true, tokens run along the height dimension (dimension 2), otherwise along the width. */
    private boolean sentencesAlongHeight;
    private DataSetPreProcessor dataSetPreProcessor;

    private int wordVectorSize;
    private int numClasses;
    /** Label string -> class index, assigned in sorted label order. */
    private Map<String, Integer> labelClassMap;
    /** Vector used for {@link #UNKNOWN_WORD_SENTINEL} tokens (only set with UseUnknownVector handling). */
    private INDArray unknown;

    /** Number of examples returned by {@link #next(int)} since construction or the last {@link #reset()}. */
    private int cursor = 0;

    /**
     * A non-empty tokenized sentence (tokens, label) read ahead by {@link #hasNext()} that has not yet been returned
     * by {@link #next(int)}; null if nothing is buffered.
     */
    private Pair<List<String>, String> preLoadedTokens;

    /**
     * Creates the iterator from a builder's configuration, builds the label-to-index mapping and resolves the word
     * vector size and (optionally) the unknown-word vector.
     * <p>
     * The sentence provider may be null (e.g. when the iterator is only used for {@link #loadSingleSentence(String)});
     * in that case there are no labels and {@link #hasNext()} / {@link #next(int)} throw.
     */
    protected CnnSentenceDataSetIterator(Builder builder) {
        this.format = builder.format;
        this.sentenceProvider = builder.sentenceProvider;
        this.wordVectors = builder.wordVectors;
        this.tokenizerFactory = builder.tokenizerFactory;
        this.unknownWordHandling = builder.unknownWordHandling;
        this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
        this.minibatchSize = builder.minibatchSize;
        this.maxSentenceLength = builder.maxSentenceLength;
        this.sentencesAlongHeight = builder.sentencesAlongHeight;
        this.dataSetPreProcessor = builder.dataSetPreProcessor;

        this.labelClassMap = new HashMap<>();
        if (this.sentenceProvider != null) {
            this.numClasses = this.sentenceProvider.numLabelClasses();
            //Sort the labels to ensure the same label assignment order (say train vs. test)
            List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
            Collections.sort(sortedLabels);
            int count = 0;
            for (String s : sortedLabels) {
                this.labelClassMap.put(s, count++);
            }
        }

        //All vectors share one length; infer it from the first vocabulary word
        this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;

        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
            if (useNormalizedWordVectors) {
                unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
            } else {
                unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
            }

            if (unknown == null) {
                //No vector for the UNK token: fall back to an array shaped like a regular word vector
                unknown = wordVectors.getWordVectorMatrix(wordVectors.vocab().wordAtIndex(0)).like();
            }
        }
    }

    /**
     * Generally used post training time to load a single sentence for predictions.
     *
     * @param sentence Raw sentence text; tokenized with the configured tokenizer factory
     * @return Features for the sentence with a minibatch dimension of 1, in the configured {@link Format}. The
     *         sentence is truncated to maxSentenceLength tokens when that limit is positive.
     * @throws IllegalStateException If the sentence yields no tokens (empty, or all words unknown with RemoveWord)
     */
    public INDArray loadSingleSentence(String sentence) {
        List<String> tokens = tokenizeSentence(sentence);
        if (tokens.isEmpty()) {
            throw new IllegalStateException("No tokens available for input sentence - empty string or no words in vocabulary with RemoveWord unknown handling? Sentence = \"" +
                    sentence + "\"");
        }
        return createFeatures(Collections.singletonList(tokens), truncatedLength(tokens.size()));
    }

    /**
     * Returns {@code length} capped at {@link #maxSentenceLength} if that limit is positive, else {@code length}.
     */
    private int truncatedLength(int length) {
        return maxSentenceLength > 0 ? Math.min(length, maxSentenceLength) : length;
    }

    /**
     * Builds the features array for the given tokenized sentences in the configured {@link Format}. Time steps past a
     * sentence's end are left as created by {@code Nd4j.create} (padding); tokens past {@code length} are ignored.
     *
     * @param sentences Tokenized sentences, one per example
     * @param length    Size of the time/sentence dimension of the output
     */
    private INDArray createFeatures(List<List<String>> sentences, int length) {
        int numSentences = sentences.size();

        if (format == Format.CNN1D || format == Format.RNN) {
            int[] featuresShape = new int[] {numSentences, wordVectorSize, length};
            INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));
            INDArrayIndex[] indices = new INDArrayIndex[3];
            indices[1] = NDArrayIndex.all();
            for (int i = 0; i < numSentences; i++) {
                indices[0] = NDArrayIndex.point(i);
                List<String> tokens = sentences.get(i);
                int numTokens = Math.min(tokens.size(), length);
                for (int j = 0; j < numTokens; j++) {
                    indices[2] = NDArrayIndex.point(j);
                    features.put(indices, getVector(tokens.get(j)));
                }
            }
            return features;
        }

        //CNN2D: [minibatch, 1, length, vectorSize] or [minibatch, 1, vectorSize, length]
        int[] featuresShape = sentencesAlongHeight
                ? new int[] {numSentences, 1, length, wordVectorSize}
                : new int[] {numSentences, 1, wordVectorSize, length};
        INDArray features = Nd4j.create(featuresShape);
        INDArrayIndex[] indices = new INDArrayIndex[4];
        indices[1] = NDArrayIndex.point(0);
        for (int i = 0; i < numSentences; i++) {
            indices[0] = NDArrayIndex.point(i);
            List<String> tokens = sentences.get(i);
            int numTokens = Math.min(tokens.size(), length);
            for (int j = 0; j < numTokens; j++) {
                if (sentencesAlongHeight) {
                    indices[2] = NDArrayIndex.point(j);
                    indices[3] = NDArrayIndex.all();
                } else {
                    indices[2] = NDArrayIndex.all();
                    indices[3] = NDArrayIndex.point(j);
                }
                features.put(indices, getVector(tokens.get(j)));
            }
        }
        return features;
    }

    /**
     * Builds the features mask: 1.0 at time steps that hold a token, padding elsewhere.
     * Shape: [minibatch, maxLength] for RNN/CNN1D; [minibatch, 1, maxLength, 1] or [minibatch, 1, 1, maxLength]
     * for CNN2D (depending on sentencesAlongHeight).
     *
     * @return The mask, or null if no sentence is shorter than {@code maxLength} (nothing is padded)
     */
    private INDArray createFeaturesMask(List<List<String>> sentences, int maxLength) {
        boolean anyPadded = false;
        for (List<String> tokens : sentences) {
            if (tokens.size() < maxLength) {
                anyPadded = true;
                break;
            }
        }
        if (!anyPadded) {
            return null;
        }

        int numSentences = sentences.size();
        boolean sequenceFormat = (format == Format.CNN1D || format == Format.RNN);
        INDArray featuresMask;
        if (sequenceFormat) {
            featuresMask = Nd4j.create(numSentences, maxLength);
        } else if (sentencesAlongHeight) {
            featuresMask = Nd4j.create(numSentences, 1, maxLength, 1);
        } else {
            featuresMask = Nd4j.create(numSentences, 1, 1, maxLength);
        }

        for (int i = 0; i < numSentences; i++) {
            int sentenceLength = sentences.get(i).size();
            if (sentenceLength >= maxLength) {
                //Sentence fills (or was truncated to) every time step
                if (sequenceFormat) {
                    featuresMask.getRow(i).assign(1.0);
                } else {
                    featuresMask.slice(i).assign(1.0);
                }
            } else if (sequenceFormat) {
                featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
            } else if (sentencesAlongHeight) {
                featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength), NDArrayIndex.point(0)).assign(1.0);
            } else {
                featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
            }
        }
        return featuresMask;
    }

    /**
     * Builds one-hot labels of shape [minibatch, numClasses] from the sentences' label strings.
     *
     * @throws IllegalStateException If a label is not among the sentence provider's labels
     */
    private INDArray createLabels(List<Pair<List<String>, String>> sentences) {
        INDArray labels = Nd4j.create(sentences.size(), numClasses);
        for (int i = 0; i < sentences.size(); i++) {
            String labelStr = sentences.get(i).getSecond();
            Integer classIdx = labelClassMap.get(labelStr);
            if (classIdx == null) {
                throw new IllegalStateException("Got label \"" + labelStr
                                + "\" that is not present in list of LabeledSentenceProvider labels");
            }
            int labelIdx = classIdx;
            labels.putScalar(i, labelIdx, 1.0);
        }
        return labels;
    }

    /**
     * Returns the vector for a token: the unknown vector for the sentinel placeholder, otherwise the (optionally
     * normalized) word vector from {@link WordVectors}.
     */
    private INDArray getVector(String word) {
        //Yes, this *should* be using == for the sentinel String here: only the sentinel instance inserted by
        //tokenizeSentence should match, not a real token that happens to have the same text
        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector && word == UNKNOWN_WORD_SENTINEL) {
            return unknown;
        }
        if (useNormalizedWordVectors) {
            return wordVectors.getWordVectorMatrixNormalized(word);
        }
        return wordVectors.getWordVectorMatrix(word);
    }

    /**
     * Tokenizes a sentence with the configured tokenizer factory. Out-of-vocabulary tokens (when the word vectors
     * do not support OOV) are dropped or replaced with {@link #UNKNOWN_WORD_SENTINEL}, per
     * {@link UnknownWordHandling}.
     *
     * @return The tokens; may be empty
     */
    private List<String> tokenizeSentence(String sentence) {
        Tokenizer t = tokenizerFactory.create(sentence);

        List<String> tokens = new ArrayList<>();
        while (t.hasMoreTokens()) {
            String token = t.nextToken();
            if (!wordVectors.outOfVocabularySupported() && !wordVectors.hasWord(token)) {
                switch (unknownWordHandling) {
                    case RemoveWord:
                        continue; //Skips to the next token (continues the while loop)
                    case UseUnknownVector:
                        token = UNKNOWN_WORD_SENTINEL;
                        break;
                }
            }
            tokens.add(token);
        }
        return tokens;
    }

    /**
     * @return A copy of the label string -> class index mapping
     */
    public Map<String, Integer> getLabelClassMap() {
        return new HashMap<>(labelClassMap);
    }

    /**
     * @return Labels ordered by class index (i.e., sorted order)
     */
    @Override
    public List<String> getLabels() {
        //We don't want to just return the list from the LabelledSentenceProvider, as we sorted them earlier to do the
        // String -> Integer mapping
        String[] str = new String[labelClassMap.size()];
        for (Map.Entry<String, Integer> e : labelClassMap.entrySet()) {
            str[e.getValue()] = e.getKey();
        }
        return Arrays.asList(str);
    }

    /**
     * @return True if at least one more sentence with a non-empty token list is available
     * @throws UnsupportedOperationException If no sentence provider was configured
     */
    @Override
    public boolean hasNext() {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }

        while (preLoadedTokens == null && sentenceProvider.hasNext()) {
            //Pre-load tokens. Because we filter out empty strings, or sentences with no valid words
            //we need to pre-load some tokens. Otherwise, sentenceProvider could have 1 (invalid) sentence
            //next, hasNext() would return true, but next(int) wouldn't be able to return anything
            preLoadTokens();
        }

        return preLoadedTokens != null;
    }

    /**
     * Reads one sentence from the provider and buffers it in {@link #preLoadedTokens} if it has any tokens.
     * Does nothing if a sentence is already buffered.
     */
    private void preLoadTokens() {
        if (preLoadedTokens != null) {
            return;
        }
        Pair<String, String> p = sentenceProvider.nextSentence();
        List<String> tokens = tokenizeSentence(p.getFirst());
        if (!tokens.isEmpty()) {
            preLoadedTokens = new Pair<>(tokens, p.getSecond());
        }
    }

    @Override
    public DataSet next() {
        return next(minibatchSize);
    }

    /**
     * Returns the next minibatch of up to {@code num} sentences (sentences with no tokens are skipped).
     * Features are padded to the longest sentence in the batch (capped at maxSentenceLength when positive); a
     * features mask is included only if some sentence is padded. Labels are one-hot.
     *
     * @throws UnsupportedOperationException If no sentence provider was configured
     * @throws NoSuchElementException        If no more sentences are available
     * @throws IllegalStateException         If a sentence has a label unknown to the sentence provider
     */
    @Override
    public DataSet next(int num) {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        if (!hasNext()) {
            throw new NoSuchElementException("No next element");
        }

        //hasNext() returned true, so a non-empty sentence is buffered: start the batch with it
        List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
        if (preLoadedTokens != null) {
            tokenizedSentences.add(preLoadedTokens);
            preLoadedTokens = null;
        }
        while (tokenizedSentences.size() < num && sentenceProvider.hasNext()) {
            Pair<String, String> p = sentenceProvider.nextSentence();
            List<String> tokens = tokenizeSentence(p.getFirst());
            //Handle edge case: no tokens from sentence - skip it
            if (!tokens.isEmpty()) {
                tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
            }
        }

        List<List<String>> tokenLists = new ArrayList<>(tokenizedSentences.size());
        int maxLength = 0;
        for (Pair<List<String>, String> sentence : tokenizedSentences) {
            tokenLists.add(sentence.getFirst());
            maxLength = Math.max(maxLength, sentence.getFirst().size());
        }
        maxLength = truncatedLength(maxLength);

        INDArray labels = createLabels(tokenizedSentences);
        INDArray features = createFeatures(tokenLists, maxLength);
        INDArray featuresMask = createFeaturesMask(tokenLists, maxLength);

        DataSet ds = new DataSet(features, labels, featuresMask, null);

        if (dataSetPreProcessor != null) {
            dataSetPreProcessor.preProcess(ds);
        }

        cursor += ds.numExamples();
        return ds;
    }

    @Override
    public int inputColumns() {
        return wordVectorSize;
    }

    @Override
    public int totalOutcomes() {
        return numClasses;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Resets the example counter and the sentence provider. Any sentence read ahead by {@link #hasNext()} is
     * discarded, so the next minibatch starts from the provider's first sentence.
     */
    @Override
    public void reset() {
        cursor = 0;
        //The buffered sentence belongs to the pass being abandoned; keeping it would emit it before the restarted data
        preLoadedTokens = null;
        if (sentenceProvider != null) {
            sentenceProvider.reset();
        }
    }

    @Override
    public int batch() {
        return minibatchSize;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.dataSetPreProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return dataSetPreProcessor;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Builder for {@link CnnSentenceDataSetIterator}. A {@link WordVectors} instance is required; a sentence provider
     * is required for iteration (next/hasNext) but not for {@link CnnSentenceDataSetIterator#loadSingleSentence(String)}.
     */
    public static class Builder {

        private Format format;
        private LabeledSentenceProvider sentenceProvider = null;
        private WordVectors wordVectors;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
        private boolean useNormalizedWordVectors = true;
        private int maxSentenceLength = -1;
        private int minibatchSize = 32;
        private boolean sentencesAlongHeight = true;
        private DataSetPreProcessor dataSetPreProcessor;

        /**
         * @deprecated Due to old default, that will be changed in the future. Use {@link #Builder(Format)} to specify
         * the {@link Format} of the activations
         */
        @Deprecated
        public Builder(){
            //Default for backward compatibility
            this(Format.CNN2D);
        }

        /**
         * @param format The format to use for the features - i.e., for RNNs, 1D or 2D CNNs
         */
        public Builder(@NonNull Format format){
            this.format = format;
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
            this.sentenceProvider = labeledSentenceProvider;
            return this;
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareIterator iterator, @NonNull List<String> labels) {
            LabelAwareConverter converter = new LabelAwareConverter(iterator, labels);
            return sentenceProvider(converter);
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareDocumentIterator iterator, @NonNull List<String> labels) {
            DocumentIteratorConverter converter = new DocumentIteratorConverter(iterator);
            return sentenceProvider(converter, labels);
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareSentenceIterator iterator, @NonNull List<String> labels) {
            SentenceIteratorConverter converter = new SentenceIteratorConverter(iterator);
            return sentenceProvider(converter, labels);
        }


        /**
         * Provide the WordVectors instance that should be used for training (required)
         */
        public Builder wordVectors(WordVectors wordVectors) {
            this.wordVectors = wordVectors;
            return this;
        }

        /**
         * The {@link TokenizerFactory} that should be used. Defaults to {@link DefaultTokenizerFactory}
         */
        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Specify how unknown words (those that don't have a word vector in the provided WordVectors instance) should be
         * handled. Default: remove/ignore unknown words.
         */
        public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
            this.unknownWordHandling = unknownWordHandling;
            return this;
        }

        /**
         * Minibatch size to use for the DataSetIterator. Default: 32
         */
        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /**
         * Whether normalized word vectors should be used. Default: true
         */
        public Builder useNormalizedWordVectors(boolean useNormalizedWordVectors) {
            this.useNormalizedWordVectors = useNormalizedWordVectors;
            return this;
        }

        /**
         * Maximum sentence/document length. If sentences exceed this, they will be truncated to this length by
         * taking the first 'maxSentenceLength' known words. Values of 0 or less (default: -1) mean no limit.
         */
        public Builder maxSentenceLength(int maxSentenceLength) {
            this.maxSentenceLength = maxSentenceLength;
            return this;
        }

        /**
         * Only used with {@link Format#CNN2D}.<br>
         * If true (default): output features data with shape [minibatchSize, 1, maxSentenceLength, wordVectorSize]<br>
         * If false: output features with shape [minibatchSize, 1, wordVectorSize, maxSentenceLength]
         */
        public Builder sentencesAlongHeight(boolean sentencesAlongHeight) {
            this.sentencesAlongHeight = sentencesAlongHeight;
            return this;
        }

        /**
         * Optional DataSetPreProcessor, applied to every DataSet returned by next
         */
        public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
            this.dataSetPreProcessor = dataSetPreProcessor;
            return this;
        }

        /**
         * @return The configured iterator
         * @throws IllegalStateException If no WordVectors instance was provided
         */
        public CnnSentenceDataSetIterator build() {
            if (wordVectors == null) {
                throw new IllegalStateException(
                                "Cannot build CnnSentenceDataSetIterator without a WordVectors instance");
            }

            return new CnnSentenceDataSetIterator(this);
        }

    }
}