deeplearning4j/deeplearning4j

View on GitHub
libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp

Summary

Maintainability
Test Coverage
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayerCell)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>

namespace sd {
namespace ops {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) {
  // equations (no peephole connections)
  // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
  // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
  // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
  // ct  = ft ◦ ct-1 + it ◦ c't
  // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
  // ht  = ot ◦ tanh(ct)

  // equations (peephole connections are present)
  // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
  // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
  // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
  // ct  = clip(ft ◦ ct-1 + it ◦ c't)
  // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
  // ht  = ot ◦ tanh(ct)

  // notations:
  // bS - batch size
  // nIn - input size
  // nOut - output size (hidden size)

  // INPUTS:
  // input x:                          [bS, nIn] or [nIn]
  // input weights Wx:                 [nIn, 4*nOut]
  // recurrent weights Wr:             [nOut, 4*nOut]
  // initial (previous) output hI:     [bS, nOut] or [nOut]
  // initial (previous) cell state cI: [bS, nOut] or [nOut]
  // biases b (optional):              [4*nOut]
  // peephole weights Wp (optional):   [3*nOut]

  // OUTPUTS:
  // current output h:     [bS, nOut] or [nOut]
  // current cell state c: [bS, nOut] or [nOut]

  // !!! dimension 4*nOut implies order it, ft, c't, ot
  // !!! dimension 3*nOut implies order it, ft, ot

  // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded
  // relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
  const auto gateAct = INT_ARG(0);  // activation for input (i), forget (f) and output (o) gates
  const auto cellAct = INT_ARG(1);  // activation for cell state (c)
  const auto outAct = INT_ARG(2);   // activation for output (h)

  const auto hasBiases = B_ARG(0);  // indicates whether biases array is provided
  const auto hasPH = B_ARG(1);      // indicates whether peephole connections are present

  const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
  const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
  const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
  const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
  const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
  const auto outActHasBeta = outAct == 3 || outAct == 6;

  sd::LongType count = 1;
  const auto cellClip = T_ARG(0);  // cell clipping value, if it = 0 then do not apply clipping
  const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
  const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
  const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
  const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
  const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
  const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;

  count = 3;
  const auto x = INPUT_VARIABLE(0);                              // input
  const auto Wx = INPUT_VARIABLE(1);                             // input weights
  const auto Wr = INPUT_VARIABLE(2);                             // recurrent weights
  const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
  const auto hI = INPUT_VARIABLE(count++);                       // initial output
  const auto cI = INPUT_VARIABLE(count++);                       // initial cell state
  const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;       // peephole weights

  REQUIRE_TRUE(cellClip >= 0, 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !");

  auto h = OUTPUT_VARIABLE(0);
  auto c = OUTPUT_VARIABLE(1);

  // evaluate dimensions
  const sd::LongType bS = x->rankOf() == 1 ? 0 : x->sizeAt(0);
  const sd::LongType nIn = x->sizeAt(-1);
  const sd::LongType nOut = Wx->sizeAt(-1) / 4;

  // inputs validations
  // Wx validation
  if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
  // Wr validation
  if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut)
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
  // initial output/cell validation
  std::vector<sd::LongType> exphIcIShape =
      x->rankOf() == 1 ? std::vector<sd::LongType>{nOut} : std::vector<sd::LongType>{bS, nOut};
  REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0,
               "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
  REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0,
               "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
  // biases validation
  if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut))
    REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
  // peephole weights validation
  if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut))
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({3 * nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());

  std::vector<float> params = {
      static_cast<float>(0) /*ignore*/, static_cast<float>(0) /*ignore*/, static_cast<float>(cellClip),
      static_cast<float>(gateAct),      static_cast<float>(gateAlpha),    static_cast<float>(gateBeta),
      static_cast<float>(cellAct),      static_cast<float>(cellAlpha),    static_cast<float>(cellBeta),
      static_cast<float>(outAct),       static_cast<float>(outAlpha),     static_cast<float>(outBeta)};

  helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c);

  return sd::Status::OK;
}

DECLARE_TYPES(lstmLayerCell) {
  getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
}

