deeplearning4j/deeplearning4j

View on GitHub
libnd4j/include/ops/declarable/platform/vednn/veda_vednn.vcpp

Summary

Maintainability
Test Coverage
#include <stdint.h>
#include <stdio.h>
#include <veda_device.h>
#include <vednn.h>

#include <cmath>
#include <iostream>
#include <memory>

using LongType = long long;
//#define SHOW_ON_FUNC_ENTRY 1
#if !defined(SHOW_ON_FUNC_ENTRY)
#define LOG_FUNC()
#else
#define LOG_FUNC() printf("%s in [%s %d]\n", __PRETTY_FUNCTION__, __FILE__, __LINE__)
#endif

// Note: it should be fine to use one status for veda OMP mode
static uint64_t last_status = 0;

#define CHECK_ERROR_BEFORE_EXEC()             \
  do {                                        \
    if (last_status != 0) return last_status; \
  } while (0)

#define CHECK_FUNC(res)     \
  do {                      \
    if (last_status == 0) { \
      if (res != 0) {       \
        RETURN(res);        \
      }                     \
    }                       \
  } while (0)

#define RETURN(res)                                                  \
  do {                                                               \
    return set_return((uint64_t)res, __PRETTY_FUNCTION__, __LINE__); \
  } while (0)

inline uint64_t set_return(uint64_t res, const char *msg, int line) {
  last_status = res;
  if (res != 0) {
    printf("%s %d result code: [%d]\n", msg, line, (int)res);
  }
  return res;
}

inline void copyTo_nhwc_generic(const vednnTensorParam_t &p, const float *nchw, float *nhwc) {
  int hs = p.width;
  int cs = p.width * p.height;
  int ns = cs * p.channel;
  int ws2 = p.channel;
  int hs2 = ws2 * p.width;
  LOG_FUNC();
#pragma omp parallel for
  for (int n = 0; n < p.batch; n++) {
    for (int h = 0; h < p.height; h++) {
      for (int w = 0; w < p.width; w++) {
        for (int c = 0; c < p.channel; c++) {
          nhwc[n * ns + h * hs2 + w * ws2 + c] = nchw[n * ns + h * hs + c * cs + w];
        }
      }
    }
  }
  LOG_FUNC();
}

inline void copyTo_nchw_generic(const vednnTensorParam_t &p, const float *nhwc, float *nchw) {
  constexpr int cs = 1;
  int ws = p.channel;
  int hs = ws * p.width;
  int ns = hs * p.height;
  LOG_FUNC();
  constexpr int ws1 = 1;
  int hs1 = p.width;
  int cs1 = hs1 * p.height;
  int ns1 = cs1 * p.channel;
#pragma omp parallel for
  for (int n = 0; n < p.batch; n++) {
    for (int h = 0; h < p.height; h++) {
      for (int w = 0; w < p.width; w++) {
        for (int c = 0; c < p.channel; c++) {
          nchw[n * ns1 + h * hs1 + w * ws1 + c * cs1] = nhwc[n * ns + h * hs + w * ws + c * cs];
        }
      }
    }
  }
  LOG_FUNC();
}

void copyIntoNHWC(const vednnTensorParam_t &param, const float *nchw_data, float *nhwc_data) {
  return copyTo_nhwc_generic(param, nchw_data, nhwc_data);
}

float *getNCHW(const vednnTensorParam_t &param, float *nhwc_data, std::unique_ptr<float[]> &temp) {
  if (param.channel == 1) {
    // there is not any need for conversion
    return nhwc_data;
  } else {
    LOG_FUNC();
    int hwSize = param.height * param.width;
    int strideN = hwSize * param.channel;
    size_t length = param.batch * strideN;
    temp.reset(new float[length]);
    float *nchw_data = temp.get();
    copyTo_nchw_generic(param, nhwc_data, nchw_data);
    LOG_FUNC();
    return nchw_data;
  }
}

void showBuffer(float *x, int l) {
  for (int i = 0; i < l; i++) std::cout << x[i] << ", ";
  std::cout << std::endl;
}

