deeplearning4j/deeplearning4j

View on GitHub
libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp

Summary

Maintainability
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_gru)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/gru.h>

namespace sd {
namespace ops {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) {
  // Forward pass of a GRU layer over the whole input sequence.
  //
  // Inputs:
  //   0: x   - input, [time, bS, nIn]
  //   1: hI  - initial cell output (at time step = 0), [bS, nOut]
  //   2: Wx  - input-to-hidden weights, [nIn, 3*nOut]
  //   3: Wh  - hidden-to-hidden weights, [nOut, 3*nOut]
  //   4: b   - biases, [3*nOut]
  // Boolean args:
  //   0: linearBeforeReset (optional, false when absent) - passed unchanged to helpers::gruTimeLoop
  //      NOTE(review): presumably the ONNX-style "linear_before_reset" variant - confirm in the helper.
  // Output:
  //   0: h   - cell outputs for every time step, [time, bS, nOut]
  auto x = INPUT_VARIABLE(0);   // input [time, bS, nIn]
  auto hI = INPUT_VARIABLE(1);  // initial cell output (at time step = 0) [bS, nOut]

  auto Wx = INPUT_VARIABLE(2);  // input-to-hidden  weights, [nIn, 3*nOut]
  auto Wh = INPUT_VARIABLE(3);  // hidden-to-hidden weights, [nOut, 3*nOut]
  auto b = INPUT_VARIABLE(4);   // biases, [3*nOut]

  auto h = OUTPUT_VARIABLE(0);  // cell outputs [time, bS, nOut], that is per each time step

  const bool linearBeforeReset = block.numB() > 0 ? B_ARG(0) : false;

  // Validate ranks first so that the sizeAt() reads below only touch axes that actually exist.
  REQUIRE_TRUE(x->rankOf() == 3, 0,
               "GRU operation: wrong rank of input array, expected is 3, but got %i instead !", x->rankOf());
  REQUIRE_TRUE(hI->rankOf() == 2, 0,
               "GRU operation: wrong rank of previous cell output array, expected is 2, but got %i instead !",
               hI->rankOf());

  // Keep dimensions in sd::LongType (the element type of the shape vectors below) so nothing is
  // narrowed to int and 3 * nOut cannot overflow an int.
  const sd::LongType bS = x->sizeAt(1);
  const sd::LongType nIn = x->sizeAt(2);
  const sd::LongType nOut = hI->sizeAt(1);

  const std::vector<sd::LongType> h0CorrectShape = {bS, nOut};
  const std::vector<sd::LongType> wxCorrectShape = {nIn, 3 * nOut};
  const std::vector<sd::LongType> whCorrectShape = {nOut, 3 * nOut};
  const std::vector<sd::LongType> bCorrectShape = {3 * nOut};

  REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0,
               "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
  REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0,
               "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
  REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0,
               "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
  REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0,
               "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());

  helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h, linearBeforeReset);

  return sd::Status::OK;
}

//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(gru) {
  // Inputs are accepted in any data type; outputs must be one of the floating-point types.
  getOpDescriptor()
      ->setAllowedInputTypes(sd::DataType::ANY)
      ->setAllowedOutputTypes({ALL_FLOATS});
}

