From a459cdf565a85011c74812ea7cec49882eea18d3 Mon Sep 17 00:00:00 2001 From: Salil Desai Date: Wed, 16 Aug 2023 11:50:05 -0700 Subject: [PATCH] Dtype compliance: clamp (#69) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/69 Adjust clamp portable op have near-complete Dtype compliance with aten version. The Aten version of clamp has a weird quirk where it allows you to pass in a value for the min and/or max args which is below the normal range of the input/output tensor dataype when that datatype is uint8 specifically. But it doesn't allow ABOVE the range, and it doesn't allow below or above for any other datatype. We are choosing to leave a discrepancy between aten and portable by making uint8 behave like the rest of the datatypes for portable (not allowing below the range). This is already tested by the ByteTensorNegativeClampDies test (which is skipped when running the aten tests). Reviewed By: SS-JIA, manuelcandales Differential Revision: D47573238 fbshipit-source-id: 94f3288dd4b5d2b58f6bbda66359196d2ad3f214 --- kernels/portable/cpu/op_clamp.cpp | 109 +++++++++++++++++++++++------- kernels/test/op_clamp_test.cpp | 4 +- 2 files changed, 87 insertions(+), 26 deletions(-) diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index 8a824444814..1cc3c4af4b0 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -84,38 +84,99 @@ Tensor& clamp_out( Error err = resize_tensor(out, in.sizes()); ET_CHECK_MSG(err == Error::Ok, "Could not resize output"); - ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out); + ScalarType in_type = in.scalar_type(); + ScalarType min_type = in_type; + ScalarType max_type = in_type; + ScalarType common_type = in_type; + ScalarType out_type = out.scalar_type(); + + bool has_min = min_opt.has_value(); + if (has_min) { + min_type = utils::get_scalar_dtype(min_opt.value()); + common_type = utils::promote_type_with_scalar(common_type, min_opt.value()); + } + bool has_max = max_opt.has_value(); + if (has_max) { + max_type = utils::get_scalar_dtype(max_opt.value()); + common_type = utils::promote_type_with_scalar(common_type, max_opt.value()); + } + + ET_CHECK_MSG( + has_min || has_max, "At least one of 'min' or 'max' must not be None"); - ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() { + ET_CHECK(common_type == out_type); + + ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { // Extract optional min value - CTYPE min = 0; - bool has_min = min_opt.has_value(); + CTYPE_OUT min = 0; if (has_min) { - bool ok = utils::extract_scalar(min_opt.value(), &min); - ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range"); + ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() { + CTYPE_MIN min_val = 0; + ET_EXTRACT_SCALAR(min_opt.value(), min_val); + if (isIntegralType(out_type, /*includeBool=*/false)) { + if (static_cast(min_val) < + std::numeric_limits::lowest() || + static_cast(min_val) > + std::numeric_limits::max()) { + ET_CHECK_MSG(false, "minimum value out of bounds"); + } + } + if (isFloatingType(out_type)) { + if (std::isfinite(min_val) && + (static_cast(min_val) < + std::numeric_limits::lowest() || + static_cast(min_val) > + std::numeric_limits::max())) { + ET_CHECK_MSG(false, "minimum value out of bounds"); + } + } + min = static_cast(min_val); + }); } + // Extract optional max value - CTYPE max = 0; - bool has_max = max_opt.has_value(); + CTYPE_OUT max = 0; if (has_max) { - bool ok = utils::extract_scalar(max_opt.value(), &max); - ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range"); - } - - apply_unary_map_fn( - [has_min, min, has_max, max](const CTYPE val_in) { - CTYPE val_out = val_in; - if (has_min) { - val_out = max_override(val_out, min); + ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() { + CTYPE_MAX max_val = 0; + ET_EXTRACT_SCALAR(max_opt.value(), max_val); + if (isIntegralType(out_type, /*includeBool=*/false)) { + if (static_cast(max_val) < + std::numeric_limits::lowest() || + static_cast(max_val) > + std::numeric_limits::max()) { + ET_CHECK_MSG(false, "maximum value out of bounds"); } - if (has_max) { - val_out = min_override(val_out, max); + } + if (isFloatingType(out_type)) { + if (std::isfinite(max_val) && + (static_cast(max_val) < + std::numeric_limits::lowest() || + static_cast(max_val) > + std::numeric_limits::max())) { + ET_CHECK_MSG(false, "maximum value out of bounds"); } - return val_out; - }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); + } + max = static_cast(max_val); + }); + } + + ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() { + apply_unary_map_fn( + [has_min, min, has_max, max](const CTYPE_IN val_in) { + CTYPE_OUT val_out = static_cast(val_in); + if (has_min) { + val_out = max_override(val_out, min); + } + if (has_max) { + val_out = min_override(val_out, max); + } + return val_out; + }, + in.const_data_ptr(), + out.mutable_data_ptr(), + in.numel()); + }); }); return out; diff --git a/kernels/test/op_clamp_test.cpp b/kernels/test/op_clamp_test.cpp index b505f91bc85..08d898733e1 100644 --- a/kernels/test/op_clamp_test.cpp +++ b/kernels/test/op_clamp_test.cpp @@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) { #ifndef USE_ATEN_LIB TEST(OpClampOutTest, IntTensorTooSmallClampDies) { - // Cannot be represented by a uint32_t. + // Cannot be represented by a int32_t. expect_bad_clamp_value_dies(-2147483649); } TEST(OpClampOutTest, IntTensorTooLargeClampDies) { - // Cannot be represented by a uint32_t. + // Cannot be represented by a int32_t. expect_bad_clamp_value_dies(2147483648); } #endif