float *getWeightFormat1Data(const vednnFilterParam_t &paramFilter, float *weight, int wFormat,
                            std::unique_ptr<float[]> &temp) {
  // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
  if (wFormat == 1) {
    return weight;
  } else {
    if (wFormat == 2) {
      LOG_FUNC();
      //[oC, kH, kW, iC] -> [oC, iC, kH, kW],
      vednnTensorParam_t param;
      param.dtype = DTYPE_FLOAT;
      param.batch = paramFilter.outChannel;
      param.channel = paramFilter.inChannel;
      param.height = paramFilter.height;
      param.width = paramFilter.width;
      auto w = getNCHW(param, weight, temp);
      LOG_FUNC();
      return w;
    } else {
      //[kH, kW, iC, oC] -> [oC, iC, kH, kW]
      LOG_FUNC();
      constexpr int ocs0 = 1;
      int ics0 = paramFilter.outChannel;
      int ws0 = ics0 * paramFilter.inChannel;
      int hs0 = ws0 * paramFilter.width;

      size_t length = hs0 * paramFilter.height;
      temp.reset(new float[length]);
      float *ret = temp.get();
      constexpr int ws1 = 1;
      int hs1 = paramFilter.width;
      int ics1 = hs1 * paramFilter.height;
      int ocs1 = ics1 * paramFilter.inChannel;
#pragma omp parallel for
      for (int h = 0; h < paramFilter.height; h++)
        for (int w = 0; w < paramFilter.width; w++)
          for (int i = 0; i < paramFilter.inChannel; i++)
            for (int j = 0; j < paramFilter.outChannel; j++) {
              ret[j * ocs1 + i * ics1 + w * ws1 + h * hs1] = weight[j * ocs0 + i * ics0 + w * ws0 + h * hs0];
            }
      //      }
      LOG_FUNC();
      return ret;
    }
  }
}

// copy WeightFormat2 into WeightFormat1
void copyWf2ToWf1(const vednnFilterParam_t &paramFilter, const float *weight, float *weightOut) {
  vednnTensorParam_t param;
  param.dtype = DTYPE_FLOAT;
  param.batch = paramFilter.outChannel;
  param.channel = paramFilter.inChannel;
  param.height = paramFilter.height;
  param.width = paramFilter.width;
  copyIntoNHWC(param, weight, weightOut);
}

void copyWf0ToWf1(const vednnFilterParam_t &paramFilter, const float *weight, float *weightOut) {
  constexpr int ocs0 = 1;
  int ics0 = paramFilter.outChannel;
  int ws0 = ics0 * paramFilter.inChannel;
  int hs0 = ws0 * paramFilter.width;

  constexpr int ws1 = 1;
  int hs1 = paramFilter.width;
  int ics1 = hs1 * paramFilter.height;
  int ocs1 = ics1 * paramFilter.inChannel;
#pragma omp parallel for
  for (int h = 0; h < paramFilter.height; h++)
    for (int w = 0; w < paramFilter.width; w++)
      for (int i = 0; i < paramFilter.inChannel; i++)
        for (int j = 0; j < paramFilter.outChannel; j++) {
          weightOut[j * ocs0 + i * ics0 + w * ws0 + h * hs0] = weight[j * ocs1 + i * ics1 + w * ws1 + h * hs1];
        }
}

