Skip to content
1 change: 0 additions & 1 deletion src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ class DynamicToStaticMutator : public MixedModeMutator {
ICHECK_EQ(scale_w->data->ndim, 0);
const UpSampling3DAttrs* param = call_node->attrs.as<UpSampling3DAttrs>();
ICHECK(param);

return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data),
ToScalar(scale_h->data), ToScalar(scale_w->data),
param->layout, param->method,
Expand Down
50 changes: 32 additions & 18 deletions src/relay/transforms/pattern_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_

#include <builtin_fp16.h>
#include <dmlc/optional.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
Expand Down Expand Up @@ -380,43 +381,56 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
* \brief Convert an element of a NDArray with type int or float to scalar.
* \param array Input NDArray
* \param i element index
* \return Converted scalar value.
* \return Converted scalar value, or None if conversion failed
*/
static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
if (array->dtype.code == kDLInt) {
if (array->dtype.bits == 8) {
return reinterpret_cast<int8_t*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
return reinterpret_cast<int16_t*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
return reinterpret_cast<int32_t*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
return reinterpret_cast<int64_t*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLUInt) {
if (array->dtype.bits == 8) {
return reinterpret_cast<uint8_t*>(array->data)[i];
if (array->dtype.bits == 1) { // bool
return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
} else if (array->dtype.bits == 8) {
return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
return reinterpret_cast<uint16_t*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
return reinterpret_cast<uint32_t*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
return reinterpret_cast<uint64_t*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLFloat) {
if (array->dtype.bits == 16) {
return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
reinterpret_cast<uint16_t*>(array->data)[i]);
return dmlc::optional<long double>(
__extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
reinterpret_cast<uint16_t*>(array->data)[i]));
}
if (array->dtype.bits == 32) {
return reinterpret_cast<float*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
return reinterpret_cast<double*>(array->data)[i];
return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
}
}
LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
// make compiler happy
return -std::numeric_limits<double>::infinity();
return dmlc::optional<long double>();
}

/*!
* \brief Convert an element of a NDArray with type int or float to scalar.
* \param array Input NDArray
* \param i element index
* \return Converted scalar value
*/
static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
auto try_value = TryToScalar(array, i);
ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
return try_value.value();
}

/*!
Expand Down
Loading