/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.datasets.datavec;

import lombok.Getter;
import lombok.Setter;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposable;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;

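/**
 * A DataSetIterator that wraps one or two {@link SequenceRecordReader}s and produces DataSets of
 * sequence (time series) features and labels, with optional masking via
 * {@link SequenceRecordReaderDataSetIterator.AlignmentMode} when the input and label sequences differ in length.
 * <p>
 * Illustrative usage sketch (the CSVSequenceRecordReader setup, file names and class/batch counts below are
 * assumptions for the example, and exception handling is omitted): many-to-one sequence classification with
 * separate feature and label files, aligned at the last time step:
 * <pre>{@code
 * SequenceRecordReader features = new CSVSequenceRecordReader(0, ",");
 * features.initialize(new NumberedFileInputSplit("features_%d.csv", 0, 99));
 * SequenceRecordReader labels = new CSVSequenceRecordReader(0, ",");
 * labels.initialize(new NumberedFileInputSplit("labels_%d.csv", 0, 99));
 *
 * DataSetIterator iter = new SequenceRecordReaderDataSetIterator(
 *         features, labels, 32, 6, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
 * }</pre>
 */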
public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
    /** Alignment mode for dealing with inputs/labels of differing lengths (for example, one-to-many and many-to-one situations).
     * For example, a sequence might have 10 time steps in total but only a single label at the final time step, for sequence classification.<br>
     * Currently supported modes:<br>
     * <b>EQUAL_LENGTH</b>: Default. Assume that label and input time series are of equal length, and all examples are of
     * the same length<br>
     * <b>ALIGN_START</b>: Align the label/input time series at the first time step, and zero pad either the labels or
     * the input at the end<br>
     * <b>ALIGN_END</b>: Align the label/input at the last time step, zero padding either the input or the labels as required<br>
     *
     * Note 1: When the time series for each example are of different lengths, the shorter time series will be padded to
     * the length of the longest time series.<br>
     * Note 2: When ALIGN_START or ALIGN_END are used, the DataSet masking functionality is used. Thus, the returned DataSets
     * will have mask arrays set for the features and/or labels. These mask arrays identify whether an input/label value is actually present,
     * or whether it is merely padding.<br>
     */
    public enum AlignmentMode {
        EQUAL_LENGTH, ALIGN_START, ALIGN_END
    }

    private static final String READER_KEY = "reader";
    private static final String READER_KEY_LABEL = "reader_labels";

    private SequenceRecordReader recordReader;
    private SequenceRecordReader labelsReader;
    private int miniBatchSize = 10;
    private final boolean regression;
    private int labelIndex = -1;
    private final int numPossibleLabels;
    private int cursor = 0;
    private int inputColumns = -1;
    private int totalOutcomes = -1;
    private boolean useStored = false;
    private DataSet stored = null;
    @Getter
    private DataSetPreProcessor preProcessor;
    private AlignmentMode alignmentMode;

    private final boolean singleSequenceReaderMode;

    @Getter
    @Setter
    private boolean collectMetaData = false;

    private RecordReaderMultiDataSetIterator underlying;
    private boolean underlyingIsDisjoint;

    /**
     * Constructor where features and labels come from different RecordReaders (for example, different files),
     * and labels are for classification.
     *
     * @param featuresReader       SequenceRecordReader for the features
     * @param labels               Labels: assumed to be a single value per time step, with values being integers in the range 0 to numPossibleLabels-1
     * @param miniBatchSize        Minibatch size for each call of next()
     * @param numPossibleLabels    Number of classes for the labels
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                    int miniBatchSize, int numPossibleLabels) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, false);
    }

    /**
     * Constructor where features and labels come from different RecordReaders (for example, different files)
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                    int miniBatchSize, int numPossibleLabels, boolean regression) {
        this(featuresReader, labels, miniBatchSize, numPossibleLabels, regression, AlignmentMode.EQUAL_LENGTH);
    }

    /**
     * Constructor where features and labels come from different RecordReaders (for example, different files)
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                    int miniBatchSize, int numPossibleLabels, boolean regression, AlignmentMode alignmentMode) {
        this.recordReader = featuresReader;
        this.labelsReader = labels;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.alignmentMode = alignmentMode;
        this.singleSequenceReaderMode = false;
    }

    /** Constructor where features and labels come from the SAME RecordReader (i.e., the target/label is a column in the
     * same data as the features). Defaults to regression = false, i.e., classification.
     * @param reader SequenceRecordReader with data
     * @param miniBatchSize size of each minibatch
     * @param numPossibleLabels number of labels/classes for classification
     * @param labelIndex index of the label column within the input. If in regression mode and numPossibleLabels > 1, labelIndex denotes the
     *                   first label column. Everything before that index will be treated as input(s) and
     *                   everything from that index (inclusive) to the end will be treated as output(s)
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize, int numPossibleLabels,
                    int labelIndex) {
        this(reader, miniBatchSize, numPossibleLabels, labelIndex, false);
    }

    /** Constructor where features and labels come from the SAME RecordReader (i.e., target/label is a column in the
     * same data as the features)
     * @param reader SequenceRecordReader with data
     * @param miniBatchSize size of each minibatch
     * @param numPossibleLabels number of labels/classes for classification
     * @param labelIndex index of the label column within the input. If in regression mode and numPossibleLabels > 1, labelIndex denotes the
     *                   first label column. Everything before that index will be treated as input(s) and
     *                   everything from that index (inclusive) to the end will be treated as output(s)
     * @param regression Whether output is for regression or classification
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader reader, int miniBatchSize, int numPossibleLabels,
                    int labelIndex, boolean regression) {
        this.recordReader = reader;
        this.labelsReader = null;
        this.miniBatchSize = miniBatchSize;
        this.regression = regression;
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.singleSequenceReaderMode = true;
    }
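
    /*
     * Illustrative sketch for the single-reader constructors above (the reader setup, file names and
     * class/batch counts are assumptions, and exception handling is omitted): sequence classification
     * where each CSV line is one time step, holding the features followed by an integer class label
     * in the last column. Passing labelIndex = -1 infers the last column as the label.
     *
     *   SequenceRecordReader reader = new CSVSequenceRecordReader(0, ",");
     *   reader.initialize(new NumberedFileInputSplit("sequence_%d.csv", 0, 99));
     *
     *   //32 sequences per minibatch, 4 classes, label in the last column
     *   DataSetIterator iter = new SequenceRecordReaderDataSetIterator(reader, 32, 4, -1, false);
     */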

    private void initializeUnderlyingFromReader() {
        initializeUnderlying(recordReader.nextSequence());
        underlying.reset();
    }

    private void initializeUnderlying(SequenceRecord nextF) {
        if (nextF.getSequenceRecord().isEmpty()) {
            throw new ZeroLengthSequenceException();
        }
        int totalSizeF = nextF.getSequenceRecord().get(0).size();

        //Allow labelIndex to be specified as -1 and infer the label column: the last column in single-reader mode, column 0 otherwise
        if (singleSequenceReaderMode && numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = totalSizeF - 1;
        } else if (!singleSequenceReaderMode && numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = 0;
        }

        recordReader.reset();

        //Add readers
        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize);
        builder.addSequenceReader(READER_KEY, recordReader);
        if (labelsReader != null) {
            builder.addSequenceReader(READER_KEY_LABEL, labelsReader);
        }


        //Add outputs
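        //Single-reader mode: decide how the columns are split between features and labels. A label in the
        //first or last column (or no label at all) gives one contiguous feature block; with regression and
        //numPossibleLabels > 1, all columns from labelIndex onwards become labels; otherwise a label column
        //in the middle splits the features into two disjoint blocks, which mdsToDataSet() concatenates again.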
        if (singleSequenceReaderMode) {

            if (labelIndex < 0 && numPossibleLabels < 0) {
                //No labels - all values -> features array
                builder.addInput(READER_KEY);
            } else if (labelIndex == 0 || labelIndex == totalSizeF - 1) {  //Features: subset of columns
                //Labels are first or last -> one input in underlying
                int inputFrom;
                int inputTo;
                if (labelIndex < 0) {
                    //No label
                    inputFrom = 0;
                    inputTo = totalSizeF - 1;
                } else if (labelIndex == 0) {
                    inputFrom = 1;
                    inputTo = totalSizeF - 1;
                } else {
                    inputFrom = 0;
                    inputTo = labelIndex - 1;
                }

                builder.addInput(READER_KEY, inputFrom, inputTo);

                underlyingIsDisjoint = false;
            } else if (regression && numPossibleLabels > 1){
                //Multiple inputs and multiple outputs
                int inputFrom = 0;
                int inputTo = labelIndex - 1;
                int outputFrom = labelIndex;
                int outputTo = totalSizeF - 1;

                builder.addInput(READER_KEY, inputFrom, inputTo);
                builder.addOutput(READER_KEY, outputFrom, outputTo);

                underlyingIsDisjoint = false;
            } else {
                //Multiple inputs (disjoint features case)
                int firstFrom = 0;
                int firstTo = labelIndex - 1;
                int secondFrom = labelIndex + 1;
                int secondTo = totalSizeF - 1;

                builder.addInput(READER_KEY, firstFrom, firstTo);
                builder.addInput(READER_KEY, secondFrom, secondTo);

                underlyingIsDisjoint = true;
            }

            if (!(labelIndex < 0 && numPossibleLabels < 0)) {
                if (regression && numPossibleLabels <= 1) {
                    //Single-output regression (the multi-output regression case was handled above)
                    builder.addOutput(READER_KEY, labelIndex, labelIndex);
                } else if (!regression) {
                    builder.addOutputOneHot(READER_KEY, labelIndex, numPossibleLabels);
                }
            }
        } else {

            //Features: entire reader
            builder.addInput(READER_KEY);
            underlyingIsDisjoint = false;

            if (regression) {
                builder.addOutput(READER_KEY_LABEL);
            } else {
                builder.addOutputOneHot(READER_KEY_LABEL, 0, numPossibleLabels);
            }
        }

        if (alignmentMode != null) {
            switch (alignmentMode) {
                case EQUAL_LENGTH:
                    builder.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.EQUAL_LENGTH);
                    break;
                case ALIGN_START:
                    builder.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START);
                    break;
                case ALIGN_END:
                    builder.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END);
                    break;
            }
        }

        underlying = builder.build();

        if (collectMetaData) {
            underlying.setCollectMetaData(true);
        }
    }

    private DataSet mdsToDataSet(MultiDataSet mds) {
        INDArray f;
        INDArray fm;
        if (underlyingIsDisjoint) {
            //Rare case: 2 input arrays -> concat
            INDArray f1 = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 0);
            INDArray f2 = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 1);
            fm = RecordReaderDataSetIterator.getOrNull(mds.getFeaturesMaskArrays(), 0); //Per-example masking only on the input -> same for both

            //Features are 3d here, shape [miniBatchSize, numColumns, timeSeriesLength]; concatenate the two blocks along dimension 1
            f = Nd4j.createUninitialized(new long[] {f1.size(0), f1.size(1) + f2.size(1), f1.size(2)});
            f.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, f1.size(1)), NDArrayIndex.all()},
                            f1);
            f.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(f1.size(1), f1.size(1) + f2.size(1)),
                            NDArrayIndex.all()}, f2);
        } else {
            //Standard case
            f = RecordReaderDataSetIterator.getOrNull(mds.getFeatures(), 0);
            fm = RecordReaderDataSetIterator.getOrNull(mds.getFeaturesMaskArrays(), 0);
        }

        INDArray l = RecordReaderDataSetIterator.getOrNull(mds.getLabels(), 0);
        INDArray lm = RecordReaderDataSetIterator.getOrNull(mds.getLabelsMaskArrays(), 0);

        DataSet ds = new DataSet(f, l, fm, lm);

        if (collectMetaData) {
            List<Serializable> temp = mds.getExampleMetaData();
            List<Serializable> temp2 = new ArrayList<>(temp.size());
            for (Serializable s : temp) {
                RecordMetaDataComposableMap m = (RecordMetaDataComposableMap) s;
                if (singleSequenceReaderMode) {
                    temp2.add(m.getMeta().get(READER_KEY));
                } else {
                    RecordMetaDataComposable c = new RecordMetaDataComposable(m.getMeta().get(READER_KEY),
                                    m.getMeta().get(READER_KEY_LABEL));
                    temp2.add(c);
                }
            }
            ds.setExampleMetaData(temp2);
        }

        if (preProcessor != null) {
            preProcessor.preProcess(ds);
        }

        return ds;
    }

    @Override
    public boolean hasNext() {
        if (underlying == null) {
            initializeUnderlyingFromReader();
        }
        return underlying.hasNext();
    }

    @Override
    public DataSet next() {
        return next(miniBatchSize);
    }


    @Override
    public DataSet next(int num) {
        if (useStored) {
            useStored = false;
            DataSet temp = stored;
            stored = null;
            if (preProcessor != null)
                preProcessor.preProcess(temp);
            return temp;
        }
        if (!hasNext())
            throw new NoSuchElementException();

        if (underlying == null) {
            initializeUnderlyingFromReader();
        }

        MultiDataSet mds = underlying.next(num);
        DataSet ds = mdsToDataSet(mds);

        if (totalOutcomes == -1) {
            inputColumns = (int) ds.getFeatures().size(1);
            totalOutcomes = ds.getLabels() == null ? -1 : (int) ds.getLabels().size(1);
        }

        return ds;
    }

    @Override
    public int inputColumns() {
        if (inputColumns != -1)
            return inputColumns;
        preLoad();
        return inputColumns;
    }

    @Override
    public int totalOutcomes() {
        if (totalOutcomes != -1)
            return totalOutcomes;
        preLoad();
        return totalOutcomes;
    }

    private void preLoad() {
        stored = next();
        useStored = true;

        inputColumns = (int) stored.getFeatures().size(1);
        totalOutcomes = (int) stored.getLabels().size(1);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        if (underlying != null)
            underlying.reset();

        cursor = 0;
        stored = null;
        useStored = false;
    }

    @Override
    public int batch() {
        return miniBatchSize;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public List<String> getLabels() {
        return null;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Remove not supported for this iterator");
    }

    /**
     * Load a single sequence example to a DataSet, using the provided RecordMetaData.
     * Note that it is more efficient to load multiple instances at once, using {@link #loadFromMetaData(List)}
     *
     * @param recordMetaData RecordMetaData to load from. Should have been produced by the given record reader
     * @return DataSet with the specified example
     * @throws IOException If an error occurs during loading of the data
     */
    public DataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Load multiple sequence examples to a DataSet, using the provided RecordMetaData instances.
     *
     * @param list List of RecordMetaData instances to load from. Should have been produced by the record reader provided
     *             to the SequenceRecordReaderDataSetIterator constructor
     * @return DataSet with the specified examples
     * @throws IOException If an error occurs during loading of the data
     */
    public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        if (underlying == null) {
            SequenceRecord r = recordReader.loadSequenceFromMetaData(list.get(0));
            initializeUnderlying(r);
        }

        //Two cases: single vs. multiple reader...
        List<RecordMetaData> l = new ArrayList<>(list.size());
        if (singleSequenceReaderMode) {
            for (RecordMetaData m : list) {
                l.add(new RecordMetaDataComposableMap(Collections.singletonMap(READER_KEY, m)));
            }
        } else {
            for (RecordMetaData m : list) {
                RecordMetaDataComposable rmdc = (RecordMetaDataComposable) m;
                Map<String, RecordMetaData> map = new HashMap<>(2);
                map.put(READER_KEY, rmdc.getMeta()[0]);
                map.put(READER_KEY_LABEL, rmdc.getMeta()[1]);
                l.add(new RecordMetaDataComposableMap(map));
            }
        }

        return mdsToDataSet(underlying.loadFromMetaData(l));
    }
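
    /*
     * Illustrative sketch for the loadFromMetaData methods above (variable names are assumptions):
     * collect metadata while iterating, then re-load exactly the same examples later, for instance to
     * inspect the raw sequences behind poorly predicted examples.
     *
     *   SequenceRecordReaderDataSetIterator iter = ...;   //constructed as in the examples above
     *   iter.setCollectMetaData(true);                    //set before the first hasNext()/next() call
     *   DataSet ds = iter.next();
     *   List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
     *
     *   DataSet reloaded = iter.loadFromMetaData(meta);   //contains the same examples as ds
     */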
}