extern "C" {

void showBufferVeAsFloat(VEDAdeviceptr v) {
#if defined(DEBUG_VEDA_LOGS)
  size_t size = 0;
  void *in;

  std::cout << "ve " << (void *)v << "\n";
  if (v) {
    vedaMemPtrSize(&in, &size, v);
    float *x = (float *)in;
    size = size > 80 ? 80 : 0;
    for (int i = 0; i < size / sizeof(float); i++) std::cout << x[i] << ", ";
  } else {
    std::cout << "null ";
  }
  std::cout << std::endl;
#endif
}

uint64_t vedaVednnConvolutionForwardAddBias(const vednnTensorParam_t *paramIn, VEDAdeviceptr vDataIn,
                                            uint8_t isDataInNCHW, const vednnFilterParam_t *paramFilter,
                                            VEDAdeviceptr vDataKernel, int32_t weightFormat,
                                            const vednnBiasParam_t *paramBias, VEDAdeviceptr vDataBias,
                                            const vednnTensorParam_t *paramOut, VEDAdeviceptr vDataOut,
                                            uint8_t isDataOutNCHW, const vednnConvolutionParam_t *paramConv,
                                            vednnConvolutionAlgorithm_t algo) {
  CHECK_ERROR_BEFORE_EXEC();
  LOG_FUNC();
  vednnError_t res;
  void *pDataIn, *pDataBias = nullptr, *pDataKernelPtr;
  void *pDataOutPtr, *pDataOut = nullptr;
  CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
  CHECK_FUNC(vedaMemPtr((void **)&pDataKernelPtr, vDataKernel));
  if (vDataBias) CHECK_FUNC(vedaMemPtr((void **)&pDataBias, vDataBias));
  CHECK_FUNC(vedaMemPtr((void **)&pDataOutPtr, vDataOut));
#if defined(DEBUG_VEDA_LOGS)
  showBufferVeAsFloat(vDataIn);
  showBufferVeAsFloat(vDataKernel);
  showBufferVeAsFloat(vDataBias);
#endif
  std::unique_ptr<float[]> tempIn, tempOut, tempW;

  if (!isDataInNCHW) {
    pDataIn = getNCHW(*paramIn, (float *)pDataIn, tempIn);
  }
  if (!isDataOutNCHW) {
    tempOut.reset(new float[paramOut->batch * paramOut->channel * paramOut->height * paramOut->width]);
    pDataOut = tempOut.get();
  } else {
    pDataOut = pDataOutPtr;
  }
  auto pDataKernel = getWeightFormat1Data(*paramFilter, (float *)pDataKernelPtr, weightFormat, tempW);

  if (pDataBias) {
    // printf("%s\n", "bias case");

    res = vednnConvolutionForwardAddBias(paramIn, pDataIn, paramFilter, pDataKernel, paramBias, pDataBias, paramOut,
                                         pDataOut, paramConv, algo);

  } else {
    res = vednnConvolutionForward(paramIn, pDataIn, paramFilter, pDataKernel, paramOut, pDataOut, paramConv, algo);
  }
  if (pDataOut != pDataOutPtr) {
    copyIntoNHWC(*paramOut, (const float *)pDataOut, (float *)pDataOutPtr);
  }

  RETURN(res);
}

uint64_t vedaVednnConvolutionBackwardDataAndFilter(const vednnTensorParam_t *paramGradOut, VEDAdeviceptr vGradOutData,
                                                   const vednnFilterParam_t *paramFilter, VEDAdeviceptr vWeightData,
                                                   int32_t weightFormat, VEDAdeviceptr vGradWeightData,
                                                   const vednnTensorParam_t *paramGradIn, VEDAdeviceptr vInData,
                                                   VEDAdeviceptr vGradInData, uint8_t isNCHW, VEDAdeviceptr vGradBias,
                                                   const vednnConvolutionParam_t *paramConv,
                                                   vednnConvolutionAlgorithm_t algo) {
  CHECK_ERROR_BEFORE_EXEC();
  LOG_FUNC();
  void *gradOutData, *weightData, *gradWeightsData, *inData, *gradInData, *gradInDataFormatted;
  float *gradBias = nullptr;
  CHECK_FUNC(vedaMemPtr((void **)&gradOutData, vGradOutData));
  CHECK_FUNC(vedaMemPtr((void **)&weightData, vWeightData));
  CHECK_FUNC(vedaMemPtr((void **)&gradWeightsData, vGradWeightData));
  CHECK_FUNC(vedaMemPtr((void **)&inData, vInData));
  CHECK_FUNC(vedaMemPtr((void **)&gradInData, vGradInData));
  if (vGradBias) CHECK_FUNC(vedaMemPtr((void **)&gradBias, vGradBias));
  gradInDataFormatted = gradInData;

  // temporary memory holders for the case when we need formatted buffers
  std::unique_ptr<float[]> tempW, tempIn, tempGradIn, tempGradOut;

  if (!isNCHW) {
    inData = getNCHW(*paramGradIn, (float *)inData, tempIn);
    gradOutData = getNCHW(*paramGradOut, (float *)gradOutData, tempGradOut);
    gradInDataFormatted = getNCHW(*paramGradIn, (float *)gradInData, tempGradIn);
  }

  auto weightDataFormatted = getWeightFormat1Data(*paramFilter, (float *)weightData, weightFormat, tempW);

  vednnError_t res = vednnConvolutionBackwardData(paramGradOut, gradOutData, paramFilter, weightDataFormatted,
                                                  paramGradIn, gradInDataFormatted, paramConv, algo);

  if (res != VEDNN_SUCCESS) return res;

  if (gradInDataFormatted != gradInData) {
    copyIntoNHWC(*paramGradIn, (const float *)gradInDataFormatted, (float *)gradInData);
  }

  // paramGradIn could be used for "in"
  // paramFilter could be used for "gradWeights"
  if (weightDataFormatted == weightData) {
    res = vednnConvolutionBackwardFilter(paramGradIn, inData, paramGradOut, gradOutData, paramFilter, gradWeightsData,
                                         paramConv, algo);
  } else {
    std::unique_ptr<float[]> tempWeightsData;
    auto len = paramFilter->outChannel * paramFilter->inChannel * paramFilter->height * paramFilter->width;
    tempWeightsData.reset(new float[len]);
    float *gradWeightsDataLocal = tempWeightsData.get();

    res = vednnConvolutionBackwardFilter(paramGradIn, inData, paramGradOut, gradOutData, paramFilter,
                                         gradWeightsDataLocal, paramConv, algo);

    // [oC, iC, kH, kW] -> [kH, kW, iC, oC]
    if (weightFormat == 0) copyWf0ToWf1(*paramFilter, (const float *)gradWeightsDataLocal, (float *)gradWeightsData);
    // [oC, iC, kH, kW] -> [oC, kH, kW, iC]
    else
      copyWf2ToWf1(*paramFilter, (const float *)gradWeightsDataLocal, (float *)gradWeightsData);
  }

  // calculate Bias
  if (gradBias) {
    //// sum formatted gradOutData over bS, oH, oW

    int height_weight = paramGradOut->width * paramGradOut->height;
    int ns1 = height_weight * paramGradOut->channel;
    for (int c = 0; c < paramGradOut->channel; c++) {
      gradBias[c] = 0;
    }
    auto nhwc_c = (float *)gradOutData;
    for (int n = 0; n < paramGradOut->batch; n++) {
      for (int c = 0; c < paramGradOut->channel; c++) {
        auto sum = 0.f;
        for (int hw = 0; hw < height_weight; hw++) {
          sum += nhwc_c[hw];
        }
        gradBias[c] += sum;
        nhwc_c += height_weight;
      }
    }
  }
  RETURN(res);
}

uint64_t vedaVednnActivationForward(const vednnActivationMode_t mode, VEDAdeviceptr vDataIn, VEDAdeviceptr vDataOut,
                                    const uint64_t nElements) {
  LOG_FUNC();
  CHECK_ERROR_BEFORE_EXEC();
  void *pDataIn;
  void *pDataOut;
  CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
  CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));

  auto res = vednnActivationForward(mode, pDataIn, pDataOut, nElements);
  RETURN(res);
}

