Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/xnnpack/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional
from typing import List, Optional, Type

from executorch.backends.xnnpack.passes.channels_last_tagged_reshape_pass import (
ChannelsLastTaggedReshapePass,
Expand All @@ -29,7 +29,9 @@

class XNNPACKPassManager:
def __init__(
self, exported_program: ExportedProgram, passes: Optional[List[PassType]] = None
self,
exported_program: ExportedProgram,
passes: Optional[List[Type[PassType]]] = None,
) -> None:
"""
A helper class to run multiple XNNPack passes on a program
Expand Down
7 changes: 4 additions & 3 deletions backends/xnnpack/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,16 @@ python_unittest(

python_unittest(
name = "test_xnnpack_passes",
srcs = [
srcs = glob([
"passes/*.py",
]) + [
"test_xnnpack_passes.py",
"test_xnnpack_utils_classes.py",
],
deps = [
"//caffe2:torch",
"//executorch/backends/xnnpack/passes:xnnpack_passes",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/exir:lib",
"//executorch/exir:pass_base",
Expand Down Expand Up @@ -127,9 +130,7 @@ python_unittest(
]),
deps = [
"//caffe2:torch",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/exir:lib",
"//pytorch/vision:torchvision",
],
)
57 changes: 57 additions & 0 deletions backends/xnnpack/test/passes/test_batch_norm_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch
from executorch.backends.xnnpack.passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestBatchNormFusion(unittest.TestCase):
    """Tests for the conv + batch-norm fusion pass."""

    # Pipeline under test: fold batch norm parameters into the preceding conv.
    PassStage = RunPasses([FuseBatchNormWithConvPass])
    # Edge-dialect name of the batch-norm op counted after the pass runs.
    bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"

    class ModelConvBN(torch.nn.Module):
        """Conv -> BN -> Conv -> add -> BN.

        The first BN directly follows a conv and should be fused away; the
        final BN follows an add, so it must survive the pass — hence the
        tests expect exactly one remaining batch norm.
        """

        def __init__(
            self, in_features: int, out_features: int, kernel_size: Tuple[int, int]
        ):
            super().__init__()
            self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
            self.bn = torch.nn.BatchNorm2d(out_features)

        def forward(self, x):
            out = self.bn(self.conv2d(x))
            out = self.conv2d(out)
            out = out + out
            return self.bn(out)

    def _fuse_and_verify(self, tester: Tester) -> None:
        # Lower to edge, run the fusion pass, and confirm exactly one batch
        # norm remains (the non-fusable one after the add), then check that
        # outputs still match the eager reference.
        (
            tester.export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count({self.bn_name: 1})
            .run_method()
            .compare_outputs()
        )

    def test_fp32_batch_norm_fusion(self):
        model = self.ModelConvBN(2, 2, (2, 2)).eval()
        self._fuse_and_verify(Tester(model, (torch.randn(2, 2, 4, 4),)))

    def test_q8_batch_norm_fusion(self):
        model = self.ModelConvBN(2, 2, (2, 2)).eval()
        self._fuse_and_verify(Tester(model, (torch.randn(2, 2, 4, 4),)).quantize())
180 changes: 180 additions & 0 deletions backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.passes.channels_last_tagged_reshape_pass import (
ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import (
OpSequencesAddConv2d,
)
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestChannelsLastTaggedReshapePass(unittest.TestCase):
    """Tests for ChannelsLastTaggedReshapePass.

    The pass inserts _to_copy nodes to convert between contiguous (NCHW)
    and channels-last (NHWC) memory formats around ops that require
    channels-last inputs; these tests count/match the inserted copies.
    """

    PassStage = RunPasses([ChannelsLastTaggedReshapePass])
    # Module -> expected number of inserted _to_copy (reshape) nodes.
    modules = {
        OpSequencesAddConv2d(0, 0).eval(): 0,
        OpSequencesAddConv2d(1, 1).eval(): 2,
        OpSequencesAddConv2d(2, 2).eval(): 2,
    }
    # Edge-dialect operator names checked in the transformed graph.
    to_copy_name = "executorch_exir_dialects_edge__ops_aten__to_copy_default"
    quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
    dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
    conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
    relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"

    def test_fp32_channels_last_tagged_reshape_pass(self):
        for module, num_reshape in self.modules.items():
            (
                Tester(module, (torch.randn(1, 1, 6, 6),))
                .export()
                .to_edge()
                .run_passes(self.PassStage)
                .check_count({self.to_copy_name: num_reshape})
                .run_method()
                .compare_outputs()
            )

    def test_qs8_channels_last_tagged_reshape_pass(self):
        # In the quantized graph each inserted copy is wrapped in a
        # quantize/dequantize sandwich: q -> dq -> _to_copy -> q -> dq.
        per_copy_pattern = [
            self.quant_name,
            self.dequant_name,
            self.to_copy_name,
            self.quant_name,
            self.dequant_name,
        ]
        for module, num_reshape in self.modules.items():
            (
                Tester(module, (torch.randn(1, 1, 6, 6),))
                .quantize()
                .export()
                .to_edge()
                .run_passes(self.PassStage)
                .check(per_copy_pattern * num_reshape)
                .run_method()
                .compare_outputs()
            )

    class ConvRelu(torch.nn.Module):
        """Minimal conv -> relu module."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    def test_fp32_channels_last_tagged_reshape_pass_conv_relu(self):
        # One copy into the conv (nchw -> nhwc) and one after relu back to nchw.
        expected_order = [
            self.to_copy_name,
            self.conv_name,
            self.relu_name,
            self.to_copy_name,
        ]
        (
            Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check(expected_order)
            .run_method()
            .compare_outputs()
        )

    def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self):
        # Same as the fp32 case, but the conv/relu body is bracketed by
        # quantize/dequantize pairs inside the two copies.
        expected_order = [
            self.to_copy_name,
            self.quant_name,
            self.dequant_name,
            self.conv_name,
            self.relu_name,
            self.quant_name,
            self.dequant_name,
            self.to_copy_name,
        ]
        (
            Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
            .quantize()
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check(expected_order)
            .run_method()
            .compare_outputs()
        )

    class Conv2dBnHardtanhMeanSequenceModule(torch.nn.Module):
        """Conv -> BatchNorm -> Hardtanh -> mean over spatial dims.

        eval() is called in __init__ so batch norm uses running statistics
        and lowers to the no-training edge op.
        """

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=1,
                kernel_size=(3, 3),
                stride=[2, 2],
                padding=[1, 1],
                groups=1,
                dilation=[1, 1],
                bias=True,
            )
            self.native_batchnorm = torch.nn.BatchNorm2d(1)
            self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
            self.eval()

        def forward(self, x):
            out = self.hardtanh(self.native_batchnorm(self.conv(x)))
            return torch.mean(out, (-1, -2), keepdim=True)

    def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
        # Four _to_copy nodes are expected:
        #   1. input -> conv           (nchw -> nhwc)
        #   2. conv -> batch norm      (nhwc -> nchw; bn runs in nchw)
        #   3. hardtanh -> mean        (nchw -> nhwc)
        #   4. mean -> output          (nhwc -> nchw)
        (
            Tester(
                self.Conv2dBnHardtanhMeanSequenceModule().eval(),
                (torch.randn(1, 1, 6, 6),),
            )
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count({self.to_copy_name: 4})
            .run_method()
            .compare_outputs()
        )
