deeplearning4j/deeplearning4j

View on GitHub
python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java

Summary

Maintainability
B
6 hrs
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.python4j;


import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.Pointer;

import java.util.*;

import static org.bytedeco.cpython.global.python.*;

/**
 * Java-side handle to a native CPython object.
 *
 * <p>Callers must hold the GIL; most methods verify this with {@link PythonGIL#assertThreadSafe()}
 * before touching native state.
 *
 * <p>An instance is either <em>owned</em>, meaning it is registered with {@link PythonGC} and/or
 * released by {@link #del()}, or <em>borrowed</em>, meaning it aliases a reference held elsewhere
 * and never decrements it.
 */
public class PythonObject {

    static {
        // Instantiated only for its side effects: guarantees PythonExecutioner's initialization has
        // run before this class makes any native call (presumably this starts the interpreter).
        new PythonExecutioner();
    }

    /** Whether this wrapper is responsible for releasing {@link #nativePythonObject}. */
    private boolean owned = true;
    /** The wrapped native object; reset to null by {@link #del()} once released. */
    private PyObject nativePythonObject;


    /**
     * Wraps a native object with explicit ownership.
     *
     * @param nativePythonObject the native object to wrap; may be null
     * @param owned whether this wrapper owns the reference; owned, non-null objects are registered
     *     with {@link PythonGC}
     */
    public PythonObject(PyObject nativePythonObject, boolean owned) {
        PythonGIL.assertThreadSafe();
        this.nativePythonObject = nativePythonObject;
        this.owned = owned;
        if (owned && nativePythonObject != null) {
            PythonGC.register(this);
        }
    }

    /**
     * Wraps a native object as an owned reference, e.g. the new reference returned by a CPython
     * call. Non-null objects are registered with {@link PythonGC}.
     *
     * @param nativePythonObject the native object to wrap; may be null
     */
    public PythonObject(PyObject nativePythonObject) {
        PythonGIL.assertThreadSafe();
        this.nativePythonObject = nativePythonObject;
        if (nativePythonObject != null) {
            PythonGC.register(this);
        }

    }

    /** @return the wrapped native pointer (may be null after {@link #del()}) */
    public PyObject getNativePythonObject() {
        return nativePythonObject;
    }

    /** Converts this object to a Java string using {@link PythonTypes#STR}. */
    @Override
    public String toString() {
        return PythonTypes.STR.toJava(this);

    }

    /**
     * Tests whether this wrapper refers to Python {@code None}.
     *
     * @return true if the native pointer is null, or the object's type prints as
     *     {@code <class 'NoneType'>} and its string form is {@code None}
     */
    public boolean isNone() {
        if (nativePythonObject == null || Pointer.isNull(nativePythonObject)) {
            return true;
        }
        try (PythonGC gc = PythonGC.pause()) {
            // Created while the GC is paused, so it must be released by hand. It is fetched once
            // and released even if a string conversion throws.
            PythonObject type = Python.type(this);
            try {
                return type.toString().equals("<class 'NoneType'>") && toString().equals("None");
            } finally {
                Py_DecRef(type.nativePythonObject);
            }
        }
    }

    /**
     * Releases the native reference if this wrapper owns it and no {@link PythonGC} scope is
     * currently watching (a watching scope is presumably responsible for the release instead).
     * After releasing, the wrapped pointer is reset to null.
     */
    public void del() {
        PythonGIL.assertThreadSafe();
        if (owned && nativePythonObject != null && !PythonGC.isWatching()) {
            Py_DecRef(nativePythonObject);
            nativePythonObject = null;
        }
    }

    /**
     * Calls this object with positional arguments only.
     *
     * @param args a Python tuple or list of positional arguments
     * @return the call result
     */
    public PythonObject callWithArgs(PythonObject args) {
        return callWithArgsAndKwargs(args, null);
    }

    /**
     * Calls this object with keyword arguments only.
     *
     * @param kwargs a Python dict of keyword arguments
     * @return the call result
     * @throws PythonException if this object is not callable or {@code kwargs} is not a dict
     */
    public PythonObject callWithKwargs(PythonObject kwargs) {
        PythonGIL.assertThreadSafe();
        if (!Python.callable(this)) {
            throw new PythonException("Object is not callable: " + toString());
        }
        PyObject dict = kwargs.nativePythonObject;
        if (PyObject_IsInstance(dict, new PyObject(PyDict_Type())) != 1) {
            throw new PythonException("Expected kwargs to be dict. Received: " + kwargs.toString());
        }
        // Allocate the empty positional tuple only after validation, and always release it.
        PyObject tuple = PyTuple_New(0);
        try {
            return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
        } finally {
            Py_DecRef(tuple);
        }
    }