//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(gru) {
  // Shape inference for gru: validates all inputs, then returns a single output shape
  // h = [time, bS, nOut] that uses hI's data type and ordering.
  auto x = INPUT_VARIABLE(0);   // input [time, bS, nIn]
  auto hI = INPUT_VARIABLE(1);  // initial cell output (at time step = 0) [bS, nOut]

  auto Wx = INPUT_VARIABLE(2);  // input-to-hidden  weights, [nIn, 3*nOut]
  auto Wh = INPUT_VARIABLE(3);  // hidden-to-hidden weights, [nOut, 3*nOut]
  auto b = INPUT_VARIABLE(4);   // biases, [3*nOut]

  // Validate ranks first so that the sizeAt() reads below only touch axes that actually exist.
  REQUIRE_TRUE(x->rankOf() == 3, 0,
               "GRU operation: wrong rank of input array, expected is 3, but got %i instead !", x->rankOf());
  REQUIRE_TRUE(hI->rankOf() == 2, 0,
               "GRU operation: wrong rank of previous cell output array, expected is 2, but got %i instead !",
               hI->rankOf());

  // Keep dimensions in sd::LongType so nothing is narrowed to int and 3 * nOut cannot overflow an int.
  const sd::LongType time = x->sizeAt(0);
  const sd::LongType bS = x->sizeAt(1);
  const sd::LongType nIn = x->sizeAt(2);
  const sd::LongType nOut = hI->sizeAt(1);

  const std::vector<sd::LongType> h0CorrectShape = {bS, nOut};
  const std::vector<sd::LongType> wxCorrectShape = {nIn, 3 * nOut};
  const std::vector<sd::LongType> whCorrectShape = {nOut, 3 * nOut};
  const std::vector<sd::LongType> bCorrectShape = {3 * nOut};

  REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0,
               "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
  REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0,
               "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
  REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0,
               "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
  REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0,
               "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());

  auto hShapeInfo =
      ConstantShapeHelper::getInstance().createShapeInfo(hI->dataType(), hI->ordering(), {time, bS, nOut});

  return SHAPELIST(hShapeInfo);
}

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(gru_bp, 6, 5, false, 0, 0) {
  // Backward pass of a GRU layer over the whole input sequence.
  //
  // Inputs:
  //   0: x     - input, [time, bS, nIn]
  //   1: hI    - initial cell output (at time step = 0), [bS, nOut]
  //   2: Wx    - input-to-hidden weights, [nIn, 3*nOut]
  //   3: Wh    - hidden-to-hidden weights, [nOut, 3*nOut]
  //   4: b     - biases, [3*nOut]
  //   5: dLdh  - gradient w.r.t. the ff output h, [time, bS, nOut]
  // Outputs (gradients w.r.t. inputs 0..4, same shapes as those inputs):
  //   dLdx, dLdhI, dLdWx, dLdWh, dLdb
  //
  // NOTE(review): unlike the forward op, no linearBeforeReset boolean argument is read here and
  // helpers::gruTimeLoopBp receives none; presumably gradients correspond to linearBeforeReset == false - confirm.
  auto x = INPUT_VARIABLE(0);   // input [time, bS, nIn]
  auto hI = INPUT_VARIABLE(1);  // initial cell output (at time step = 0) [bS, nOut]

  auto Wx = INPUT_VARIABLE(2);  // input-to-hidden  weights, [nIn, 3*nOut]
  auto Wh = INPUT_VARIABLE(3);  // hidden-to-hidden weights, [nOut, 3*nOut]
  auto b = INPUT_VARIABLE(4);   // biases, [3*nOut]

  auto dLdh = INPUT_VARIABLE(5);  // gradient vs. ff output, [time, bS, nOut]

  // dLdx is fetched with OUTPUT_VARIABLE; the remaining gradients use OUTPUT_NULLIFIED
  // (presumably zero-filled so the helper can accumulate into them - confirm).
  auto dLdx = OUTPUT_VARIABLE(0);    // gradient vs. ff input, [time, bS, nIn]
  auto dLdhI = OUTPUT_NULLIFIED(1);  // gradient vs. initial cell output, [bS, nOut]
  auto dLdWx = OUTPUT_NULLIFIED(2);  // gradient vs. input-to-hidden  weights, [nIn, 3*nOut]
  auto dLdWh = OUTPUT_NULLIFIED(3);  // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut]
  auto dLdb = OUTPUT_NULLIFIED(4);   // gradient vs. biases [3*nOut]

  // Validate ranks first so that the sizeAt() reads below only touch axes that actually exist.
  REQUIRE_TRUE(x->rankOf() == 3, 0,
               "GRU_BP operation: wrong rank of input array, expected is 3, but got %i instead !", x->rankOf());
  REQUIRE_TRUE(hI->rankOf() == 2, 0,
               "GRU_BP operation: wrong rank of previous cell output array, expected is 2, but got %i instead !",
               hI->rankOf());

  // Keep dimensions in sd::LongType so nothing is narrowed to int and 3 * nOut cannot overflow an int.
  const sd::LongType time = x->sizeAt(0);
  const sd::LongType bS = x->sizeAt(1);
  const sd::LongType nIn = x->sizeAt(2);
  const sd::LongType nOut = hI->sizeAt(1);

  const std::vector<sd::LongType> h0CorrectShape = {bS, nOut};
  const std::vector<sd::LongType> wxCorrectShape = {nIn, 3 * nOut};
  const std::vector<sd::LongType> whCorrectShape = {nOut, 3 * nOut};
  const std::vector<sd::LongType> bCorrectShape = {3 * nOut};
  const std::vector<sd::LongType> hCorrectShape = {time, bS, nOut};

  REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0,
               "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
  REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0,
               "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
  REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0,
               "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
  REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0,
               "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
  REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0,
               "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());

  helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, dLdhI, dLdWx, dLdWh, dLdb);

  return sd::Status::OK;
}

