
Standalone Custom Tokens Tuner and integrated into LoRA#2376

Merged
githubnemo merged 92 commits into huggingface:main from githubnemo:feature/custom-token-tuner
Feb 26, 2025

Conversation

@githubnemo
Collaborator

@githubnemo githubnemo commented Feb 13, 2025

This PR is based on the nifty addition of @marcusinthesky from #1541.

I took the liberty of bringing the branch up to date and did a proof of concept where we not only have CustomTokens as a PEFT method for selectively re-training tokens but also a `trainable_token_indices` parameter in LoRA to combine both approaches (and possibly with other methods in the future).

What is this

When adding tokens or fine-tuning the representation of specific tokens we currently have little choice but to retrain the whole embedding matrix which can be huge and adds to the memory footprint (in RAM but also on disk). This method creates a sparse matrix of shape (n, embed_dim) where n is the number of tokens to be customized and only trains these few values.
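To illustrate the idea, here is a dense numpy sketch with made-up numbers (not the actual implementation): only the delta rows for the selected token indices are trainable, and the forward pass sees the base weights plus the deltas.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 8, 4
embedding = rng.normal(size=(vocab_size, embed_dim))  # frozen base weights

token_indices = [1, 3]                             # the tokens we want to retrain
delta = np.zeros((len(token_indices), embed_dim))  # the only trainable values

def effective_weights(embedding, token_indices, delta):
    """Merged view used in forward: base rows plus deltas at the chosen indices."""
    out = embedding.copy()
    out[token_indices] += delta
    return out

delta[0] = 1.0  # pretend training updated the delta for token 1
merged = effective_weights(embedding, token_indices, delta)
assert np.allclose(merged[1], embedding[1] + 1.0)  # retrained token shifted
assert np.allclose(merged[0], embedding[0])        # all other tokens untouched
```

So only `n * embed_dim` values need to be stored and optimized instead of the full `vocab_size * embed_dim` matrix.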

How to use this

Two possibilities:

  1. use the CustomTokens PEFT method

```py
peft_config = CustomTokensConfig(target_modules=['embed_tokens'], token_indices=[0, 1, 2])
peft_model = get_peft_model(model, peft_config)
```

  2. use it in conjunction with LoRA

```py
peft_config = LoraConfig(
    target_modules='all-linear',
    trainable_token_indices={'embed_tokens': [0, 1, 2]},
)
peft_model = get_peft_model(model, peft_config)
```

Implementation details

This is an early draft since I found no better way of implementing it without touching the modules_to_save infrastructure.

The idea is to abstract the ModulesToSaveWrapper into an AuxiliaryTrainingWrapper that allows for more functionality than simply setting requires_grad_(True) on specific modules and saving them alongside other modules. There are now three classes:

  • AuxiliaryTrainingWrapper the base class that provides a common interface for wrapping modules and forwarding getattr/forward calls from said modules
  • ModulesToSaveWrapper is the same as before but extended by having a method to get the state dict from the wrapped models for the given adapter so that we know which modules to save without having to match the state dict names
  • NewTokensWrapper is a thin wrapper around CustomTokensLayer that can be applied to layers specified by the trainable_token_indices parameter from LoraConfig (and others in the future)
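A minimal plain-Python sketch of the forwarding behaviour described above (the real classes subclass `torch.nn.Module` and do considerably more; names are simplified):

```python
class AuxiliaryTrainingWrapper:
    """Sketch: wraps a module, forwards calls and unknown attribute lookups."""

    def __init__(self, module_to_wrap, adapter_name):
        self.original_module = module_to_wrap
        self.adapter_name = adapter_name

    def __call__(self, *args, **kwargs):
        # forward() of the wrapped module; subclasses add adapter logic here
        return self.original_module(*args, **kwargs)

    def __getattr__(self, name):
        # only reached when normal lookup fails -> delegate to wrapped module
        return getattr(self.original_module, name)


class DummyEmbedding:
    num_embeddings = 10

    def __call__(self, indices):
        return [i * 2 for i in indices]


wrapped = AuxiliaryTrainingWrapper(DummyEmbedding(), "default")
assert wrapped([1, 2]) == [2, 4]      # forward is delegated
assert wrapped.num_embeddings == 10   # attribute access is delegated
```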

To load and save these modules we iterate over the model's named_modules to filter all AuxiliaryTrainingWrapper instances, get their state dicts and - depending on load or save - read adapter-specific names and write them out to be adapter-less or vice versa. In theory this should handle saving modules_to_save as well as trainable_token_indices but that's one point that needs verification and careful review.
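The save path described in the paragraph above could look roughly like this (hypothetical method and key names, not the actual PEFT code): iterate over named modules, pick out the wrappers, and strip the adapter name from the keys so the saved state dict is adapter-less.

```python
class AuxWrapper:
    """Stand-in for AuxiliaryTrainingWrapper with a per-adapter state dict."""

    def __init__(self, params):
        self._params = params  # {adapter_name: {param_name: value}}

    def adapter_state_dict(self, adapter_name):
        return self._params[adapter_name]


def collect_auxiliary_state(named_modules, adapter_name):
    """Gather adapter-specific params and strip the adapter name for saving."""
    state = {}
    for module_name, module in named_modules:
        if isinstance(module, AuxWrapper):
            for key, value in module.adapter_state_dict(adapter_name).items():
                # e.g. stored as "default.delta" -> saved as "delta"
                state[f"{module_name}.{key.replace(f'{adapter_name}.', '')}"] = value
    return state


modules = [("embed_tokens", AuxWrapper({"default": {"default.delta": [0.1]}}))]
print(collect_auxiliary_state(modules, "default"))
# {'embed_tokens.delta': [0.1]}
```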

Things that I did not explicitly address as of yet:

  • I'm unsure about how weight-tying comes into play here, writing an explicit test for this is one of my immediate next steps but I think that it should be fine as long as we restore the embedding weight matrix properly
  • get_peft_model_state_dict will probably also mark the embedding layer as a target since it has a valid embedding layer name, which is redundant here. We could prevent this by overriding the default setting for save_embedding_layers, but I'm unsure if that is a good idea. We could also just tell users that they can delete the weights if they want to. Not sure about this yet.

Open tasks

  • add documentation and example on how to use this method (with LoRA and standalone)
  • Add custom model tests with TrainableTokensConfig being used directly
  • Add test that runs on GPU
  • allow for different token_indices per adapter
  • tests that cover having multiple custom token tuners at once
  • tests that cover having multiple targets for custom tokens

Marcus Gawronsky and others added 17 commits March 6, 2024 14:56
This change makes it possible to combine the `CustomTokens` tuner
with LoRA (and potentially other) tuners.
Particularly interesting is the method for enabling adapters which
now needs to check for `AuxiliaryTrainingWrapper` instead of
`ModulesToSaveWrapper`. This is something that ought to be done
for each tuner that does this type of enabling.
This will probably be moved to somewhere else but these are necessary
for development so they can live here for now.
It turns out that it is more common than I thought for the embedding
layer to be called something else so we need to support dictionary
inputs to the `trainable_token_indices` parameter.
It was too late to make that change.
There's a dependency of `super().__init__()` on `.update()` but
the latter depends on an attribute that is set in the child class.

Therefore initialization of that attribute now happens in `.update()`
which is not ideal but better than changing the parent class even
more.
In theory there are now two parts that handle modules to save so
a next step is to see if there are conflicts between the two.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@marcusinthesky

Thanks for the shout-out. Looks super cool.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks a lot for picking up this feature, having this should be very useful for many PEFT users.

At this point, I have only done a quick review to iterate fast. In addition to my inline comments, I have some more general points:

  1. Naming: I wonder if "custom tokens" is the right name for the feature. WDYT about "extra tokens", "trainable tokens", or "additional tokens"? Or even something like "sparse embedding update" or so? Let's just hold on for a sec and ensure we find the best name, as we can't change it once the feature is out.
  2. I'm just wondering out loud about whether a sparse matrix is the best way to implement this. If we have a high embedding dimension, the matrix will contain a lot of items, not sure if this could be inefficient. If it's implemented as a dense matrix (of course, only for the relevant columns), would that be possible? Perhaps via usage of index_add or scatter_add. I haven't investigated this option, just throwing some ideas out there.
  3. Let's try to address as many TODOs as possible before merging or else they tend to stick around.
  4. Did you run any realistic tests to ensure that this saves memory and reduces file size? I can help with that.
  5. We should have updates to the docs and examples to show the standalone version and the LoRA integration. It would be fine to do that in a separate PR after this one but ideally it'll be added here.
Regarding testing:

Python 3.9 tests are failing because the `foo: bar | baz` type annotation syntax is not yet supported. Please add a `from __future__ import annotations` import where necessary.
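For reference, PEP 604 union syntax in annotations only works on Python 3.9 when annotation evaluation is deferred; a minimal example:

```python
from __future__ import annotations
# With this future import, annotations are treated as strings and not evaluated
# at runtime, so `int | None` and `list[int]` also work on Python 3.9.

def lookup(index: int | None = None) -> list[int]:
    # trivial body just to make the example runnable
    return [] if index is None else [index]

assert lookup() == []
assert lookup(2) == [2]
```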

Moreover, let's add a test case to test_custom_models.py, like here:

("Vanilla MLP 5 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),

This should result in a nice bump in test coverage.

Comment thread src/peft/peft_model.py Outdated
if target_layer in self.modules_to_save:
raise ValueError(
"The embedding layer is already marked to be trained fully, either specify "
f'`modules_to_save=[..., "{target_layer}", ...]` or `trainable_tokens=x` but not both.'
Member

Replace x in the message with target_layer?

Comment thread src/peft/tuners/custom_tokens/config.py Outdated

@dataclass
class CustomTokensConfig(PeftConfig):
token_indices: List[int] = field(default_factory=list)
Member

Let's add a help here too and a docstring for the config.

Comment thread src/peft/tuners/custom_tokens/layer.py Outdated
Comment on lines +58 to +60
values = torch.rand(
(self.num_trainable_embeddings * self.base_layer.weight.shape[-1],)
) # we initialize the values from a normal distribution N(0, 1), as in https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
Member

Maybe I'm missing something, but could we not take the values from the actual embedding matrix and use a copy of those to initialize the weights?

Member

I think I'm wrong, since this would add the same value twice, right? So to keep the initial values, this would have to be zeros.

I'm wondering: If a user extends the embedding for new tokens and then uses the custom token tuner, after training time when they load the model, they would need to ensure that when they resize the embedding again, the seed is exactly the same right? Or else they would need to save the original embedding, but that would mean much larger file sizes for the adapter, which we want to avoid.

This is okay I guess, but not super user friendly. For instance, with modules_to_save, we don't have this issue as the adapter weights contain all the info we need (but of course it's a full copy of the original weights, so quite large).

In an ideal world, the extra params for the custom tokens would replace to params of the original dict, so that users won't have to worry about restoring those. Not sure if that's possible. As an alternative, I wonder if we can save a checksum of the weights that are being replaced as a buffer and then, when loading, we can raise an error if the checksum does not match?
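The checksum idea could be sketched like this (a minimal sketch with assumed helper names, using numpy in place of torch): hash the embedding rows that the adapter replaces at save time, and compare at load time.

```python
import hashlib
import numpy as np

def weight_checksum(rows):
    """Hash the bytes of the embedding rows that the adapter will replace."""
    return hashlib.sha256(np.ascontiguousarray(rows).tobytes()).hexdigest()

embedding = np.arange(12, dtype=np.float32).reshape(4, 3)
token_indices = [1, 3]

saved = weight_checksum(embedding[token_indices])  # stored as a buffer at save time

# at load time: verify the base weights still match what the adapter was trained on
assert weight_checksum(embedding[token_indices]) == saved
embedding[1] += 1.0  # simulate a base model with different embedding values
assert weight_checksum(embedding[token_indices]) != saved  # -> raise an error here
```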

Collaborator Author

I changed the initialization of the delta values to zero. I don't think that it is important to initialize them randomly as they're probably used in different contexts so random is not important. This is also nice because most tests assume that initializing a PEFT model does not change the parameters.


@githubnemo I agree.

As reparameterized PEFT aims to train the deltas, initializing at zero may help when the user accidentally initializes tokens which are not in their post-training/FT corpora. This may also be a safer option with 'token drag'.

Comment thread src/peft/tuners/custom_tokens/layer.py Outdated
orig_weights += self.sparse_delta_tokens[active_adapter]

if safe_merge and not torch.isfinite(orig_weights).all():
raise ValueError(
Member

If this fails, the original weights have still been mutated, right? The idea of safe_merge is that if it fails, the model stays in its original state. It is acceptable if that means we need to create a copy in case of safe_merge=True, but for safe_merge=False, copies should be avoided.

Comment thread tests/test_custom_tokens.py Outdated
Comment on lines +43 to +44
output_mod = peft_model.forward(output_hidden_states=True, **X)
output_org = original_model.forward(output_hidden_states=True, **X)
Member

Suggested change
output_mod = peft_model.forward(output_hidden_states=True, **X)
output_org = original_model.forward(output_hidden_states=True, **X)
output_mod = peft_model(output_hidden_states=True, **X)
output_org = original_model(output_hidden_states=True, **X)

I was also confused about _org. Maybe change to _orig?

Comment thread tests/test_custom_tokens.py Outdated
"input_ids": torch.tensor([[0, 1, 2, 3]]),
"attention_mask": torch.tensor([[1, 1, 1, 1]]),
}
output_trn = peft_model.forward(output_hidden_states=True, **X)
Member

_trn?

Comment thread tests/test_custom_tokens.py Outdated
assert not torch.allclose(W_mod[:, :3], W_org[:, :3])
assert torch.allclose(W_mod[:, 3:], W_org[:, 3:])

def test_combined_with_lora_usage(self, model, tokenizer, tmp_path):
Member

Would it make sense to refactor the test to avoid most duplication? Especially if we plan on supporting other PEFT methods too.

Collaborator Author

Parametrized peft_config. The model is not parametrized as we can also test such combinations in testing common at a later point.

Comment thread tests/test_custom_tokens.py Outdated

def test_stand_alone_usage(self, model, tokenizer, tmp_path):
original_model = copy.deepcopy(model)
peft_config = CustomTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 2])
Member

Just wondering: Would the test cover corner cases a little better if token_indices were not consecutive tokens starting at 0? So e.g. [1, 3] instead?

Comment thread tests/test_custom_tokens.py Outdated
assert torch.allclose(W_mod, W_trn)

assert not torch.allclose(W_mod[:, :3], W_org[:, :3])
assert torch.allclose(W_mod[:, 3:], W_org[:, 3:])
Member

We should also ensure that there are tests that cover:

  • multiple targets for custom tokens
  • having multiple custom token tuners at once

@githubnemo
Collaborator Author

githubnemo commented Feb 13, 2025

  1. Naming: I wonder if "custom tokens" is the right name for the feature. WDYT about "extra tokens", "trainable tokens", or "additional tokens"? Or even something like "sparse embedding update" or so? Let's just hold on for a sec and ensure we find the best name, as we can't change it once the feature is out.

Agreed. A part of me wants it to be more general but I think that TrainableTokens is general enough without being too specific. I'll change it.

  2. I'm just wondering out loud about whether a sparse matrix is the best way to implement this. If we have a high embedding dimension, the matrix will contain a lot of items, not sure if this could be inefficient. If it's implemented as a dense matrix (of course, only for the relevant columns), would that be possible? Perhaps via usage of index_add or scatter_add. I haven't investigated this option, just throwing some ideas out there.

Naïvely I would expect the sparse implementation to be able to cope with this but I agree, it is not certain that this is the best way to implement it (or the single best way, depending on the conditions). Let's table this discussion until we have a benchmark in place.
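For reference, the dense alternative mentioned above (index_add/scatter_add on just the selected rows) could look like this, sketched with numpy's scatter-add instead of torch (illustrative only, not the PR's code):

```python
import numpy as np

embedding = np.zeros((8, 4))     # stand-in for the full embedding matrix
token_indices = np.array([1, 3])
delta = np.ones((2, 4))          # dense deltas, one row per trainable token

# scatter-add: only the selected rows are updated, no sparse tensor involved
np.add.at(embedding, token_indices, delta)

assert embedding[1].sum() == 4.0 and embedding[3].sum() == 4.0
assert embedding[0].sum() == 0.0  # untouched rows stay frozen
```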

  3. Let's try to address as many TODOs as possible before merging or else they tend to stick around.

Yep. Most of the TODOs are points where I was unsure about how to proceed before the initial review(s). Getting on these now.

  4. Did you run any realistic tests to ensure that this saves memory and reduces file size? I can help with that.

Nope, just functional tests. It would be great if you could do a bit of benchmarking, especially with the points from above regarding efficiency with larger embedding sizes.

  5. We should have updates to the docs and examples to show the standalone version and the LoRA integration. It would be fine to do that in a separate PR after this one but ideally it'll be added here.

Yes, adding it as a to do item in the PR description.

@BenjaminBossan
Member

Agreed. A part of me wants it to be more general but I think that TrainableTokens is general enough without being too specific. I'll change it.

👍

It would be great if you could do a bit of benchmarking, especially with the points from above regarding efficiency with larger embedding sizes.

I'll do a comparison with what would be the current approach, adding the embedding to modules_to_save. I plan to check that tomorrow.

nemo added 5 commits February 13, 2025 17:50
Merge onto the base weights only after checks have completed.
Refactor PEFT method as parameter and use non-consecutive indices for testing the layer modification
Member

@BenjaminBossan BenjaminBossan left a comment

I did a pass on the testing and doc changes specifically. Overall looks good, just some smaller comments.

Comment thread docs/source/developer_guides/lora.md Outdated

## Efficiently train tokens alongside LoRA

Sometimes it is necessary to not only change some layer's weights but to add new tokens as well. With larger models this can be a memory-costly endeavour. PEFT LoRA adapters support the `trainable_token_indices` parameter which allows tuning of specific tokens alongside fine-tuning of specific layers with LoRA. This method only trains the tokens you specify and leaves all other tokens untouched which saves memory and doesn't throw away learned context of existing token embeddings in contrast to when training the whole embedding matrix. Under the hood this method uses the [`~TrainableTokenLayer`].
Member

Suggested change
Sometimes it is necessary to not only change some layer's weights but to add new tokens as well. With larger models this can be a memory-costly endeavour. PEFT LoRA adapters support the `trainable_token_indices` parameter which allows tuning of specific tokens alongside fine-tuning of specific layers with LoRA. This method only trains the tokens you specify and leaves all other tokens untouched which saves memory and doesn't throw away learned context of existing token embeddings in contrast to when training the whole embedding matrix. Under the hood this method uses the [`~TrainableTokenLayer`].
Sometimes it is necessary to not only change some layer's weights but to add new tokens as well. With larger models this can be a memory-costly endeavour. PEFT LoRA adapters support the `trainable_token_indices` parameter which allows tuning of specific tokens alongside fine-tuning of other layers with LoRA. This method only trains the tokens you specify and leaves all other tokens untouched. This saves memory and doesn't throw away learned context of existing token embeddings in contrast to training the whole embedding matrix. Under the hood this method uses the [`~TrainableTokenLayer`].

A bit more readable, WDYT?

Comment thread docs/source/developer_guides/lora.md Outdated
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})

# make room for new tokens in the embedding matrix
base_model.resize_token_embeddings(len(tokenizer))
Member

I wonder if we should change this to:

Suggested change
base_model.resize_token_embeddings(len(tokenizer))
base_model.resize_token_embeddings(max(len(tokenizer), base_model.model.embed_tokens.num_embeddings))

For this specific model, it makes no difference. However, for some models the embedding matrix is actually larger than the vocab size (e.g. so that its size is a multiple of some power of 2). See e.g. the Qwen models. Thus, len(tokenizer.vocab) could be smaller than the embedding size, even after adding new tokens. In that case, transformers actually shrinks the embedding, which is not a good idea most of the time.
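The reasoning behind the `max(...)` guard can be shown with a tiny helper (hypothetical sizes; the padded-matrix case is Qwen-style):

```python
def safe_embedding_size(tokenizer_len, current_num_embeddings):
    # never shrink the embedding: some models pad it beyond the vocab size
    return max(tokenizer_len, current_num_embeddings)

# vocab grew past the matrix after adding tokens -> resize up
assert safe_embedding_size(50260, 50257) == 50260
# padded matrix larger than the vocab -> keep it, don't shrink
assert safe_embedding_size(151646, 151936) == 151936
```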

Collaborator Author

Yep, good point. It is not widely known, I think, that this is a thing. Makes it even more important.

Comment thread docs/source/developer_guides/lora.md
peft_model = get_peft_model(base_model, lora_config)

# proceed to train the model like normal
[...]
Member

Some results from my training script (all with LoRA rank 32 and bfloat16):

  1. modules_to_save=['embed_tokens']

cuda memory avg: 15038MB
cuda memory max: 16316MB
total time: 10.81s
file size of checkpoint: 302.0MB

  2. LoRA on embedding

cuda memory avg: 14056MB
cuda memory max: 15581MB
total time: 9.75s
file size of checkpoint: 306.4MB

  3. Trainable tokens (6 indices)

cuda memory avg: 14039MB
cuda memory max: 15562MB
total time: 9.02s
file size of checkpoint: 52.1MB

It's not a huge saving in terms of VRAM, but it can make a difference.

# Trainable Tokens

The Trainable Tokens method provides a way to target specific token embeddings for fine-tuning without resorting to
training the full embedding matrix or using a low-rank adapter. It is based on the initial implementation from
Member

Suggested change
training the full embedding matrix or using a low-rank adapter. It is based on the initial implementation from
training the full embedding matrix or using an adapter on the embedding matrix. It is based on the initial implementation from

To make it less LoRA specific.


Some preliminary benchmarks acquired with [this script](https://github.com/huggingface/peft/blob/main/scripts/train_memory.py)
suggest that for `gemma-2-2b` (which has a rather large embedding matrix) you can save 4.8GiB VRAM with Trainable Tokens
over fully fine-tuning. While LoRA will use even less memory (-6.3GiB total over fine-tuning) it might also target
Member

Suggested change
over fully fine-tuning. While LoRA will use even less memory (-6.3GiB total over fine-tuning) it might also target
over fully fine-tuning the embedding matrix. While LoRA will use even less memory (-6.3GiB total over fine-tuning) it might also target

Member

I ran the check again (all with LoRA rank 32 and bfloat16):

  1. modules_to_save=['embed_tokens']

cuda memory avg: 9621MB
cuda memory max: 10880MB
total time: 11.78s
file size of checkpoint: 1149.4MB

  2. LoRA on embedding

cuda memory avg: 5245MB
cuda memory max: 6988MB
total time: 9.60s
file size of checkpoint: 1180.9MB

  3. Trainable tokens (6 indices)

cuda memory avg: 5117MB
cuda memory max: 6890MB
total time: 10.28s
file size of checkpoint: 24.4MB

So LoRA on embedding vs trainable tokens is pretty much on par when it comes to VRAM.

Collaborator Author

Is this gemma2 2b again?

Comment thread tests/test_gpu_examples.py Outdated
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None

@pytest.mark.single_gpu_tests
Member

You put this test into the wrong class, which results in it trying to load a GPTQ-quantized base model. Please put it in the previous test class at line ~1393.

Comment thread tests/test_gpu_examples.py Outdated
)

model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
Member

This attribute is undefined, you can use "facebook/opt-350m".

Collaborator Author

That's from being in the wrong test class.

Comment thread tests/test_gpu_examples.py Outdated
)

# add 2 new tokens
tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
Member

Same

@githubnemo
Collaborator Author

Thanks for the review :) Addressed your comments.

Weight-tying will be handled in a follow-up PR: #2399

Member

@BenjaminBossan BenjaminBossan left a comment

Great work, really thorough PR and it should be helpful to many users. I have noticed a few minor issues still with the docs, up to you if you want to fix them. Anyway, feel free to merge once the CI is green.


Note that this method does not add tokens for you, you have to add tokens to the tokenizer yourself and resize the
embedding matrix of the model accordingly. This method will only re-train the embeddings for the tokens you specify.
This method can also be used in conjunction with LoRA layers! See [`~peft.LoraConfig.trainable_token_indices`].
Member

Link does not appear to be working :/


Comment thread docs/source/developer_guides/lora.md Outdated

## Efficiently train tokens alongside LoRA

Sometimes it is necessary to not only change some layer's weights but to add new tokens as well. With larger models this can be a memory-costly endeavour. PEFT LoRA adapters support the `trainable_token_indices` parameter which allows tuning of specific tokens alongside fine-tuning of other layers with LoRA. This method only trains the tokens you specify and leaves all other tokens untouched. This saves memory and doesn't throw away learned context of existing token embeddings in contrast to when training the whole embedding matrix. Under the hood this method uses the [`~TrainableTokenLayer`].
Member

This link is also broken :/


The core module from Megatron to use, defaults to `"megatron.core"`.
trainable_token_indices (`Optional[Union[List[int], dict[str, List[int]]]]`)
Lets you specify which token indices to selectively fine-tune without requiring to re-train the whole
embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list of indices
Member

Makes sense.

Comment thread docs/source/developer_guides/lora.md Outdated
Sometimes it is necessary to not only change some layer's weights but to add new tokens as well. With larger models this can be a memory-costly endeavour. PEFT LoRA adapters support the `trainable_token_indices` parameter which allows tuning of specific tokens alongside fine-tuning of other layers with LoRA. This method only trains the tokens you specify and leaves all other tokens untouched. This saves memory and doesn't throw away learned context of existing token embeddings in contrast to when training the whole embedding matrix. Under the hood this method uses the [`~TrainableTokenLayer`].

```py
# for layer 'embedding'
Member

Suggested change
# for layer 'embedding'
# for layer 'embed_tokens'

@githubnemo githubnemo merged commit f51203f into huggingface:main Feb 26, 2025
githubnemo added a commit that referenced this pull request Mar 6, 2025
This is a follow-up PR of #2376 to add support for weight-tying.

Some models, such as gpt2, tie the weights between the LM head and the input embeddings for various reasons. If we use the trainable tokens adapter, we're changing the result of the forward() of the input embeddings but we do not change the weights (unless we merge()). This means that the changes are not reflected in the tied weights, such as the LM head, leading to wrong results when training.
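The failure mode can be illustrated with a tiny numpy sketch (assumed names, not the PEFT code): the adapter patches the embedding's forward, but a tied LM head that reads the raw weight matrix never sees the delta.

```python
import numpy as np

weight = np.zeros((4, 2))          # shared by embedding and LM head (tied)
delta = {1: np.array([1.0, 1.0])}  # trainable-tokens delta on the embedding

def embed(idx):
    # adapter-patched forward: base row plus delta for trainable tokens
    row = weight[idx].copy()
    if idx in delta:
        row += delta[idx]
    return row

def lm_head_logits(hidden):
    # tied head still reads the raw, unmerged weight matrix
    return weight @ hidden

h = embed(1)                        # reflects the delta
assert np.allclose(h, [1.0, 1.0])
logits = lm_head_logits(h)
assert np.allclose(logits[1], 0.0)  # head unaware of the delta -> wrong results
```

Merging the delta into `weight` (or tying the adapter itself, as this follow-up does) makes the head consistent again.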

The current approach is searching for tied layers and putting TrainableTokensLayer adapters on them as well, initialized to use the parameters from the embedding layer's TrainableTokensLayer. This is done via the tied_adapter argument of TrainableTokensLayer.__init__().

Notable other changes:

* Implement weight-tying for encoder-decoder models

Notably we are removing the duplication filter of `named_modules` when searching for
the (tied) target modules since tied weights are by definition duplicates.

* Implement embedding name inference

It's now possible to let the adapter decide which is the input embedding layer based on the output
of `model.get_input_embeddings()`. If that fails, the default is still `embed_tokens`.

* Refactor getattr in AuxiliaryTrainingWrapper

Before this change only the selection of the module that was supposed to have the queried
attribute was given to the wrapper implementation (via `_{has,get}attr_wrapped`). Now the full
`getattr()` call is done by the implementation.

This change is motivated by the need for access to `embedding.weight` at certain times which,
for `ModulesToSaveWrapper` is not a problem - but it is for `TrainableTokensWrapper` since
the original module's weights differ from the current weights, at least potentially.

What we do now is to merge the weights and return those when `embedding.weight` is accessed.
No other attributes are currently forwarded.

* initialization from buffers was broken since `persistent` flag was set too late
  (update() is called before setting the flag)

* update from other BufferDict was broken since it was assumed that BufferDict was
  a mapping collection object. we cannot simply change it to a Mapping since it
  then will break pytorch code which assumes that modules are hashable.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
…2376)

This change is based on the nifty addition of @marcusinthesky from huggingface#1541.

When adding tokens or fine-tuning the representation of specific tokens we currently have little choice but to retrain the whole embedding matrix which can be huge and adds to the memory footprint (in RAM but also on disk). This method creates a sparse matrix of shape (n, embed_dim) where n is the number of tokens to be customized and only trains these few values.

This change introduces two ways of using it:

```python
peft_config = TrainableTokensConfig(target_modules=['embed_tokens'], token_indices=[0, 1, 2])
peft_model = get_peft_model(model, peft_config)
```

and with LoRA

```python
peft_config = LoraConfig(
    target_modules='all-linear',
    trainable_token_indices={'embed_tokens': [0, 1, 2]},
)
peft_model = get_peft_model(model, peft_config)
```

Adding this feature to adapters other than LoRA should be relatively easy, mostly adding the `trainable_token_indices` config option and some debugging.

To make this change it was necessary to rework the `modules_to_save` infrastructure, since combining this feature with LoRA is quite similar. This refactoring entailed moving most of the basic functionality of `ModulesToSave` to the `AuxiliaryTrainingWrapper` class. It also changes the logic for how `modules_to_save` is loaded from and saved to the state dict, so there could still be bugs here.

This implementation does not entail support for weight-tied layers yet. This will follow in a future change.

---

Notable commits in this squash:

* Use unload_and_optionally_merge_module protocol

With `AuxiliaryTrainingWrapper` as abstraction it is probably a good idea to
have support for `unload_and_optionally_merge_module`.

Since the wrapper is more akin to a PEFT layer than to a model, the naming fits
and it does basically the same job.

* trainable tokens is also trained in certain adapters

Before, the assumption was that `modules_to_save` was the only thing that
is trained alongside an adapter's parameters. Now there are also the
token adapter's delta tokens via `NewTokensWrapper`.

* Remove old modules_to_save handling

This is now all handled via the `AuxiliaryTrainingWrapper`.

* Fix modules_to_save module overwriting

The state dict implementation of `ModulesToSaveWrapper` was incorrect in that
it did not include its own parameters, only the parameters it needs to overwrite
in the end. I.e., if layer `lin1` is wrapped via modules to save,
`lin1.{weight,bias}` is saved and overwritten but `lin1.modules_to_save.<adapter_name>.[...]`
is not saved.

* Introduce a load key map for aux. train wrapper

Before this change it was only possible to remove a key prefix from the wrapper's
state dict (e.g., `modules_to_save.default.weight` -> `weight`); now it is also possible
to restore such a reduced key by mapping it back
(i.e., `weight` -> `modules_to_save.default.weight`).
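The key-map idea above can be sketched as a pair of dict transformations. This is a hypothetical illustration with made-up helper names (`save_key_map`, `load_key_map`), not the actual PEFT code:

```python
def save_key_map(state_dict, prefix="modules_to_save.default."):
    # reduce: 'modules_to_save.default.weight' -> 'weight'
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}


def load_key_map(state_dict, reduced_keys, prefix="modules_to_save.default."):
    # restore: 'weight' -> 'modules_to_save.default.weight'
    return {prefix + k if k in reduced_keys else k: v
            for k, v in state_dict.items()}


sd = {"modules_to_save.default.weight": 1, "other": 2}
reduced = save_key_map(sd)            # {'weight': 1, 'other': 2}
restored = load_key_map(reduced, {"weight"})  # round-trips back to sd
```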

* Replace sparse matrix with dense + index_copy

This change is mostly because sparse matrices are not that beneficial in this case
(at least not from what we can see right now) and they do not solve the problem
of having to change the new tokens in-place to avoid outdated deltas when new token
vectors are initialized randomly after loading the deltas.
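The dense-delta approach could be sketched as follows. This is a minimal illustration of the idea under stated assumptions, not the actual PEFT implementation; the class name and structure are made up:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseDeltaEmbeddingSketch(nn.Module):
    """Hypothetical sketch: the full embedding matrix stays frozen;
    only an (n, embed_dim) parameter holding the selected token rows
    is trained, and `index_copy` writes it into a copy of the frozen
    weight on each forward pass."""

    def __init__(self, base: nn.Embedding, token_indices):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the big matrix
        self.register_buffer("token_indices", torch.tensor(token_indices))
        # initialize the trainable rows from the frozen weight
        self.trainable_tokens = nn.Parameter(
            base.weight[self.token_indices].clone()
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # out-of-place index_copy merges the trained rows in
        merged = self.base.weight.index_copy(
            0, self.token_indices, self.trainable_tokens
        )
        return F.embedding(input_ids, merged)


emb = nn.Embedding(10, 4)
layer = DenseDeltaEmbeddingSketch(emb, [0, 1, 2])
out = layer(torch.tensor([[0, 5]]))
# only the 3 selected rows (3 * 4 values) are trainable
n_trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
```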

* Make peft_config.layers_to_transform optional

Before this change the base tuner class was forcing this attribute
to be present on the config class even though the attribute is not
specified in the base config.

* Implement missing key logic in `_set_trainable`

Before this change it was not checked whether the module targeted by `modules_to_save` or `trainable_token_indices`
existed (when used in conjunction with a PEFT method). Now an error message similar to the `inject_adapter`
error is raised when no matching module is found.

---------

Co-authored-by: Marcus Gawronsky <marcus.g@myrunway.co.za>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>