diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index b4fdcbff58b4..3b767547357b 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -389,4 +389,20 @@ inline DLDataType String2DLDataType(std::string s) {
 using DataType = runtime::DataType;
 }  // namespace tvm
+
+namespace std {
+template <>
+struct hash<tvm::DataType> {
+  // Cantor pairing is a bijection N x N -> N, so composing it gives distinct hashes for distinct (code, bits, lanes) triples.
+  inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
+  std::size_t operator()(tvm::DataType const& dtype) const {
+    int a = dtype.code();
+    int b = dtype.bits();
+    int c = dtype.lanes();
+    int d = cantor_pairing_function(a, b);
+    return cantor_pairing_function(c, d);
+  }
+};
+}  // namespace std
+
 #endif  // TVM_RUNTIME_DATA_TYPE_H_
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 4c693fe64ee0..2e509a111c4a 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -29,6 +29,7 @@
     debug,
     register_external_compiler,
     register_fake_quantization_to_integer,
+    register_mixed_precision_conversion,
 )
 from . import strategy
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index ccf011819a97..0d90a5cdeafa 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -18,10 +18,11 @@
 """The base node types for the Relay language."""
 import tvm._ffi
 import tvm.ir
-from tvm.driver import lower, build
-from tvm.target import get_native_generic_func, GenericFunc
-from tvm.runtime import Object
 import tvm.ir._ffi_api
+from tvm.driver import build, lower
+from tvm.runtime import Object
+from tvm.target import GenericFunc, get_native_generic_func
+
 from . import _make
@@ -457,6 +458,44 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10):
     return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)
 
 
+def register_mixed_precision_conversion(op_name, func=None, level=10):
+    """Register mixed precision conversion function for an op
+
+    Given an op, the function should return information on how its value should be
+    converted. Specifically, the function should take a call node and the target
+    mixed precision datatype (e.g. FP16) and return the conversion category
+    (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation
+    and output datatypes of the operation in the mixed precision dtype space.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    func: function (call_node: relay.Call, target_dtype: string)
+    -> [conversion category, accumulation dtype, output dtype]: [int, string, string]
+        A function which, given a call_node and a target_dtype (e.g. FP16), returns the
+        conversion category and the associated accumulation/output dtypes of the
+        operation when transformed into the mixed precision dtype space.
+
+    level : int
+        The priority level
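+
+    Example
+    -------
+    An illustrative sketch; ``my_custom_op`` is a placeholder op name (the op must
+    already exist for the registration to succeed):
+
+    .. code-block:: python
+
+        @register_mixed_precision_conversion("my_custom_op")
+        def my_custom_op_mixed_precision_rule(call_node, target_dtype):
+            # 0 == MIXED_PRECISION_ALWAYS (see transform/mixed_precision.py)
+            return [0, "float32", target_dtype]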
+    """
+    return tvm.ir.register_op_attr(op_name, "FTVMMixedPrecisionConversionType", func, level)
+
+
 @tvm._ffi.register_func("relay.op.compiler._lower")
 def _lower(name, schedule, inputs, outputs):
     return lower(schedule, list(inputs) + list(outputs), name=name)
diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py
new file mode 100644
index 000000000000..6aa3ac09cfee
--- /dev/null
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=line-too-long,unused-argument
+"""Default behavior for ops in mixed_precision pass. Import this file to use."""
+from typing import List
+
+from tvm import relay
+from tvm.relay.op import register_mixed_precision_conversion
+
+# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+# numerical reasons.
+MIXED_PRECISION_ALWAYS = 0
+MIXED_PRECISION_FOLLOW = 1
+MIXED_PRECISION_NEVER = 2
+
+# Default lists inspired by TF's classifications:
+# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+# They are biased toward Nvidia Tensor Cores, so modify these lists for your hardware of choice.
+DEFAULT_ALWAYS_LIST = [
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    # "nn.batch_matmul", # Handled by a special case
+]
+DEFAULT_FOLLOW_LIST = [
+    # These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    # Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    # By definition copy and cast will depend on inputs for output.
+    "copy",
+    "cast",
+    "cast_like",
+    # Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    # Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    # Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    # Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    # "nn.global_max_pool1d", # does not exist yet
+    "nn.global_max_pool2d",
+    # "nn.global_max_pool3d", # does not exist yet
+    # "nn.global_avg_pool1d", # does not exist yet
+    "nn.global_avg_pool2d",
+    # "nn.global_avg_pool3d", # does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+]
+DEFAULT_NEVER_LIST = [
+    # In general if |f(x)| >> |x| for expected inputs then put the op here.
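+    # e.g. exp overflows fp16 quickly: exp(12) is roughly 1.6e5, beyond the fp16 max of 65504.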
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    # The error function cannot currently be lowered into an fp16 version in llvm.
+    # Move to the follow list when it can be.
+    "erf",
+]
+
+
+# Returns a decorator which registers the decorated function under FTVMMixedPrecisionConversionType for every given op.
+def register_func_to_op_list(list_ops: List):
+    def decorator(func):
+        for op_name in list_ops:
+            register_mixed_precision_conversion(op_name, func=func)
+
+    return decorator
+
+
+def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
+    """A function which returns output dtypes in a way which works for most ops.
+
+    Parameters
+    ----------
+    call_node: relay.Call
+        The call node containing the op.
+    mixed_precision_type: str
+        The target type to run the operation in.
+
+    Returns
+    -------
+    output_dtypes : [str, str]
+        A list of two strings. The first represents the datatype used for accumulation
+        in the operation. The second represents the actual output datatype.
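+
+    Examples
+    --------
+    An illustrative sketch of the heuristic under the default rules;
+    ``conv2d_call`` and ``add_call`` are placeholder call nodes:
+
+    .. code-block:: python
+
+        # conv2d carries an out_dtype attr -> accumulate in fp32, output fp16
+        get_generic_out_dtypes(conv2d_call, "float16")  # ["float32", "float16"]
+
+        # add has no out_dtype attr -> accumulate and output in fp16
+        get_generic_out_dtypes(add_call, "float16")  # ["float16", "float16"]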
+    """
+    # Assume an op supports separate accumulation dtypes <---> it has an out_dtype attr.
+    # This is because there is no better way right now to tell which ops support accumulating
+    # at different data types.
+    # Some discussion about making this better is here:
+    # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
+    if hasattr(call_node.attrs, "out_dtype"):
+        return ["float32", mixed_precision_type]
+
+    # [accumulation_dtype, output_dtype] for the operations
+    return [mixed_precision_type, mixed_precision_type]
+
+
+# Functions for FTVMMixedPrecisionConversionType which
+# take in a CallNode and a DType and return a conversion category,
+# an accumulation dtype, and an output_dtype.
+@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
+def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List:
+    return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)
+
+
+@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
+def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
+    return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)
+
+
+@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
+def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
+    return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)
+
+
+@register_mixed_precision_conversion("nn.batch_matmul")
+def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
+    # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
+    # Batched matmul has inconsistent support for mixed precision operations.
+    # Many schedules ignore the out_dtype attribute, which leads to errors when
+    # input types do not match the out_dtype. Therefore, accumulate to output_dtype.
+    return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 20e045abab6c..fa7f4c4db644 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -18,16 +18,15 @@
 """
 Relay pass transformation infrastructure.
 """
-import types
-import inspect
 import functools
+import inspect
+import types
 import warnings
 
 import tvm.ir
-from tvm import te
+from tvm import relay, te
 from tvm.runtime import ndarray as _nd
-from tvm import relay
 
 from . import _ffi_api
@@ -1168,7 +1167,7 @@ def AnnotateSpans():
     Returns
     -------
     ret : tvm.transform.Pass
-        The regsistered AnnotateSpans pass.
+        The registered AnnotateSpans pass.
     """
     return _ffi_api.AnnotateSpans()
@@ -1199,3 +1198,42 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
+    """
+    Automatic mixed precision rewriter. Rewrites an FP32 relay graph into a version
+    where as many operations as possible are in the target mixed_precision_type.
+
+    Parameters
+    ----------
+    mixed_precision_type: str
+        The target datatype to transform operations in the graph to use.
+
+    missing_op_mode: int
+        Determines how to handle ops not registered with FTVMMixedPrecisionConversionType:
+        0: Does not allow any missing ops. Will throw errors when encountering any.
+        1: Allow missing ops but emit warnings.
+        2: Allow missing ops and silently ignore them.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass.
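+
+    Examples
+    --------
+    A minimal sketch: convert a float32 module to float16. Importing
+    ``tvm.relay.transform.mixed_precision`` registers the default conversion
+    rules for common ops.
+
+    .. code-block:: python
+
+        from tvm.relay.transform import mixed_precision  # pylint: disable=unused-import
+
+        mod = tvm.relay.transform.InferType()(mod)
+        mod = tvm.relay.transform.ToMixedPrecision("float16")(mod)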
+    """
+    if missing_op_mode < 0 or missing_op_mode > 2:
+        raise ValueError("missing_op_mode must be 0, 1, or 2")
+    return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode)
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 130eb4b69844..3f72bdc4b667 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -18,13 +18,15 @@
 # pylint: disable=unused-argument, redefined-builtin
 """Conv2D operators"""
 from __future__ import absolute_import as _abs
+
 from collections import namedtuple
+
 import tvm
-from tvm import te, auto_scheduler
+from tvm import auto_scheduler, te
 
+from ..utils import get_const_int, get_const_tuple, simplify, tag
 from .pad import pad
 from .utils import get_pad_tuple
-from ..utils import simplify, get_const_tuple, get_const_int, tag
 from .winograd_util import winograd_transform_matrices
 
 # workload description of conv2d
@@ -548,7 +550,9 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
                     ow * WSTR + kw * dilation_w,
                     idxmod(ic, ic_bn),
                 ].astype(out_dtype)
-                * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block],
+                * kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(
+                    out_dtype
+                ),
                 axis=[ic, kh, kw],
             ),
             name="conv2d_NCHWc",
diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
new file mode 100644
index 000000000000..ae10c937ff1c
--- /dev/null
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -0,0 +1,463 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs, e.g. turning a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's hash_combine strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+/*! \brief This class transforms the given relay module into a version where
+ * as many operations as possible operate in the target mixed precision dtype.
+ *
+ * Input : A Relay module with operations registered with FTVMMixedPrecisionConversionType
+ *         functions. These describe when and how the operations will be transformed
+ *         into the target precision dtype.
+ *
+ * Output : A Relay module with some operations transformed according to the below
+ *          methodology.
+ *
+ * Methodology :
+ *      1) Each relay Op is either of conversion category ALWAYS, FOLLOW, NEVER
+ *         defined by the associated FTVMMixedPrecisionConversionType function.
+ *         If an operation is not registered, it by default is assumed to be
+ *         FOLLOW.
+ *      2) ALWAYS operations always convert the input floating point args into
+ *         the target mixed precision dtype. FOLLOW Ops will convert the input
+ *         floating point args back into FP32 unless all floating point args
+ *         are in the target mixed precision dtypes. NEVER ops will always cast
+ *         inputs back into FP32.
+ *      3) Each ALWAYS Op, and FOLLOW Op with mixed precision dtype arguments
+ *         also have an associated accumulation_dtype and output_dtype which
+ *         describe whether a larger dtype is used to accumulate the results
+ *         of the operation. The output_dtype meanwhile describes the dtype
+ *         most Ops should use from this accumulator.
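+ *
+ *  An illustrative example with fp16 as the target (a sketch, not verbatim output):
+ *      conv2d(cast_fp16(x), cast_fp16(w))   // ALWAYS: float args are cast to fp16
+ *      add(conv_result, cast_fp16(c))       // FOLLOW: all float args fp16 -> stays fp16
+ *      softmax(cast_fp32(add_result))       // NEVER: float args are cast back to fp32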
+ */
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  /*! \brief A cache of nodes + target dtype to a cast version of the node with target dtype. */
+  CachedCastNodes cast_nodes_cache_;
+
+  /*! \brief The target datatype we want to convert to e.g. FP16 */
+  const DataType mixed_precision_type_;
+
+  /*! \brief Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were
+   * encountered. Used for emitting warnings on missing ops in the pass.
+   */
+  std::unordered_map<std::string, int> missing_ops_;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
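+       For example, with fp16 as the target: a TupleType of [fp16 tensor, int32 tensor]
+       yields true when ignore_non_float is set (the integer tensor is ignored) and
+       false otherwise.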
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (ignore_non_float && !(tensor_type->dtype).is_float()) ||
+             tensor_type->dtype == mixed_precision_type_;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << "; we don't know how to handle it";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    CHECK(expr_node) << "Non-expression node found in cast: " << expr;
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache_.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache_.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache_[{expr_node, wanted_dtype}] = result;
+
+    // Cache the reverse direction too: casting the result back to the original dtype
+    // simply points to the original node, avoiding cast round trips.
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache_[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    }
+    CHECK(0) << "Unsupported type " << expr_type << "; we don't know how to cast it for arguments!";
+    return expr;
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
+      : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
+    if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
+                 << mixed_precision_type_;
+    }
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    CHECK(post_call_node) << "Expected a CallNode, but got " << post;
+
+    Expr cur_op = post_call_node->op;
+
+    // TODO(AndrewZhaoLuo): Support ADTs
+    // Relay's algebraic data types are not supported yet.
+    ICHECK(!cur_op.as<GlobalVarNode>()       // used to declare functions for recursion
+           && !cur_op.as<ConstructorNode>()  // constructing ADT types
+           && !cur_op.as<VarNode>())         // used for calling recursive functions
+        << "Algebraic Data Types (ADT) are not supported yet for mixed precision pass.";
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type_));
+        ICHECK(op_descriptor.size() == 3)
+            << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size()
+            << ") from FTVMMixedPrecisionConversionType for " << AsText(op, false);
+
+        int64_t op_conversion_type = Downcast<IntImm>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        missing_ops_[op->name] += 1;
+
+        // If not registered, by default assume it is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = mixed_precision_type_;
+        output_dtype = mixed_precision_type_;
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
+    }
+
+    // First check if all the new mutated args are in lower precision form
+    Array<Type> cur_arg_types;
+    bool all_args_mixed_type_compatible = true;
+    for (Expr arg : post_call_node->args) {
+      Type cur_arg_type = GetType(arg);
+      cur_arg_types.push_back(cur_arg_type);
+
+      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
+                                        arg->IsInstance<VarNode>() ||
+                                        arg->IsInstance<ConstantNode>();
+        all_args_mixed_type_compatible &= is_mixed_type_compatible;
+      }
+    }
+
+    // Determine the final category we want for conversion
+    MixedTypeConversionCategory final_category = initial_category;
+    if (initial_category == MIXED_PRECISION_FOLLOW) {
+      final_category =
+          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes =
+        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
+    Array<Expr> new_args = call_args_and_types.first;
+    Array<Type> new_arg_types;
+
+    if (pre_call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      new_arg_types = pre_call_node->type_args;
+    } else {
+      new_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_category == MIXED_PRECISION_ALWAYS) {
+      Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
+      Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
+      if (accumulation_dtype != output_dtype) {
+        output = CastArg(output, GetType(output), output_dtype);
+      }
+      return output;
+    }
+
+    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    // Erase the ret_type annotation and let the normal pass recalculate
+    const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
+    return ExprMutator::VisitExpr_(func);
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    // First convert as much of the bound computation to lower precision as possible
+    Expr value = this->Mutate(op->value);
+
+    // Then rewrite the var type and associated expression
+    Var var = Downcast<Var>(this->Mutate(op->var));
+    VarNode* mutable_var = const_cast<VarNode*>((op->var).as<VarNode>());
+    mutable_var->type_annotation = GetType(value);
+    mutable_var->checked_type_ = mutable_var->type_annotation;
+
+    // Mutate body last as it may depend on previous results
+    Expr body = this->Mutate(op->body);
+    return Let(var, value, body, op->span);
+  }
+
+  // To access map of ops not registered for error reporting
+  friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
+                               int missing_op_mode);
+};
+
+Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
+  /*
+  missing_op_mode:
+
+  0: Does not allow any missing ops. Will throw errors and terminate the pass when encountering any.
+  1: Allow missing ops but emit warnings.
+  2: Allow missing ops and silently ignore them.
+ */ + ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2) + << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode; + + MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type); + auto result = converter.Mutate(expr); + + for (auto it = converter.missing_ops_.begin(); + missing_op_mode != 2 && it != converter.missing_ops_.end(); it++) { + std::string op_name = it->first; + int appear_count = it->second; + + LOG(WARNING) << "Op \"" << op_name << "\" not registered " + << "FTVMMixedPrecisionConversionType appears " << appear_count + << " times in graph."; + } + + if (converter.missing_ops_.size() != 0 && missing_op_mode == 0) { + CHECK(0) << "Missing ops were found!"; + } + return result; +} + +namespace transform { + +Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ToMixedPrecision(f, mixed_precision_type, missing_op_mode)); + }; + return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 362a9b623d25..a6c3d6efec56 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -14,22 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np import operator +import random +import numpy as np +import pytest import tvm -from tvm import te +import tvm.testing +from tvm import relay, te from tvm.contrib import graph_executor -from tvm import relay -import mxnet as mx +import model_zoo +import mxnet as mx from mxnet import gluon from mxnet.gluon.model_zoo import vision -import random -import pytest -import model_zoo - -import tvm.testing def verify_mxnet_frontend_impl( @@ -1231,7 +1229,9 @@ def verify(shape, axis=1, epsilon=1e-5): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) op_res = intrp.evaluate()(x, gamma, beta) - tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=2e-5, atol=1e-5 + ) verify((2, 3, 4, 5)) verify((32, 64, 80, 64)) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py new file mode 100644 index 000000000000..caccd52d60c2 --- /dev/null +++ b/tests/python/relay/test_to_mixed_precision.py @@ -0,0 +1,446 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for testing ToMixedPrecision pass"""
+from typing import Any, Dict, List
+
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tvm.relay.testing import lstm
+from tvm.relay.transform import InferType, ToMixedPrecision, mixed_precision
+
+
+def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List:
+    dev = tvm.device("llvm", 0)
+    intrp = relay.create_executor("debug", mod, device=dev, target="llvm")
+    result = intrp.evaluate()(**mod_params)
+    if isinstance(result, tvm.runtime.container.ADT):
+        result = [r.asnumpy() for r in result]
+        return result
+    else:
+        return [result.asnumpy()]
+
+
+def verify_mixed_precision_output_close(
+    mod: tvm.runtime.Module,
+    mod_params: Dict[str, Any],
+    mixed_precision_dtype="float16",
+    rtol: float = 1e-3,
+    atol: float = 0,
+) -> tvm.runtime.Module:
+
+    mod = InferType()(mod)
+    result_fp32 = run_module(mod, mod_params)
+    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+    result_fp16 = run_module(fp16_mod, mod_params)
+
+    # Ensure the results are close
+    for fp32, fp16 in zip(result_fp32, result_fp16):
+        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
+
+    return fp16_mod
+
+
+def test_lstm():
+    """A small stress test on a single unrolled lstm unit.
+
+    Has internal functions and let statements the pass must work on.
+    """
+    units = 3
+    iterations = 5
+    mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)
+
+    # This is an unrolled lstm, so each data input should be the previous iteration's
+    # result, but we don't care here; we just want to stress test things.
+    for i in range(iterations):
+        mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform(
+            -10, 10, (1, units)
+        ).astype("float32")
+
+    verify_mixed_precision_output_close(mod, mod_params, rtol=0.01, atol=0.01)
+
+
+def test_lstm_float64():
+    """Tests that the pass can handle other mixed precision types.
+
+    As a toy example, show that we can convert a graph to float64 and have it run.
+
+    It doesn't really make sense to do this; it just shows we can change
+    the target mixed_precision_dtype.
+    """
+    units = 3
+    iterations = 5
+    mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)
+
+    # This is an unrolled lstm, so each data input should be the previous iteration's
+    # result, but we don't care here; we just want to stress test things.
+    for i in range(iterations):
+        mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform(
+            -10, 10, (1, units)
+        ).astype("float32")
+
+    verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype="float64", rtol=0.01, atol=0.01
+    )
+
+
+def test_convert_single_conv():
+    """Conv is a green-listed operation, meaning it will always run in fp16.
+
+    By default it accumulates to fp32 and outputs fp16.
+    """
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    expected_mod = tvm.IRModule.from_expr(
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float32",
+            ),
+            "float16",
+        )
+    )
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_convert_single_conv_fp64():
+    """As above, but checks that choosing a mixed_precision_type other than FP16 works."""
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype="float64", atol=0.01, rtol=1e-3
+    )
+
+    # Note we still accumulate to FP32 by default; a user would need to override the
+    # default behavior for this to make more sense.
+    expected_mod = tvm.IRModule.from_expr(
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float64"),
+                relay.cast(weight, "float64"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float32",
+            ),
+            "float64",
+        )
+    )
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_convert_conv_bn():
+    """Conv is green and batch norm is gray. Since conv outputs fp16, batch_norm should be converted to fp16 as well."""
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+
+    bn_shape = [5]
+    gamma = relay.var("gamma", shape=bn_shape)
+    beta = relay.var("beta", shape=bn_shape)
+    moving_mean = relay.var("moving_mean", shape=bn_shape)
+    moving_var = relay.var("moving_var", shape=bn_shape)
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+    mod = tvm.IRModule.from_expr(bn[0])
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+        "gamma": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "beta": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+    }
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    # Creating expected module
+    data = relay.cast(relay.var("data", shape=data_shape), "float16")
+    weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
+    conv = relay.cast(
+        relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
+        "float16",
+    )
+
+    bn_shape = [5]
+    gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
+    beta = relay.cast(relay.var("beta", shape=bn_shape), "float16")
+    moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), "float16")
+    moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), "float16")
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+
+    expected_mod = tvm.IRModule.from_expr(bn[0])
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_do_not_convert_softmax():
+    """Softmax is a red-listed operation and therefore should never be fp16."""
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    mod = tvm.IRModule.from_expr(b)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0)
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_green_gray_propagates_simple():
+    """Conv is a green-listed operation, while addition is gray.
+
+    As conv outputs fp16, the add should be done in fp16.
+    """
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    conv = conv + conv
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    conv_expr = relay.cast(
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float32",
+        ),
+        "float16",
+    )
+    expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_green_red_not_use_extraneous_cast():
+    """Conv is a green-listed operation, while softmax is red.
+
+    Conv also by default accumulates to fp32 but outputs fp16.
+
+    We want to avoid a situation where we have extraneous casts.
+    E.g. because softmax wants to operate on FP32 we might have
+
+    conv (FP32) -> cast (FP16) -> cast (FP32) -> softmax (FP32)
+
+    To get around this, when we cast in the pass we internally cache
+    the output nodes and the reverse of the cast back to the original
+    node. For example, casting the `conv (FP32)` to FP16 would produce:
+
+    `conv (FP32) -> cast (FP16)`
+
+    as the outputs. Now anytime we try to cast the `conv (FP32)` node
+    to FP16 it would return the cached result instead of a new cast node:
+
+    `conv (FP32) -> cast (FP16)`
+
+    Furthermore, if we try to cast the `cast (FP16)` node back to FP32 it
+    would just return
+
+    `conv (FP32)`.
+
+    This test makes sure this behavior occurs.
+    """
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    result = relay.nn.softmax(conv)
+    mod = tvm.IRModule.from_expr(result)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    # Construct expected structure
+    conv = relay.nn.conv2d(
+        relay.cast(data, "float16"),
+        relay.cast(weight, "float16"),
+        strides=(1, 1),
+        padding=(1, 1),
+        out_dtype="float32",
+    )
+    result = relay.nn.softmax(conv)
+    expected_mod = tvm.IRModule.from_expr(result)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, fp16_mod)
+
+
+def test_red_gray_propagates_simple():
+    """Everything after a softmax should be in FP32 (except green-colored ops)."""
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    c = b + b
+    mod = tvm.IRModule.from_expr(c)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0.0)
+
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_let_statement_simple():
+    """A 'simple' let statement example.
+
+    Note the mutation of the bound variable types.
+    """
+    var1 = relay.var("var1", shape=[1, 20])
+    var2 = relay.var("var2", shape=[1, 20])
+
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+
+    r1 = var1 + var1
+
+    r2 = var2 + var2
+    let2 = relay.Let(var2, relay.nn.dense(r1, weight, units=20), r2)
+    let1 = relay.Let(var1, relay.nn.dense(data, weight, units=20), let2)
+
+    mod = tvm.IRModule.from_expr(let1)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Construct expected structure
+    var1 = relay.var("var1", shape=[1, 20], dtype="float16")
+    var2 = relay.var("var2", shape=[1, 20], dtype="float16")
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    r1 = var1 + var1
+    r2 = var2 + var2
+    let2 = relay.Let(
+        var2,
+        relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
+        r2,
+    )
+    let1 = relay.Let(
+        var1,
+        relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
+        let2,
+    )
+    expected_mod = tvm.IRModule.from_expr(let1)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_where_simple():
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+    a = relay.nn.dense(data, weight, units=20)
+    b = relay.where(data, a, a)
+    mod = tvm.IRModule.from_expr(b)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+
+    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
+    b = relay.where(data, a, a)
+    expected_mod = tvm.IRModule.from_expr(b)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_batch_matmul_simple():
+    """Batch matmul is a special case where we try to accumulate to fp16.
+
+    This is because heterogeneous accumulation dtypes do not work
+    on all platforms at the moment.
+    """
+    data = relay.var("data", shape=[1, 1, 20])
+    weight = relay.var("weight", shape=[1, 20, 20])
+    a = relay.nn.batch_matmul(data, weight)
+    mod = tvm.IRModule.from_expr(a)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"),
+    }
+    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16")
+    a = relay.nn.batch_matmul(data, weight, out_dtype="float16")
+    expected_mod = tvm.IRModule.from_expr(a)
+    expected_mod = InferType()(expected_mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])