//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(gru_bp) {
  // Inputs (including the incoming gradient) may be any data type; the produced gradients are floating-point.
  auto descriptor = getOpDescriptor();
  descriptor->setAllowedInputTypes(sd::DataType::ANY)
            ->setAllowedOutputTypes({ALL_FLOATS});
}

//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(gru_bp) {
  // Shape inference for gru_bp: validates all inputs, then returns five gradient shapes
  // (dLdx, dLdhI, dLdWx, dLdWh, dLdb). Each gradient copies the shape of the matching input
  // (x, hI, Wx, Wh, b) and takes the data type of dLdh.
  auto x = INPUT_VARIABLE(0);   // input [time, bS, nIn]
  auto hI = INPUT_VARIABLE(1);  // initial cell output (at time step = 0) [bS, nOut]

  auto Wx = INPUT_VARIABLE(2);  // input-to-hidden  weights, [nIn, 3*nOut]
  auto Wh = INPUT_VARIABLE(3);  // hidden-to-hidden weights, [nOut, 3*nOut]
  auto b = INPUT_VARIABLE(4);   // biases, [3*nOut]

  auto dLdh = INPUT_VARIABLE(5);  // gradient vs. ff output, [time, bS, nOut]

  // Validate ranks first so that the sizeAt() reads below only touch axes that actually exist.
  REQUIRE_TRUE(x->rankOf() == 3, 0,
               "GRU_BP operation: wrong rank of input array, expected is 3, but got %i instead !", x->rankOf());
  REQUIRE_TRUE(hI->rankOf() == 2, 0,
               "GRU_BP operation: wrong rank of previous cell output array, expected is 2, but got %i instead !",
               hI->rankOf());

  // Keep dimensions in sd::LongType so nothing is narrowed to int and 3 * nOut cannot overflow an int.
  const sd::LongType time = x->sizeAt(0);
  const sd::LongType bS = x->sizeAt(1);
  const sd::LongType nIn = x->sizeAt(2);
  const sd::LongType nOut = hI->sizeAt(1);

  const std::vector<sd::LongType> h0CorrectShape = {bS, nOut};
  const std::vector<sd::LongType> wxCorrectShape = {nIn, 3 * nOut};
  const std::vector<sd::LongType> whCorrectShape = {nOut, 3 * nOut};
  const std::vector<sd::LongType> bCorrectShape = {3 * nOut};
  const std::vector<sd::LongType> hCorrectShape = {time, bS, nOut};

  REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0,
               "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
  REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0,
               "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
  REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0,
               "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
  REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0,
               "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
  REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0,
               "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());

  auto dLdxShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), x->shapeInfo());
  auto dLdhIShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), hI->shapeInfo());
  auto dLdWxShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), Wx->shapeInfo());
  auto dLdWhShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), Wh->shapeInfo());
  auto dLdbShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(dLdh->dataType(), b->shapeInfo());

  return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo);
}

}  // namespace ops
}  // namespace sd

#endif