diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index f9c6265d9b7..69f9b0fe093 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -33,30 +33,42 @@ def isQuantArg(arg): ) +# Check if scale32 mode is used for given output element type +def isScale32(type): + return type == ts.DType.INT8 + + # TOSA uses the RESCALE operation to scale between values with differing precision. # The RESCALE operator is defined using an integer multiply, add, and shift. # This utility function is for calculating the multier and shift given a scale. # Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling -def computeMultiplierAndShift(scale): +def computeMultiplierAndShift(scale, scaleWidth=32): + if scaleWidth == 16: + offset = 15 + elif scaleWidth == 32: + offset = 31 + else: + raise AssertionError("unsupported scale width") + assert isinstance(scale, float) mantissa, exponent = math.frexp(scale) shift = exponent - const_two_to_31 = 1 << 31 - shifted_mantissa = round(mantissa * const_two_to_31) + const_2_power_15_or_31 = 1 << offset + shifted_mantissa = round(mantissa * const_2_power_15_or_31) - assert shifted_mantissa <= const_two_to_31 + assert shifted_mantissa <= const_2_power_15_or_31 - if shifted_mantissa == const_two_to_31: + if shifted_mantissa == const_2_power_15_or_31: shifted_mantissa = shifted_mantissa / 2 shift += 1 - # TOSA expects right shift to be positive, and embed (1 << 31) into right shift bits. - shift = 31 - shift + # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. + shift = offset - shift # INT32_MAX, 2^31 - 1 - assert shifted_mantissa <= (const_two_to_31 - 1) + assert shifted_mantissa <= (const_2_power_15_or_31 - 1) multiplier = shifted_mantissa @@ -66,8 +78,41 @@ def computeMultiplierAndShift(scale): return multiplier, shift +def buildRescale( + tosa_fb, + scale, + input_node, + output_type, + output_shape, + input_zp, + output_zp, + is_double_round, +): + is_scale32 = isScale32(output_type) + scale_width = 32 if is_scale32 else 16 + multiplier, shift = computeMultiplierAndShift(scale, scale_width) + + attr_rescale = ts.TosaSerializerAttribute() + attr_rescale.RescaleAttribute( + input_zp=input_zp, + output_zp=output_zp, + multiplier=[multiplier], + shift=[shift], + scale32=is_scale32, + double_round=is_double_round, + per_channel=False, + ) + + rescale_out = tosa_fb.addIntermediate(output_shape, output_type) + tosa_fb.addOperator( + TosaOp.Op().RESCALE, [input_node.name], [rescale_out.name], attr_rescale + ) + + return rescale_out + + def buildRescaleToInt32( - tosa_fb, input, input_zp, rescale_scale + tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=True ) -> TosaSerializerTensor: multiplier, shift = computeMultiplierAndShift(rescale_scale) attr_rescale = ts.TosaSerializerAttribute() @@ -76,8 +121,8 @@ def buildRescaleToInt32( output_zp=0, multiplier=[multiplier], shift=[shift], - scale32=True, - double_round=True, + scale32=is_scale32, + double_round=is_double_round, per_channel=False, ) input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32) @@ -92,7 +137,13 @@ def buildRescaleToInt32( def buildRescaleFromInt32( - tosa_fb, input_name, output_name, output_zp, rescale_scale + tosa_fb, + input_name, + output_name, + output_zp, + rescale_scale, + is_scale32=True, + is_double_round=True, ) -> TosaSerializerTensor: multiplier, shift = computeMultiplierAndShift(rescale_scale) attr_rescale_output = ts.TosaSerializerAttribute() @@ -101,8 +152,8 @@ def buildRescaleFromInt32( output_zp=output_zp, multiplier=[multiplier], shift=[shift], - scale32=True, - double_round=True, + scale32=is_scale32, + double_round=is_double_round, per_channel=False, )