SciRuby/nmatrix

View on GitHub
ext/nmatrix/math/gemm.h

Summary

Maintainability
Test Coverage
/////////////////////////////////////////////////////////////////////
// = NMatrix
//
// A linear algebra library for scientific computation in Ruby.
// NMatrix is part of SciRuby.
//
// NMatrix was originally inspired by and derived from NArray, by
// Masahiro Tanaka: http://narray.rubyforge.org
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
// == Contributing
//
// By contributing source code to SciRuby, you agree to be bound by
// our Contributor Agreement:
//
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
//
// == gemm.h
//
// Header file for interface with ATLAS's CBLAS gemm functions and
// native templated version of the reference BLAS gemm function.
//

#ifndef GEMM_H
# define GEMM_H

#include "cblas_enums.h"
#include "math/long_dtype.h"

namespace nm { namespace math {
/*
 * GEneral Matrix Multiplication: based on dgemm.f from Netlib.
 *
 * This is an extremely inefficient algorithm. Recommend using ATLAS' version instead.
 *
 * Template parameters: LT -- long version of type T. Type T is the matrix dtype.
 *
 * This version throws no errors. Use gemm<DType> instead for error checking.
 */
template <typename DType>
inline void gemm_nothrow(const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
                 const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
{

  typename LongDType<DType>::type temp;

  // Quick return if possible
  if (!M or !N or ((*alpha == 0 or !K) and *beta == 1)) return;

  // For alpha = 0
  if (*alpha == 0) {
    if (*beta == 0) {
      for (int j = 0; j < N; ++j)
        for (int i = 0; i < M; ++i) {
          C[i+j*ldc] = 0;
        }
    } else {
      for (int j = 0; j < N; ++j)
        for (int i = 0; i < M; ++i) {
          C[i+j*ldc] *= *beta;
        }
    }
    return;
  }

  // Start the operations
  if (TransB == CblasNoTrans) {
    if (TransA == CblasNoTrans) {
      // C = alpha*A*B+beta*C
      for (int j = 0; j < N; ++j) {
        if (*beta == 0) {
          for (int i = 0; i < M; ++i) {
            C[i+j*ldc] = 0;
          }
        } else if (*beta != 1) {
          for (int i = 0; i < M; ++i) {
            C[i+j*ldc] *= *beta;
          }
        }

        for (int l = 0; l < K; ++l) {
          if (B[l+j*ldb] != 0) {
            temp = *alpha * B[l+j*ldb];
            for (int i = 0; i < M; ++i) {
              C[i+j*ldc] += A[i+l*lda] * temp;
            }
          }
        }
      }

    } else {

      // C = alpha*A**DType*B + beta*C
      for (int j = 0; j < N; ++j) {
        for (int i = 0; i < M; ++i) {
          temp = 0;
          for (int l = 0; l < K; ++l) {
            temp += A[l+i*lda] * B[l+j*ldb];
          }

          if (*beta == 0) {
            C[i+j*ldc] = *alpha*temp;
          } else {
            C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
          }
        }
      }

    }

  } else if (TransA == CblasNoTrans) {

    // C = alpha*A*B**T + beta*C
    for (int j = 0; j < N; ++j) {
      if (*beta == 0) {
        for (int i = 0; i < M; ++i) {
          C[i+j*ldc] = 0;
        }
      } else if (*beta != 1) {
        for (int i = 0; i < M; ++i) {
          C[i+j*ldc] *= *beta;
        }
      }

      for (int l = 0; l < K; ++l) {
        if (B[j+l*ldb] != 0) {
          temp = *alpha * B[j+l*ldb];
          for (int i = 0; i < M; ++i) {
            C[i+j*ldc] += A[i+l*lda] * temp;
          }
        }
      }

    }

  } else {

    // C = alpha*A**DType*B**T + beta*C
    for (int j = 0; j < N; ++j) {
      for (int i = 0; i < M; ++i) {
        temp = 0;
        for (int l = 0; l < K; ++l) {
          temp += A[l+i*lda] * B[j+l*ldb];
        }

        if (*beta == 0) {
          C[i+j*ldc] = *alpha*temp;
        } else {
          C[i+j*ldc] = *alpha*temp + *beta*C[i+j*ldc];
        }
      }
    }

  }

  return;
}



/*
 * GEneral Matrix Multiplication with argument checking:
 *
 *     C := alpha * op(A) * op(B) + beta * C
 *
 * Checks the leading dimensions lda, ldb and ldc against M, N and K for the
 * requested storage Order and transpose flags, raising a Ruby ArgumentError
 * (rb_eArgError) on the first violation, and then delegates the arithmetic to
 * gemm_nothrow<DType>. Any Order other than CblasRowMajor is treated as
 * column-major, and any transpose flag other than CblasNoTrans is treated as
 * a transpose.
 *
 * gemm_nothrow indexes its operands column-major, so a row-major request is
 * served by computing the transposed product: A and B (with their flags and
 * leading dimensions) and M and N are swapped in the call.
 */
template <typename DType>
inline void gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
                 const DType* alpha, const DType* A, const int lda, const DType* B, const int ldb, const DType* beta, DType* C, const int ldc)
{
  if (Order == CblasRowMajor) {
    // Row-major: the leading dimension must cover the stored row length.
    if (TransA == CblasNoTrans) {
      if (lda < std::max(K,1)) {
        rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
      }
    } else {
      if (lda < std::max(M,1)) { // && TransA == CblasTrans
        rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
      }
    }

    if (TransB == CblasNoTrans) {
      if (ldb < std::max(N,1)) {
        rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
      }
    } else {
      if (ldb < std::max(K,1)) {
        rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
      }
    }

    if (ldc < std::max(N,1)) {
      rb_raise(rb_eArgError, "ldc must be >= MAX(N,1): ldc=%d N=%d", ldc, N);
    }
  } else { // CblasColMajor
    // Column-major: the leading dimension must cover the stored column length.
    if (TransA == CblasNoTrans) {
      if (lda < std::max(M,1)) {
        rb_raise(rb_eArgError, "lda must be >= MAX(M,1): lda=%d M=%d", lda, M);
      }
    } else {
      if (lda < std::max(K,1)) { // && TransA == CblasTrans
        rb_raise(rb_eArgError, "lda must be >= MAX(K,1): lda=%d K=%d", lda, K);
      }
    }

    if (TransB == CblasNoTrans) {
      if (ldb < std::max(K,1)) {
        // The value printed is K, so the label must say K (it used to say N).
        rb_raise(rb_eArgError, "ldb must be >= MAX(K,1): ldb=%d K=%d", ldb, K);
      }
    } else {
      if (ldb < std::max(N,1)) { // NOTE: This error message is actually wrong in the ATLAS source currently. Or are we wrong?
        rb_raise(rb_eArgError, "ldb must be >= MAX(N,1): ldb=%d N=%d", ldb, N);
      }
    }

    if (ldc < std::max(M,1)) {
      // The value printed is M, so the label must say M (it used to say N).
      rb_raise(rb_eArgError, "ldc must be >= MAX(M,1): ldc=%d M=%d", ldc, M);
    }
  }

  /*
   * Call SYRK when that's what the user is actually asking for; just handle beta=0, because beta=X requires
   * we copy C and then subtract to preserve asymmetry.
   *
   * NOTE(review): `beta == 0` compares the POINTER against null, not the
   * value it points to, so this branch is only taken when beta is a null
   * pointer. It is deliberately left as-is: while syrk/syreflect remain
   * unimplemented, testing `*beta == 0` instead would make otherwise valid
   * A * A**T products raise NotImplementedError rather than falling through
   * to gemm_nothrow below. Revisit once syrk exists.
   */

  if (A == B && M == N && TransA != TransB && lda == ldb && beta == 0) {
    rb_raise(rb_eNotImpError, "syrk and syreflect not implemented");
    /*syrk<DType>(CblasUpper, (Order == CblasColMajor) ? TransA : TransB, N, K, alpha, A, lda, beta, C, ldc);
    syreflect(CblasUpper, N, C, ldc);
    */
  }

  // Row-major C = op(A)*op(B) is column-major C**T = op(B)**T * op(A)**T,
  // hence the swapped operands and dimensions in the first call.
  if (Order == CblasRowMajor)    gemm_nothrow<DType>(TransB, TransA, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
  else                           gemm_nothrow<DType>(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);

}


}} // end of namespace nm::math

#endif // GEMM_H