DECLARE_SHAPE_FN(lstmLayerCell) {
  const auto hasBiases = B_ARG(0);  // indicates whether biases array is provided

  sd::LongType count = hasBiases ? 4 : 3;
  const auto hI = INPUT_VARIABLE(count++);  // initial output
  const auto cI = INPUT_VARIABLE(count);    // initial cell state

  return new ShapeList({hI->shapeInfo(), cI->shapeInfo()});
}

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {
  // equations (no peephole connections)
  // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
  // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
  // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
  // ct  = ft ◦ ct-1 + it ◦ c't
  // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
  // ht  = ot ◦ tanh(ct)

  // equations (peephole connections are present)
  // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
  // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
  // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
  // ct  = clip(ft ◦ ct-1 + it ◦ c't)
  // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
  // ht  = ot ◦ tanh(ct)

  // notations:
  // bS - batch size
  // nIn - input size
  // nOut - output size (hidden size)

  // INPUTS:
  // input x:                          [bS, nIn] or [nIn]
  // input weights Wx:                 [nIn, 4*nOut]
  // recurrent weights Wr:             [nOut, 4*nOut]
  // initial (previous) output hI:     [bS, nOut] or [nOut]
  // initial (previous) cell state cI: [bS, nOut] or [nOut]
  // gradient wrt output dLdh:         [bS, nOut] or [nOut]
  // gradient wrt cell state dLdc:     [bS, nOut] or [nOut]
  // peephole weights Wp (optional):   [3*nOut]
  // biases b (optional):              [4*nOut]

  // OUTPUTS:
  // gradient wrt x dLdx:              [bS, nIn] or [nIn]
  // gradient wrt Wx dLdWx:            [nIn, 4*nOut]
  // gradient wrt Wr dLdWr:            [nOut, 4*nOut]
  // gradient wrt hI dLdhI:            [bS, nOut] or [nOut]
  // gradient wrt cI dLdcI:            [bS, nOut] or [nOut]
  // gradient wrt b dLdb (optional):   [4*nOut]
  // gradient wrt Wp dLdWp (optional): [3*nOut]

  // !!! dimension 4*nOut implies order it, ft, c't, ot
  // !!! dimension 3*nOut implies order it, ft, ot

  // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded
  // relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
  const auto gateAct = INT_ARG(0);  // activation for input (i), forget (f) and output (o) gates
  const auto cellAct = INT_ARG(1);  // activation for cell state (c)
  const auto outAct = INT_ARG(2);   // activation for output (h)

  const auto hasBiases = B_ARG(0);  // indicates whether biases array is provided
  const auto hasPH = B_ARG(1);      // indicates whether peephole connections are present

  const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
  const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
  const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
  const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
  const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
  const auto outActHasBeta = outAct == 3 || outAct == 6;

  sd::LongType count = 1;
  const auto cellClip = T_ARG(0);  // cell clipping value, if it = 0 then do not apply clipping
  const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
  const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
  const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
  const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
  const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
  const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;

  count = 3;
  const auto x = INPUT_VARIABLE(0);                              // input
  const auto Wx = INPUT_VARIABLE(1);                             // input weights
  const auto Wr = INPUT_VARIABLE(2);                             // recurrent weights
  const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
  const auto hI = INPUT_VARIABLE(count++);                       // initial output
  const auto cI = INPUT_VARIABLE(count++);                       // initial cell state
  const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr;     // peephole weights
  const auto dLdh = INPUT_VARIABLE(count);                       // gradient wrt output

  REQUIRE_TRUE(cellClip >= 0, 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !");

  count = 3;
  auto dLdx = OUTPUT_VARIABLE(0);
  auto dLdWx = OUTPUT_VARIABLE(1);
  auto dLdWr = OUTPUT_VARIABLE(2);
  auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr;
  auto dLdhI = OUTPUT_VARIABLE(count++);
  auto dLdcI = OUTPUT_VARIABLE(count++);
  auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr;

  // evaluate dimensions
  const sd::LongType bS = x->rankOf() == 1 ? 0 : x->sizeAt(0);
  const sd::LongType nIn = x->sizeAt(-1);
  const sd::LongType nOut = Wx->sizeAt(-1) / 4;

  // inputs validations
  // Wx validation
  if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
  // Wr validation
  if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut)
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
  // initial output/cell validation
  std::vector<sd::LongType> exphIcIShape =
      x->rankOf() == 1 ? std::vector<sd::LongType>{nOut} : std::vector<sd::LongType>{bS, nOut};
  REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0,
               "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
  REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0,
               "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
  REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0,
               "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !",
               ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
  // biases validation
  if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut))
    REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
  if (dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4 * nOut))
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({4 * nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str());
  // peephole weights validation
  if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut))
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({3 * nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
  if (dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3 * nOut))
    REQUIRE_TRUE(false, 0,
                 "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !",
                 ShapeUtils::shapeAsString({3 * nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str());

  std::vector<float> params = {
      static_cast<float>(0) /*ignore*/, static_cast<float>(0) /*ignore*/, static_cast<float>(cellClip),
      static_cast<float>(gateAct),      static_cast<float>(gateAlpha),    static_cast<float>(gateBeta),
      static_cast<float>(cellAct),      static_cast<float>(cellAlpha),    static_cast<float>(cellBeta),
      static_cast<float>(outAct),       static_cast<float>(outAlpha),     static_cast<float>(outBeta)};

  std::vector<sd::LongType> zShape =
      x->rankOf() == 1 ? std::vector<sd::LongType>({4 * nOut}) : std::vector<sd::LongType>({bS, 4 * nOut});

  NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext());
  NDArray a = z.ulike();
  NDArray h = cI->ulike();
  NDArray c = cI->ulike();

  helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);

  helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr,
                           dLdhI, dLdcI, dLdb, dLdWp);

  return sd::Status::OK;
}

DECLARE_TYPES(lstmLayerCellBp) {
  getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
}

DECLARE_SHAPE_FN(lstmLayerCellBp) {
  const auto hasBiases = B_ARG(0);  // indicates whether biases array is provided
  const auto hasPH = B_ARG(1);      // indicates whether peephole connections are present

  sd::LongType count = 3;
  const auto x = INPUT_VARIABLE(0);                              // input
  const auto Wx = INPUT_VARIABLE(1);                             // input weights
  const auto Wr = INPUT_VARIABLE(2);                             // recurrent weights
  const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
  const auto hI = INPUT_VARIABLE(count++);                       // initial output
  const auto cI = INPUT_VARIABLE(count++);                       // initial cell state
  const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;       // peephole weights

  auto shapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo());

  if (b != nullptr) shapes->push_back(b->shapeInfo());

  shapes->push_back(hI->shapeInfo());
  shapes->push_back(cI->shapeInfo());

  if (Wp != nullptr) shapes->push_back(Wp->shapeInfo());

  return shapes;
}

}  // namespace ops
}  // namespace sd

#endif