    /**
     * Calls this object with Python-side positional and keyword arguments.
     *
     * @param args a Python tuple or list of positional arguments; null means no arguments
     * @param kwargs a Python dict of keyword arguments, or null for none
     * @return the call result
     * @throws PythonException if this object is not callable, {@code args} is neither a tuple nor a
     *     list, or {@code kwargs} is not a dict
     */
    public PythonObject callWithArgsAndKwargs(PythonObject args, PythonObject kwargs) {
        PythonGIL.assertThreadSafe();
        PyObject tuple = null;
        boolean ownsTuple = false;
        try {
            if (!Python.callable(this)) {
                throw new PythonException("Object is not callable: " + toString());
            }

            if (args == null) {
                // Same convention as the List overload: no positional arguments.
                tuple = PyTuple_New(0);
                ownsTuple = true;
            } else if (PyObject_IsInstance(args.nativePythonObject, new PyObject(PyTuple_Type())) == 1) {
                // Borrowed: passed straight through and not released below.
                tuple = args.nativePythonObject;
            } else if (PyObject_IsInstance(args.nativePythonObject, new PyObject(PyList_Type())) == 1) {
                // PyList_AsTuple returns a new reference that this method must release.
                tuple = PyList_AsTuple(args.nativePythonObject);
                ownsTuple = true;
            } else {
                throw new PythonException("Expected args to be tuple or list. Received: " + args.toString());
            }
            if (kwargs != null && PyObject_IsInstance(kwargs.nativePythonObject, new PyObject(PyDict_Type())) != 1) {
                throw new PythonException("Expected kwargs to be dict. Received: " + kwargs.toString());
            }
            return new PythonObject(PyObject_Call(nativePythonObject, tuple, kwargs == null ? null : kwargs.nativePythonObject));
        } finally {
            if (ownsTuple) Py_DecRef(tuple);
        }

    }


    /**
     * Calls this object, passing each Java value as one positional argument.
     *
     * @param args Java values converted through {@link PythonTypes#convert}
     * @return the call result
     */
    public PythonObject call(Object... args) {
        return callWithArgsAndKwargs(Arrays.asList(args), null);
    }

    /**
     * Calls this object with the list's elements as positional arguments.
     *
     * @param args Java values to pass positionally; null or empty means no arguments
     * @return the call result
     */
    public PythonObject callWithArgs(List args) {
        // Delegate directly: routing through call(Object...) would wrap the list itself (plus a
        // trailing null) as the positional arguments.
        return callWithArgsAndKwargs(args, null);
    }

    /**
     * Calls this object with the map's entries as keyword arguments.
     *
     * @param kwargs Java map converted to a Python dict; null means no keyword arguments
     * @return the call result
     */
    public PythonObject callWithKwargs(Map kwargs) {
        // Delegate directly: routing through call(Object...) would pass (None, kwargs) positionally.
        return callWithArgsAndKwargs((List) null, kwargs);
    }

    /**
     * Calls this object with Java-side positional and keyword arguments.
     *
     * <p>Temporaries created during conversion are tracked by a {@link PythonGC#watch()} scope.
     * The result is handed to {@link PythonGC#keep} so it is not released when that scope closes.
     *
     * @param args Java values to pass positionally; null or empty means no arguments
     * @param kwargs Java map of keyword arguments, or null for none
     * @return the call result
     * @throws PythonException if this object is not callable
     */
    public PythonObject callWithArgsAndKwargs(List args, Map kwargs) {
        PythonGIL.assertThreadSafe();
        try (PythonGC gc = PythonGC.watch()) {
            if (!Python.callable(this)) {
                throw new PythonException("Object is not callable: " + toString());
            }
            PythonObject pyArgs;
            PythonObject pyKwargs;

            if (args == null || args.isEmpty()) {
                pyArgs = new PythonObject(PyTuple_New(0));
            } else {
                PythonObject argsList = PythonTypes.convert(args);
                pyArgs = new PythonObject(PyList_AsTuple(argsList.getNativePythonObject()));
            }
            if (kwargs == null) {
                pyKwargs = null;
            } else {
                pyKwargs = PythonTypes.convert(kwargs);
            }

            PythonObject ret = new PythonObject(
                    PyObject_Call(
                            nativePythonObject,
                            pyArgs.nativePythonObject,
                            pyKwargs == null ? null : pyKwargs.nativePythonObject
                    )
            );

            PythonGC.keep(ret);

            return ret;
        }

    }