uint64_t vedaVednnActivationBackward(const vednnActivationMode_t mode, VEDAdeviceptr vDataGradOut,
                                     VEDAdeviceptr vDataIn, VEDAdeviceptr vDataGradIn, const uint64_t nElements) {
  LOG_FUNC();
  CHECK_ERROR_BEFORE_EXEC();
  void *pDataGradOut;
  void *pDataIn;
  void *pDataGradIn;
  CHECK_FUNC(vedaMemPtr((void **)&pDataGradOut, vDataGradOut));
  CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
  CHECK_FUNC(vedaMemPtr((void **)&pDataGradIn, vDataGradIn));

  auto res = vednnActivationBackward(mode, pDataGradOut, pDataIn, pDataGradIn, nElements);
  RETURN(res);
}

uint64_t vedaVednnSoftmaxForward(const vednnSoftmaxMode_t mode, VEDAdeviceptr vDataIn, VEDAdeviceptr vDataOut,
                                 const uint64_t nBatch, const uint64_t nClass) {
  CHECK_ERROR_BEFORE_EXEC();
  LOG_FUNC();
  void *pDataIn;
  void *pDataOut;
  CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
  CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));

  return vednnSoftmaxForward(mode, pDataIn, pDataOut, nBatch, nClass);
}

uint64_t vedaVednnLinearForwardExF32(uint64_t bGemm, const uint64_t inDim, const uint64_t outDim, const uint64_t nBatch,
                                     VEDAdeviceptr vX, const uint64_t xStride, VEDAdeviceptr vY, const uint64_t yStride,
                                     VEDAdeviceptr vZ, const uint64_t zStride) {
  CHECK_ERROR_BEFORE_EXEC();
  LOG_FUNC();
  vednnError_t res;
  float *x, *y;
  float *z;
  CHECK_FUNC(vedaMemPtr((void **)&x, vX));
  CHECK_FUNC(vedaMemPtr((void **)&y, vY));
  CHECK_FUNC(vedaMemPtr((void **)&z, vZ));

  if (bGemm == 1) {
    RETURN(vednnLinearForward(inDim, outDim, nBatch, 1, x, y, z));
  } else {
    // because of the bgemm did not work as expected, we will manually parallelize over bGemm

    //#pragma omp parallel for
    for (int i = 0; i < bGemm; i++) {
      float *xPtr = x + i * xStride;
      float *yPtr = y + i * yStride;
      float *zPtr = z + i * zStride;
      vednnLinearForward(inDim, outDim, nBatch, 1, xPtr, yPtr, zPtr);
    }
    // WARNING: we will silently return success
    RETURN(VEDNN_SUCCESS);
  }
}

