// libnd4j/include/ops/declarable/generic/transforms/concat.cpp
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author raver119@gmail.com
//  @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>

#include <array>
#if NOT_EXCLUDED(OP_concat)

namespace sd {
namespace ops {

//////////////////////////////////////////////////////////////////////////
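// concat: concatenates all input arrays along a given axis. The axis is taken from
// IArgs (INT_ARG(0)) or, when the first boolean argument is set, from a scalar held
// in the last input array. Empty inputs are skipped, and scalar inputs are promoted
// to vectors of length 1 before concatenation.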
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
  REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

  const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

  const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

  // account for the possible presence of empty arrays;
  // if a scalar is present, copy its value into a vector of length 1
  std::vector<const NDArray*> nonEmptyArrs;
  std::vector<sd::LongType> arrsToDelete;
  sd::LongType index = 0;
  bool allOfSameType = true;
  auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();

  for (sd::LongType i = 0; i < numOfInArrs; ++i) {
    auto input = INPUT_VARIABLE(i);

    if (!input->isEmpty()) {
      allOfSameType &= (typeOfFirstArr == input->dataType());

      if (input->rankOf() == 0) {
        auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
        vec->assign(input);
        nonEmptyArrs.push_back(vec);
        arrsToDelete.push_back(index);
      } else {
        nonEmptyArrs.push_back(input);
      }
      ++index;
    }
  }

  const sd::LongType numOfNonEmptyArrs = nonEmptyArrs.size();

  if (numOfNonEmptyArrs == 0) {
    // All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
    REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty");
    return sd::Status::OK;
  }

  const sd::LongType rank = nonEmptyArrs[0]->rankOf();  // rank is taken from the first non-empty array
  sd::LongType axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<sd::LongType>(0) : INT_ARG(0);
  if (axis < 0) {
    axis += rank;
  }

  // ******** input validation ******** //
  REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all input arrays must have the same type!");
  REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0,
               "CONCAT op: output array should have the same type as input arrays!");
  REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0,
               "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank - 1, axis);

  for (sd::LongType i = 1; i < numOfNonEmptyArrs; ++i) {
    if (nonEmptyArrs[i]->rankOf() != rank) {
      std::string error("CONCAT op: array at index: ");
      error += std::to_string(i);
      error += " did not have same rank. Expected rank: ";
      error += std::to_string(rank);
      error += " but was: ";
      error += std::to_string(nonEmptyArrs[i]->rankOf());
      REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, error.c_str());
    }

    for (sd::LongType dim = 0; dim < rank; ++dim) {
      if (dim != axis) {
        if (nonEmptyArrs[i]->sizeAt(dim) != nonEmptyArrs[0]->sizeAt(dim)) {
          std::string error("CONCAT op: array at index: ");
          error += std::to_string(i);
          error += " did not have same dimension. Expected dimension: ";
          error += std::to_string(nonEmptyArrs[0]->sizeAt(dim));
          error += " but was: ";
          error += std::to_string(nonEmptyArrs[i]->sizeAt(dim));
          REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, error.c_str());
        }
      }
    }

  }

  // ******** end of input validation ******** //

  auto output = OUTPUT_VARIABLE(0);

  helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);

  // delete the temporary length-1 vectors that were created for scalar inputs
  for (sd::LongType idx : arrsToDelete) delete nonEmptyArrs[idx];

  return sd::Status::OK;
}
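
// A minimal usage sketch (illustrative only; the exact evaluate() overload may
// differ between versions of this codebase):
//
//   sd::NDArray x('c', {2, 3}, sd::DataType::FLOAT32);
//   sd::NDArray y('c', {2, 3}, sd::DataType::FLOAT32);
//   sd::ops::concat op;
//   auto result = op.evaluate({&x, &y}, {}, {0});  // concat along axis 0 -> shape {4, 3}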

DECLARE_SYN(ParallelConcat, concat);
DECLARE_SYN(concat_v2, concat);
DECLARE_SYN(concatv2, concat);

DECLARE_TYPES(concat) {
  getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY);
}

//////////////////////////////////////////////////////////////////////////
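// Shape function: the output shape is the first (shape-normalized) input's shape with
// its extent along the concatenation axis replaced by the sum of all inputs' extents
// along that axis. Rank-0 inputs are treated as vectors of length 1 (length 0 if empty).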
DECLARE_SHAPE_FN(concat) {
  REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

  const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

  const sd::LongType numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
  // account for the possible presence of empty arrays;
  // if a scalar is present, use the shape of a length-1 vector instead
  ShapeList arrShapes;
  for (sd::LongType i = 0; i < numOfInArrs; ++i) {
    if (inputShape->at(i)[0] == 0) {
      if (shape::isEmpty(inputShape->at(i))) {
        arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType()));
      } else {
        arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
      }
    } else {
      arrShapes.push_back(inputShape->at(i));
    }
  }

  const sd::LongType numOfArrs = arrShapes.size();

  const sd::LongType rank = shape::rank(arrShapes.at(0));

  sd::LongType axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<sd::LongType>(0) : INT_ARG(0);
  if (axis < 0) {
    axis += rank;
  }

  // ******** input validation ******** //
  REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!",
               rank - 1, axis);


  for (sd::LongType i = 1; i < numOfArrs; ++i) {
    if (shape::rank(arrShapes.at(i)) != rank) {
      std::string error("CONCAT op: array at index: ");
      error += std::to_string(i);
      error += " did not have same rank. Expected rank: ";
      error += std::to_string(rank);
      error += " but was: ";
      error += std::to_string(shape::rank(arrShapes.at(i)));
      THROW_EXCEPTION(error.c_str());
    }

    for (sd::LongType dim = 0; dim < rank; ++dim) {
      if (dim != axis) {
        if (arrShapes.at(i)[dim + 1] != arrShapes.at(0)[dim + 1]) {
          std::string error("CONCAT op: array at index: ");
          error += std::to_string(i);
          error += " did not have same dimension. Expected dimension: ";
          error += std::to_string(arrShapes.at(0)[dim + 1]);
          error += " but was: ";
          error += std::to_string(arrShapes.at(i)[dim + 1]);
          THROW_EXCEPTION(error.c_str());
        }
      }
    }
  }

  // ******** end of input validation ******** //

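  // build the output shapeInfo: start from a copy of the first shape, then accumulate
  // the extents along the concatenation axis; extents live at offsets [1..rank] of
  // shapeInfo, hence the dim + 1 indexing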
  sd::LongType* outShapeInfo(nullptr);
  COPY_SHAPE(arrShapes.at(0), outShapeInfo);
  // case when we have only one input array
  if (numOfArrs == 1) {
    ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0)));
    return SHAPELIST(CONSTANT(outShapeInfo));
  }


  for (sd::LongType i = 1; i < numOfArrs; ++i) {
    outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1];
  }

  ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0)));

  auto desc = new ShapeDescriptor(outShapeInfo);
  auto result = ConstantShapeHelper::getInstance().createShapeInfo(desc);
  delete desc;
  return SHAPELIST(result);
}

//////////////////////////////////////////////////////////////////////////
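// concat_bp: backpropagation for concat. The incoming gradient is the last input
// (before the optional axis array); it is sliced along the concatenation axis and
// each slice is copied into the gradient output of the corresponding forward input.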
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) {
  const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

  const sd::LongType numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

  auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1);

  const sd::LongType axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<sd::LongType>(0)
                                            : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf());

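  // running offset along the concatenation axis, marking where each chunk begins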
  sd::LongType startPos = 0;

  for (sd::LongType e = 0; e < numOfInArrs - 1; e++) {
    auto originalChunk = INPUT_VARIABLE(e);
    auto epsilonChunk = OUTPUT_VARIABLE(e);
    std::vector<sd::LongType> indices(2 * epsilonNext->rankOf());

    sd::LongType width = originalChunk->sizeAt(axis);

    // take the full extent of every dimension except the concatenation axis,
    // where the interval is [startPos, startPos + width)
    for (sd::LongType d = 0; d < epsilonNext->rankOf(); d++) {
      if (d == axis)
        indices[2 * d + 1] = (indices[2 * d] = startPos) + width;
      else
        indices[2 * d + 1] = indices[2 * d] = 0;
    }

    auto subarray = (*epsilonNext)(indices, true);
    epsilonChunk->assign(subarray);

    startPos += width;
  }

  return sd::Status::OK;
}

DECLARE_TYPES(concat_bp) {
  getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
}

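// Each gradient output mirrors the shape, order and type of the corresponding forward input.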
DECLARE_SHAPE_FN(concat_bp) {
  const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

  const sd::LongType numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

  auto shapeList = SHAPELIST();

  for (sd::LongType e = 0; e < numOfInArrs - 1; e++) {
    auto inShape = inputShape->at(e);
    auto desc = new ShapeDescriptor(
        ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape));
    shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(desc));
    delete desc;
  }

  return shapeList;
}

}  // namespace ops
}  // namespace sd

#endif