    /**
     * Looks up an attribute ({@code getattr(self, attrName)}).
     *
     * @param attrName attribute name
     * @return the attribute value as a new owned wrapper
     */
    public PythonObject attr(String attrName) {
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyObject_GetAttrString(nativePythonObject, attrName));
    }


    /**
     * Wraps an arbitrary Java value.
     *
     * <p>If {@code javaObject} is already a {@link PythonObject}, its native pointer is shared as a
     * borrowed (non-owned) reference. Otherwise the value is converted with
     * {@link PythonTypes#convert} while the GC is paused, and the result is registered with
     * {@link PythonGC} under this wrapper.
     *
     * @param javaObject the value to wrap
     */
    public PythonObject(Object javaObject) {
        PythonGIL.assertThreadSafe();
        if (javaObject instanceof PythonObject) {
            owned = false;
            nativePythonObject = ((PythonObject) javaObject).nativePythonObject;
        } else {
            try (PythonGC gc = PythonGC.pause()) {
                nativePythonObject = PythonTypes.convert(javaObject).getNativePythonObject();
            }
            PythonGC.register(this);
        }

    }

    /** @return this object converted with {@link PythonTypes#INT}, narrowed to {@code int} */
    public int toInt() {
        return PythonTypes.INT.toJava(this).intValue();
    }

    /** @return this object converted with {@link PythonTypes#INT} */
    public long toLong() {
        return PythonTypes.INT.toJava(this);
    }

    /** @return this object converted with {@link PythonTypes#FLOAT}, narrowed to {@code float} */
    public float toFloat() {
        return PythonTypes.FLOAT.toJava(this).floatValue();
    }

    /** @return this object converted with {@link PythonTypes#FLOAT} */
    public double toDouble() {
        return PythonTypes.FLOAT.toJava(this);
    }

    /** @return this object converted with {@link PythonTypes#BOOL} */
    public boolean toBoolean() {
        return PythonTypes.BOOL.toJava(this);

    }

    /** @return this object converted with {@link PythonTypes#LIST} */
    public List toList() {
        return PythonTypes.LIST.toJava(this);
    }

    /** @return this object converted with {@link PythonTypes#DICT} */
    public Map toMap() {
        return PythonTypes.DICT.toJava(this);
    }

    /**
     * Subscripts this object with an integer key ({@code self[key]}).
     *
     * @param key the index
     * @return the item as a new owned wrapper
     */
    public PythonObject get(int key) {
        PythonGIL.assertThreadSafe();
        return getWithTemporaryKey(PyLong_FromLong(key));
    }

    /**
     * Subscripts this object with a string key ({@code self[key]}).
     *
     * @param key the key
     * @return the item as a new owned wrapper
     */
    public PythonObject get(String key) {
        PythonGIL.assertThreadSafe();
        return getWithTemporaryKey(PyUnicode_FromString(key));
    }

    /**
     * Subscripts this object with a Python key ({@code self[key]}). The key stays owned by the
     * caller.
     *
     * @param key the key
     * @return the item as a new owned wrapper
     */
    public PythonObject get(PythonObject key) {
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyObject_GetItem(nativePythonObject, key.nativePythonObject));
    }

    /**
     * Performs {@code self[key]} with a freshly created key. The key is always released afterwards,
     * because PyObject_GetItem does not steal it.
     */
    private PythonObject getWithTemporaryKey(PyObject key) {
        try {
            return new PythonObject(PyObject_GetItem(nativePythonObject, key));
        } finally {
            Py_DecRef(key);
        }
    }

    /**
     * Performs {@code self[key] = value}. Neither argument's reference is consumed.
     *
     * @param key the key
     * @param value the value to store
     */
    public void set(PythonObject key, PythonObject value) {
        PythonGIL.assertThreadSafe();
        PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
    }


    // ---- Binary number protocol: each returns a new owned wrapper around the result. ----

    /** @return {@code abs(self)} */
    public PythonObject abs(){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_Absolute(nativePythonObject));
    }
    /** @return {@code self + pythonObject} */
    public PythonObject add(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** @return {@code self - pythonObject} */
    public PythonObject sub(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** @return {@code self % pythonObject} (not {@code divmod}, which would yield a tuple) */
    public PythonObject mod(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_Remainder(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** @return {@code self * pythonObject} */
    public PythonObject mul(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** @return {@code self / pythonObject} */
    public PythonObject trueDiv(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** @return {@code self // pythonObject} */
    public PythonObject floorDiv(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** @return {@code self @ pythonObject} */
    public PythonObject matMul(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
    }

    // ---- In-place number protocol. ----
    // PyNumber_InPlace* returns a new reference: to this same object when the type supports
    // in-place mutation, or to a freshly computed object otherwise. This wrapper keeps pointing at
    // the original object, so the returned reference is released to avoid leaking it. NOTE: for
    // immutable types (e.g. int, float, str) these calls therefore have no visible effect; use the
    // non-in-place variants instead.

    /** Performs {@code self += pythonObject} for types that support in-place addition. */
    public void addi(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        Py_DecRef(PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** Performs {@code self -= pythonObject} for types that support in-place subtraction. */
    public void subi(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        Py_DecRef(PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** Performs {@code self *= pythonObject} for types that support in-place multiplication. */
    public void muli(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        Py_DecRef(PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** Performs {@code self /= pythonObject} for types that support in-place true division. */
    public void trueDivi(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        Py_DecRef(PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** Performs {@code self //= pythonObject} for types that support in-place floor division. */
    public void floorDivi(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        Py_DecRef(PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject));
    }
    /** Performs {@code self @= pythonObject} for types that support in-place matrix multiply. */
    public void matMuli(PythonObject pythonObject){
        PythonGIL.assertThreadSafe();
        Py_DecRef(PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
    }
}