From 7579d91466024410eee8f3dd6812708616eb865b Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 13 Aug 2025 15:38:29 -0700 Subject: [PATCH] Arm backend: use tosa_ref_model only if installed Not making these tests XFail because there is still value in running these tests w/o validating output since we do go through a lot of other checks in the AoT flow when generating a PTE. TOSA ref model being not installed is not the common case anyway. Added explicit warnings (which should show through pytest unlike logger) as a reminder for the comparison is being skipped. Test: with and w/o tosa_reference_model installed locally --- backends/arm/test/tester/test_pipeline.py | 44 ++++++++++++++++++++--- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index cbe3f5f613d..28bb25d1cae 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging +import warnings as _warnings from typing import ( Any, @@ -226,6 +227,12 @@ def find_pos(self, stage_id: str): raise Exception(f"Stage id {stage_id} not found in pipeline") + def has_stage(self, stage_id: str): + try: + return self.find_pos(stage_id) >= 0 + except: + return False + def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs): """Adds a stage after the given stage id.""" pos = self.find_pos(stage_id) + 1 @@ -271,7 +278,34 @@ def run(self): raise e -class TosaPipelineINT(BasePipelineMaker, Generic[T]): +class TOSAPipelineMaker(BasePipelineMaker, Generic[T]): + + @staticmethod + def is_tosa_ref_model_available(): + """Checks if the TOSA reference model is available.""" + # Not all deployments of ET have the TOSA reference model available. + # Make sure we don't try to use it if it's not available. + try: + import tosa_reference_model + + # Check if the module has content + return bool(dir(tosa_reference_model)) + except ImportError: + return False + + def run(self): + if ( + self.has_stage("run_method_and_compare_outputs") + and not self.is_tosa_ref_model_available() + ): + _warnings.warn( + "Warning: Skipping run_method_and_compare_outputs stage. TOSA reference model is not available." + ) + self.pop_stage("run_method_and_compare_outputs") + super().run() + + +class TosaPipelineINT(TOSAPipelineMaker, Generic[T]): """ Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model. @@ -375,7 +409,7 @@ def __init__( ) -class TosaPipelineFP(BasePipelineMaker, Generic[T]): +class TosaPipelineFP(TOSAPipelineMaker, Generic[T]): """ Lowers a graph to FP TOSA spec and tests it with the TOSA reference model. @@ -629,7 +663,7 @@ def __init__( ) -class PassPipeline(BasePipelineMaker, Generic[T]): +class PassPipeline(TOSAPipelineMaker, Generic[T]): """ Runs single passes directly on an edge_program and checks operators before/after. @@ -719,7 +753,7 @@ def __init__( self.add_stage(self.tester.run_method_and_compare_outputs) -class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]): +class TransformAnnotationPassPipeline(TOSAPipelineMaker, Generic[T]): """ Runs transform_for_annotation_pipeline passes directly on an exported program and checks output. @@ -775,7 +809,7 @@ def __init__( ) -class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]): +class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]): """ Runs the partitioner on a module and checks that ops are not delegated to test SupportedTOSAOperatorChecks.