OpenJij/OpenJij

View on GitHub
include/openjij/utility/gpu/cublas.hpp

Summary

Maintainability
Test Coverage
//    Copyright 2023 Jij Inc.

//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at

//        http://www.apache.org/licenses/LICENSE-2.0

//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

#ifndef OPENJIJ_UTILITY_GPU_CUBLAS_HPP__
#define OPENJIJ_UTILITY_GPU_CUBLAS_HPP__

#ifdef USE_CUDA

#include <type_traits>

#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "openjij/utility/gpu/handle_error.hpp"
#include "openjij/utility/gpu/memory.hpp"

namespace openjij {
namespace utility {
namespace cuda {

/**
 * @brief Compile-time map from an element type to its cudaDataType_t tag.
 *
 * The primary template is declared but never defined, so only the element
 * types specialized below (float and double) can be used.
 *
 * @tparam FloatType element type to look up
 */
template <typename FloatType> struct cudaDataType_impl;

/// float maps to CUDA_R_32F
template <> struct cudaDataType_impl<float> {
  static constexpr cudaDataType_t type = CUDA_R_32F;
};

/// double maps to CUDA_R_64F
template <> struct cudaDataType_impl<double> {
  static constexpr cudaDataType_t type = CUDA_R_64F;
};

/**
 * @brief Type-dispatched front end for the cuBLAS I?amax routines.
 *
 * The primary template is only declared; the float and double
 * specializations below forward their arguments unchanged to the matching
 * cuBLAS entry point and hand back its status code.
 *
 * @tparam FloatType element type of x
 * @param handle cuBLAS handle
 * @param n passed through as the element count
 * @param x input vector
 * @param incx stride between consecutive elements of x
 * @param result receives the index written by cuBLAS
 * @return cublasStatus_t status reported by cuBLAS
 */
template <typename FloatType>
inline cublasStatus_t cublas_Iamax_impl(cublasHandle_t handle, int n,
                                        const FloatType *x, int incx,
                                        int *result);

/// float specialization: forwards to cublasIsamax
template <>
inline cublasStatus_t cublas_Iamax_impl<float>(cublasHandle_t handle, int n,
                                               const float *x, int incx,
                                               int *result) {
  const cublasStatus_t status = cublasIsamax(handle, n, x, incx, result);
  return status;
}

/// double specialization: forwards to cublasIdamax
template <>
inline cublasStatus_t cublas_Iamax_impl<double>(cublasHandle_t handle, int n,
                                                const double *x, int incx,
                                                int *result) {
  const cublasStatus_t status = cublasIdamax(handle, n, x, incx, result);
  return status;
}

/**
 * @brief Type-dispatched front end for the cuBLAS ?dot routines.
 *
 * The primary template is only declared; the float and double
 * specializations below forward their arguments unchanged to cublasSdot and
 * cublasDdot respectively and hand back the cuBLAS status code.
 *
 * @tparam FloatType element type of x, y and result
 * @param handle cuBLAS handle
 * @param n passed through as the element count
 * @param x first input vector
 * @param incx stride between consecutive elements of x
 * @param y second input vector
 * @param incy stride between consecutive elements of y
 * @param result receives the dot product
 * @return cublasStatus_t status reported by cuBLAS
 */
template <typename FloatType>
inline cublasStatus_t
cublas_dot_impl(cublasHandle_t handle, int n, const FloatType *x, int incx,
                const FloatType *y, int incy, FloatType *result);

/// float specialization: forwards to cublasSdot
template <>
inline cublasStatus_t
cublas_dot_impl<float>(cublasHandle_t handle, int n, const float *x, int incx,
                       const float *y, int incy, float *result) {
  const cublasStatus_t status = cublasSdot(handle, n, x, incx, y, incy, result);
  return status;
}

/// double specialization: forwards to cublasDdot
template <>
inline cublasStatus_t
cublas_dot_impl<double>(cublasHandle_t handle, int n, const double *x,
                        int incx, const double *y, int incy, double *result) {
  const cublasStatus_t status = cublasDdot(handle, n, x, incx, y, incy, result);
  return status;
}

/**
 * @brief cuBLAS wrapper
 */
class CuBLASWrapper {
public:
  CuBLASWrapper() {
    // generate cuBLAS instance
    HANDLE_ERROR_CUBLAS(cublasCreate(&_handle));
    // use tensor core if possible
    HANDLE_ERROR_CUBLAS(cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH));
  }

  CuBLASWrapper(CuBLASWrapper &&obj) noexcept {
    // move cuBLAS handler
    this->_handle = obj._handle;
    obj._handle = NULL;
  }

