deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cpu-backend-common/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java

Summary

Maintainability
F
4 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.cpu.nativecpu.blas;

import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.nd4j.linalg.api.blas.BlasException;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;


public class CpuLapack extends BaseLapack {
    public static final int LAPACK_ROW_MAJOR = 101;
    public static final int LAPACK_COL_MAJOR = 102;

    protected static int getColumnOrder(INDArray A) {
        return A.ordering() == 'f' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
    }

    protected static int getLda(INDArray A) {
        if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        return A.ordering() == 'f' ? (int) A.rows() : (int) A.columns();
    }
    //=========================
// L U DECOMP
    @Override
    public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_sgetrf(getColumnOrder(A), M, N,
                (FloatPointer)A.data().addressPointer(),
                getLda(A), (IntPointer)IPIV.data().addressPointer()
        );
        if( status < 0 ) {
            throw new BlasException( "Failed to execute sgetrf", status ) ;
        }
    }

    @Override
    public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_dgetrf(getColumnOrder(A), M, N, (DoublePointer)A.data().addressPointer(),
                getLda(A), (IntPointer)IPIV.data().addressPointer()
        );
        if( status < 0 ) {
            throw new BlasException( "Failed to execute dgetrf", status ) ;
        }
    }

    //=========================
// Q R DECOMP
    @Override
    public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
        INDArray tau = Nd4j.create(DataType.FLOAT, N ) ;

        int status = Nd4j.getBlasLapackDelegator().LAPACKE_sgeqrf(getColumnOrder(A), M, N,
                (FloatPointer)A.data().addressPointer(), getLda(A),
                (FloatPointer)tau.data().addressPointer()
        );
        if( status != 0 ) {
            throw new BlasException( "Failed to execute sgeqrf", status ) ;
        }

        // Copy R ( upper part of Q ) into result
        if( R != null ) {
            R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
            INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;

            for( int i = 1 ; i < Math.min( A.rows(), A.columns() ) ; i++ ) {
                ix[0] = NDArrayIndex.point(i) ;
                ix[1] = NDArrayIndex.interval( 0, i ) ;
                R.put(ix, 0) ;
            }
        }

        status = Nd4j.getBlasLapackDelegator().LAPACKE_sorgqr( getColumnOrder(A), M, N, N,
                (FloatPointer)A.data().addressPointer(), getLda(A),
                (FloatPointer)tau.data().addressPointer()
        );
        if( status != 0 ) {
            throw new BlasException( "Failed to execute sorgqr", status ) ;
        }
    }

    @Override
    public void dgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO)  {
        INDArray tau = Nd4j.create(DataType.DOUBLE, N ) ;

        int status = Nd4j.getBlasLapackDelegator().LAPACKE_dgeqrf(getColumnOrder(A), M, N,
                (DoublePointer)A.data().addressPointer(), getLda(A),
                (DoublePointer)tau.data().addressPointer()
        );
        if( status != 0 ) {
            throw new BlasException( "Failed to execute dgeqrf", status ) ;
        }

        // Copy R ( upper part of Q ) into result
        if( R != null ) {
            R.assign( A.get(NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
            INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;

            for( int i = 1 ; i < Math.min( A.rows(), A.columns() ) ; i++ ) {
                ix[0] = NDArrayIndex.point(i) ;
                ix[1] = NDArrayIndex.interval( 0, i ) ;
                R.put(ix, 0) ;
            }
        }

        status = Nd4j.getBlasLapackDelegator().LAPACKE_dorgqr( getColumnOrder(A), M, N, N,
                (DoublePointer)A.data().addressPointer(), getLda(A),
                (DoublePointer)tau.data().addressPointer()
        );
        if( status != 0 ) {
            throw new BlasException( "Failed to execute dorgqr", status ) ;
        }
    }


    //=========================
// CHOLESKY DECOMP
    @Override
    public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) {
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_spotrf(getColumnOrder(A), uplo, N,
                (FloatPointer)A.data().addressPointer(), getLda(A) );
        if( status != 0 ) {
            throw new BlasException( "Failed to execute spotrf", status ) ;
        }
        if( uplo == 'U' ) {
            INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
            for( int i = 1 ; i < Math.min( A.rows(), A.columns() ) ; i++ ) {
                ix[0] = NDArrayIndex.point(i);
                ix[1] = NDArrayIndex.interval(0, i ) ;
                A.put(ix, 0);
            }
        } else {
            INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
            for( int i = 0 ; i < Math.min(A.rows(), A.columns()-1 ) ; i++ ) {
                ix[0] = NDArrayIndex.point(i) ;
                ix[1] = NDArrayIndex.interval(i + 1, A.columns() ) ;
                A.put(ix, 0) ;
            }
        }
    }

    @Override
    public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) {
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_dpotrf(getColumnOrder(A), uplo, N,
                (DoublePointer)A.data().addressPointer(), getLda(A) );
        if( status != 0 ) {
            throw new BlasException( "Failed to execute dpotrf", status ) ;
        }
        if( uplo == 'U' ) {
            INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
            for( int i = 1 ; i < Math.min( A.rows(), A.columns() ) ; i++ ) {
                ix[0] = NDArrayIndex.point(i) ;
                ix[1] = NDArrayIndex.interval( 0, i ) ;
                A.put(ix, 0) ;
            }
        } else {
            INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
            for( int i = 0; i < Math.min( A.rows(), A.columns()-1 ) ; i++ ) {
                ix[0] = NDArrayIndex.point( i ) ;
                ix[1] = NDArrayIndex.interval( i+1, A.columns() ) ;
                A.put(ix, 0) ;
            }
        }
    }



    //=========================