100 changes: 100 additions & 0 deletions backends/xnnpack/test/passes/test_remove_get_item_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestRemoveGetItemPass(unittest.TestCase):
    """Tests for RemoveGetItemPass.

    The pass rewrites (op -> getitem) pairs into single-output ops:
    max_pool2d-with-indices becomes plain max_pool2d, and torch.max with
    a dim becomes amax when only the values are used.
    """

    PassStage = RunPasses([RemoveGetItemPass])
    max_pool2d_name = "executorch_exir_dialects_edge__ops_aten_max_pool2d_default"
    amax_name = "executorch_exir_dialects_edge__ops_aten_amax_default"

    class MaxPool2dModule(torch.nn.Module):
        """Single MaxPool2d layer with configurable hyperparameters."""

        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )

        def forward(self, x):
            return self.max_pool2d_module(x)

    def _run_pass_and_verify(self, tester: Tester, op_name: str) -> None:
        # Lower to edge, run the pass, assert exactly one rewritten op is
        # present, then confirm outputs still match the eager reference.
        (
            tester.export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count({op_name: 1})
            .run_method()
            .compare_outputs()
        )

    def test_fp32_max_pool2d_remove_getitem(self):
        inputs = (torch.randn(4, 3, 24, 24),)
        self._run_pass_and_verify(
            Tester(self.MaxPool2dModule(), inputs), self.max_pool2d_name
        )

    def test_q8_max_pool2d_remove_getitem(self):
        inputs = (torch.randn(4, 3, 24, 24),)
        self._run_pass_and_verify(
            Tester(self.MaxPool2dModule(), inputs).quantize(), self.max_pool2d_name
        )

    class MaxModule(torch.nn.Module):
        """torch.max over a dim, returning only the values (indices unused)."""

        def __init__(self):
            super().__init__()

        def forward(self, x):
            max_vals, _ = torch.max(x, dim=2, keepdim=True)
            return max_vals

    def test_fp32_max_remove_getitem(self):
        inputs = (torch.randn(4, 3, 24, 24),)
        self._run_pass_and_verify(Tester(self.MaxModule(), inputs), self.amax_name)

    def test_q8_max_remove_getitem(self):
        inputs = (torch.randn(4, 3, 24, 24),)
        self._run_pass_and_verify(
            Tester(self.MaxModule(), inputs).quantize(), self.amax_name
        )
Loading