From b56f3cffd683aa9cebc940c1dfcde3862954a2af Mon Sep 17 00:00:00 2001 From: lucylq Date: Thu, 20 Nov 2025 11:26:39 -0800 Subject: [PATCH] tag scales for external data --- backends/xnnpack/operators/node_visitor.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 68226644859..4643ada9336 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -275,7 +275,7 @@ def get_per_channel_dtype( return dtype def get_quant_params( - self, quant_params: QuantParams, xnn_graph: XNNGraph + self, quant_params: QuantParams, xnn_graph: XNNGraph, external_tag: str = None ) -> XNNQuantParams: if quant_params.per_channel: scale = cast(torch.Tensor, quant_params.scale) @@ -291,13 +291,18 @@ def get_quant_params( ctypes.POINTER(ctypes.c_char * num_bytes), ).contents scale_name = hashlib.sha256(bytes(scale_array)).hexdigest() + scale_name = "scale_" + scale_name xnn_graph.constant_data.append( ConstantDataOffset( offset=UINT64_MAX, size=num_bytes, named_key=scale_name ) ) + if external_tag is not None: + logging.info( + f"Adding constant data with name, key {scale_name} and external_tag {external_tag} to named_data_store" + ) self._named_data_store.add_named_data( - scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT + scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT, external_tag ) if quant_params.per_channel_group: @@ -470,13 +475,19 @@ def define_tensor( # noqa: C901 assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : {quant_params.axis}, expecting 0 / 1." # Serialize tensor value + custom_meta = tensor.meta.get("custom", None) + external_tag = ( + custom_meta.get("delegate_constant_tag", None) if custom_meta else None + ) ser_val = ( XValue(xvalue_union=tvalue) if quant_params is None else XValue( xvalue_union=XNNQuantizedTensorValue( tensor_value=tvalue, - quant_params=self.get_quant_params(quant_params, xnn_graph), + quant_params=self.get_quant_params( + quant_params, xnn_graph, external_tag + ), ) ) ) @@ -614,7 +625,7 @@ def get_serialized_buffer_index( f"Serializing constant data node {tensor} but tensor value has no bytes", ) sha256_hash = hashlib.sha256(bytes(array)) - named_key = sha256_hash.hexdigest() + named_key = tensor.name + "_" + sha256_hash.hexdigest() size = const_val.untyped_storage().nbytes() xnn_graph.constant_data.append( @@ -626,7 +637,6 @@ def get_serialized_buffer_index( custom_meta.get("delegate_constant_tag", None) if custom_meta else None ) if external_tag is not None: - external_tag = custom_meta.get("delegate_constant_tag", None) logging.info( f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" )