deeplearning4j/deeplearning4j
libnd4j/include/ops/declarable/generic/transforms/split_v.cpp
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author raver119@gmail.com
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_v)

#include <ops/declarable/headers/parity_ops.h>

namespace sd {
namespace ops {
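/**
 * split_v splits the first input along a single axis into as many sub-arrays as there are entries in
 * the second input ("sizes"); the e-th output covers sizes[e] elements of that axis. The split axis is
 * taken from IArgs[0] when integer arguments are present, otherwise from an optional third input array;
 * negative axis values are counted from the last dimension.
 */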
CUSTOM_OP_IMPL(split_v, 2, -1, false, 0, -2) {
  auto input = INPUT_VARIABLE(0);
  auto sizes = INPUT_VARIABLE(1);

  int axis = 0;

  if (block.getIArguments()->size() > 0) {
    axis = INT_ARG(0);
  } else if (block.width() > 2) {
    auto _a = INPUT_VARIABLE(2);
    axis = _a->e<int>(0);
  }

  if (axis < 0) axis += input->rankOf();

  int pos = 0;
  std::vector<sd::LongType> indices(2 * input->rankOf());
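
  // For every output chunk the loop below fills `indices` with one {begin, end} pair per dimension
  // (indices[2 * d], indices[2 * d + 1]): along the split axis the pair advances by the chunk size,
  // while the {0, 0} pair elsewhere selects the whole dimension, so (*input)(indices) yields the
  // sub-array view that is copied into the e-th output.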

  for (sd::LongType e = 0; e < sizes->lengthOf(); e++) {
    int c_size = sizes->e<int>(e);

    for (int d = 0; d < input->rankOf(); d++) {
      if (d == axis)
        indices[2 * d + 1] = (indices[2 * d] = pos) + c_size;
      else
        indices[2 * d] = indices[2 * d + 1] = 0;
    }

    auto output = OUTPUT_VARIABLE(e);
    REQUIRE_TRUE(output->dataType() == input->dataType(), 0, "SplitV: all outputs must have the same data type as the input");

    auto sub = (*input)(indices);

    output->assign(sub);

    pos += c_size;
  }

  return sd::Status::OK;
}
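
// A minimal usage sketch (not part of this file; shapes, values and the evaluate() test helper are
// assumptions): in the libnd4j test suites a custom op like this is typically driven as
//
//   auto input = NDArrayFactory::create<float>('c', {4, 6});
//   auto sizes = NDArrayFactory::create<int>('c', {3}, {2, 1, 3});
//   sd::ops::split_v op;
//   auto result = op.evaluate({&input, &sizes}, {}, {1});   // split columns into chunks of 2, 1 and 3
//   // expected: result.at(0) -> 4x2, result.at(1) -> 4x1, result.at(2) -> 4x3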

DECLARE_TYPES(split_v) {
  getOpDescriptor()
      ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
      ->setAllowedInputTypes(1, {ALL_INTS})
      ->setAllowedInputTypes(2, {ALL_INTS})
      ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
}
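
// The shape function mirrors the execution logic above: it resolves the split axis in the same way and
// emits one output shape per entry of the sizes array, equal to the input shape except along that axis.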

DECLARE_SHAPE_FN(split_v) {
  auto input = inputShape->at(0);
  // auto sizes = inputShape->at(1);

  auto shapeList = SHAPELIST();
  int rank = shape::rank(input);

  // the split axis defaults to 0 unless overridden by IArgs[0] or a third input array
  int axis = 0;

  if (block.getIArguments()->size() > 0)
    axis = INT_ARG(0);
  else if (block.width() > 2) {
    auto _a = INPUT_VARIABLE(2);
    axis = _a->e<int>(0);
  }

  if (axis < 0) axis += shape::rank(input);

  // this op requires the sizes array (second input) to be available at shape-calculation time
  auto sizes = INPUT_VARIABLE(1);

  auto length = sizes->lengthOf();
  int pos = 0;
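  // each output shape copies the input shape and overrides the split axis with the chunk size;
  // ConstantShapeHelper caches the resulting shape buffers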
  for (sd::LongType e = 0; e < length; e++) {
    int c_size = sizes->e<int>(e);

    std::vector<sd::LongType> shape(rank);

    for (sd::LongType d = 0; d < rank; d++) {
      if (d != axis)
        shape[d] = shape::sizeAt(input, d);
      else
        shape[d] = c_size;
    }

    auto newShape =
        ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(input), shape::order(input), shape);
    shapeList->push_back(newShape);
  }

  return shapeList;
}
}  // namespace ops
}  // namespace sd

#endif