libnd4j/include/ops/declarable/generic/loss/ctcLoss.cpp

/*******************************************************************************
 * Copyright (c) 2021 Deeplearning4j Contributors
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 *******************************************************************************/

//
// @author AbdelRauf
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/ctc.h>
#include <system/op_boilerplate.h>

#if NOT_EXCLUDED(OP_ctc_loss)
namespace sd {
namespace ops {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(ctc_loss, 4, 1, false, 0, 1) {
  auto targetLabels = INPUT_VARIABLE(0);
  auto logitInput = INPUT_VARIABLE(1);
  auto targetLabelLengths = INPUT_VARIABLE(2);
  auto logitInputLengths = INPUT_VARIABLE(3);
  auto outputLosses = OUTPUT_VARIABLE(0);

  int blankIndex = INT_ARG(0);

  REQUIRE_TRUE(
      targetLabels->rankOf() == 2, 0,
      "CtcLoss: target labels fail to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ",
      targetLabels->rankOf());
  REQUIRE_TRUE(logitInput->rankOf() == 3, 0,
               "CtcLoss: logit input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ",
               logitInput->rankOf());
  REQUIRE_TRUE(targetLabelLengths->rankOf() == 1, 0,
               "CtcLoss: target label lengths fail to meet rank requirement (batch_size): %i == 1 ",
               targetLabelLengths->rankOf());
  REQUIRE_TRUE(logitInputLengths->rankOf() == 1, 0,
               "CtcLoss: logit input lengths fail to meet rank requirement (batch_size): %i == 1 ",
               logitInputLengths->rankOf());

  auto batchSize0 = targetLabels->sizeAt(0);
  auto batchSize1 = logitInput->sizeAt(0);
  auto batchSize2 = targetLabelLengths->sizeAt(0);
  auto batchSize3 = logitInputLengths->sizeAt(0);
  auto batchSize4 = outputLosses->sizeAt(0);

  bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
  check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);

  REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be %i", batchSize0);
  REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0,
               "CtcLoss: wrong shape of output array: expected %s but got %s instead!",
               ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str());

  auto emptyGradients = NDArrayFactory::empty<float>();
  sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, *outputLosses,
                            emptyGradients, blankIndex);

  return sd::Status::OK;
}
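
//////////////////////////////////////////////////////////////////////////
// Usage sketch (illustrative only, and kept as a comment so this translation
// unit still compiles). Shapes follow the checks above: targetLabels is
// [batch, max_label_sequence_length], logitInput is [batch, frames, classes],
// and both length arrays are [batch]. The variable names and the blank index
// value 0 are assumptions for the example, not part of this file:
//
//   sd::ops::ctc_loss op;
//   auto results = op.evaluate({&targetLabels, &logitInput,
//                               &targetLabelLengths, &logitInputLengths},
//                              {}, {0 /* blankIndex */});
//   auto perSampleLoss = results.at(0);  // float losses, shape [batch]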

//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(ctc_loss) {
  getOpDescriptor()
      ->setAllowedInputTypes(0, {ALL_INDICES})
      ->setAllowedInputTypes(1, {ALL_FLOATS})
      ->setAllowedInputTypes(2, {ALL_INDICES})
      ->setAllowedInputTypes(3, {ALL_INDICES})
      ->setAllowedOutputTypes({ALL_FLOATS});
}

//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(ctc_loss) {
  auto yShapeInfo = inputShape->at(1);  // logits: [batch_size, frames, classes]
  auto zShapeInfo = inputShape->at(2);  // target label lengths: [batch_size]

  // the per-sample loss mirrors the shape of the label-lengths input, but uses the logits' floating-point type
  auto dtype = ArrayOptions::dataType(yShapeInfo);
  auto desc = new ShapeDescriptor(zShapeInfo, dtype);
  auto ret = SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(desc));
  delete desc;
  return ret;
}

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(ctc_loss_grad, 4, 1, false, 0, 1) {
  auto targetLabels = INPUT_VARIABLE(0);
  auto logitInput = INPUT_VARIABLE(1);
  auto targetLabelLengths = INPUT_VARIABLE(2);
  auto logitInputLengths = INPUT_VARIABLE(3);
  auto outputGradients = OUTPUT_VARIABLE(0);

  int blankIndex = INT_ARG(0);

  REQUIRE_TRUE(
      targetLabels->rankOf() == 2, 0,
      "CtcLoss Gradient: target labels fail to meet rank requirement (batch_size, max_label_sequence_length): %i == 2 ",
      targetLabels->rankOf());
  REQUIRE_TRUE(logitInput->rankOf() == 3, 0,
               "CtcLoss Gradient: logit input fails to meet rank requirement (batch_size, frames, classes): %i == 3 ",
               logitInput->rankOf());
  REQUIRE_TRUE(targetLabelLengths->rankOf() == 1, 0,
               "CtcLoss Gradient: target label lengths fail to meet rank requirement (batch_size): %i == 1 ",
               targetLabelLengths->rankOf());
  REQUIRE_TRUE(logitInputLengths->rankOf() == 1, 0,
               "CtcLoss Gradient: logit input lengths fail to meet rank requirement (batch_size): %i == 1 ",
               logitInputLengths->rankOf());

  auto batchSize0 = targetLabels->sizeAt(0);
  auto batchSize1 = logitInput->sizeAt(0);
  auto batchSize2 = targetLabelLengths->sizeAt(0);
  auto batchSize3 = logitInputLengths->sizeAt(0);
  auto batchSize4 = outputGradients->sizeAt(0);

  bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
  check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);

  REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be %i", batchSize0);
  REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0,
               "CtcLoss Gradient: wrong shape of output array: expected %s but got %s instead!",
               ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str());

  auto emptyLoss = NDArrayFactory::empty<float>();
  sd::ops::helpers::ctcLoss(block, *logitInput, *targetLabels, *logitInputLengths, *targetLabelLengths, emptyLoss,
                            *outputGradients, blankIndex);

  return sd::Status::OK;
}
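
//////////////////////////////////////////////////////////////////////////
// Usage sketch for the backward pass (illustrative only, and kept as a comment
// so this translation unit still compiles). Inputs mirror ctc_loss, but the
// single output is the gradient w.r.t. the logits, with the same
// [batch, frames, classes] shape as logitInput. The variable names and the
// blank index value 0 are assumptions for the example:
//
//   sd::ops::ctc_loss_grad gradOp;
//   auto results = gradOp.evaluate({&targetLabels, &logitInput,
//                                   &targetLabelLengths, &logitInputLengths},
//                                  {}, {0 /* blankIndex */});
//   auto dLdLogits = results.at(0);  // same shape as logitInput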

//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(ctc_loss_grad) {
  getOpDescriptor()
      ->setAllowedInputTypes(0, {ALL_INDICES})
      ->setAllowedInputTypes(1, {ALL_FLOATS})
      ->setAllowedInputTypes(2, {ALL_INDICES})
      ->setAllowedInputTypes(3, {ALL_INDICES})
      ->setAllowedOutputTypes({ALL_FLOATS});
}

//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(ctc_loss_grad) {
  auto yShapeInfo = inputShape->at(1);  // logits: [batch_size, frames, classes]

  // the gradient is taken w.r.t. the logits, so it shares the logits' shape and floating-point type
  auto dtype = ArrayOptions::dataType(yShapeInfo);
  auto desc = new ShapeDescriptor(yShapeInfo, dtype);
  auto ret = SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(desc));
  delete desc;
  return ret;
}

}  // namespace ops
}  // namespace sd

#endif