deeplearning4j/deeplearning4j

View on GitHub
datavec/datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java

Summary

Maintainability
A
3 hrs
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.transform.ndarray;

import lombok.Data;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnsMathOpTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Arrays;
import java.util.Objects;

/**
 * Element-wise math operation across two or more NDArray columns, producing one new NDArray
 * column named {@code newColumnName}.
 *
 * <p>{@link MathOp#Add} and {@link MathOp#Multiply} accept any number of columns and fold them
 * left to right. {@link MathOp#Subtract}, {@link MathOp#Divide}, {@link MathOp#ReverseSubtract}
 * and {@link MathOp#ReverseDivide} take exactly two columns. {@link MathOp#Modulus},
 * {@link MathOp#ScalarMin} and {@link MathOp#ScalarMax} are rejected. All input columns must be
 * NDArray columns with identical declared shapes.
 */
@Data
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {

    /**
     * @param newColumnName name of the output column
     * @param mathOp        operation to apply across the input columns
     * @param columns       names of the input NDArray columns, in operand order
     */
    public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,
                    @JsonProperty("mathOp") MathOp mathOp, @JsonProperty("columns") String... columns) {
        super(newColumnName, mathOp, columns);
    }

    /**
     * Validates the input columns and describes the output column.
     *
     * @param newColumnName name of the output column
     * @param inputSchema   schema containing {@code columns}
     * @return NDArray metadata for the output column, with the same shape as the inputs
     * @throws RuntimeException              if any input column is not an NDArray column
     * @throws UnsupportedOperationException if the input columns' declared shapes differ
     */
    @Override
    protected ColumnMetaData derivedColumnMetaData(String newColumnName, Schema inputSchema) {
        // Check types first, so a non-NDArray column is reported before any shape mismatch
        for (int i = 0; i < columns.length; i++) {
            if (inputSchema.getMetaData(columns[i]).getColumnType() != ColumnType.NDArray) {
                throw new RuntimeException("Column " + columns[i] + " is not an NDArray column");
            }
        }

        // Check shapes: every column must match the first column's shape exactly
        NDArrayMetaData meta = (NDArrayMetaData) inputSchema.getMetaData(columns[0]);
        for (int i = 1; i < columns.length; i++) {
            NDArrayMetaData meta2 = (NDArrayMetaData) inputSchema.getMetaData(columns[i]);
            if (!Arrays.equals(meta.getShape(), meta2.getShape())) {
                throw new UnsupportedOperationException(
                                "Cannot perform NDArray operation on columns with different shapes: " + "Columns \""
                                                + columns[0] + "\" and \"" + columns[i] + "\" have shapes: "
                                                + Arrays.toString(meta.getShape()) + " and "
                                                + Arrays.toString(meta2.getShape()));
            }
        }

        return new NDArrayMetaData(newColumnName, meta.getShape());
    }

    /**
     * Applies {@code mathOp} to the input arrays.
     *
     * <p>The first input is copied with {@code dup()} before the in-place ops run, so the input
     * writables are never modified.
     *
     * @param input one {@link NDArrayWritable} per configured column, in column order
     * @return a new {@link NDArrayWritable} holding the result
     * @throws IllegalArgumentException if {@code mathOp} is Modulus, ScalarMin or ScalarMax
     * @throws IllegalStateException    if a two-column op receives a number of inputs other than 2
     */
    @Override
    protected Writable doOp(Writable... input) {
        INDArray out = ((NDArrayWritable) input[0]).get().dup();

        switch (mathOp) {
            case Add:
                for (int i = 1; i < input.length; i++) {
                    out.addi(((NDArrayWritable) input[i]).get());
                }
                break;
            case Subtract:
                out.subi(secondOperand(mathOp, input));
                break;
            case Multiply:
                for (int i = 1; i < input.length; i++) {
                    out.muli(((NDArrayWritable) input[i]).get());
                }
                break;
            case Divide:
                out.divi(secondOperand(mathOp, input));
                break;
            case ReverseSubtract:
                out.rsubi(secondOperand(mathOp, input));
                break;
            case ReverseDivide:
                out.rdivi(secondOperand(mathOp, input));
                break;
            case Modulus:
            case ScalarMin:
            case ScalarMax:
                throw new IllegalArgumentException(
                                "Invalid MathOp: cannot use " + mathOp + " with NDArrayColumnsMathOpTransform");
            default:
                throw new RuntimeException("Unknown MathOp: " + mathOp);
        }

        //To avoid threading issues...
        // NOTE(review): presumably flushes any queued ND4J ops before the result leaves this method
        Nd4j.getExecutioner().commit();

        return new NDArrayWritable(out);
    }

    /**
     * Returns the second operand of a two-column op.
     *
     * <p>Fails loudly when the number of inputs is not exactly 2. Without this check, extra
     * columns would be silently ignored.
     */
    private static INDArray secondOperand(MathOp op, Writable... input) {
        if (input.length != 2) {
            throw new IllegalStateException("MathOp " + op
                            + " requires exactly 2 input columns for NDArrayColumnsMathOpTransform, got "
                            + input.length);
        }
        return ((NDArrayWritable) input[1]).get();
    }

    /**
     * Two transforms are equal when they have the same class, output column name, math op and
     * ordered input column list.
     *
     * <p>Written by hand on purpose. Lombok's {@code @Data} on this subclass, which declares no
     * fields of its own, would otherwise generate an {@code equals} that ignores the inherited
     * configuration ({@code callSuper} defaults to false), making every instance equal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        NDArrayColumnsMathOpTransform other = (NDArrayColumnsMathOpTransform) o;
        return Objects.equals(newColumnName, other.newColumnName)
                        && mathOp == other.mathOp
                        && Arrays.equals(columns, other.columns);
    }

    /** Hash code consistent with {@link #equals(Object)}: it uses the same three fields. */
    @Override
    public int hashCode() {
        return 31 * Objects.hash(newColumnName, mathOp) + Arrays.hashCode(columns);
    }

    @Override
    public String toString() {
        return "NDArrayColumnsMathOpTransform(newColumnName=\"" + newColumnName + "\",mathOp=" + mathOp + ",columns="
                        + Arrays.toString(columns) + ")";
    }

    /** Not supported by this transform; always throws {@link UnsupportedOperationException}. */
    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not supported by this transform; always throws {@link UnsupportedOperationException}. */
    @Override
    public Object mapSequence(Object sequence) {
        throw new UnsupportedOperationException("Not yet implemented");
    }
}