deeplearning4j/deeplearning4j

View on GitHub
libnd4j/include/ops/declarable/generic/reduce/argmin.cpp

Summary

Maintainability
Test Coverage
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver119 on 01.11.2017.
// Modified by GS <sgazeos@gmail.com> 4/5/2018.

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_argmin)

#include <helpers/ConstantTadHelper.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/axis.h>
#include <ops/declarable/helpers/reductions.h>

namespace sd {
namespace ops {

// Type constraints for argmin: inputs are restricted to the ALL_FLOATS and
// ALL_INTS type sets, outputs to the ALL_INTS type set.
DECLARE_TYPES(argmin) {
  auto descriptor = getOpDescriptor();
  descriptor->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
      ->setAllowedOutputTypes({ALL_INTS});
}

// argmin op body.
//
// Reads the array to reduce from input 0 and writes the result to output 0 via
// helpers::argMin. The reduction axes are taken from the op's integer
// arguments; when none are given and a second input is present, the axes are
// derived from that second input instead.
//
// Returns sd::Status::OK, including the early exit taken when the output is
// empty.
CUSTOM_OP_IMPL(argmin, 1, 1, false, 0, -2) {
  auto input = INPUT_VARIABLE(0);
  // Working copy of the integer arguments: it may be rewritten below.
  auto dimensions = *block.getIArguments();

  auto output = OUTPUT_VARIABLE(0);

  // Nothing to compute for an empty output.
  if (output->isEmpty()) {
    return sd::Status::OK;
  }

  // axis might be dynamic (i.e. tf mode): supplied as a second input array
  // rather than through the integer arguments
  const bool axisIsDynamic = block.width() > 1 && dimensions.empty();
  if (axisIsDynamic) {
    auto dynamicAxis = INPUT_VARIABLE(1);
    helpers::adjustAxis(input->rankOf(), dynamicAxis, dimensions);
  }

  helpers::argMin(*input, *output, dimensions);

  STORE_RESULT(output);

  return sd::Status::OK;
}

// Shape function for argmin.
//
// Builds the output shape info from the shape of input 0 and the reduction
// axes. Axes come from the op's integer arguments when the op has exactly one
// input, otherwise from the values of input 1.
//
// Optional arguments:
//   - boolean argument 0: keepDims (false when absent)
//   - data type argument 0: output data type (DataType::INT64 when absent)
//
// Returns a scalar shape when no axes are given, or when the only axis is the
// sd::DataTypeUtils::max<int>() sentinel; otherwise the shape produced by
// ShapeUtils::evalReduceShapeInfo. Every non-sentinel axis must satisfy
// 0 <= axis < rank after adjustment and must not address a dimension of size 0.
//
// NOTE(review): the op body only reads axes from input 1 when the integer
// arguments are empty, while this function switches purely on block.width().
// Confirm the two are meant to differ when both sources are supplied.
DECLARE_SHAPE_FN(argmin) {
  std::vector<sd::LongType> dims;

  if (block.width() == 1) {
    dims = *block.getIArguments();
  } else {
    auto y = INPUT_VARIABLE(1);
    dims = y->template asVectorT<sd::LongType>();
  }

  auto keepDims = block.numB() ? B_ARG(0) : false;
  auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;

  auto in = inputShape->at(0);
  const auto inRank = shape::rank(in);

  // we're resolving negative axis here
  helpers::adjustAxis(inRank, dims);

  for (auto d : dims) {
    // we have special case here
    if (d == sd::DataTypeUtils::max<int>()) continue;

    // An axis that is still negative after adjustment would make in[d + 1]
    // below read outside the dimension entries, so reject it before indexing.
    REQUIRE_TRUE(d >= 0, 0, "ArgMin: axis can't be negative after adjustment");
    REQUIRE_TRUE(d < inRank, 0, "ArgMin: axis can't be above rank");
    REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
  }

  // special case - output is scalar
  if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
    return SHAPELIST(ConstantShapeHelper::getInstance().scalarShapeInfo(dtype));
  }

  return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', &dims, in, dtype, keepDims, false, block.getWorkspace()));
}

}  // namespace ops
}  // namespace sd

#endif