deeplearning4j/deeplearning4j

View on GitHub
libnd4j/include/ops/declarable/generic/images/image_resize.cpp

Summary

Maintainability
Test Coverage
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author sgazeos@gmail.com
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_image_resize)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/image_resize.h>

namespace sd {
namespace ops {
/**
 * image_resize: resizes the input image with the selected interpolation method.
 *
 * Inputs:
 *   0 - image to resize, rank 3 or rank 4
 *   1 - target size, exactly two values read as {height, width}
 *
 * Boolean args (all optional):
 *   0 - not read by this body (the shape function of this op reads it)
 *   1 - antialias, default false
 *   2 - exclude_outside, default true
 *
 * Floating point args (optional):
 *   0 - bicubic coefficient, default helpers::KeysCubicKernelFunc<double>::KEYS_CUBIC_COEF
 *
 * Integer args (all optional):
 *   0 - resize method, default kResizeBilinear
 *   1 - coordinate transformation mode, default HALF_PIXEL (HALF_PIXEL_NN when the method is kResizeNearest)
 *   2 - nearest mode, only read when the method is kResizeNearest, default FLOOR
 *
 * Output:
 *   0 - resized image; has to be FLOAT32 unless the method is kResizeNearest
 *
 * Returns the status produced by helpers::resizeFunctor.
 */
CUSTOM_OP_IMPL(image_resize, 2, 1, false, -2, -2) {
  auto image = INPUT_VARIABLE(0);
  auto size = INPUT_VARIABLE(1);
  auto output = OUTPUT_VARIABLE(0);

  REQUIRE_TRUE(size->lengthOf() == 2, 0, "image_resize: Resize params is a pair of values, not %lld.",
               size->lengthOf());
  // the size input is ordered {height, width}
  const int height = size->e<int>(0);
  const int width = size->e<int>(1);

  // optional boolean flags
  bool antialias = false;
  if (block.numB() >= 2) {
    antialias = B_ARG(1);
  }
  bool exclude_outside = true;
  if (block.numB() >= 3) {
    exclude_outside = B_ARG(2);
  }

  // optional floating point argument
  double bicubicCoefficient = helpers::KeysCubicKernelFunc<double>::KEYS_CUBIC_COEF;
  if (block.numT() > 0) {
    bicubicCoefficient = T_ARG(0);
  }

  // resize method: validated right after it is read, before any other check relies on it
  auto method = helpers::ImageResizeMethods::kResizeBilinear;
  if (block.numI() >= 1) {
    method = static_cast<helpers::ImageResizeMethods>(INT_ARG(0));
  }
  REQUIRE_TRUE(
      method >= helpers::ImageResizeMethods::kResizeFirst && method <= helpers::ImageResizeMethods::kResizeLast, 0,
      "image_resize: Resize method should be between %i and %i, but %i was given.",
      (int)helpers::ImageResizeMethods::kResizeFirst, (int)helpers::ImageResizeMethods::kResizeLast, (int)method);
  const bool isNearest = method == helpers::ImageResizeMethods::kResizeNearest;

  auto coorMode = helpers::CoordinateTransformationMode::HALF_PIXEL;
  if (block.numI() >= 2) {
    coorMode = static_cast<helpers::CoordinateTransformationMode>(INT_ARG(1));
  } else if (isNearest) {
    // retain old behaviour
    coorMode = helpers::CoordinateTransformationMode::HALF_PIXEL_NN;
  }

  // ">= 3" matches the other optional-argument checks, so the nearest mode is not silently
  // dropped when more than three integer arguments are supplied
  auto nearestMode = helpers::NearestMode::FLOOR;
  if (isNearest && block.numI() >= 3) {
    nearestMode = static_cast<helpers::NearestMode>(INT_ARG(2));
    REQUIRE_TRUE(nearestMode >= helpers::NearestMode::FLOOR && nearestMode <= helpers::NearestMode::CEIL, 0,
                 "image_resize: nearest Mode should be between %i and %i, but %i was given.",
                 (int)helpers::NearestMode::FLOOR, (int)helpers::NearestMode::CEIL, (int)nearestMode);
  }

  REQUIRE_TRUE(isNearest || output->dataType() == DataType::FLOAT32, 0,
               "image_resize: Output data type should be FLOAT32 for this method %i", (int)method);

  const auto inRank = image->rankOf();
  REQUIRE_TRUE(inRank >= 3 && inRank <= 4, 0, "image_resize: Input rank should be 4 or 3, but %i given.",
               image->rankOf());

  // inform the user about the current state of the implementation
  const bool halfPixelAndExcludeOutside =
      coorMode == helpers::CoordinateTransformationMode::HALF_PIXEL && exclude_outside;
  if (antialias && !isNearest) {
    REQUIRE_TRUE(halfPixelAndExcludeOutside, 0,
                 "antialiasing is effective only with HALF_PIXEL and exclude_outside being set true");
  }
  if (method != helpers::ImageResizeMethods::kResizeBicubic && !isNearest) {
    REQUIRE_TRUE(halfPixelAndExcludeOutside, 0,
                 "this method supports only HALF_PIXEL and exclude_outside being set true");
  }

  // NOTE(review): image and output are handed to the helper as-is and no rank-4 view is created
  // here - confirm that helpers::resizeFunctor accepts rank-3 arrays.
  return helpers::resizeFunctor(block.launchContext(), image, width, height, method, coorMode, exclude_outside,
                                nearestMode, bicubicCoefficient, antialias, output);
}

/**
 * Shape function of image_resize.
 *
 * The output keeps the rank of the input (3 or 4). Its height and width come from the second
 * input, read as {height, width}; every other dimension is copied from the input shape.
 * When boolean argument 0 is set, the requested width is replaced by
 * ceil(height / (inputHeight / inputWidth)).
 * The output data type is FLOAT32, except for kResizeNearest, which keeps the input data type.
 */
DECLARE_SHAPE_FN(image_resize) {
  auto in = inputShape->at(0);

  const auto inRank = shape::rank(in);
  REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "image_resize: Input rank should be 4 or 3, but %i given.",
               static_cast<int>(inRank));

