diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index cbe3f5f613d..28bb25d1cae 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging +import warnings as _warnings from typing import ( Any, @@ -226,6 +227,12 @@ def find_pos(self, stage_id: str): raise Exception(f"Stage id {stage_id} not found in pipeline") + def has_stage(self, stage_id: str): + try: + return self.find_pos(stage_id) >= 0 + except: + return False + def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs): """Adds a stage after the given stage id.""" pos = self.find_pos(stage_id) + 1 @@ -271,7 +278,34 @@ def run(self): raise e -class TosaPipelineINT(BasePipelineMaker, Generic[T]): +class TOSAPipelineMaker(BasePipelineMaker, Generic[T]): + + @staticmethod + def is_tosa_ref_model_available(): + """Checks if the TOSA reference model is available.""" + # Not all deployments of ET have the TOSA reference model available. + # Make sure we don't try to use it if it's not available. + try: + import tosa_reference_model + + # Check if the module has content + return bool(dir(tosa_reference_model)) + except ImportError: + return False + + def run(self): + if ( + self.has_stage("run_method_and_compare_outputs") + and not self.is_tosa_ref_model_available() + ): + _warnings.warn( + "Warning: Skipping run_method_and_compare_outputs stage. TOSA reference model is not available." + ) + self.pop_stage("run_method_and_compare_outputs") + super().run() + + +class TosaPipelineINT(TOSAPipelineMaker, Generic[T]): """ Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model. @@ -375,7 +409,7 @@ def __init__( ) -class TosaPipelineFP(BasePipelineMaker, Generic[T]): +class TosaPipelineFP(TOSAPipelineMaker, Generic[T]): """ Lowers a graph to FP TOSA spec and tests it with the TOSA reference model. @@ -629,7 +663,7 @@ def __init__( ) -class PassPipeline(BasePipelineMaker, Generic[T]): +class PassPipeline(TOSAPipelineMaker, Generic[T]): """ Runs single passes directly on an edge_program and checks operators before/after. @@ -719,7 +753,7 @@ def __init__( self.add_stage(self.tester.run_method_and_compare_outputs) -class TransformAnnotationPassPipeline(BasePipelineMaker, Generic[T]): +class TransformAnnotationPassPipeline(TOSAPipelineMaker, Generic[T]): """ Runs transform_for_annotation_pipeline passes directly on an exported program and checks output. @@ -775,7 +809,7 @@ def __init__( ) -class OpNotSupportedPipeline(BasePipelineMaker, Generic[T]): +class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]): """ Runs the partitioner on a module and checks that ops are not delegated to test SupportedTOSAOperatorChecks.