// U S V' DECOMP  (aka SVD)
    @Override
    public void sgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT,
                       INDArray INFO) {
        INDArray superb = Nd4j.create(DataType.FLOAT, M < N ? M : N ) ;
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_sgesvd(getColumnOrder(A), jobu, jobvt, M, N,
                (FloatPointer)A.data().addressPointer(), getLda(A),
                (FloatPointer)S.data().addressPointer(),
                U == null ? null : (FloatPointer)U.data().addressPointer(), U == null ? 1 : getLda(U),
                VT == null ? null : (FloatPointer)VT.data().addressPointer(), VT == null ? 1 : getLda(VT),
                (FloatPointer)superb.data().addressPointer()
        );
        if( status != 0 ) {
            throw new BlasException( "Failed to execute sgesvd", status ) ;
        }
    }

    @Override
    public void dgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT,
                       INDArray INFO) {
        INDArray superb = Nd4j.create(DataType.DOUBLE, M < N ? M : N ) ;
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_dgesvd(getColumnOrder(A), jobu, jobvt, M, N,
                (DoublePointer)A.data().addressPointer(), getLda(A),
                (DoublePointer)S.data().addressPointer(),
                U == null ? null : (DoublePointer)U.data().addressPointer(), U == null ? 1 : getLda(U),
                VT == null ? null : (DoublePointer)VT.data().addressPointer(), VT == null ? 1 : getLda(VT),
                (DoublePointer)superb.data().addressPointer()
        ) ;
        if( status != 0 ) {
            throw new BlasException( "Failed to execute dgesvd", status ) ;
        }
    }


    //=========================
// syev EigenValue/Vectors
//
    @Override
    public int ssyev( char jobz, char uplo, int N, INDArray A, INDArray R ) {
        FloatPointer fp = new FloatPointer(1) ;
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_ssyev_work( getColumnOrder(A), (byte)jobz, (byte)uplo,
                N, (FloatPointer)A.data().addressPointer(), getLda(A),
                (FloatPointer)R.data().addressPointer(), fp, -1 ) ;
        if( status == 0 ) {
            int lwork = (int)fp.get() ;
            INDArray work = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createFloat(lwork),
                    Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {lwork}, A.dataType()).getFirst());

            status = Nd4j.getBlasLapackDelegator().LAPACKE_ssyev( getColumnOrder(A), (byte)jobz, (byte)uplo, N,
                    (FloatPointer)A.data().addressPointer(), getLda(A),
                    (FloatPointer)work.data().addressPointer() ) ;
            if( status == 0 ) {
                R.assign(work.get(NDArrayIndex.interval(0,N))) ;
            }
        }
        return status ;
    }


    public int dsyev( char jobz, char uplo, int N, INDArray A, INDArray R ) {

        DoublePointer dp = new DoublePointer(1) ;
        int status = Nd4j.getBlasLapackDelegator().LAPACKE_dsyev_work( getColumnOrder(A), (byte)jobz, (byte)uplo,
                N, (DoublePointer)A.data().addressPointer(), getLda(A),
                (DoublePointer)R.data().addressPointer(), dp, -1 ) ;
        if( status == 0 ) {
            int lwork = (int)dp.get() ;
            INDArray work = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createDouble(lwork),
                    Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {lwork}, A.dataType()).getFirst());

            status = Nd4j.getBlasLapackDelegator().LAPACKE_dsyev( getColumnOrder(A), (byte)jobz, (byte)uplo, N,
                    (DoublePointer)A.data().addressPointer(), getLda(A),
                    (DoublePointer)work.data().addressPointer() ) ;

            if( status == 0 ) {
                R.assign( work.get( NDArrayIndex.interval(0,N) ) ) ;
            }
        }
        return status ;
    }




    /**
     * Generate inverse given LU decomp
     *
     * @param N
     * @param A
     * @param lda
     * @param IPIV
     * @param WORK
     * @param lwork
     * @param INFO
     */
    @Override
    public void getri(int N, INDArray A, int lda, int[] IPIV, INDArray WORK, int lwork, int INFO) {
        throw new UnsupportedOperationException();
    }
}