deeplearning4j/deeplearning4j

View on GitHub
datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java

Summary

Maintainability
D
1 day
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.util.ndarray;

import org.nd4j.shade.guava.base.Preconditions;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import lombok.NonNull;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
 * @author Adam Gibson
 */
public class RecordConverter {
    private RecordConverter() {}

    /**
     * Convert a record to an ndarray
     * @param record the record to convert
     *
     * @return the array
     */
    public static INDArray toArray(DataType dataType, Collection<Writable> record, int size) {
        return toArray(dataType, record);
    }

    /**
     * Convert a set of records in to a matrix
     * @param matrix the records ot convert
     * @return the matrix for the records
     */
    public static List<List<Writable>> toRecords(INDArray matrix) {
        List<List<Writable>> ret = new ArrayList<>();
        for (int i = 0; i < matrix.rows(); i++) {
            ret.add(RecordConverter.toRecord(matrix.getRow(i)));
        }

        return ret;
    }


    /**
     * Convert a set of records in to a matrix
     * @param records the records ot convert
     * @return the matrix for the records
     */
    public static INDArray toTensor(List<List<List<Writable>>> records) {
       return TimeSeriesWritableUtils.convertWritablesSequence(records).getFirst();
    }

    /**
     * Convert a set of records in to a matrix
     * As per {@link #toMatrix(DataType, List)} but hardcoded to Float datatype
     * @param records the records ot convert
     * @return the matrix for the records
     */
    public static INDArray toMatrix(List<List<Writable>> records) {
        return toMatrix(DataType.FLOAT, records);
    }

    /**
     * Convert a set of records in to a matrix
     * @param records the records ot convert
     * @return the matrix for the records
     */
    public static INDArray toMatrix(DataType dataType, List<List<Writable>> records) {
        List<INDArray> toStack = new ArrayList<>();
        for(List<Writable> l : records){
            toStack.add(toArray(dataType, l));
        }

        return Nd4j.vstack(toStack);
    }

    /**
     * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables.
     * As per {@link #toArray(DataType, Collection)} but hardcoded to Float datatype
     * @param record the record to convert
     * @return the array
     */
    public static INDArray toArray(Collection<? extends Writable> record){
        return toArray(DataType.FLOAT, record);
    }

    /**
     * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables.
     * @param record the record to convert
     * @return the array
     */
    public static INDArray toArray(DataType dataType, Collection<? extends Writable> record) {
        List<Writable> l;
        if(record instanceof List){
            l = (List<Writable>)record;
        } else {
            l = new ArrayList<>(record);
        }

        //Edge case: single NDArrayWritable
        if(l.size() == 1 && l.get(0) instanceof NDArrayWritable){
            return ((NDArrayWritable) l.get(0)).get();
        }

        int length = 0;
        for (Writable w : record) {
            if (w instanceof NDArrayWritable) {
                INDArray a = ((NDArrayWritable) w).get();
                if (!a.isRowVector()) {
                    throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is "
                            + "not a row vector. Can only concat row vectors with other writables. Shape: "
                            + Arrays.toString(a.shape()));
                }
                length += a.length();
            } else {
                //Assume all others are single value
                length++;
            }
        }

        INDArray arr = Nd4j.create(dataType, 1, length);

        int k = 0;
        for (Writable w : record ) {
            if (w instanceof NDArrayWritable) {
                INDArray toPut = ((NDArrayWritable) w).get();
                arr.put(new INDArrayIndex[] {NDArrayIndex.point(0),
                        NDArrayIndex.interval(k, k + toPut.length())}, toPut);
                k += toPut.length();
            } else {
                arr.putScalar(0, k, w.toDouble());
                k++;
            }
        }

