[LoRA] Add LoRA model converter with dynamic class inheritance #2923

Draft
mori360 wants to merge 1 commit into pytorch:main from mori360:lora_adapter

Conversation


@mori360 mori360 commented Apr 9, 2026

Summary

  • Add LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model
  • LoRA uses dynamic subclass creation (class swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g., Float8Linear, FakeQuantizedLinear)
  • Base model weights are frozen (requires_grad=False), only LoRA adapter weights (lora_a, lora_b) are trainable
  • Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow
  • Add llama3_debugmodel_lora debug config
  • Add converter ordering validation: quantization must come before LoRA
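The class-swap approach described above can be illustrated with a minimal sketch. This is not the PR's actual code — names like `wrap_with_lora` and `_CLASS_CACHE` are illustrative — but it shows why a dynamically generated subclass composes with any `nn.Linear` subclass: the original `forward` (possibly quantized) is still the one that runs for the base projection.

```python
# Minimal sketch (assumed names, not the PR's code) of LoRA via __class__ swap:
# generate a subclass of the module's *actual* class, so any nn.Linear
# subclass (Float8Linear, FakeQuantizedLinear, ...) keeps its own forward.
import torch
import torch.nn as nn

_CLASS_CACHE: dict[type, type] = {}  # one generated class per base class

def _make_lora_class(cls: type) -> type:
    def forward(self, x):
        out = cls.forward(self, x)  # original (possibly quantized) forward
        return out + self.lora_b(self.lora_a(x)) * self._lora_scaling
    return type(f"LoRA{cls.__name__}", (cls,), {"forward": forward})

def wrap_with_lora(linear: nn.Linear, rank: int = 8, alpha: float = 16.0) -> nn.Linear:
    cls = linear.__class__
    if cls not in _CLASS_CACHE:
        _CLASS_CACHE[cls] = _make_lora_class(cls)
    linear.__class__ = _CLASS_CACHE[cls]  # in-place class swap
    linear.weight.requires_grad_(False)   # freeze base weights
    if linear.bias is not None:
        linear.bias.requires_grad_(False)
    linear.lora_a = nn.Linear(linear.in_features, rank, bias=False)
    linear.lora_b = nn.Linear(rank, linear.out_features, bias=False)
    nn.init.zeros_(linear.lora_b.weight)  # delta starts at zero
    linear._lora_scaling = alpha / rank
    return linear
```

Because `lora_b` is zero-initialized, the wrapped module initially produces exactly the base module's output, and only the adapter weights carry gradients.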

Parallelism Compatibility

FSDP (Data Parallel) — Works out of the box. LoRA adapters (lora_a, lora_b) are registered as child nn.Modules of the wrapped linear, so fully_shard() treats them like any other submodule. Base weights are frozen (requires_grad=False) and LoRA weights are trainable — FSDP only all-gathers/reduces gradients for trainable parameters, so no wasted communication.

Pipeline Parallel (PP) — Works. LoRA doesn't change the model's module structure or the number of layers, so pipeline stage splitting is unaffected. Each stage's linear layers independently get LoRA adapters during convert().

Context Parallel (CP) — Works. CP operates on the attention module's sequence dimension, not on linear layer internals. LoRA adapters participate in the same forward pass as the base linear.

Tensor Parallel (TP) — Not supported yet. PyTorch's ColwiseParallel._partition_linear_fn (in
torch/distributed/tensor/parallel/style.py) uses named_parameters(recurse=True), which yields dotted names like
lora_a.weight from child modules. register_parameter() rejects names containing ".", causing a KeyError. RowwiseParallel
is safe (it explicitly accesses module.weight and module.bias only). This requires an upstream PyTorch fix to
ColwiseParallel._partition_linear_fn — either switching to recurse=False or explicitly accessing
module.weight/module.bias like RowwiseParallel does. Until then, LoRA should not be used with TP. The debug config
llama3_debugmodel_lora uses TP degree 1.
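The failure mode can be reproduced without any TP machinery. This toy sketch (not the `ColwiseParallel` code itself) shows that after attaching a child adapter module, `named_parameters(recurse=True)` yields a dotted name, and `register_parameter` rejects it:

```python
import torch.nn as nn

lin = nn.Linear(4, 4)
lin.lora_a = nn.Linear(4, 2, bias=False)  # child module, like a LoRA adapter

# recurse=True yields dotted names from child modules:
names = [n for n, _ in lin.named_parameters(recurse=True)]
assert "lora_a.weight" in names

# register_parameter refuses names containing ".", which is the crash
# described above when ColwiseParallel iterates recursively:
try:
    lin.register_parameter("lora_a.weight", lin.lora_a.weight)
except KeyError as e:
    print("rejected:", e)
```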

Expert Parallel (EP) — Works in principle. LoRA applies to nn.Linear subclasses, and MoE expert layers use
GroupedExperts (not nn.Linear), so LoRA would only wrap the shared/dense linear layers (attention projections, shared
expert FFN), not the routed experts. EP sharding of the expert dimension is unaffected.

Test Plan

Unit tests (pytest tests/unit_tests/test_model_converter.py):

  • test_lora_freeze_and_training: base params frozen, LoRA params trainable, base weights unchanged after 5 training steps
  • test_lora_filter_fqns: filter_fqns correctly skips specified modules from LoRA conversion
  • test_lora_before_quantization_raises: converter ordering validation rejects LoRA before quantization
  • Integration tests, 3D compile test (PP+DP+TP, 8 GPU), 3D compile test (PP+DP+CP, 8 GPU)
  • Loss tests: full finetune, LoRA finetune (rank=8), and LoRA finetune (rank=64) on the llama3 debugmodel
[loss curves image]

meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Apr 9, 2026

mori360 commented Apr 9, 2026

@claude review the pr


claude bot commented Apr 9, 2026

Claude finished @mori360's task in 6m 48s


PR Review: [LoRA] Add LoRA model converter with dynamic class inheritance

  • Read changed files and gather context
  • Review torchtitan/components/lora.py
  • Review torchtitan/protocols/model_converter.py
  • Review tests/unit_tests/test_model_converter.py
  • Review torchtitan/models/llama3/config_registry.py
  • Post review feedback

Overall this is a well-structured PR. The dynamic class inheritance approach is a good design choice for composing LoRA with arbitrary nn.Linear subclasses (Float8Linear, FakeQuantizedLinear, etc.). The loss comparison shows expected behavior across full finetune, LoRA r=8, and LoRA r=64 configurations.

Below are findings organized by severity.


Issues

1. _lora_scaling is not a proper parameter/buffer — does it survive a state_dict round trip?
torchtitan/components/lora.py:52 — self._lora_scaling = alpha / rank stores a plain float attribute. This value is not in the state_dict and will be lost on checkpoint reload: if checkpoint loading only restores lora_a.weight and lora_b.weight without re-running the converter, the scaling factor is gone. That is fine as long as the converter always runs before checkpoint load (which appears to be the case in the current training flow), but it is fragile and worth documenting explicitly.
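The difference is easy to demonstrate in isolation (attribute and buffer names here are illustrative): a plain Python attribute never enters the state_dict, while a registered buffer is serialized and restored with the checkpoint.

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
lin._lora_scaling = 2.0  # plain attribute: absent from state_dict
lin.register_buffer("lora_scaling", torch.tensor(2.0))  # buffer: serialized

sd = lin.state_dict()
assert "lora_scaling" in sd        # survives a checkpoint round trip
assert "_lora_scaling" not in sd   # would require the converter to re-run
```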


2. filter_fqns substring matching could be overly broad
torchtitan/components/lora.py:122-124 — The filter uses any(f in fqn for f in self.filter_fqns), which is a substring match. For example, filter_fqns=["q"] would match wq, quantize, and any other FQN containing "q", not just an exact module name. This is consistent with how the quantization converter does it (torchtitan/components/quantization/utils.py:25), so it is not unique to this PR, but the docstring at line 103 should state that this is a substring match to help users avoid surprising behavior.

3. init_weights patching doesn't chain well with multiple converters
torchtitan/components/lora.py:127-137 — The init_weights patching via object.__setattr__ captures the current init_weights at patch time. If another converter later patches init_weights the same way, only the last patch survives. Currently LoRA is the only converter that does this, but it's a latent fragility. Consider adding a comment noting this single-patch assumption.
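A plain-Python sketch (not torchtitan code) of why capture order matters: patches compose only if each new patch wraps whatever init_weights is currently installed; a converter that instead captures the class-level original would silently drop earlier patches.

```python
class ToyModel:
    def init_weights(self):
        return ["base"]

def patch_init_weights(model, tag):
    prev = model.init_weights  # capture whatever is installed *right now*
    def patched():
        calls = prev()         # chain to the previous patch (or the original)
        calls.append(tag)
        return calls
    # object.__setattr__ mirrors the PR's bypass of nn.Module.__setattr__
    object.__setattr__(model, "init_weights", patched)

m = ToyModel()
patch_init_weights(m, "lora")
patch_init_weights(m, "other")
assert m.init_weights() == ["base", "lora", "other"]
```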

4. No validation that alpha > 0
torchtitan/components/lora.py:106-108 — rank is validated to be positive, but alpha is not. A negative or zero alpha would silently produce broken scaling. Consider adding validation for alpha as well.
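A sketch of the suggested check, mirroring the existing rank validation (the helper name is hypothetical):

```python
def validate_lora_config(rank: int, alpha: float) -> None:
    # hypothetical helper: validate alpha the same way rank is validated
    if rank <= 0:
        raise ValueError(f"LoRA rank must be positive, got {rank}")
    if alpha <= 0:
        raise ValueError(f"LoRA alpha must be positive, got {alpha}")
```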



Suggestions (Non-blocking)

5. Consider reusing the shared module_filter_fn pattern
torchtitan/components/lora.py:121-125 — The quantization converters use module_filter_fn from torchtitan/components/quantization/utils.py for filter logic. The LoRA converter reimplements this inline. For consistency and to avoid divergence, consider extracting the shared substring-match filter to a common location (e.g., torchtitan/components/utils.py) or reusing the existing one. This isn't blocking since the quantization version also checks dimension divisibility (not applicable to LoRA).

6. _init_weight naming is potentially confusing
torchtitan/components/lora.py:71 — The method _init_weight (singular) coexists with the model-level init_weights (plural) that gets patched. This naming proximity could confuse readers. Consider _reset_lora_parameters or similar to make the scope clearer.

7. Moving _validate_quantization to __post_init__ is a behavioral change
torchtitan/protocols/model_converter.py:56-58 — The PR moves _validate_quantization from ModelConvertersContainer.__init__ to Config.__post_init__. This means validation now runs at config construction time rather than at container build time. This is generally better (fail-fast), but it's a behavioral change for existing code that constructs Config objects and later modifies converters before calling build(). I don't see such patterns in the codebase today, but it's worth noting.

8. The post_optimizer_hook is a no-op — could omit
torchtitan/components/lora.py:139-140 — Since ModelConverter is a Protocol (structural typing), you don't need to implement methods you don't use. The post_optimizer_hook no-op can be removed and the Protocol will still be satisfied structurally. However, keeping it is fine for explicitness.


Positive Notes

  • The dynamic class inheritance via __class__ swap is the right approach for composing LoRA with arbitrary Linear subclasses. The class cache at module level prevents redundant class creation.
  • Using _LoRALinear = Module.from_nn_module(nn.Linear) ensures LoRA adapter submodules satisfy the Module protocol, avoiding verify_module_protocol failures.
  • The converter ordering validation (_validate_converter_ordering) correctly enforces that quantization runs before LoRA, with a clear error message explaining why.
  • The list(module.named_modules()) snapshot at line 121 correctly avoids iterating over newly-created lora_a/lora_b children during mutation.
  • Tests cover the key behaviors: freeze correctness, filter_fqns, and ordering validation.
  • Loss curves show expected behavior: full finetune > LoRA r=64 > LoRA r=8 in convergence quality.


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.