  ~CuBLASWrapper() {
    // destroy generator
    if (_handle != NULL)
      HANDLE_ERROR_CUBLAS(cublasDestroy(_handle));
  }

  template <typename FloatType>
  inline void SgemmEx(cublasOperation_t transa, cublasOperation_t transb, int m,
                      int n, int k, const float *alpha,
                      const utility::cuda::unique_dev_ptr<FloatType> &A,
                      int lda,
                      const utility::cuda::unique_dev_ptr<FloatType> &B,
                      int ldb, const float *beta,
                      utility::cuda::unique_dev_ptr<FloatType> &C, int ldc) {

    cublasPointerMode_t mode;
    HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
    HANDLE_ERROR_CUBLAS(
        cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_HOST));
    HANDLE_ERROR_CUBLAS(cublasSgemmEx(
        _handle, transa, transb, m, n, k, alpha, A.get(),
        cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
        lda, B.get(),
        cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
        ldb, beta, C.get(),
        cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
        ldc));
    HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
  }

  /**
   * @brief matrix multiplication
   * C_mn = A_mk B_kn
   *
   * @tparam FloatType
   * @param m
   * @param k
   * @param n
   * @param A
   * @param B
   * @param C
   */
  template <typename FloatType>
  inline void matmul(int m, int k, int n,
                     const utility::cuda::unique_dev_ptr<FloatType> &A,
                     const utility::cuda::unique_dev_ptr<FloatType> &B,
                     utility::cuda::unique_dev_ptr<FloatType> &C,
                     cublasOperation_t transa = CUBLAS_OP_N,
                     cublasOperation_t transb = CUBLAS_OP_N) {
    typename std::remove_extent<FloatType>::type alpha = 1.0;
    typename std::remove_extent<FloatType>::type beta = 0;

    cublasPointerMode_t mode;
    HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
    HANDLE_ERROR_CUBLAS(
        cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_HOST));
    HANDLE_ERROR_CUBLAS(cublasSgemmEx(
        _handle, transa, transb, m, n, k, &alpha, A.get(),
        cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
        m, B.get(),
        cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
        k, &beta, C.get(),
        cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
        m));
    HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
  }

  /**
   * @brief wrap function of cublasIsamax
   * Note: returned value will be 1-indexed!
   *
   * @tparam FloatType
   * @param n
   * @param x
   * @param incx
   * @param result
   */
  template <typename FloatType>
  inline void Iamax(int n, const FloatType *x, int incx, int *result) {
    cublasPointerMode_t mode;
    HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
    // set pointermode to device
    HANDLE_ERROR_CUBLAS(
        cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_DEVICE));
    HANDLE_ERROR_CUBLAS(cublas_Iamax_impl(_handle, n, x, incx, result));
    // reset pointermode
    HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
  }

  /**
   * @brief return the index of maximal element
   * Note: returned value will be 1-indexed!
   *
   * @tparam FloatType
   * @param n
   * @param x
   * @param result
   */
  template <typename FloatType>
  inline void
  absmax_val_index(int n, const utility::cuda::unique_dev_ptr<FloatType[]> &x,
                   utility::cuda::unique_dev_ptr<int[]> &result) {
    Iamax(n, x.get(), 1, result.get());
  }

  /**
   * @brief wrap function of cublasXdot
   *
   * @tparam FloatType
   * @param n
   * @param x
   * @param incx
   * @param y
   * @param incy
   * @param result
   */
  template <typename FloatType>
  inline void dot(int n, const FloatType *x, int incx, const FloatType *y,
                  int incy, FloatType *result) {
    cublasPointerMode_t mode;
    HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
    HANDLE_ERROR_CUBLAS(
        cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_DEVICE));
    // set pointermode to device
    HANDLE_ERROR_CUBLAS(cublas_dot_impl(_handle, n, x, incx, y, incy, result));
    // reset pointermode
    HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
  }

  /**
   * @brief return the dot product of x and y
   *
   * @tparam FloatType
   * @param n
   * @param x
   * @param y
   * @param result
   */
  template <typename FloatType>
  inline void dot(int n, const utility::cuda::unique_dev_ptr<FloatType[]> &x,
                  const utility::cuda::unique_dev_ptr<FloatType[]> &y,
                  utility::cuda::unique_dev_ptr<FloatType[]> &result) {
    dot(n, x.get(), 1, y.get(), 1, result.get());
  }

private:
  cublasHandle_t _handle;
};

} // namespace cuda
} // namespace utility
} // namespace openjij

#endif
#endif