diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index 20018610fce..df6067a7d68 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -144,9 +144,10 @@ def check_common_constraints( return True def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): - # Check inputs are valid dtypes + # Check inputs are valid and have the same dtypes # Gather all args which are nodes args_to_check = [] + reference_dtype = None for arg in node.args: if isinstance(arg, list) or isinstance(arg, tuple): for item in arg: @@ -174,11 +175,32 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): if arg_val.dtype not in valid_dtypes: return False + # Use the first dtype as reference + reference_dtype = reference_dtype or arg_val.dtype + + # Check for mixed dtypes + if arg_val.dtype != reference_dtype: + # Get op name if the attribute exists, otherwise use the full node target for logging + op_name = ( + node.target.__name__ + if hasattr(node.target, "__name__") + else str(node.target) + ) + why( + node, + reason=( + f"{op_name} does not support mixed input dtypes, " + f"got: [{reference_dtype}, {arg_val.dtype}]" + ), + ) + return False + return True def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): - # Check outputs are valid dtype + # Check outputs are valid node_val = node.meta.get("val", None) + if node_val is None: return True