diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 9e49f9e4e15..a9412d513c7 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -159,13 +159,22 @@ def quantize(  # noqa C901
         from torchao.utils import unwrap_tensor_subclass
 
+        # Module types whose base ``weight`` should be quantized. LoRALinear is
+        # optional: when the LoRA module cannot be imported, fall back to
+        # nn.Linear only. Resolved once here rather than on every filter call.
+        try:
+            from executorch.examples.models.llama.lora import LoRALinear
+        except ImportError:
+            quantizable_types = (nn.Linear,)
+        else:
+            quantizable_types = (nn.Linear, LoRALinear)
+
         def filter_fn(m, fqn):
-            is_linear = isinstance(m, nn.Linear)
-            has_shape_compatible_with_group_size = False
-            if is_linear:
-                has_shape_compatible_with_group_size = (
-                    m.weight.shape[1] % group_size == 0
-                )
-            return is_linear and has_shape_compatible_with_group_size
+            # Select nn.Linear / LoRALinear modules whose weight's second
+            # dimension is divisible by group_size. The isinstance check
+            # short-circuits, so ``m.weight`` is never read on other modules.
+            return (
+                isinstance(m, quantizable_types)
+                and m.weight.shape[1] % group_size == 0
+            )
 
         quantize_(
             model,