60 changes: 53 additions & 7 deletions src/relay/transforms/to_mixed_precision.cc
@@ -36,6 +36,7 @@
namespace tvm {
namespace relay {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
// A callable which hashes std::pair
struct pair_hash {
template <class T1, class T2>
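The option registered at the top of this hunk is read from the PassContext rather than passed through a new constructor argument, so enabling it from Python needs no new API surface. A minimal sketch, assuming `mod` is a type-inferred Relay module (this mirrors the test changes further down):

import tvm
from tvm.relay.transform import ToMixedPrecision

# Opt in to dtype preservation; without this config entry the pass keeps
# its legacy behavior and leaves outputs in the mixed-precision dtype.
with tvm.transform.PassContext(
    config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
    fp16_mod = ToMixedPrecision("float16")(mod)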
@@ -105,6 +106,9 @@ class MixedPrecisionPass : public MixedModeMutator {
* encountered. Used for emitting warnings on missing ops in the pass.
*/
std::unordered_map<std::string, int> missing_ops_;
const RelayExprNode* root_;
std::vector<DataType> original_dtype_;
bool keep_orig_output_dtype_;

Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
/* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
public:
using MixedModeMutator::VisitExpr_;

explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
: MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
DataType mixed_precision_type = DataType::Float(16))
: MixedModeMutator(),
mixed_precision_type_(mixed_precision_type),
root_(Downcast<Function>(base)->body.get()),
keep_orig_output_dtype_(keep_orig_output_dtype) {
if (keep_orig_output_dtype_) {
if (root_->IsInstance<tvm::relay::TupleNode>()) {
const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
for (Type t : tuple_type->fields) {
const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
original_dtype_.push_back(tensor_type->dtype);
}
} else if (root_->IsInstance<tvm::relay::CallNode>()) {
original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
}
}
if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
<< mixed_precision_type_;
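For readers skimming the C++: the constructor above records one dtype per root output before any mutation happens. A rough Python analogue under the same type-inference assumption — the C++ inspects the root expression's checked type directly, while this illustrative `collect_output_dtypes` helper (not part of the patch) reads the equivalent information off `func.ret_type`:

from tvm import relay

def collect_output_dtypes(func):
    # Mirrors original_dtype_: one entry per root output of the function.
    ret_ty = func.ret_type  # valid only after InferType has run
    if isinstance(ret_ty, relay.TupleType):
        return [field.dtype for field in ret_ty.fields]
    return [ret_ty.dtype]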
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
if (accumulation_dtype != output_dtype) {
output = CastArg(output, GetType(output), output_dtype);
}
if (pre_call_node == root_ && keep_orig_output_dtype_) {
if (original_dtype_[0] != output_dtype) {
output = CastArg(output, GetType(output), original_dtype_[0]);
}
}
return output;
}

@@ -396,6 +420,21 @@
Expr Rewrite_(const TupleNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
if (pre == root_ && keep_orig_output_dtype_) {
Array<Expr> new_expr;
bool all_same = true;
for (size_t i = 0; i < original_dtype_.size(); i++) {
Expr output_element = GetField(post, i);
Expr casted_element;
auto output_element_type = transform::InferTypeLocal(output_element);
casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
new_expr.push_back(casted_element);
all_same &= casted_element.same_as(output_element);
}
if (!all_same) {
return Tuple(new_expr);
}
}
return post;
}

@@ -421,11 +460,12 @@
}

// To access map of ops not registered for error reporting
friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
int missing_op_mode);
friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
const DataType& mixed_precision_type, int missing_op_mode);
};

Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
const DataType& mixed_precision_type, int missing_op_mode) {
/*
missing_op_mode:

@@ -436,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
<< " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;

MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
MixedPrecisionPass converter =
MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
auto result = converter.Mutate(expr);

for (auto it = converter.missing_ops_.begin();
@@ -460,7 +501,12 @@ namespace transform {
Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(ToMixedPrecision(f, mixed_precision_type, missing_op_mode));
bool keep_orig_output_dtype = false;
keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype",
Bool(keep_orig_output_dtype))
.value();
return Downcast<Function>(
ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
};
return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
}
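Because GetConfig falls back to Bool(false), the pass stays backward compatible when the option is unset, and it composes in a pipeline like any other function pass. A sketch under the same assumptions as the snippets above (`mod` is an existing Relay IRModule):

import tvm
from tvm import relay

seq = tvm.transform.Sequential(
    [relay.transform.InferType(), relay.transform.ToMixedPrecision("float16")]
)
# Dropping the config entry below reverts to the legacy behavior, where
# outputs stay in the mixed-precision dtype.
with tvm.transform.PassContext(
    opt_level=3,
    config={"relay.ToMixedPrecision.keep_orig_output_dtype": True},
):
    out_mod = seq(mod)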
39 changes: 29 additions & 10 deletions tests/python/relay/test_to_mixed_precision.py
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
mixed_precision_dtype="float16",
rtol: float = 1e-3,
atol: float = 0,
keep_orig_output_dtype=False,
) -> tvm.runtime.Module:

mod = InferType()(mod)
result_fp32 = run_module(mod, mod_params)
fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
result_fp16 = run_module(fp16_mod, mod_params)

if not keep_orig_output_dtype:
fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
result_fp16 = run_module(fp16_mod, mod_params)
else:
with tvm.transform.PassContext(
config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
result_fp16 = run_module(fp16_mod, mod_params)

# Ensure the results are close
for fp32, fp16 in zip(result_fp32, result_fp16):
np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)

if keep_orig_output_dtype:
assert (
np.array(result_fp16).dtype == np.array(result_fp32).dtype
), "output type and original type mismatch"

return fp16_mod


@@ -117,16 +131,21 @@ def test_convert_single_conv():
"data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
"weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
}
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
fp16_mod = verify_mixed_precision_output_close(
mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True
)

expected_mod = tvm.IRModule.from_expr(
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float16",
),
relay.cast(
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float16",
),
"float32",
)
)
expected_mod = tvm.relay.transform.InferType()(expected_mod)