uint64_t vedaVednnMaxPoolingForward(const vednnTensorParam_t *pParamIn, VEDAdeviceptr vDataIn,
                                    const vednnTensorParam_t *pParamOut, VEDAdeviceptr vDataOut,
                                    const vednnPoolingParam_t *pParamPool) {
  CHECK_ERROR_BEFORE_EXEC();
  LOG_FUNC();
  void *pDataIn;
  void *pDataOut;
  CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
  CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));
#if defined(DEBUG_VEDA_LOGS)
  showBufferVeAsFloat(vDataIn);
#endif
  RETURN(vednnMaxPoolingForward(pParamIn, pDataIn, pParamOut, pDataOut, pParamPool));
}

uint64_t vedaVednnMaxPoolingBackwardEx(const vednnTensorParam_t *pParamGradOut, VEDAdeviceptr vDataGradOut,
                                       const vednnTensorParam_t *pParamOut, VEDAdeviceptr vDataOut,
                                       const vednnTensorParam_t *pParamIn, VEDAdeviceptr vDataIn,
                                       const vednnTensorParam_t *pParamGradIn, VEDAdeviceptr vDataGradIn,
                                       const vednnPoolingParam_t *pParamPool) {
  CHECK_ERROR_BEFORE_EXEC();
  LOG_FUNC();
  void *pDataGradOut, *pDataIn, *pDataGradIn, *pDataOut;
  CHECK_FUNC(vedaMemPtr((void **)&pDataGradOut, vDataGradOut));
  CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
  CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));
  CHECK_FUNC(vedaMemPtr((void **)&pDataGradIn, vDataGradIn));

  vednnError_t res = vednnMaxPoolingForward(pParamIn, pDataIn, pParamOut, pDataOut, pParamPool);

  if (res == VEDNN_SUCCESS) {
    vednnMaxPoolingBackward(pParamGradOut, pDataGradOut, pParamOut, pDataOut, pParamIn, pDataIn, pParamGradIn,
                            pDataGradIn, pParamPool);
  }
  RETURN(res);
}

uint64_t vedaConcatUpTo32(const uint64_t nInput, VEDAdeviceptr *inputList, uint64_t *inputLengthInBytesList,
                          VEDAdeviceptr vO) {
  CHECK_ERROR_BEFORE_EXEC();

  LOG_FUNC();
  if (nInput > 0 && nInput <= 32) {
    //WARNING: we make it as uint32_t* because it is vectorizable
    //For now please pass only data types divisible by sizeof(uint32_t)
    uint32_t *output;
    CHECK_FUNC(vedaMemPtr((void **)&output, vO));
    struct OmpArgs {
      uint32_t *in;
      uint32_t *out;
      uint64_t size;
    };

    OmpArgs zPtrList[32];
    for (uint64_t i = 0; i < nInput; i++) {
      uint32_t *in;
      CHECK_FUNC(vedaMemPtr((void **)&in, inputList[i]));
      auto lengthOf = inputLengthInBytesList[i]/sizeof(uint32_t);

      zPtrList[i].in = in;
      zPtrList[i].out = output;
      zPtrList[i].size = lengthOf ;
      output += lengthOf;
    }

#pragma omp parallel for
    for (int i = 0; i < nInput; ++i) {
      auto inputPtr = zPtrList[i].in;
      auto outPtr = zPtrList[i].out;
      uint64_t size = zPtrList[i].size;
      for (uint64_t j = 0; j < size; j++) {
        outPtr[j] = inputPtr[j];
      }
    }
    RETURN(0);
  }

  RETURN(-1);

}

