/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author raver119@gmail.com
//   @author Adam Gibson
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_batched_gemm)

#include <ops/declarable/headers/blas.h>
#include <ops/declarable/helpers/batched_gemm.h>

namespace sd {
namespace ops {

CUSTOM_OP_IMPL(batched_gemm, -1, -1, false, 0, 9) {
  int transA = INT_ARG(0);
  int transB = INT_ARG(1);
  int M = INT_ARG(2);
  int N = INT_ARG(3);
  int K = INT_ARG(4);
  int ldA = INT_ARG(5);
  int ldB = INT_ARG(6);
  int ldC = INT_ARG(7);
  int batchSize = INT_ARG(8);
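
  // Validate the leading dimensions against the standard BLAS xGEMM
  // constraints: lda >= max(1,m) when A is not transposed (max(1,k) when it
  // is), and likewise for ldb and ldc.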
  if(transA == 0) {
    int ldaComp = M > 1 ? M : 1;
    if(ldA < ldaComp) THROW_EXCEPTION("LDA must be >= max(1,m) when transa == false");
  } else {
    int ldaComp = K > 1 ? K : 1;
    if(ldA < ldaComp)
      THROW_EXCEPTION("LDA must be >= max(1,k) when transa == true");
  }

  if(transB == 0) {
    int ldBComp = K > 1 ? K : 1;
    if(ldB < ldBComp) {
      THROW_EXCEPTION("LDB must be >= max(1,k) when transb == false");
    }
  } else {
    int ldbComp = N > 1 ? N : 1;
    if(ldB < ldbComp)
      THROW_EXCEPTION("LDB must be >= max(1,N) when transb == true");
  }

  int ldcComp = M > 1 ? M : 1;
  if(ldC < ldcComp) {
    THROW_EXCEPTION("LDC must be >= max(1,M)");
  }


  // Map the 0/1 transpose flags onto the CBLAS enum values
  // (CblasNoTrans == 111, CblasTrans == 112); 111/112 pass through unchanged.
  if (transA == 0) transA = 111;
  if (transA == 1) transA = 112;
  if (transB == 0) transB = 111;
  if (transB == 1) transB = 112;

  if(M < 0) THROW_EXCEPTION("M < 0");
  if(N < 0) THROW_EXCEPTION("N < 0");
  if(K < 0) THROW_EXCEPTION("K < 0");

  REQUIRE_TRUE((transA == 111 || transA == 112) && (transB == 111 || transB == 112), 0,
               "BatchedGemm: valid values for transA and transB are: 0/1 or 111/112, for NoTrans/Trans respectively")
  REQUIRE_TRUE(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0, 0,
               "BatchedGemm: M, N, K, ldA, ldB, ldC and batchSize must all be positive");

  auto alpha = INPUT_VARIABLE(0);
  NDArray *alphaInput = nullptr;
  if(alpha->isScalar()) {
    alphaInput = new NDArray('c',{batchSize},alpha->dataType());
    alphaInput->assign(alpha);
  } else {
    alphaInput = alpha;
  }


  auto beta = INPUT_VARIABLE(1);
  NDArray *betaInput = nullptr;
  if(beta->isScalar()) {
    betaInput = new NDArray('c',{batchSize},beta->dataType());
    betaInput->assign(beta);
  } else {
    betaInput = beta;
  }

  std::vector<NDArray*> vA(batchSize);
  std::vector<NDArray*> vB(batchSize);
  std::vector<NDArray*> vC(batchSize);

  auto firstType = INPUT_VARIABLE(0)->dataType();
  for (int e = 0; e < batchSize; e++) {
    vA[e] = INPUT_VARIABLE(e + 2);
    vB[e] = INPUT_VARIABLE(e + 2 + batchSize);
    vC[e] = OUTPUT_VARIABLE(e);

    REQUIRE_TRUE(firstType == vA[e]->dataType() && firstType == vB[e]->dataType() && firstType == vC[e]->dataType(), 0,
                 "BatchedGemm: all inputs and outputs must have same data type");

    REQUIRE_TRUE(vA[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of A should be equal to 2", e);
    REQUIRE_TRUE(vB[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of B should be equal to 2", e);
    REQUIRE_TRUE(vC[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of C should be equal to 2", e);

    if(transA == 111) {
      REQUIRE_TRUE(M == vA[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of A.rows() should be equal to M transA: false", e);
      REQUIRE_TRUE(K == vA[e]->sizeAt(1) , 0,
                   "BatchedGemm: batch %i, number of A.columns() should be equal to K transA: false", e);
    } else  {
      REQUIRE_TRUE(M == vA[e]->sizeAt(1), 0, "BatchedGemm: batch %i, number of A.columns() should be equal to M transA: true", e);
      REQUIRE_TRUE(K == vA[e]->sizeAt(0) , 0,
                   "BatchedGemm: batch %i, number of A.rows() should be equal to K transA: true", e);
    }

    if(transB == 111) {
      REQUIRE_TRUE(N == vB[e]->sizeAt(1), 0, "BatchedGemm: batch %i, number of B.columns() should be equal to N transB: false", e);
      REQUIRE_TRUE(K == vB[e]->sizeAt(0), 0,
                   "BatchedGemm: batch %i, number of B.rows() should be equal to K transB: false", e);
    } else {
      REQUIRE_TRUE(N == vB[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of B.rows() should be equal to N transB: true", e);
      REQUIRE_TRUE(K == vB[e]->sizeAt(1), 0,
                   "BatchedGemm: batch %i, number of B.columns() should be equal to K transB: true", e);
    }
  }

  REQUIRE_TRUE(vA.size() == vB.size() && vA.size() == vC.size() && (int) vA.size() == batchSize, 0,
               "BatchedGemm: number of A, B and C matrices must all equal batchSize");

  sd::ops::helpers::bgemm(vA, vB, vC, alphaInput, betaInput, transA, transB, M, N, K, ldA, ldB, ldC);

  // free the broadcast alpha/beta vectors if they were allocated above
  if(alphaInput != alpha) delete alphaInput;
  if(betaInput != beta) delete betaInput;

  return sd::Status::OK;
};

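// Shape function: each of the batchSize outputs is an M x N matrix in
// column-major ('f') order, matching the BLAS convention used by the helper.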
DECLARE_SHAPE_FN(batched_gemm) {
  int transA = INT_ARG(0);
  int transB = INT_ARG(1);
  int M = INT_ARG(2);
  int N = INT_ARG(3);
  int K = INT_ARG(4);
  int ldA = INT_ARG(5);
  int ldB = INT_ARG(6);
  int ldC = INT_ARG(7);
  int batchSize = INT_ARG(8);
  auto firstType = ArrayOptions::dataType(inputShape->at(0));
  for (int e = 1; e < block.width(); e++) {
    REQUIRE_TRUE(firstType == ArrayOptions::dataType(inputShape->at(e)), 0,
                 "BatchedGemm: all inputs must have same data type");
  }

  auto shapeList = SHAPELIST();

  if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0)) {
    sd_printf("Invalid input shape: all dimensions must be positive. M: %d N: %d K: %d ldA: %d ldB: %d ldC: %d batchSize: %d\n", M, N, K, ldA, ldB, ldC, batchSize);
    shapeList->push_back(
        ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'c', {1, 1}));
    return shapeList;
  }

  std::vector<sd::LongType> shape({M, N});
  // all outputs share the same shape descriptor, so create it once
  auto newShape =
      ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'f', shape);

  for (int e = 0; e < batchSize; e++) {
    shapeList->push_back(newShape);
  }

  return shapeList;
}

DECLARE_TYPES(batched_gemm) {
  getOpDescriptor()
      ->setAllowedInputTypes({ALL_FLOATS})
      ->setAllowedOutputTypes({ALL_FLOATS});
}



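/**
 * batched_gemm_bp: backprop for batched_gemm. Inputs mirror the forward op,
 * followed by batchSize gradient matrices dL/dC. Outputs are gradients for
 * every forward input: alpha, beta, then the A matrices, then the B matrices.
 */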
CUSTOM_OP_IMPL(batched_gemm_bp, -1, -1, false, 0, 9) {
  int transA = INT_ARG(0);
  int transB = INT_ARG(1);
  int M = INT_ARG(2);
  int N = INT_ARG(3);
  int K = INT_ARG(4);
  int ldA = INT_ARG(5);
  int ldB = INT_ARG(6);
  int ldC = INT_ARG(7);
  int batchSize = INT_ARG(8);

  std::vector<NDArray *> matricesA;
  std::vector<NDArray *> matricesB;
  std::vector<NDArray *> dlDOut;
  std::vector<NDArray *> dldXOutputs;
  std::vector<NDArray *> dldYOutputs;

  for (int e = 0; e < batchSize; e++) {
    matricesA.push_back(INPUT_VARIABLE(e + 2));
    matricesB.push_back(INPUT_VARIABLE(e + 2 + batchSize));
    dlDOut.push_back(INPUT_VARIABLE(e + 2 + batchSize * 2));
    // output slots 0 and 1 are reserved for the alpha and beta gradients,
    // so the per-matrix gradients start at index 2; every input needs a gradient
    dldXOutputs.push_back(OUTPUT_VARIABLE(e + 2));
    dldYOutputs.push_back(OUTPUT_VARIABLE(e + 2 + batchSize));
  }


  auto alpha = INPUT_VARIABLE(0);
  NDArray *alphaInput = nullptr;
  if(alpha->lengthOf() != batchSize) {
    alphaInput = new NDArray('c',{batchSize},alpha->dataType());
    alphaInput->assign(alpha);
  } else {
    alphaInput = alpha;
  }


  auto beta = INPUT_VARIABLE(1);
  NDArray *betaInput = nullptr;
  if(beta->lengthOf() != batchSize) {
    betaInput = new NDArray('c',{batchSize},beta->dataType());
    betaInput->assign(beta);
  } else {
    betaInput = beta;
  }


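  // Gradients for the A matrices: bgemm of the incoming gradients dL/dC with
  // the B matrices, with dimensions taken from the first batch entry.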
  int transA1 = 0;
  int transB1 = transB;
  int M1 = dlDOut[0]->sizeAt(0);
  int N1 = matricesB[0]->sizeAt(1);
  int k1 = dlDOut[0]->sizeAt(1);
  int lda1 = dlDOut[0]->sizeAt(0);
  int ldb1 = matricesB[0]->sizeAt(0);
  int ldc1 = dldXOutputs[0]->sizeAt(0);
  sd::ops::helpers::bgemm(dlDOut, matricesB, dldXOutputs, alphaInput, betaInput, transA1, transB1, M1, N1, k1, lda1, ldb1, ldc1);

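  // Gradients for the B matrices: bgemm of the A matrices with the incoming
  // gradients dL/dC.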
  int transA2 = transA;
  int transB2 = 0;
  int M2 = matricesA[0]->sizeAt(0);
  int N2 = dlDOut[0]->sizeAt(1);
  int k2 = matricesA[0]->sizeAt(1);
  int lda2 = dlDOut[0]->sizeAt(0);
  int ldb2 = dlDOut[0]->sizeAt(0);
  int ldc2 = dlDOut[0]->sizeAt(0);
  sd::ops::helpers::bgemm(matricesA, dlDOut, dldYOutputs, alphaInput, betaInput, transA2, transB2, M2, N2, k2, lda2, ldb2, ldc2);


  if(alphaInput != alpha) {
    delete alphaInput;
  }

  if(betaInput != beta) {
    delete betaInput;
  }


  return sd::Status::OK;
};



DECLARE_SHAPE_FN(batched_gemm_bp) {
  sd::LongType *alphaShapeInfo;
  sd::LongType *betaShapeInfo;
  sd::LongType *xShapeInfo;
  sd::LongType *yShapeInfo;
  int batchSize = INT_ARG(8);
  COPY_SHAPE(inputShape->at(0), alphaShapeInfo);
  COPY_SHAPE(inputShape->at(1), betaShapeInfo);
  COPY_SHAPE(inputShape->at(2), xShapeInfo);
  COPY_SHAPE(inputShape->at(2 + batchSize), yShapeInfo);
  auto xConstant = CONSTANT(xShapeInfo);
  auto yConstant = CONSTANT(yShapeInfo);
  auto ret = SHAPELIST();
  // gradients w.r.t. alpha and beta match the shapes of alpha and beta
  ret->push_back(CONSTANT(alphaShapeInfo));
  ret->push_back(CONSTANT(betaShapeInfo));
  // gradients w.r.t. each A matrix, then each B matrix
  for(int i = 0; i < batchSize; i++) {
    ret->push_back(xConstant);
  }

  for(int i = 0; i < batchSize; i++) {
    ret->push_back(yConstant);
  }
  return ret;
}


DECLARE_TYPES(batched_gemm_bp) {
  getOpDescriptor()
      ->setAllowedInputTypes({ALL_FLOATS})
      ->setAllowedOutputTypes({ALL_FLOATS});
}


}  // namespace ops
}  // namespace sd

#endif