/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.datasets.datavec;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.datavec.api.records.Record;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.primitives.Pair;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;

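/**
 * An iterator that builds {@link MultiDataSet} objects (with any number of input and output arrays) from one or
 * more named {@link RecordReader} / {@link SequenceRecordReader} instances. Readers are registered with the
 * {@link Builder}, and each input/output is defined as the entire contents, a column subset, or a one-hot encoded
 * column of a named reader.
 * <p>
 * A minimal usage sketch (the file path, column indices and number of classes below are illustrative only):
 * <pre>{@code
 * RecordReader csv = new CSVRecordReader(0, ',');
 * csv.initialize(new FileSplit(new File("data.csv")));
 * MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(32)
 *         .addReader("csv", csv)
 *         .addInput("csv", 0, 3)           //Columns 0 to 3 inclusive as features
 *         .addOutputOneHot("csv", 4, 5)    //Column 4 as a one-hot label, 5 classes
 *         .build();
 * }</pre>
 */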
@Getter
public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, Serializable {

    /**
     * When dealing with time series data of different lengths, how should we align the input/labels time series?
     * For sequences that are all the same length: use EQUAL_LENGTH
     * For variable length sequences aligned to the start (time 0): use ALIGN_START
     * For variable length sequences aligned to the end (for example, sequence classification): use ALIGN_END
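     * <p>
     * Illustration (hypothetical lengths): a sequence of length 3 in a minibatch padded to length 5:
     * <pre>
     * ALIGN_START: values [x0 x1 x2  0  0], mask [1 1 1 0 0]
     * ALIGN_END:   values [ 0  0 x0 x1 x2], mask [0 0 1 1 1]
     * </pre>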
     */
    public enum AlignmentMode {
        EQUAL_LENGTH, ALIGN_START, ALIGN_END
    }

    private int batchSize;
    private AlignmentMode alignmentMode;
    private Map<String, RecordReader> recordReaders = new HashMap<>();
    private Map<String, SequenceRecordReader> sequenceRecordReaders = new HashMap<>();

    private List<SubsetDetails> inputs = new ArrayList<>();
    private List<SubsetDetails> outputs = new ArrayList<>();

    @Getter
    @Setter
    private boolean collectMetaData = false;

    private boolean timeSeriesRandomOffset = false;
    private Random timeSeriesRandomOffsetRng;

    private MultiDataSetPreProcessor preProcessor;

    private boolean resetSupported = true;

    private RecordReaderMultiDataSetIterator(Builder builder) {
        this.batchSize = builder.batchSize;
        this.alignmentMode = builder.alignmentMode;
        this.recordReaders = builder.recordReaders;
        this.sequenceRecordReaders = builder.sequenceRecordReaders;
        this.inputs.addAll(builder.inputs);
        this.outputs.addAll(builder.outputs);
        this.timeSeriesRandomOffset = builder.timeSeriesRandomOffset;
        if (this.timeSeriesRandomOffset) {
            timeSeriesRandomOffsetRng = new Random(builder.timeSeriesRandomOffsetSeed);
        }


        if(recordReaders != null){
            for(RecordReader rr : recordReaders.values()){
                resetSupported &= rr.resetSupported();
            }
        }
        if(sequenceRecordReaders != null){
            for(SequenceRecordReader srr : sequenceRecordReaders.values()){
                resetSupported &= srr.resetSupported();
            }
        }
    }

    @Override
    public MultiDataSet next() {
        return next(batchSize);
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Remove not supported");
    }

    @Override
    public MultiDataSet next(int num) {
        if (!hasNext())
            throw new NoSuchElementException("No next elements");

        //First: load the next values from the RR / SeqRRs
        Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
        Map<String, List<INDArray>> nextRRValsBatched = null;
        Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
        List<RecordMetaDataComposableMap> nextMetas =
                        (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);


        for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
            RecordReader rr = entry.getValue();
            if (!collectMetaData && rr.batchesSupported()) {
                //Batch case, for efficiency: ImageRecordReader etc
                List<List<Writable>> batchWritables = rr.next(num);

                List<INDArray> batch;
                if(batchWritables instanceof NDArrayRecordBatch) {
                    //ImageRecordReader etc case
                    batch = ((NDArrayRecordBatch)batchWritables).getArrays();
                } else {
                    batchWritables = filterRequiredColumns(entry.getKey(), batchWritables);
                    batch = new ArrayList<>();
                    List<Writable> temp = new ArrayList<>();
                    int sz = batchWritables.get(0).size();
                    for( int i = 0; i < sz; i++) {
                        temp.clear();
                        for( int j = 0; j < batchWritables.size(); j++) {
                            temp.add(batchWritables.get(j).get(i));
                        }

                        batch.add(RecordConverter.toMinibatchArray(temp));
                    }
                }

                if (nextRRValsBatched == null) {
                    nextRRValsBatched = new HashMap<>();
                }
                nextRRValsBatched.put(entry.getKey(), batch);
            } else {
                //Standard case
                List<List<Writable>> writables = new ArrayList<>(Math.min(num, 100000));    //Min op: in case user puts batch size >> amount of data
                for (int i = 0; i < num && rr.hasNext(); i++) {
                    List<Writable> record;
                    if (collectMetaData) {
                        Record r = rr.nextRecord();
                        record = r.getRecord();
                        if (nextMetas.size() <= i) {
                            nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                        }
                        RecordMetaDataComposableMap map = nextMetas.get(i);
                        map.getMeta().put(entry.getKey(), r.getMetaData());
                    } else {
                        record = rr.next();
                    }
                    writables.add(record);
                }

                nextRRVals.put(entry.getKey(), writables);
            }
        }

        for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
            SequenceRecordReader rr = entry.getValue();
            List<List<List<Writable>>> writables = new ArrayList<>(num);
            for (int i = 0; i < num && rr.hasNext(); i++) {
                List<List<Writable>> sequence;
                if (collectMetaData) {
                    SequenceRecord r = rr.nextSequence();
                    sequence = r.getSequenceRecord();
                    if (nextMetas.size() <= i) {
                        nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                    }
                    RecordMetaDataComposableMap map = nextMetas.get(i);
                    map.getMeta().put(entry.getKey(), r.getMetaData());
                } else {
                    sequence = rr.sequenceRecord();
                }
                writables.add(sequence);
            }

            nextSeqRRVals.put(entry.getKey(), writables);
        }

        return nextMultiDataSet(nextRRVals, nextRRValsBatched, nextSeqRRVals, nextMetas);
    }

    //Filter out the required columns before conversion. This is to avoid trying to convert String etc columns
    private List<List<Writable>> filterRequiredColumns(String readerName, List<List<Writable>> list){

        //Options: (a) entire reader
        //(b) one or more subsets
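        //Example (illustrative): if this reader has 5 columns and the only subset covers columns 1..2, then
        //columns 0, 3 and 4 are replaced below with a constant IntWritable(0) so they are never parsed/converted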

        boolean entireReader = false;
        List<SubsetDetails> subsetList = null;
        int max = -1;
        int min = Integer.MAX_VALUE;
        for(List<SubsetDetails> sdList : Arrays.asList(inputs, outputs)) {
            for (SubsetDetails sd : sdList) {
                if (readerName.equals(sd.readerName)) {
                    if (sd.entireReader) {
                        entireReader = true;
                        break;
                    } else {
                        if (subsetList == null) {
                            subsetList = new ArrayList<>();
                        }
                        subsetList.add(sd);
                        max = Math.max(max, sd.subsetEndInclusive);
                        min = Math.min(min, sd.subsetStart);
                    }
                }
            }
        }

        if(entireReader){
            //No filtering required
            return list;
        } else if(subsetList == null){
            throw new IllegalStateException("Found no usages of reader: " + readerName);
        } else {
            //we need some - but not all - columns
            boolean[] req = new boolean[max+1];
            for(SubsetDetails sd : subsetList){
                for( int i=sd.subsetStart; i<= sd.subsetEndInclusive; i++ ){
                    req[i] = true;
                }
            }

            List<List<Writable>> out = new ArrayList<>();
            IntWritable zero = new IntWritable(0);
            for(List<Writable> l : list){
                List<Writable> lNew = new ArrayList<>(l.size());
                for(int i=0; i<l.size(); i++ ){
                    if(i >= req.length || !req[i]){
                        lNew.add(zero);
                    } else {
                        lNew.add(l.get(i));
                    }
                }
                out.add(lNew);
            }
            return out;
        }
    }

    public MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> nextRRVals,
                    Map<String, List<INDArray>> nextRRValsBatched,
                    Map<String, List<List<List<Writable>>>> nextSeqRRVals,
                    List<RecordMetaDataComposableMap> nextMetas) {
        int minExamples = Integer.MAX_VALUE;
        for (List<List<Writable>> exampleData : nextRRVals.values()) {
            minExamples = Math.min(minExamples, exampleData.size());
        }
        if (nextRRValsBatched != null) {
            for (List<INDArray> exampleData : nextRRValsBatched.values()) {
                //Assume all NDArrayWritables here
                for (INDArray w : exampleData) {
                    val n = w.size(0);
                    minExamples = (int) Math.min(minExamples, n);
                }
            }
        }
        for (List<List<List<Writable>>> exampleData : nextSeqRRVals.values()) {
            minExamples = Math.min(minExamples, exampleData.size());
        }


        if (minExamples == Integer.MAX_VALUE)
            throw new RuntimeException("Error occurred during data set generation: no readers?"); //Should never happen

        //In order to align data at the end (for each example individually), we need to know the length of the
        // longest time series for each example
        int[] longestSequence = null;
        if (timeSeriesRandomOffset || alignmentMode == AlignmentMode.ALIGN_END) {
            longestSequence = new int[minExamples];
            for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
                List<List<List<Writable>>> list = entry.getValue();
                for (int i = 0; i < list.size() && i < minExamples; i++) {
                    longestSequence[i] = Math.max(longestSequence[i], list.get(i).size());
                }
            }
        }

        //Second: create the input/feature arrays
        //To do this, we need to know longest time series length, so we can do padding
        int longestTS = -1;
        if (alignmentMode != AlignmentMode.EQUAL_LENGTH) {
            for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
                List<List<List<Writable>>> list = entry.getValue();
                for (List<List<Writable>> c : list) {
                    longestTS = Math.max(longestTS, c.size());
                }
            }
        }
        long rngSeed = (timeSeriesRandomOffset ? timeSeriesRandomOffsetRng.nextLong() : -1);
        Pair<INDArray[], INDArray[]> features = convertFeaturesOrLabels(new INDArray[inputs.size()],
                        new INDArray[inputs.size()], inputs, minExamples, nextRRVals, nextRRValsBatched, nextSeqRRVals,
                        longestTS, longestSequence, rngSeed);


        //Third: create the outputs/labels
        Pair<INDArray[], INDArray[]> labels = convertFeaturesOrLabels(new INDArray[outputs.size()],
                        new INDArray[outputs.size()], outputs, minExamples, nextRRVals, nextRRValsBatched,
                        nextSeqRRVals, longestTS, longestSequence, rngSeed);



        MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(features.getFirst(), labels.getFirst(),
                        features.getSecond(), labels.getSecond());
        if (collectMetaData) {
            mds.setExampleMetaData(nextMetas);
        }
        if (preProcessor != null)
            preProcessor.preProcess(mds);
        return mds;
    }

    private Pair<INDArray[], INDArray[]> convertFeaturesOrLabels(INDArray[] featuresOrLabels, INDArray[] masks,
                    List<SubsetDetails> subsetDetails, int minExamples, Map<String, List<List<Writable>>> nextRRVals,
                    Map<String, List<INDArray>> nextRRValsBatched,
                    Map<String, List<List<List<Writable>>>> nextSeqRRVals, int longestTS, int[] longestSequence,
                    long rngSeed) {
        boolean hasMasks = false;
        int i = 0;

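        //For each requested subset, dispatch on where its reader's values were loaded:
        //batched arrays, standard records, or sequence records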
        for (SubsetDetails d : subsetDetails) {
            if (nextRRValsBatched != null && nextRRValsBatched.containsKey(d.readerName)) {
                //Standard reader, but batch ops
                featuresOrLabels[i] = convertWritablesBatched(nextRRValsBatched.get(d.readerName), d);
            } else if (nextRRVals.containsKey(d.readerName)) {
                //Standard reader
                List<List<Writable>> list = nextRRVals.get(d.readerName);
                featuresOrLabels[i] = convertWritables(list, minExamples, d);
            } else {
                //Sequence reader
                List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
                Pair<INDArray, INDArray> p =
                                convertWritablesSequence(list, minExamples, longestTS, d, longestSequence, rngSeed);
                featuresOrLabels[i] = p.getFirst();
                masks[i] = p.getSecond();
                if (masks[i] != null)
                    hasMasks = true;
            }
            i++;
        }

        return new Pair<>(featuresOrLabels, hasMasks ? masks : null);
    }

    private INDArray convertWritablesBatched(List<INDArray> list, SubsetDetails details) {
        INDArray arr;
        if (details.entireReader) {
            if (list.size() == 1) {
                arr = list.get(0);
            } else {
                //Need to concat column vectors
                INDArray[] asArray = list.toArray(new INDArray[list.size()]);
                arr = Nd4j.concat(1, asArray);
            }
        } else if (details.subsetStart == details.subsetEndInclusive || details.oneHot) {
            arr = list.get(details.subsetStart);
        } else {
            //Concat along dimension 1
            int count = details.subsetEndInclusive - details.subsetStart + 1;
            INDArray[] temp = new INDArray[count];
            int x = 0;
            for( int i=details.subsetStart; i<= details.subsetEndInclusive; i++){
                temp[x++] = list.get(i);
            }
            arr = Nd4j.concat(1, temp);
        }

        if (!details.oneHot || arr.size(1) == details.oneHotNumClasses) {
            //Not one-hot: no conversion required
            //Also, ImageRecordReader already does the one-hot conversion internally
            return arr;
        }

        //Do one-hot conversion
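        //e.g. (illustrative): a single column of class indices [2, 0, 1] with oneHotNumClasses = 3 becomes
        //[[0,0,1], [1,0,0], [0,1,0]]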
        if (arr.size(1) != 1) {
            throw new UnsupportedOperationException("Cannot do conversion to one hot using batched reader: "
                            + details.oneHotNumClasses + " output classes, but array.size(1) is " + arr.size(1)
                            + " (must be equal to 1 or numClasses = " + details.oneHotNumClasses + ")");
        }

        val n = arr.size(0);
        INDArray out = Nd4j.create(n, details.oneHotNumClasses);
        for (int i = 0; i < n; i++) {
            int v = arr.getInt(i, 0);
            out.putScalar(i, v, 1.0);
        }

        return out;
    }

    private int countLength(List<Writable> list) {
        return countLength(list, 0, list.size() - 1);
    }

    private int countLength(List<Writable> list, int from, int to) {
        int length = 0;
        for (int i = from; i <= to; i++) {
            Writable w = list.get(i);
            if (w instanceof NDArrayWritable) {
                INDArray a = ((NDArrayWritable) w).get();
                if (!a.isRowVectorOrScalar()) {
                    throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is "
                                    + "not a row vector. Can only concat row vectors with other writables. Shape: "
                                    + Arrays.toString(a.shape()));
                }
                length += a.length();
            } else {
                //Assume all others are single value
                length++;
            }
        }

        return length;
    }

    private INDArray convertWritables(List<List<Writable>> list, int minValues, SubsetDetails details) {
        try{
            return convertWritablesHelper(list, minValues, details);
        } catch (NumberFormatException e) {
            throw new RuntimeException("Error parsing data (writables) from record readers - value is non-numeric", e);
        } catch(IllegalStateException e){
            throw e;
        } catch (Throwable t){
            throw new RuntimeException("Error parsing data (writables) from record readers", t);
        }
    }

    private INDArray convertWritablesHelper(List<List<Writable>> list, int minValues, SubsetDetails details) {
        INDArray arr;
        if (details.entireReader) {
            if (list.get(0).size() == 1 && list.get(0).get(0) instanceof NDArrayWritable) {
                //Special case: single NDArrayWritable...
                INDArray temp = ((NDArrayWritable) list.get(0).get(0)).get();
                val shape = ArrayUtils.clone(temp.shape());
                shape[0] = minValues;
                arr = Nd4j.create(shape);
            } else {
                arr = Nd4j.create(minValues, countLength(list.get(0)));
            }
        } else if (details.oneHot) {
            arr = Nd4j.zeros(minValues, details.oneHotNumClasses);
        } else {
            if (details.subsetStart == details.subsetEndInclusive
                            && list.get(0).get(details.subsetStart) instanceof NDArrayWritable) {
                //Special case: single NDArrayWritable (example: ImageRecordReader)
                INDArray temp = ((NDArrayWritable) list.get(0).get(details.subsetStart)).get();
                val shape = ArrayUtils.clone(temp.shape());
                shape[0] = minValues;
                arr = Nd4j.create(shape);
            } else {
                //Need to check for multiple NDArrayWritables, or mixed NDArrayWritable + DoubleWritable etc
                int length = countLength(list.get(0), details.subsetStart, details.subsetEndInclusive);
                arr = Nd4j.create(minValues, length);
            }
        }

        for (int i = 0; i < minValues; i++) {
            List<Writable> c = list.get(i);
            if (details.entireReader) {
                //Convert entire reader contents, without modification
                INDArray converted = RecordConverter.toArray(Nd4j.defaultFloatingPointType(), c);
                putExample(arr, converted, i);
            } else if (details.oneHot) {
                //Convert a single column to a one-hot representation
                Writable w = c.get(details.subsetStart);
                //Index of class
                int classIdx = w.toInt();
                if (classIdx >= details.oneHotNumClasses) {
                    throw new IllegalStateException("Cannot convert writables to one-hot: class index " + classIdx
                                    + " >= numClasses (" + details.oneHotNumClasses + "). (Note that classes are zero-" +
                            "indexed, thus only values 0 to numClasses-1 are valid)");
                }
                arr.putScalar(i, classIdx, 1.0);
            } else {
                //Convert a subset of the columns

                //Special case: subsetStart == subsetEndInclusive && NDArrayWritable. Example: ImageRecordReader
                if (details.subsetStart == details.subsetEndInclusive
                                && (c.get(details.subsetStart) instanceof NDArrayWritable)) {
                    putExample(arr, ((NDArrayWritable) c.get(details.subsetStart)).get(), i);
                } else {

                    Iterator<Writable> iter = c.iterator();
                    for (int j = 0; j < details.subsetStart; j++)
                        iter.next();
                    int k = 0;
                    for (int j = details.subsetStart; j <= details.subsetEndInclusive; j++) {
                        Writable w = iter.next();

                        if (w instanceof NDArrayWritable) {
                            INDArray toPut = ((NDArrayWritable) w).get();
                            arr.put(new INDArrayIndex[] {NDArrayIndex.point(i),
                                            NDArrayIndex.interval(k, k + toPut.length())}, toPut);
                            k += toPut.length();
                        } else {
                            arr.putScalar(i, k, w.toDouble());
                            k++;
                        }
                    }
                }
            }
        }

        return arr;
    }

    private void putExample(INDArray arr, INDArray singleExample, int exampleIdx) {
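        //e.g. (illustrative): for image data, this puts a single [1, 3, 32, 32] example array into row
        //exampleIdx of a [minibatch, 3, 32, 32] output array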
        Preconditions.checkState(singleExample.size(0) == 1 && singleExample.rank() == arr.rank(), "Cannot put array: array should have leading dimension of 1 " +
                "and equal rank to output array. Attempting to put array of shape %s into output array of shape %s", singleExample.shape(), arr.shape());

        long[] arrShape = arr.shape();
        long[] singleShape = singleExample.shape();
        for( int i=1; i<arr.rank(); i++ ){
            Preconditions.checkState(arrShape[i] == singleShape[i], "Single example array and output array differ at position %s: " +
                    "single example shape %s, output array shape %s", i, singleShape, arrShape);
        }
        switch (arr.rank()) {
            case 2:
                arr.put(new INDArrayIndex[] {NDArrayIndex.point(exampleIdx), NDArrayIndex.all()}, singleExample);
                break;
            case 3:
                arr.put(new INDArrayIndex[] {NDArrayIndex.point(exampleIdx), NDArrayIndex.all(), NDArrayIndex.all()},
                                singleExample);
                break;
            case 4:
                arr.put(new INDArrayIndex[] {NDArrayIndex.point(exampleIdx), NDArrayIndex.all(), NDArrayIndex.all(),
                                NDArrayIndex.all()}, singleExample);
                break;
            case 5:
                arr.put(new INDArrayIndex[] {NDArrayIndex.point(exampleIdx), NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.all(), NDArrayIndex.all()}, singleExample);
                break;
            default:
                throw new RuntimeException("Unexpected array rank: " + arr.rank() + " with shape " + Arrays.toString(arr.shape()) + " input arrays should be rank 2 to 5 inclusive");
        }
    }

    /**
     * Convert the writables to a sequence (3d) data set, and also return the mask array (if necessary)
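     * <p>
     * The returned sequence array has shape [miniBatch, size, maxTSLength] ('f' ordered); the mask array (where
     * required) has shape [miniBatch, maxTSLength], with 0.0 at padded time steps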
     */
    private Pair<INDArray, INDArray> convertWritablesSequence(List<List<List<Writable>>> list, int minValues,
                    int maxTSLength, SubsetDetails details, int[] longestSequence, long rngSeed) {
        if (maxTSLength == -1)
            maxTSLength = list.get(0).size();
        INDArray arr;

        if (list.get(0).isEmpty()) {
            throw new ZeroLengthSequenceException("Zero length sequence encountered");
        }

        List<Writable> firstStep = list.get(0).get(0);

        int size = 0;
        if (details.entireReader) {
            //Need to account for NDArrayWritables etc in list:
            for (Writable w : firstStep) {
                if (w instanceof NDArrayWritable) {
                    size += ((NDArrayWritable) w).get().size(1);
                } else {
                    size++;
                }
            }
        } else if (details.oneHot) {
            size = details.oneHotNumClasses;
        } else {
            //Need to account for NDArrayWritables etc in list:
            for (int i = details.subsetStart; i <= details.subsetEndInclusive; i++) {
                Writable w = firstStep.get(i);
                if (w instanceof NDArrayWritable) {
                    size += ((NDArrayWritable) w).get().size(1);
                } else {
                    size++;
                }
            }
        }
        arr = Nd4j.create(new int[] {minValues, size, maxTSLength}, 'f');

        boolean needMaskArray = false;
        for (List<List<Writable>> c : list) {
            if (c.size() < maxTSLength)
                needMaskArray = true;
        }

        if (needMaskArray && alignmentMode == AlignmentMode.EQUAL_LENGTH) {
            throw new UnsupportedOperationException(
                            "Alignment mode is set to EQUAL_LENGTH but variable length data was "
                                            + "encountered. Use AlignmentMode.ALIGN_START or AlignmentMode.ALIGN_END with variable length data");
        }

        INDArray maskArray;
        if (needMaskArray) {
            maskArray = Nd4j.ones(minValues, maxTSLength);
        } else {
            maskArray = null;
        }

        //Don't use the global RNG as we need repeatability for each subset (i.e., features and labels must be aligned)
        Random rng = null;
        if (timeSeriesRandomOffset) {
            rng = new Random(rngSeed);
        }

        for (int i = 0; i < minValues; i++) {
            List<List<Writable>> sequence = list.get(i);

            //Offset for alignment:
            int startOffset;
            if (alignmentMode == AlignmentMode.ALIGN_START || alignmentMode == AlignmentMode.EQUAL_LENGTH) {
                startOffset = 0;
            } else {
                //Align end
                //Only practical differences here are: (a) offset, and (b) masking
                startOffset = longestSequence[i] - sequence.size();
            }

            if (timeSeriesRandomOffset) {
                int maxPossible = maxTSLength - sequence.size() + 1;
                startOffset = rng.nextInt(maxPossible);
            }

            int t = 0;
            int k;
            for (List<Writable> timeStep : sequence) {
                k = startOffset + t++;

                if (details.entireReader) {
                    //Convert entire reader contents, without modification
                    Iterator<Writable> iter = timeStep.iterator();
                    int j = 0;
                    while (iter.hasNext()) {
                        Writable w = iter.next();

                        if (w instanceof NDArrayWritable) {
                            INDArray row = ((NDArrayWritable) w).get();

                            arr.put(new INDArrayIndex[] {NDArrayIndex.point(i),
                                            NDArrayIndex.interval(j, j + row.length()), NDArrayIndex.point(k)}, row);
                            j += row.length();
                        } else {
                            arr.putScalar(i, j, k, w.toDouble());
                            j++;
                        }
                    }
                } else if (details.oneHot) {
                    //Convert a single column to a one-hot representation
                    //timeStep is always a List<Writable>, so the column can be indexed directly
                    Writable w = timeStep.get(details.subsetStart);
                    int classIdx = w.toInt();
                    if (classIdx >= details.oneHotNumClasses) {
                        throw new IllegalStateException("Cannot convert sequence writables to one-hot: class index " + classIdx
                                        + " >= numClasses (" + details.oneHotNumClasses + "). (Note that classes are zero-" +
                                "indexed, thus only values 0 to numClasses-1 are valid)");
                    }
                    arr.putScalar(i, classIdx, k, 1.0);
                } else {
                    //Convert a subset of the columns...
                    int l = 0;
                    for (int j = details.subsetStart; j <= details.subsetEndInclusive; j++) {
                        Writable w = timeStep.get(j);

                        if (w instanceof NDArrayWritable) {
                            INDArray row = ((NDArrayWritable) w).get();
                            arr.put(new INDArrayIndex[] {NDArrayIndex.point(i),
                                            NDArrayIndex.interval(l, l + row.length()), NDArrayIndex.point(k)}, row);

                            l += row.length();
                        } else {
                            arr.putScalar(i, l++, k, w.toDouble());
                        }
                    }
                }
            }

            //For any remaining time steps: set mask array to 0 (just padding)
            if (needMaskArray) {
                //Masking array entries at start (for align end)
                if (timeSeriesRandomOffset || alignmentMode == AlignmentMode.ALIGN_END) {
                    for (int t2 = 0; t2 < startOffset; t2++) {
                        maskArray.putScalar(i, t2, 0.0);
                    }
                }

                //Masking array entries at end (for align start)
                int lastStep = startOffset + sequence.size();
                if (timeSeriesRandomOffset || alignmentMode == AlignmentMode.ALIGN_START || lastStep < maxTSLength) {
                    for (int t2 = lastStep; t2 < maxTSLength; t2++) {
                        maskArray.putScalar(i, t2, 0.0);
                    }
                }
            }
        }

        return new Pair<>(arr, maskArray);
    }

    @Override
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public MultiDataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    @Override
    public boolean resetSupported() {
        return resetSupported;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        if(!resetSupported){
            throw new IllegalStateException("Cannot reset iterator - reset not supported (resetSupported() == false):" +
                    " one or more underlying (sequence) record readers do not support resetting");
        }

        for (RecordReader rr : recordReaders.values())
            rr.reset();
        for (SequenceRecordReader rr : sequenceRecordReaders.values())
            rr.reset();
    }

    @Override
    public boolean hasNext() {
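        //Readers are advanced in lockstep, so there is a next MultiDataSet only if every reader has a next record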
        for (RecordReader rr : recordReaders.values())
            if (!rr.hasNext())
                return false;
        for (SequenceRecordReader rr : sequenceRecordReaders.values())
            if (!rr.hasNext())
                return false;
        return true;
    }


    public static class Builder {

        private int batchSize;
        private AlignmentMode alignmentMode = AlignmentMode.ALIGN_START;
        private Map<String, RecordReader> recordReaders = new HashMap<>();
        private Map<String, SequenceRecordReader> sequenceRecordReaders = new HashMap<>();

        private List<SubsetDetails> inputs = new ArrayList<>();
        private List<SubsetDetails> outputs = new ArrayList<>();

        private boolean timeSeriesRandomOffset = false;
        private long timeSeriesRandomOffsetSeed = System.currentTimeMillis();

        /**
         * @param batchSize The batch size for the RecordReaderMultiDataSetIterator
         */
        public Builder(int batchSize) {
            this.batchSize = batchSize;
        }

        /**
         * Add a RecordReader for use in .addInput(...) or .addOutput(...)
         *
         * @param readerName   Name of the reader (for later reference)
         * @param recordReader RecordReader
         */
        public Builder addReader(String readerName, RecordReader recordReader) {
            recordReaders.put(readerName, recordReader);
            return this;
        }

        /**
         * Add a SequenceRecordReader for use in .addInput(...) or .addOutput(...)
         *
         * @param seqReaderName   Name of the sequence reader (for later reference)
         * @param seqRecordReader SequenceRecordReader
         */
        public Builder addSequenceReader(String seqReaderName, SequenceRecordReader seqRecordReader) {
            sequenceRecordReaders.put(seqReaderName, seqRecordReader);
            return this;
        }

        /**
         * Set the sequence alignment mode for all sequences
         */
        public Builder sequenceAlignmentMode(AlignmentMode alignmentMode) {
            this.alignmentMode = alignmentMode;
            return this;
        }

        /**
         * Set as an input the entire contents (all columns) of the named RecordReader or SequenceRecordReader
         */
        public Builder addInput(String readerName) {
            inputs.add(new SubsetDetails(readerName, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Set as an input a subset of the columns from the named RecordReader or SequenceRecordReader
         *
         * @param readerName  Name of the reader
         * @param columnFirst First column index, inclusive
         * @param columnLast  Last column index, inclusive
         */
        public Builder addInput(String readerName, int columnFirst, int columnLast) {
            inputs.add(new SubsetDetails(readerName, false, false, -1, columnFirst, columnLast));
            return this;
        }

        /**
         * Add as an input a single column from the specified RecordReader / SequenceRecordReader
         * The assumption is that the specified column contains integer values in range 0..numClasses-1;
         * this integer will be converted to a one-hot representation
         *
         * @param readerName Name of the RecordReader or SequenceRecordReader
         * @param column     Column that contains the index
         * @param numClasses Total number of classes
         */
        public Builder addInputOneHot(String readerName, int column, int numClasses) {
            inputs.add(new SubsetDetails(readerName, false, true, numClasses, column, column));
            return this;
        }

        /**
         * Set as an output the entire contents (all columns) of the named RecordReader or SequenceRecordReader
         */
        public Builder addOutput(String readerName) {
            outputs.add(new SubsetDetails(readerName, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Add an output, with a subset of the columns from the named RecordReader or SequenceRecordReader
         *
         * @param readerName  Name of the reader
         * @param columnFirst First column index, inclusive
         * @param columnLast  Last column index, inclusive
         */
        public Builder addOutput(String readerName, int columnFirst, int columnLast) {
            outputs.add(new SubsetDetails(readerName, false, false, -1, columnFirst, columnLast));
            return this;
        }

        /**
         * Add an output taken from a single column of the specified RecordReader / SequenceRecordReader.
         * The assumption is that the specified column contains integer values in range 0..numClasses-1;
         * this integer will be converted to a one-hot representation (usually for classification)
         *
         * @param readerName Name of the RecordReader / SequenceRecordReader
         * @param column     index of the column
         * @param numClasses Number of classes
         */
        public Builder addOutputOneHot(String readerName, int column, int numClasses) {
            outputs.add(new SubsetDetails(readerName, false, true, numClasses, column, column));
            return this;
        }

        /**
         * For use with time series trained with truncated BPTT (TBPTT).
         * In a given minibatch, shorter time series are padded and appropriately masked to be the same length as
         * the longest time series. A skewed distribution of lengths can result in the last few updates coming
         * mostly from masked time steps. Setting timeSeriesRandomOffset randomly offsets each time series within
         * the minibatch (with appropriate masking) to address this.
         * @param timeSeriesRandomOffset true to randomly offset time series within a minibatch
         * @param rngSeed seed for reproducibility
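         * <pre>
         * e.g. (illustrative): a sequence of length 3 in a minibatch padded to length 10 may be placed at a
         * random offset, [x x x . . . . . . .] or [. . . x x x . . . .] etc., with the mask 0 at padded steps
         * </pre>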
         */
        public Builder timeSeriesRandomOffset(boolean timeSeriesRandomOffset, long rngSeed) {
            this.timeSeriesRandomOffset = timeSeriesRandomOffset;
            this.timeSeriesRandomOffsetSeed = rngSeed;
            return this;
        }

        /**
         * Create the RecordReaderMultiDataSetIterator
         */
        public RecordReaderMultiDataSetIterator build() {
            //Validate input:
            if (recordReaders.isEmpty() && sequenceRecordReaders.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no readers");
            }

            if (batchSize <= 0)
                throw new IllegalStateException(
                                "Cannot construct RecordReaderMultiDataSetIterator with batch size <= 0");

            if (inputs.isEmpty() && outputs.isEmpty()) {
                throw new IllegalStateException(
                                "Cannot construct RecordReaderMultiDataSetIterator with no inputs/outputs");
            }

            for (SubsetDetails ssd : inputs) {
                if (!recordReaders.containsKey(ssd.readerName) && !sequenceRecordReaders.containsKey(ssd.readerName)) {
                    throw new IllegalStateException(
                                    "Invalid input name: \"" + ssd.readerName + "\" - no reader found with this name");
                }
            }

            for (SubsetDetails ssd : outputs) {
                if (!recordReaders.containsKey(ssd.readerName) && !sequenceRecordReaders.containsKey(ssd.readerName)) {
                    throw new IllegalStateException(
                                    "Invalid output name: \"" + ssd.readerName + "\" - no reader found with this name");
                }
            }

            return new RecordReaderMultiDataSetIterator(this);
        }
    }

    /**
     * Load a single example to a MultiDataSet, using the provided RecordMetaData.
     * Note that it is more efficient to load multiple instances at once, using {@link #loadFromMetaData(List)}
     *
     * @param recordMetaData RecordMetaData to load from. Should have been produced by the record reader(s) used
     *                       with this iterator
     * @return MultiDataSet with the specified example
     * @throws IOException If an error occurs during loading of the data
     */
    public MultiDataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Load multiple examples to a MultiDataSet, using the provided RecordMetaData instances.
     *
     * @param list List of RecordMetaData instances to load from. Should have been produced by the record reader(s)
     *             used with this iterator
     * @return MultiDataSet with the specified examples
     * @throws IOException If an error occurs during loading of the data
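     * <p>
     * A usage sketch, assuming metadata collection was enabled before iterating:
     * <pre>{@code
     * iterator.setCollectMetaData(true);
     * MultiDataSet mds = iterator.next();
     * List<RecordMetaData> meta = mds.getExampleMetaData(RecordMetaData.class);
     * MultiDataSet reloaded = iterator.loadFromMetaData(meta);
     * }</pre>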
     */
    public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        //First: load the next values from the RR / SeqRRs
        Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
        Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
        List<RecordMetaDataComposableMap> nextMetas =
                        (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);


        for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
            RecordReader rr = entry.getValue();

            List<RecordMetaData> thisRRMeta = new ArrayList<>();
            for (RecordMetaData m : list) {
                RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
                thisRRMeta.add(m2.getMeta().get(entry.getKey()));
            }

            List<Record> fromMeta = rr.loadFromMetaData(thisRRMeta);
            List<List<Writable>> writables = new ArrayList<>(list.size());
            for (Record r : fromMeta) {
                writables.add(r.getRecord());
            }

            nextRRVals.put(entry.getKey(), writables);
        }

        for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
            SequenceRecordReader rr = entry.getValue();

            List<RecordMetaData> thisRRMeta = new ArrayList<>();
            for (RecordMetaData m : list) {
                RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
                thisRRMeta.add(m2.getMeta().get(entry.getKey()));
            }

            List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(thisRRMeta);
            List<List<List<Writable>>> writables = new ArrayList<>(list.size());
            for (SequenceRecord r : fromMeta) {
                writables.add(r.getSequenceRecord());
            }

            nextSeqRRVals.put(entry.getKey(), writables);
        }

        return nextMultiDataSet(nextRRVals, null, nextSeqRRVals, nextMetas);

    }

    @AllArgsConstructor
    private static class SubsetDetails implements Serializable {
        private final String readerName;
        private final boolean entireReader;
        private final boolean oneHot;
        private final int oneHotNumClasses;
        private final int subsetStart;
        private final int subsetEndInclusive;
    }
}