libnd4j/include/ops/declarable/platform/vednn/veda_vednn.vcpp
#include <stdint.h>
#include <stdio.h>
#include <veda_device.h>
#include <vednn.h>
#include <cmath>
#include <iostream>
#include <memory>
using LongType = long long;
//#define SHOW_ON_FUNC_ENTRY 1
#if !defined(SHOW_ON_FUNC_ENTRY)
#define LOG_FUNC()
#else
#define LOG_FUNC() printf("%s in [%s %d]\n", __PRETTY_FUNCTION__, __FILE__, __LINE__)
#endif
// Note: it should be fine to use one status for veda OMP mode
static uint64_t last_status = 0;
#define CHECK_ERROR_BEFORE_EXEC() \
do { \
if (last_status != 0) return last_status; \
} while (0)
#define CHECK_FUNC(res) \
do { \
if (last_status == 0) { \
auto _res_local = (res); \
if (_res_local != 0) { \
RETURN(_res_local); \
} \
} \
} while (0)
#define RETURN(res) \
do { \
return set_return((uint64_t)res, __PRETTY_FUNCTION__, __LINE__); \
} while (0)
inline uint64_t set_return(uint64_t res, const char *msg, int line) {
last_status = res;
if (res != 0) {
printf("%s %d result code: [%d]\n", msg, line, (int)res);
}
return res;
}
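// Converts a dense NCHW buffer into NHWC. For a tensor of shape [N, C, H, W] the source element
// (n, c, h, w) sits at offset n*C*H*W + c*H*W + h*W + w and is written to n*H*W*C + h*W*C + w*C + c.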
inline void copyTo_nhwc_generic(const vednnTensorParam_t &p, const float *nchw, float *nhwc) {
int hs = p.width;
int cs = p.width * p.height;
int ns = cs * p.channel;
int ws2 = p.channel;
int hs2 = ws2 * p.width;
LOG_FUNC();
#pragma omp parallel for
for (int n = 0; n < p.batch; n++) {
for (int h = 0; h < p.height; h++) {
for (int w = 0; w < p.width; w++) {
for (int c = 0; c < p.channel; c++) {
nhwc[n * ns + h * hs2 + w * ws2 + c] = nchw[n * ns + h * hs + c * cs + w];
}
}
}
}
LOG_FUNC();
}
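// Inverse of copyTo_nhwc_generic: converts a dense NHWC buffer into NCHW.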
inline void copyTo_nchw_generic(const vednnTensorParam_t &p, const float *nhwc, float *nchw) {
constexpr int cs = 1;
int ws = p.channel;
int hs = ws * p.width;
int ns = hs * p.height;
LOG_FUNC();
constexpr int ws1 = 1;
int hs1 = p.width;
int cs1 = hs1 * p.height;
int ns1 = cs1 * p.channel;
#pragma omp parallel for
for (int n = 0; n < p.batch; n++) {
for (int h = 0; h < p.height; h++) {
for (int w = 0; w < p.width; w++) {
for (int c = 0; c < p.channel; c++) {
nchw[n * ns1 + h * hs1 + w * ws1 + c * cs1] = nhwc[n * ns + h * hs + w * ws + c * cs];
}
}
}
}
LOG_FUNC();
}
void copyIntoNHWC(const vednnTensorParam_t &param, const float *nchw_data, float *nhwc_data) {
return copyTo_nhwc_generic(param, nchw_data, nhwc_data);
}
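// Returns a pointer to NCHW data for the given NHWC buffer. With channel == 1 the two layouts
// coincide and the input pointer is returned as is; otherwise a temporary buffer is allocated
// in `temp` and filled via copyTo_nchw_generic.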
float *getNCHW(const vednnTensorParam_t &param, float *nhwc_data, std::unique_ptr<float[]> &temp) {
if (param.channel == 1) {
// no conversion needed: with a single channel, NHWC and NCHW coincide
return nhwc_data;
} else {
LOG_FUNC();
int hwSize = param.height * param.width;
int strideN = hwSize * param.channel;
size_t length = param.batch * strideN;
temp.reset(new float[length]);
float *nchw_data = temp.get();
copyTo_nchw_generic(param, nhwc_data, nchw_data);
LOG_FUNC();
return nchw_data;
}
}
void showBuffer(float *x, int l) {
for (int i = 0; i < l; i++) std::cout << x[i] << ", ";
std::cout << std::endl;
}
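// Returns weights in WeightFormat1 layout [oC, iC, kH, kW]. Format 1 input is passed through;
// formats 0 [kH, kW, iC, oC] and 2 [oC, kH, kW, iC] are rearranged into a temporary buffer held by `temp`.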
float *getWeightFormat1Data(const vednnFilterParam_t &paramFilter, float *weight, int wFormat,
std::unique_ptr<float[]> &temp) {
// 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC]
if (wFormat == 1) {
return weight;
} else {
if (wFormat == 2) {
LOG_FUNC();
//[oC, kH, kW, iC] -> [oC, iC, kH, kW],
vednnTensorParam_t param;
param.dtype = DTYPE_FLOAT;
param.batch = paramFilter.outChannel;
param.channel = paramFilter.inChannel;
param.height = paramFilter.height;
param.width = paramFilter.width;
auto w = getNCHW(param, weight, temp);
LOG_FUNC();
return w;
} else {
//[kH, kW, iC, oC] -> [oC, iC, kH, kW]
LOG_FUNC();
constexpr int ocs0 = 1;
int ics0 = paramFilter.outChannel;
int ws0 = ics0 * paramFilter.inChannel;
int hs0 = ws0 * paramFilter.width;
size_t length = hs0 * paramFilter.height;
temp.reset(new float[length]);
float *ret = temp.get();
constexpr int ws1 = 1;
int hs1 = paramFilter.width;
int ics1 = hs1 * paramFilter.height;
int ocs1 = ics1 * paramFilter.inChannel;
#pragma omp parallel for
for (int h = 0; h < paramFilter.height; h++)
for (int w = 0; w < paramFilter.width; w++)
for (int i = 0; i < paramFilter.inChannel; i++)
for (int j = 0; j < paramFilter.outChannel; j++) {
ret[j * ocs1 + i * ics1 + w * ws1 + h * hs1] = weight[j * ocs0 + i * ics0 + w * ws0 + h * hs0];
}
LOG_FUNC();
return ret;
}
}
}
// copy weights given in WeightFormat1 layout [oC, iC, kH, kW] into WeightFormat2 layout [oC, kH, kW, iC]
void copyWf2ToWf1(const vednnFilterParam_t &paramFilter, const float *weight, float *weightOut) {
vednnTensorParam_t param;
param.dtype = DTYPE_FLOAT;
param.batch = paramFilter.outChannel;
param.channel = paramFilter.inChannel;
param.height = paramFilter.height;
param.width = paramFilter.width;
copyIntoNHWC(param, weight, weightOut);
}
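// copy weights given in WeightFormat1 layout [oC, iC, kH, kW] into WeightFormat0 layout [kH, kW, iC, oC]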
void copyWf0ToWf1(const vednnFilterParam_t &paramFilter, const float *weight, float *weightOut) {
constexpr int ocs0 = 1;
int ics0 = paramFilter.outChannel;
int ws0 = ics0 * paramFilter.inChannel;
int hs0 = ws0 * paramFilter.width;
constexpr int ws1 = 1;
int hs1 = paramFilter.width;
int ics1 = hs1 * paramFilter.height;
int ocs1 = ics1 * paramFilter.inChannel;
#pragma omp parallel for
for (int h = 0; h < paramFilter.height; h++)
for (int w = 0; w < paramFilter.width; w++)
for (int i = 0; i < paramFilter.inChannel; i++)
for (int j = 0; j < paramFilter.outChannel; j++) {
weightOut[j * ocs0 + i * ics0 + w * ws0 + h * hs0] = weight[j * ocs1 + i * ics1 + w * ws1 + h * hs1];
}
}
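// The extern "C" functions below are the VE-side entry points; they are presumably invoked from the
// host through VEDA kernel launches, receive VEDAdeviceptr handles and resolve them with vedaMemPtr.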
extern "C" {
void showBufferVeAsFloat(VEDAdeviceptr v) {
#if defined(DEBUG_VEDA_LOGS)
size_t size = 0;
void *in;
std::cout << "ve " << (void *)v << "\n";
if (v) {
vedaMemPtrSize(&in, &size, v);
float *x = (float *)in;
size = size > 80 ? 80 : size;  // cap the dump at 80 bytes
for (int i = 0; i < size / sizeof(float); i++) std::cout << x[i] << ", ";
} else {
std::cout << "null ";
}
std::cout << std::endl;
#endif
}
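// Convolution forward (optionally with bias add). Resolves the VEDA pointers, converts the input to
// NCHW and the weights to WeightFormat1 when necessary, calls vednn, and copies the result back to
// NHWC if the output is not NCHW.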
uint64_t vedaVednnConvolutionForwardAddBias(const vednnTensorParam_t *paramIn, VEDAdeviceptr vDataIn,
uint8_t isDataInNCHW, const vednnFilterParam_t *paramFilter,
VEDAdeviceptr vDataKernel, int32_t weightFormat,
const vednnBiasParam_t *paramBias, VEDAdeviceptr vDataBias,
const vednnTensorParam_t *paramOut, VEDAdeviceptr vDataOut,
uint8_t isDataOutNCHW, const vednnConvolutionParam_t *paramConv,
vednnConvolutionAlgorithm_t algo) {
CHECK_ERROR_BEFORE_EXEC();
LOG_FUNC();
vednnError_t res;
void *pDataIn, *pDataBias = nullptr, *pDataKernelPtr;
void *pDataOutPtr, *pDataOut = nullptr;
CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
CHECK_FUNC(vedaMemPtr((void **)&pDataKernelPtr, vDataKernel));
if (vDataBias) CHECK_FUNC(vedaMemPtr((void **)&pDataBias, vDataBias));
CHECK_FUNC(vedaMemPtr((void **)&pDataOutPtr, vDataOut));
#if defined(DEBUG_VEDA_LOGS)
showBufferVeAsFloat(vDataIn);
showBufferVeAsFloat(vDataKernel);
showBufferVeAsFloat(vDataBias);
#endif
std::unique_ptr<float[]> tempIn, tempOut, tempW;
if (!isDataInNCHW) {
pDataIn = getNCHW(*paramIn, (float *)pDataIn, tempIn);
}
if (!isDataOutNCHW) {
tempOut.reset(new float[paramOut->batch * paramOut->channel * paramOut->height * paramOut->width]);
pDataOut = tempOut.get();
} else {
pDataOut = pDataOutPtr;
}
auto pDataKernel = getWeightFormat1Data(*paramFilter, (float *)pDataKernelPtr, weightFormat, tempW);
if (pDataBias) {
res = vednnConvolutionForwardAddBias(paramIn, pDataIn, paramFilter, pDataKernel, paramBias, pDataBias, paramOut,
pDataOut, paramConv, algo);
} else {
res = vednnConvolutionForward(paramIn, pDataIn, paramFilter, pDataKernel, paramOut, pDataOut, paramConv, algo);
}
if (pDataOut != pDataOutPtr) {
copyIntoNHWC(*paramOut, (const float *)pDataOut, (float *)pDataOutPtr);
}
RETURN(res);
}
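// Backward pass computing gradients w.r.t. data, filter and (optionally) bias in one call.
// Buffers are converted to NCHW / WeightFormat1 as needed and the gradients are copied back
// into the caller's layout afterwards.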
uint64_t vedaVednnConvolutionBackwardDataAndFilter(const vednnTensorParam_t *paramGradOut, VEDAdeviceptr vGradOutData,
const vednnFilterParam_t *paramFilter, VEDAdeviceptr vWeightData,
int32_t weightFormat, VEDAdeviceptr vGradWeightData,
const vednnTensorParam_t *paramGradIn, VEDAdeviceptr vInData,
VEDAdeviceptr vGradInData, uint8_t isNCHW, VEDAdeviceptr vGradBias,
const vednnConvolutionParam_t *paramConv,
vednnConvolutionAlgorithm_t algo) {
CHECK_ERROR_BEFORE_EXEC();
LOG_FUNC();
void *gradOutData, *weightData, *gradWeightsData, *inData, *gradInData, *gradInDataFormatted;
float *gradBias = nullptr;
CHECK_FUNC(vedaMemPtr((void **)&gradOutData, vGradOutData));
CHECK_FUNC(vedaMemPtr((void **)&weightData, vWeightData));
CHECK_FUNC(vedaMemPtr((void **)&gradWeightsData, vGradWeightData));
CHECK_FUNC(vedaMemPtr((void **)&inData, vInData));
CHECK_FUNC(vedaMemPtr((void **)&gradInData, vGradInData));
if (vGradBias) CHECK_FUNC(vedaMemPtr((void **)&gradBias, vGradBias));
gradInDataFormatted = gradInData;
// temporary memory holders for the case when we need formatted buffers
std::unique_ptr<float[]> tempW, tempIn, tempGradIn, tempGradOut;
if (!isNCHW) {
inData = getNCHW(*paramGradIn, (float *)inData, tempIn);
gradOutData = getNCHW(*paramGradOut, (float *)gradOutData, tempGradOut);
gradInDataFormatted = getNCHW(*paramGradIn, (float *)gradInData, tempGradIn);
}
auto weightDataFormatted = getWeightFormat1Data(*paramFilter, (float *)weightData, weightFormat, tempW);
vednnError_t res = vednnConvolutionBackwardData(paramGradOut, gradOutData, paramFilter, weightDataFormatted,
paramGradIn, gradInDataFormatted, paramConv, algo);
if (res != VEDNN_SUCCESS) return res;
if (gradInDataFormatted != gradInData) {
copyIntoNHWC(*paramGradIn, (const float *)gradInDataFormatted, (float *)gradInData);
}
// paramGradIn also describes the shape of "in", and paramFilter the shape of "gradWeights"
if (weightDataFormatted == weightData) {
res = vednnConvolutionBackwardFilter(paramGradIn, inData, paramGradOut, gradOutData, paramFilter, gradWeightsData,
paramConv, algo);
} else {
std::unique_ptr<float[]> tempWeightsData;
auto len = paramFilter->outChannel * paramFilter->inChannel * paramFilter->height * paramFilter->width;
tempWeightsData.reset(new float[len]);
float *gradWeightsDataLocal = tempWeightsData.get();
res = vednnConvolutionBackwardFilter(paramGradIn, inData, paramGradOut, gradOutData, paramFilter,
gradWeightsDataLocal, paramConv, algo);
if (weightFormat == 0) {
// [oC, iC, kH, kW] -> [kH, kW, iC, oC]
copyWf0ToWf1(*paramFilter, (const float *)gradWeightsDataLocal, (float *)gradWeightsData);
} else {
// [oC, iC, kH, kW] -> [oC, kH, kW, iC]
copyWf2ToWf1(*paramFilter, (const float *)gradWeightsDataLocal, (float *)gradWeightsData);
}
}
// calculate the bias gradient: sum gradOutData (NCHW at this point) over bS, oH, oW
if (gradBias) {
int height_width = paramGradOut->width * paramGradOut->height;
for (int c = 0; c < paramGradOut->channel; c++) {
gradBias[c] = 0;
}
auto nchw_c = (float *)gradOutData;
for (int n = 0; n < paramGradOut->batch; n++) {
for (int c = 0; c < paramGradOut->channel; c++) {
auto sum = 0.f;
for (int hw = 0; hw < height_width; hw++) {
sum += nchw_c[hw];
}
gradBias[c] += sum;
nchw_c += height_width;
}
}
}
RETURN(res);
}
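// Thin wrappers around the vednn activation and softmax primitives: they only resolve the VEDA
// pointers and forward the call.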
uint64_t vedaVednnActivationForward(const vednnActivationMode_t mode, VEDAdeviceptr vDataIn, VEDAdeviceptr vDataOut,
const uint64_t nElements) {
LOG_FUNC();
CHECK_ERROR_BEFORE_EXEC();
void *pDataIn;
void *pDataOut;
CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));
auto res = vednnActivationForward(mode, pDataIn, pDataOut, nElements);
RETURN(res);
}
uint64_t vedaVednnActivationBackward(const vednnActivationMode_t mode, VEDAdeviceptr vDataGradOut,
VEDAdeviceptr vDataIn, VEDAdeviceptr vDataGradIn, const uint64_t nElements) {
LOG_FUNC();
CHECK_ERROR_BEFORE_EXEC();
void *pDataGradOut;
void *pDataIn;
void *pDataGradIn;
CHECK_FUNC(vedaMemPtr((void **)&pDataGradOut, vDataGradOut));
CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
CHECK_FUNC(vedaMemPtr((void **)&pDataGradIn, vDataGradIn));
auto res = vednnActivationBackward(mode, pDataGradOut, pDataIn, pDataGradIn, nElements);
RETURN(res);
}
uint64_t vedaVednnSoftmaxForward(const vednnSoftmaxMode_t mode, VEDAdeviceptr vDataIn, VEDAdeviceptr vDataOut,
const uint64_t nBatch, const uint64_t nClass) {
CHECK_ERROR_BEFORE_EXEC();
LOG_FUNC();
void *pDataIn;
void *pDataOut;
CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));
return vednnSoftmaxForward(mode, pDataIn, pDataOut, nBatch, nClass);
}
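// Linear (fully connected) forward. For bGemm == 1 a single vednnLinearForward call is made,
// otherwise the batched GEMM is emulated by looping over bGemm with the given strides.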
uint64_t vedaVednnLinearForwardExF32(uint64_t bGemm, const uint64_t inDim, const uint64_t outDim, const uint64_t nBatch,
VEDAdeviceptr vX, const uint64_t xStride, VEDAdeviceptr vY, const uint64_t yStride,
VEDAdeviceptr vZ, const uint64_t zStride) {
CHECK_ERROR_BEFORE_EXEC();
LOG_FUNC();
vednnError_t res;
float *x, *y;
float *z;
CHECK_FUNC(vedaMemPtr((void **)&x, vX));
CHECK_FUNC(vedaMemPtr((void **)&y, vY));
CHECK_FUNC(vedaMemPtr((void **)&z, vZ));
if (bGemm == 1) {
RETURN(vednnLinearForward(inDim, outDim, nBatch, 1, x, y, z));
} else {
// the built-in bGemm path did not work as expected, so we loop over bGemm manually (OpenMP left disabled for now)
//#pragma omp parallel for
for (int i = 0; i < bGemm; i++) {
float *xPtr = x + i * xStride;
float *yPtr = y + i * yStride;
float *zPtr = z + i * zStride;
vednnLinearForward(inDim, outDim, nBatch, 1, xPtr, yPtr, zPtr);
}
// WARNING: errors from the per-batch calls above are ignored; success is returned unconditionally
RETURN(VEDNN_SUCCESS);
}
}
uint64_t vedaVednnMaxPoolingForward(const vednnTensorParam_t *pParamIn, VEDAdeviceptr vDataIn,
const vednnTensorParam_t *pParamOut, VEDAdeviceptr vDataOut,
const vednnPoolingParam_t *pParamPool) {
CHECK_ERROR_BEFORE_EXEC();
LOG_FUNC();
void *pDataIn;
void *pDataOut;
CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));
#if defined(DEBUG_VEDA_LOGS)
showBufferVeAsFloat(vDataIn);
#endif
RETURN(vednnMaxPoolingForward(pParamIn, pDataIn, pParamOut, pDataOut, pParamPool));
}
uint64_t vedaVednnMaxPoolingBackwardEx(const vednnTensorParam_t *pParamGradOut, VEDAdeviceptr vDataGradOut,
const vednnTensorParam_t *pParamOut, VEDAdeviceptr vDataOut,
const vednnTensorParam_t *pParamIn, VEDAdeviceptr vDataIn,
const vednnTensorParam_t *pParamGradIn, VEDAdeviceptr vDataGradIn,
const vednnPoolingParam_t *pParamPool) {
CHECK_ERROR_BEFORE_EXEC();
LOG_FUNC();
void *pDataGradOut, *pDataIn, *pDataGradIn, *pDataOut;
CHECK_FUNC(vedaMemPtr((void **)&pDataGradOut, vDataGradOut));
CHECK_FUNC(vedaMemPtr((void **)&pDataIn, vDataIn));
CHECK_FUNC(vedaMemPtr((void **)&pDataOut, vDataOut));
CHECK_FUNC(vedaMemPtr((void **)&pDataGradIn, vDataGradIn));
vednnError_t res = vednnMaxPoolingForward(pParamIn, pDataIn, pParamOut, pDataOut, pParamPool);
if (res == VEDNN_SUCCESS) {
vednnMaxPoolingBackward(pParamGradOut, pDataGradOut, pParamOut, pDataOut, pParamIn, pDataIn, pParamGradIn,
pDataGradIn, pParamPool);
}
RETURN(res);
}
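// Concatenates up to 32 device buffers back to back into vO. Lengths are given in bytes and the
// copy is done in uint32_t chunks (see the warning below).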
uint64_t vedaConcatUpTo32(const uint64_t nInput, VEDAdeviceptr *inputList, uint64_t *inputLengthInBytesList,
VEDAdeviceptr vO) {
CHECK_ERROR_BEFORE_EXEC();
LOG_FUNC();
if (nInput > 0 && nInput <= 32) {
// WARNING: buffers are treated as uint32_t* so that the copy loop vectorizes;
// only pass data whose byte length is divisible by sizeof(uint32_t)
uint32_t *output;
CHECK_FUNC(vedaMemPtr((void **)&output, vO));
struct OmpArgs {
uint32_t *in;
uint32_t *out;
uint64_t size;
};
OmpArgs zPtrList[32];
for (uint64_t i = 0; i < nInput; i++) {
uint32_t *in;
CHECK_FUNC(vedaMemPtr((void **)&in, inputList[i]));
auto lengthOf = inputLengthInBytesList[i]/sizeof(uint32_t);
zPtrList[i].in = in;
zPtrList[i].out = output;
zPtrList[i].size = lengthOf;
output += lengthOf;
}
#pragma omp parallel for
for (int i = 0; i < nInput; ++i) {
auto inputPtr = zPtrList[i].in;
auto outPtr = zPtrList[i].out;
uint64_t size = zPtrList[i].size;
for (uint64_t j = 0; j < size; j++) {
outPtr[j] = inputPtr[j];
}
}
RETURN(0);
}
RETURN(-1);
}
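// Broadcast add of the shorter buffer over the longer one: out[i] = big[i] + small[i % min_len],
// e.g. adding a channel bias of length C to a flat NHWC tensor. The code assumes max_len is a
// multiple of min_len; when min_len is small, the small buffer is replicated into a local
// 4096-element tile to lengthen the inner loop.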
uint64_t vedaAdd_A(uint64_t length0, VEDAdeviceptr vIn0, uint64_t length1, VEDAdeviceptr vIn1, VEDAdeviceptr vO) {
CHECK_ERROR_BEFORE_EXEC();
uint64_t min_len = 0;
uint64_t max_len = 0;
float *big = nullptr;
float *small = nullptr;
float *out = nullptr;
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
if (length1 > length0) {
min_len = length0;
max_len = length1;
CHECK_FUNC(vedaMemPtr((void **)&small, vIn0));
CHECK_FUNC(vedaMemPtr((void **)&big, vIn1));
} else {
min_len = length1;
max_len = length0;
CHECK_FUNC(vedaMemPtr((void **)&small, vIn1));
CHECK_FUNC(vedaMemPtr((void **)&big, vIn0));
}
if (min_len == 1) {
auto val = small[0];
for (uint64_t i = 0L; i < max_len; i++) {
out[i] = big[i] + val;
}
} else {
int times = (int)(max_len / min_len);
auto copy_times = 4096 / min_len;
// guard against division by zero when min_len > 4096 (copy_times == 0)
int times_k = copy_times > 0 ? (int)(times / copy_times) : 0;
if (min_len < 4096 && times_k > 10) {
float local_ll[4096];
// copy into buffer
float *ll = local_ll;
for (int i = 0; i < copy_times; i++) {
for (int j = 0; j < min_len; j++) {
ll[j] = small[j];
}
ll += min_len;
}
auto llen = ll - local_ll;
int times_tail = times % copy_times;
for (int i = 0; i < times_k; i++) {
auto out_inner = &(out[i * llen]);
auto big_inner = &(big[i * llen]);
for (uint64_t j = 0; j < llen; j++) {
out_inner[j] = big_inner[j] + local_ll[j];
}
}
if (times_tail > 0) {
auto out_inner = &(out[times_k * llen]);
auto big_inner = &(big[times_k * llen]);
for (uint64_t j = 0; j < times_tail * min_len; j++) {
out_inner[j] = big_inner[j] + local_ll[j];
}
}
} else {
for (int i = 0; i < times; i++) {
auto out_inner = &(out[i * min_len]);
auto big_inner = &(big[i * min_len]);
for (uint64_t j = 0; j < min_len; j++) {
out_inner[j] = big_inner[j] + small[j];
}
}
}
}
RETURN(0);
}
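// Broadcast multiply; same tiling scheme and assumptions as vedaAdd_A, with + replaced by *.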
uint64_t vedaMult_A(uint64_t length0, VEDAdeviceptr vIn0, uint64_t length1, VEDAdeviceptr vIn1, VEDAdeviceptr vO) {
uint64_t min_len = 0;
uint64_t max_len = 0;
float *big = nullptr;
float *small = nullptr;
float *out = nullptr;
CHECK_ERROR_BEFORE_EXEC();
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
if (length1 > length0) {
min_len = length0;
max_len = length1;
CHECK_FUNC(vedaMemPtr((void **)&small, vIn0));
CHECK_FUNC(vedaMemPtr((void **)&big, vIn1));
} else {
min_len = length1;
max_len = length0;
CHECK_FUNC(vedaMemPtr((void **)&small, vIn1));
CHECK_FUNC(vedaMemPtr((void **)&big, vIn0));
}
if (min_len == 1) {
auto val = small[0];
for (uint64_t i = 0L; i < max_len; i++) {
out[i] = big[i] * val;
}
} else {
int times = (int)(max_len / min_len);
auto copy_times = 4096 / min_len;
// guard against division by zero when min_len > 4096 (copy_times == 0)
int times_k = copy_times > 0 ? (int)(times / copy_times) : 0;
if (min_len < 4096 && times_k > 10) {
float local_ll[4096];
// copy into buffer
float *ll = local_ll;
for (int i = 0; i < copy_times; i++) {
for (int j = 0; j < min_len; j++) {
ll[j] = small[j];
}
ll += min_len;
}
auto llen = ll - local_ll;
int times_tail = times % copy_times;
for (int i = 0; i < times_k; i++) {
auto out_inner = &(out[i * llen]);
auto big_inner = &(big[i * llen]);
for (uint64_t j = 0; j < llen; j++) {
out_inner[j] = big_inner[j] * local_ll[j];
}
}
if (times_tail > 0) {
auto out_inner = &(out[times_k * llen]);
auto big_inner = &(big[times_k * llen]);
for (uint64_t j = 0; j < times_tail * min_len; j++) {
out_inner[j] = big_inner[j] * local_ll[j];
}
}
} else {
for (int i = 0; i < times; i++) {
auto out_inner = &(out[i * min_len]);
auto big_inner = &(big[i * min_len]);
for (uint64_t j = 0; j < min_len; j++) {
out_inner[j] = big_inner[j] * small[j];
}
}
}
}
RETURN(0);
}
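// Permuted assignment for rank 2..4 tensors. The shape-info layout used here is
// [rank, shape_0..shape_{rank-1}, stride_0..stride_{rank-1}, ...]: shape[i] is read at offset 1 + i
// and stride[i] at offset 1 + rank + i.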
uint64_t vedaPermuteAssignRank2_4(LongType *inputShapeInfo, VEDAdeviceptr vIn, LongType *outputShapeInfo,
VEDAdeviceptr vO, const int *permutation) {
// Assumes C order and ews == 1; these conditions are not checked here, the caller must verify them
float *in, *out;
CHECK_ERROR_BEFORE_EXEC();
CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
LongType rank = inputShapeInfo[0];
if (rank >= 2 && rank <= 4) {
auto shape = inputShapeInfo + 1;
auto shape2 = outputShapeInfo + 1;
if (rank == 2) {
// rank2 copy [h, w] <- permute([a, b])
auto strideH = shape2[rank];
constexpr decltype(strideH) strideW = 1;
auto stridePermH = shape[rank + permutation[0]];
auto stridePermW = shape[rank + permutation[1]];
auto shapeH = shape2[0];
auto shapeW = shape2[1];
#pragma omp parallel for
for (decltype(shapeH) h = 0; h < shapeH; h++)
for (decltype(shapeW) w = 0; w < shapeW; w++) {
out[h * strideH + w * strideW] = in[h * stridePermH + w * stridePermW];
}
} else if (rank == 3) {
// rank3 copy [g, h, w] <- permute([a, b, c])
auto strideG = shape2[rank];
auto strideH = shape2[rank + 1];
constexpr decltype(strideH) strideW = 1;
auto stridePermG = shape[rank + permutation[0]];
auto stridePermH = shape[rank + permutation[1]];
auto stridePermW = shape[rank + permutation[2]];
auto shapeG = shape2[0];
auto shapeH = shape2[1];
auto shapeW = shape2[2];
#pragma omp parallel for
for (decltype(shapeG) g = 0; g < shapeG; g++)
for (decltype(shapeH) h = 0; h < shapeH; h++)
for (decltype(shapeW) w = 0; w < shapeW; w++) {
out[g * strideG + h * strideH + w * strideW] = in[g * stridePermG + h * stridePermH + w * stridePermW];
}
} else {
// rank 4 copy [f, g, h, w] <- permute([a, b, c, d])
auto strideF = shape2[rank];
auto strideG = shape2[rank + 1];
auto strideH = shape2[rank + 2];
constexpr decltype(strideH) strideW = 1;
auto stridePermF = shape[rank + permutation[0]];
auto stridePermG = shape[rank + permutation[1]];
auto stridePermH = shape[rank + permutation[2]];
auto stridePermW = shape[rank + permutation[3]];
auto shapeF = shape2[0];
auto shapeG = shape2[1];
auto shapeH = shape2[2];
auto shapeW = shape2[3];
#pragma omp parallel for
for (decltype(shapeF) f = 0; f < shapeF; f++)
for (decltype(shapeG) g = 0; g < shapeG; g++)
for (decltype(shapeH) h = 0; h < shapeH; h++)
for (decltype(shapeW) w = 0; w < shapeW; w++) {
out[f * strideF + g * strideG + h * strideH + w * strideW] =
in[f * stridePermF + g * stridePermG + h * stridePermH + w * stridePermW];
}
}
RETURN(0);
}
RETURN(-1);
}
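// Constant padding for rank 4 tensors: optionally fills the output with padValue, then copies the
// input block into the output starting at the precomputed flat offset paddingOffset.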
uint64_t vedaPadConstantRank4(LongType *inputShapeInfo, VEDAdeviceptr vIn, LongType *outputShapeInfo, VEDAdeviceptr vO,
const LongType paddingOffset, float padValue) {
float *in, *out;
CHECK_ERROR_BEFORE_EXEC();
CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
LongType rank = inputShapeInfo[0];
if (rank == 4) {
auto shape = inputShapeInfo + 1;
auto shape2 = outputShapeInfo + 1;
// rank 4: copy the input core [f, g, h, w] into the padded output
auto strideInF = shape[rank + 0];
auto strideInG = shape[rank + 1];
auto strideInH = shape[rank + 2];
constexpr decltype(strideInH) strideInW = 1;
auto strideF = shape2[rank + 0];
auto strideG = shape2[rank + 1];
auto strideH = shape2[rank + 2];
constexpr decltype(strideH) strideW = 1;
auto shapeF = shape[0];
auto shapeG = shape[1];
auto shapeH = shape[2];
auto shapeW = shape[3];
if (paddingOffset != 0) {
// For now, fill the whole output with padValue rather than only the padded border
LongType length = shape2[0] * shape2[1] * shape2[2] * shape2[3];
for (LongType i = 0L; i < length; i++) {
out[i] = padValue;
}
}
// copy the core into Output
auto offsetedOut = out + paddingOffset;
#pragma omp parallel for
for (decltype(shapeF) f = 0; f < shapeF; f++)
for (decltype(shapeG) g = 0; g < shapeG; g++)
for (decltype(shapeH) h = 0; h < shapeH; h++)
for (decltype(shapeW) w = 0; w < shapeW; w++) {
offsetedOut[f * strideF + g * strideG + h * strideH + w * strideW] =
in[f * strideInF + g * strideInG + h * strideInH + w * strideInW];
}
}
RETURN(0);
}
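// Simple element-wise float kernels (exp, log, tanh, sigmoid, leaky relu) operating on flat buffers.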
uint64_t vedaExpF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
CHECK_ERROR_BEFORE_EXEC();
float *in, *out;
CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
for (uint64_t i = 0; i < length; i++) {
out[i] = exp(in[i]);
}
RETURN(0);
}
uint64_t vedaLogF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
float *in, *out;
CHECK_ERROR_BEFORE_EXEC();
CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
for (uint64_t i = 0; i < length; i++) {
out[i] = log(in[i]);
}
RETURN(0);
}
uint64_t vedaTanhF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
float *in, *out;
CHECK_ERROR_BEFORE_EXEC();
CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
for (uint64_t i = 0; i < length; i++) {
out[i] = tanh(in[i]);
}
RETURN(0);
}
uint64_t vedaSigmoidF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO) {
float *in, *out;
CHECK_ERROR_BEFORE_EXEC();
CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
for (uint64_t i = 0; i < length; i++) {
out[i] = 1.0f / (1.0f + exp(-in[i]));
}
RETURN(0);
}
uint64_t vedaLeakyRELUF32(uint64_t length, VEDAdeviceptr vIn, VEDAdeviceptr vO, float alpha) {
float *in, *out;
CHECK_ERROR_BEFORE_EXEC();
CHECK_FUNC(vedaMemPtr((void **)&in, vIn));
CHECK_FUNC(vedaMemPtr((void **)&out, vO));
for (uint64_t i = 0; i < length; i++) {
auto val = in[i];
out[i] = val < 0.0f ? alpha * val : val;
}
RETURN(0);
}
} // extern "C"