  auto method = helpers::ImageResizeMethods::kResizeBilinear;
  if (block.numI() >= 1) {
    method = static_cast<helpers::ImageResizeMethods>(INT_ARG(0));
  }

  auto newImageSize = INPUT_VARIABLE(1);
  REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "image_resize: Resize params is a pair of values, not %lld.",
               newImageSize->lengthOf());

  // The height axis is 1 for rank 4 and 0 for rank 3; width and the trailing dimension follow it.
  // The axes have to be chosen by rank: using fixed positions makes a rank-3 input read its
  // width/last dimension as height/width, and read past its dimensions for the last one.
  const sd::LongType heightAxis = inRank == 4 ? 1 : 0;
  const sd::LongType widthAxis = heightAxis + 1;
  const sd::LongType lastAxis = heightAxis + 2;
  const sd::LongType inHeight = shape::sizeAt(in, heightAxis);
  const sd::LongType inWidth = shape::sizeAt(in, widthAxis);
  const sd::LongType lastDim = shape::sizeAt(in, lastAxis);

  const int height = newImageSize->e<int>(0);
  int width = newImageSize->e<int>(1);
  if (block.numB() > 0) {
    if (B_ARG(0)) {
      // guard the division below: a zero-sized input gives a zero, infinite or NaN ratio
      REQUIRE_TRUE(inHeight > 0 && inWidth > 0, 0,
                   "image_resize: Input height and width should be positive to preserve the aspect ratio, but %lld "
                   "and %lld were given.",
                   static_cast<long long>(inHeight), static_cast<long long>(inWidth));
      const double ratio = inHeight / static_cast<double>(inWidth);
      width = math::sd_ceil<double, int>(height / ratio);
    }
  }

  auto dtype = DataType::FLOAT32;
  if (method == helpers::ImageResizeMethods::kResizeNearest) dtype = ArrayOptions::dataType(in);

  auto shape = ConstantShapeHelper::getInstance().createShapeInfo(
      dtype, 'c',
      inRank == 4 ? std::vector<sd::LongType>{shape::sizeAt(in, static_cast<sd::LongType>(0)), height, width, lastDim}
                  : std::vector<sd::LongType>{height, width, lastDim});

  return SHAPELIST(shape);
}
/**
 * Data types accepted by image_resize:
 *   input 0 - any integer or floating point type
 *   input 1 - any integer type
 *   output  - any floating point or integer type
 */
DECLARE_TYPES(image_resize) {
  auto descriptor = getOpDescriptor();
  descriptor->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
      ->setAllowedInputTypes(1, {ALL_INTS})
      ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS});
}

}  // namespace ops
}  // namespace sd

#endif