        return arr;
    }

    /**
     * Convert a record to an INDArray, for use in minibatch training. That is, for an input record of length N, the output
     * array has dimension 0 of size N (i.e., suitable for minibatch training in DL4J, for example).<br>
     * The input list of writables must all be the same type (i.e., all NDArrayWritables or all non-array writables such
     * as DoubleWritable etc).<br>
     * Note that for NDArrayWritables, they must have leading dimension 1, and all other dimensions must match. <br>
     * For example, row vectors are valid NDArrayWritables, as are 3d (usually time series) with shape [1, x, y], or
     * 4d (usually images) with shape [1, x, y, z] where (x,y,z) are the same for all inputs
     * @param l the records to convert
     * @return the array
     * @see #toArray(Collection) for the "single example concatenation" version of this method
     */
    public static INDArray toMinibatchArray(@NonNull List<? extends Writable> l) {
        Preconditions.checkArgument(l.size() > 0, "Cannot convert empty list");

        //Edge case: single NDArrayWritable
        if(l.size() == 1 && l.get(0) instanceof NDArrayWritable){
            return ((NDArrayWritable) l.get(0)).get();
        }

        //Check: all NDArrayWritable or all non-writable
        List<INDArray> toConcat = null;
        DoubleArrayList list = null;
        for (Writable w : l) {
            if (w instanceof NDArrayWritable) {
                INDArray a = ((NDArrayWritable) w).get();
                if (a.size(0) != 1) {
                    throw new UnsupportedOperationException("NDArrayWritable must have leading dimension 1 for this " +
                            "method. Received array with shape: " + Arrays.toString(a.shape()));
                }
                if(toConcat == null) {
                    toConcat = new ArrayList<>();
                }
                toConcat.add(a);
            } else {
                //Assume all others are single value
                if(list == null) {
                    list = new DoubleArrayList();
                }
                list.add(w.toDouble());
            }
        }


        if(toConcat != null && list != null){
            throw new IllegalStateException("Error converting writables: found both NDArrayWritable and single value" +
                    " (DoubleWritable etc) in the one list. All writables must be NDArrayWritables or " +
                    "single value writables only for this method");
        }

        if(toConcat != null){
            return Nd4j.concat(0, toConcat.toArray(new INDArray[toConcat.size()]));
        } else {
            return Nd4j.create(list.toArray(new double[list.size()]), new long[]{list.size(), 1}, DataType.FLOAT);
        }
    }

    /**
     * Convert an ndarray to a record
     * @param array the array to convert
     * @return the record
     */
    public static List<Writable> toRecord(INDArray array) {
        List<Writable> writables = new ArrayList<>();
        writables.add(new NDArrayWritable(array));
        return writables;
    }

    /**
     *  Convert a collection into a `List<Writable>`, i.e. a record that can be used with other datavec methods.
     *  Uses a schema to decide what kind of writable to use.
     *
     * @return a record
     */
    public static List<Writable> toRecord(Schema schema, List<Object> source){
        final List<Writable> record = new ArrayList<>(source.size());
        final List<ColumnMetaData> columnMetaData = schema.getColumnMetaData();

        if(columnMetaData.size() != source.size()){
            throw new IllegalArgumentException("Schema and source list don't have the same length!");
        }

        for (int i = 0; i < columnMetaData.size(); i++) {
            final ColumnMetaData metaData = columnMetaData.get(i);
            final Object data = source.get(i);
            if(!metaData.isValid(data)){
                throw new IllegalArgumentException("Element "+i+": "+data+" is not valid for Column \""+metaData.getName()+"\" ("+metaData.getColumnType()+")");
            }

            try {
                final Writable writable;
                switch (metaData.getColumnType().getWritableType()){
                    case Float:
                        writable = new FloatWritable((Float) data);
                        break;
                    case Double:
                        writable = new DoubleWritable((Double) data);
                        break;
                    case Int:
                        writable = new IntWritable((Integer) data);
                        break;
                    case Byte:
                        writable = new ByteWritable((Byte) data);
                        break;
                    case Boolean:
                        writable = new BooleanWritable((Boolean) data);
                        break;
                    case Long:
                        writable = new LongWritable((Long) data);
                        break;
                    case Null:
                        writable = new NullWritable();
                        break;
                    case Bytes:
                        writable = new BytesWritable((byte[]) data);
                        break;
                    case NDArray:
                        writable = new NDArrayWritable((INDArray) data);
                        break;
                    case Text:
                        if(data instanceof String)
                            writable = new Text((String) data);
                        else if(data instanceof Text)
                            writable = new Text((Text) data);
                        else if(data instanceof byte[])
                            writable = new Text((byte[]) data);
                        else
                            throw new IllegalArgumentException("Element "+i+": "+data+" is not usable for Column \""+metaData.getName()+"\" ("+metaData.getColumnType()+")");
                        break;
                    default:
                        throw new IllegalArgumentException("Element "+i+": "+data+" is not usable for Column \""+metaData.getName()+"\" ("+metaData.getColumnType()+")");
                }
                record.add(writable);
            } catch (ClassCastException e) {
                throw new IllegalArgumentException("Element "+i+": "+data+" is not usable for Column \""+metaData.getName()+"\" ("+metaData.getColumnType()+")", e);
            }
        }

        return record;
    }

    /**
     * Convert a DataSet to a matrix
     * @param dataSet the DataSet to convert
     * @return the matrix for the records
     */
    public static List<List<Writable>> toRecords(DataSet dataSet) {
        if (isClassificationDataSet(dataSet)) {
            return getClassificationWritableMatrix(dataSet);
        } else {
            return getRegressionWritableMatrix(dataSet);
        }
    }

    private static boolean isClassificationDataSet(DataSet dataSet) {
        INDArray labels = dataSet.getLabels();

        return labels.sum(0, -1).getInt(0) == dataSet.numExamples() && labels.shape()[1] > 1;
    }

    private static List<List<Writable>> getClassificationWritableMatrix(DataSet dataSet) {
        List<List<Writable>> writableMatrix = new ArrayList<>();

        for (int i = 0; i < dataSet.numExamples(); i++) {
            List<Writable> writables = toRecord(dataSet.getFeatures().getRow(i, true));
            writables.add(new IntWritable(Nd4j.argMax(dataSet.getLabels().getRow(i)).getInt(0)));

            writableMatrix.add(writables);
        }

        return writableMatrix;
    }

    private static List<List<Writable>> getRegressionWritableMatrix(DataSet dataSet) {
        List<List<Writable>> writableMatrix = new ArrayList<>();

        for (int i = 0; i < dataSet.numExamples(); i++) {
            List<Writable> writables = toRecord(dataSet.getFeatures().rank() > 1 ?
                    dataSet.getFeatures().getRow(i) : dataSet.getFeatures());
            INDArray labelRow = dataSet.getLabels().rank() > 1 ? dataSet.getLabels().getRow(i)
                    : dataSet.getLabels();

            for (int j = 0; j < labelRow.size(-1); j++) {
                writables.add(new DoubleWritable(labelRow.getDouble(j)));
            }

            writableMatrix.add(writables);
        }

        return writableMatrix;
    }
}