diff --git a/exir/backend/test/backend_with_delegate_mapping_demo.py b/exir/backend/test/backend_with_delegate_mapping_demo.py
index 99d44d7c758..61910fbc1e5 100644
--- a/exir/backend/test/backend_with_delegate_mapping_demo.py
+++ b/exir/backend/test/backend_with_delegate_mapping_demo.py
@@ -159,21 +159,51 @@ def preprocess(
     @staticmethod
-    # The sample model that will work with BackendWithDelegateMapping show above.
+    # The sample model that will work with BackendWithDelegateMapping shown above.
     def get_test_model_and_inputs():
-        class ConvReLUAddModel(nn.Module):
+        class SimpleConvNet(nn.Module):
             def __init__(self):
-                super(ConvReLUAddModel, self).__init__()
-                # Define a convolutional layer
-                self.conv_layer = nn.Conv2d(
-                    in_channels=1, out_channels=64, kernel_size=3, padding=1
+                super(SimpleConvNet, self).__init__()
+
+                # First convolutional layer
+                self.conv1 = nn.Conv2d(
+                    in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
+                )
+                self.relu1 = nn.ReLU()
+
+                # Second convolutional layer
+                self.conv2 = nn.Conv2d(
+                    in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
                 )
+                self.relu2 = nn.ReLU()
+
+            def forward(self, x):
+                # Forward pass through the first convolutional layer
+                x = self.conv1(x)
+                x = self.relu1(x)
+
+                # Forward pass through the second convolutional layer
+                x = self.conv2(x)
+                x = self.relu2(x)
+
+                return x
+
+        class ConvReLUTanModel(nn.Module):
+            def __init__(self):
+                super(ConvReLUTanModel, self).__init__()
+
+                # Define the convolutional sub-network (two conv + ReLU stages)
+                self.conv_layer = SimpleConvNet()
 
             def forward(self, x):
                 # Forward pass through convolutional layer
                 conv_output = self.conv_layer(x)
-                # Apply ReLU activation
-                relu_output = nn.functional.relu(conv_output)
-                # Perform tan on relu output
-                added_output = torch.tan(relu_output)
-                return added_output
-        return (ConvReLUAddModel(), (torch.randn(1, 1, 32, 32),))
+                # Perform tan on conv_output
+                tan_output = torch.tan(conv_output)
+
+                return tan_output
+
+        batch_size = 4
+        channels = 3
+        height = 64
+        width = 64
+        return (ConvReLUTanModel(), (torch.randn(batch_size,
channels, height, width),))
diff --git a/exir/backend/test/test_delegate_map_builder.py b/exir/backend/test/test_delegate_map_builder.py
index 6a2bbd60af4..5c173feacaa 100644
--- a/exir/backend/test/test_delegate_map_builder.py
+++ b/exir/backend/test/test_delegate_map_builder.py
@@ -121,12 +121,26 @@ def test_backend_with_delegate_mapping(self) -> None:
         debug_handle_map = lowered_module.meta.get("debug_handle_map")
         self.assertIsNotNone(debug_handle_map)
-        # There should be 3 backend ops in this model.
-        self.assertEqual(len(debug_handle_map), 4)
-        # Check to see that all the delegate debug indexes in the range [0,2] are present.
+        # There should be 5 entries in the debug handle map for this model.
+        self.assertEqual(len(debug_handle_map), 5)
+        # Check to see that all the delegate debug indexes in the range [0,3] are present.
         self.assertTrue(
             all(element in debug_handle_map.keys() for element in [0, 1, 2, 3])
         )
-        lowered_module.program()
+
+        class CompositeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lowered_module = lowered_module
+
+            def forward(self, x):
+                return self.lowered_module(x)
+
+        composite_model = CompositeModule()
+        # TODO: Switch this to lowered_module.program() once lowered_module has support
+        # for storing debug delegate identifier maps.
+ exir.capture( + composite_model, inputs, exir.CaptureConfig() + ).to_edge().to_executorch() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/runtime/executor/test/test_backend_with_delegate_mapping.cpp b/runtime/executor/test/test_backend_with_delegate_mapping.cpp index a7313e7e1e8..9d5bb6156fb 100644 --- a/runtime/executor/test/test_backend_with_delegate_mapping.cpp +++ b/runtime/executor/test/test_backend_with_delegate_mapping.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include /* strtol */ #include @@ -125,22 +126,23 @@ class BackendWithDelegateMapping final : public PyTorchBackendInterface { "Op name = %s Delegate debug index = %ld", op_list->ops[index].name, op_list->ops[index].debug_handle); + event_tracer_log_profiling_delegate( + context.event_tracer(), + nullptr, + op_list->ops[index].debug_handle, + 0, + 1); + /** + If you used string based delegate debug identifiers then the profiling + call would be as below. + event_tracer_log_profiling_delegate( + context.event_tracer(), + pointer_to_delegate_debug_string, + -1, + 0, + 1); + */ } - // The below API's are not available yet but they are a representative - // example of what we'll be enabling. - /* - Option 1: Log performance event with an ID. An integer ID must have been - provided to DelegateMappingBuilder during AOT compilation. - */ - // EVENT_TRACER_LOG_DELEGATE_PROFILING_EVENT_ID(op_list->ops[index].debug_handle, - // start_time, end_time); - /* - Option 2: Log performance event with a name. A string - name must have been provided to DelegateMappingBuilder during AOT - compilation. - */ - // EVENT_TRACER_LOG_DELEGATE_PROFILING_EVENT_NAME(op_list->ops[index].name, - // start_time, end_time); return Error::Ok; }