uint64_t vedaAdd_A(uint64_t length0, VEDAdeviceptr vIn0, uint64_t length1, VEDAdeviceptr vIn1, VEDAdeviceptr vO) {
  CHECK_ERROR_BEFORE_EXEC();

  uint64_t min_len = 0;
  uint64_t max_len = 0;
  float *big = nullptr;
  float *small = nullptr;
  float *out = nullptr;

  CHECK_FUNC(vedaMemPtr((void **)&out, vO));
  if (length1 > length0) {
    min_len = length0;
    max_len = length1;
    CHECK_FUNC(vedaMemPtr((void **)&small, vIn0));
    CHECK_FUNC(vedaMemPtr((void **)&big, vIn1));

  } else {
    min_len = length1;
    max_len = length0;
    CHECK_FUNC(vedaMemPtr((void **)&small, vIn1));
    CHECK_FUNC(vedaMemPtr((void **)&big, vIn0));
  }

  if (min_len == 1) {
    auto val = small[0];
    for (uint64_t i = 0L; i < max_len; i++) {
      out[i] = big[i] + val;
    }
  } else {
    int times = (int)(max_len / min_len);
    auto copy_times = 4096 / min_len;
    int times_k = (int)(times / copy_times);
    if (min_len < 4096 && times_k > 10) {
      float local_ll[4096];
      // copy into buffer
      float *ll = local_ll;

      for (int i = 0; i < copy_times; i++) {
        for (int j = 0; j < min_len; j++) {
          ll[j] = small[j];
        }
        ll += min_len;
      }
      auto llen = ll - local_ll;

      int times_tail = times % copy_times;
      for (int i = 0; i < times_k; i++) {
        auto out_inner = &(out[i * llen]);
        auto big_inner = &(big[i * llen]);
        for (uint64_t j = 0; j < llen; j++) {
          out_inner[j] = big_inner[j] + local_ll[j];
        }
      }
      if (times_tail > 0) {
        auto out_inner = &(out[times_k * llen]);
        auto big_inner = &(big[times_k * llen]);
        for (uint64_t j = 0; j < times_tail * min_len; j++) {
          out_inner[j] = big_inner[j] + local_ll[j];
        }
      }

    } else {
      for (int i = 0; i < times; i++) {
        auto out_inner = &(out[i * min_len]);
        auto big_inner = &(big[i * min_len]);
        for (uint64_t j = 0; j < min_len; j++) {
          out_inner[j] = big_inner[j] + small[j];
        }
      }
    }
  }

  RETURN(0);
}

uint64_t vedaMult_A(uint64_t length0, VEDAdeviceptr vIn0, uint64_t length1, VEDAdeviceptr vIn1, VEDAdeviceptr vO) {
  uint64_t min_len = 0;
  uint64_t max_len = 0;
  float *big = nullptr;
  float *small = nullptr;
  float *out = nullptr;

  CHECK_ERROR_BEFORE_EXEC();

  CHECK_FUNC(vedaMemPtr((void **)&out, vO));
  if (length1 > length0) {
    min_len = length0;
    max_len = length1;
    CHECK_FUNC(vedaMemPtr((void **)&small, vIn0));
    CHECK_FUNC(vedaMemPtr((void **)&big, vIn1));

  } else {
    min_len = length1;
    max_len = length0;
    CHECK_FUNC(vedaMemPtr((void **)&small, vIn1));
    CHECK_FUNC(vedaMemPtr((void **)&big, vIn0));
  }

  if (min_len == 1) {
    auto val = small[0];
    for (uint64_t i = 0L; i < max_len; i++) {
      out[i] = big[i] * val;
    }
  } else {
    int times = (int)(max_len / min_len);
    auto copy_times = 4096 / min_len;
    int times_k = (int)(times / copy_times);
    if (min_len < 4096 && times_k > 10) {
      float local_ll[4096];
      // copy into buffer
      float *ll = local_ll;

      for (int i = 0; i < copy_times; i++) {
        for (int j = 0; j < min_len; j++) {
          ll[j] = small[j];
        }
        ll += min_len;
      }
      auto llen = ll - local_ll;

      int times_tail = times % copy_times;
      for (int i = 0; i < times_k; i++) {
        auto out_inner = &(out[i * llen]);
        auto big_inner = &(big[i * llen]);
        for (uint64_t j = 0; j < llen; j++) {
          out_inner[j] = big_inner[j] * local_ll[j];
        }
      }
      if (times_tail > 0) {
        auto out_inner = &(out[times_k * llen]);
        auto big_inner = &(big[times_k * llen]);
        for (uint64_t j = 0; j < times_tail * min_len; j++) {
          out_inner[j] = big_inner[j] * local_ll[j];
        }
      }

    } else {
      for (int i = 0; i < times; i++) {
        auto out_inner = &(out[i * min_len]);
        auto big_inner = &(big[i * min_len]);
        for (uint64_t j = 0; j < min_len; j++) {
          out_inner[j] = big_inner[j] * small[j];
        }
      }
    }
  }

  RETURN(0);
}

uint64_t vedaPermuteAssignRank2_4(LongType *inputShapeInfo, VEDAdeviceptr vIn, LongType *outputShapeInfo,
                                  VEDAdeviceptr vO, const int *permutation) {
  // C order and ews Unchecked call, caller should check for the conditions
  float *in, *out;
  CHECK_ERROR_BEFORE_EXEC();

  CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
  CHECK_FUNC(vedaMemPtr((void **)&out, vO));

  LongType rank = inputShapeInfo[0];
  if (rank >= 2 && rank <= 4) {
    auto shape = inputShapeInfo + 1;
    auto shape2 = outputShapeInfo + 1;
    if (rank == 2) {
      // rank2 copy  [h, w] <- permute([a, b])
      auto strideH = shape2[rank];
      constexpr decltype(strideH) strideW = 1;

      auto stridePermH = shape[rank + permutation[0]];
      auto stridePermW = shape[rank + permutation[1]];
      auto shapeH = shape2[0];
      auto shapeW = shape2[1];

#pragma omp parallel for
      for (decltype(shapeH) h = 0; h < shapeH; h++)
        for (decltype(shapeW) w = 0; w < shapeW; w++) {
          out[h * strideH + w * strideW] = in[h * stridePermH + w * stridePermW];
        }

    } else if (rank == 3) {
      // rank3 copy  [g, h, w] <- permute([a, b, c])
      auto strideG = shape2[rank];
      auto strideH = shape2[rank + 1];
      constexpr decltype(strideH) strideW = 1;
      auto stridePermG = shape[rank + permutation[0]];
      auto stridePermH = shape[rank + permutation[1]];
      auto stridePermW = shape[rank + permutation[2]];
      auto shapeG = shape2[0];
      auto shapeH = shape2[1];
      auto shapeW = shape2[2];
#pragma omp parallel for
      for (decltype(shapeG) g = 0; g < shapeG; g++)
        for (decltype(shapeH) h = 0; h < shapeH; h++)
          for (decltype(shapeW) w = 0; w < shapeW; w++) {
            out[g * strideG + h * strideH + w * strideW] = in[g * stridePermG + h * stridePermH + w * stridePermW];
          }

    } else {
      // rank 4 copy  [f, g, h, w] <- permute([a, b, c, d])
      auto strideF = shape2[rank];
      auto strideG = shape2[rank + 1];
      auto strideH = shape2[rank + 2];
      constexpr decltype(strideH) strideW = 1;
      auto stridePermF = shape[rank + permutation[0]];
      auto stridePermG = shape[rank + permutation[1]];
      auto stridePermH = shape[rank + permutation[2]];
      auto stridePermW = shape[rank + permutation[3]];
      auto shapeF = shape2[0];
      auto shapeG = shape2[1];
      auto shapeH = shape2[2];
      auto shapeW = shape2[3];
#pragma omp parallel for
      for (decltype(shapeF) f = 0; f < shapeF; f++)
        for (decltype(shapeG) g = 0; g < shapeG; g++)
          for (decltype(shapeH) h = 0; h < shapeH; h++)
            for (decltype(shapeW) w = 0; w < shapeW; w++) {
              out[f * strideF + g * strideG + h * strideH + w * strideW] =
                  in[f * stridePermF + g * stridePermG + h * stridePermH + w * stridePermW];
            }
    }

    RETURN(0);
  }

  RETURN(-1);
}

uint64_t vedaPadConstantRank4(LongType *inputShapeInfo, VEDAdeviceptr vIn, LongType *outputShapeInfo, VEDAdeviceptr vO,
                              const LongType paddingOffset, float padValue) {
  float *in, *out;
  CHECK_ERROR_BEFORE_EXEC();

  CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
  CHECK_FUNC(vedaMemPtr((void **)&out, vO));

  LongType rank = inputShapeInfo[0];
  if (rank == 4) {
    auto shape = inputShapeInfo + 1;
    auto shape2 = outputShapeInfo + 1;
    // rank 4 copy  [f, g, h, w] <- permute([a, b, c, d])
    auto strideInF = shape[rank + 0];
    auto strideInG = shape[rank + 1];
    auto strideInH = shape[rank + 2];
    constexpr decltype(strideInH) strideInW = 1;
    auto strideF = shape2[rank + 0];
    auto strideG = shape2[rank + 1];
    auto strideH = shape2[rank + 2];
    constexpr decltype(strideH) strideW = 1;
    auto shapeF = shape[0];
    auto shapeG = shape[1];
    auto shapeH = shape[2];
    auto shapeW = shape[3];
    if (paddingOffset != 0) {
      // For now, assign value to the whole
      LongType length = shape2[0] * shape2[1] * shape2[2] * shape2[3];
      for (LongType i = 0L; i < length; i++) {
        out[i] = padValue;
      }
    }
    // copy the core into Output
    auto offsetedOut = out + paddingOffset;
#pragma omp parallel for
    for (decltype(shapeF) f = 0; f < shapeF; f++)
      for (decltype(shapeG) g = 0; g < shapeG; g++)
        for (decltype(shapeH) h = 0; h < shapeH; h++)
          for (decltype(shapeW) w = 0; w < shapeW; w++) {
            offsetedOut[f * strideF + g * strideG + h * strideH + w * strideW] =
                in[f * strideInF + g * strideInG + h * strideInH + w * strideInW];
          }
  }

  RETURN(0);
}

uint64_t vedaExpF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
  CHECK_ERROR_BEFORE_EXEC();
  float *in, *out;
  CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
  CHECK_FUNC(vedaMemPtr((void **)&out, vO));
  for (uint64_t i = 0; i < length; i++) {
    out[i] = exp(in[i]);
  }
  RETURN(0);
}

uint64_t vedaLogF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
  float *in, *out;
  CHECK_ERROR_BEFORE_EXEC();
  CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
  CHECK_FUNC(vedaMemPtr((void **)&out, vO));
  for (uint64_t i = 0; i < length; i++) {
    out[i] = log(in[i]);
  }
  RETURN(0);
}

uint64_t vedaTanhF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
  float *in, *out;
  CHECK_ERROR_BEFORE_EXEC();
  CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
  CHECK_FUNC(vedaMemPtr((void **)&out, vO));
  for (uint64_t i = 0; i < length; i++) {
    out[i] = tanh(in[i]);
  }
  RETURN(0);
}

uint64_t vedaSigmoidF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
  float *in, *out;
  CHECK_ERROR_BEFORE_EXEC();
  CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
  CHECK_FUNC(vedaMemPtr((void **)&out, vO));
  for (uint64_t i = 0; i < length; i++) {
    out[i] = 1.0f / (1.0f + exp(-in[i]));
  }
  RETURN(0);
}

uint64_t vedaLeakyRELUF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO, float alpha) {
  float *in, *out;
  CHECK_ERROR_BEFORE_EXEC();
  CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
  CHECK_FUNC(vedaMemPtr((void **)&out, vO));
  for (uint64_t i = 0; i < length; i++) {
    auto val = in[i];
    out[i] = val < 0.0f ? alpha * val : val;
  }
  RETURN(0);
}

}  // extern "C"