Added more generic monkey patch function #42
Conversation
    apply_liger_kernel_to_gemma,
    apply_liger_kernel_to_llama,
    apply_liger_kernel_to_mistral,
I am thinking of exposing only one generic patch function instead of one per model..
One reason to keep the apply_liger_kernel_to_{model_type} functions is that they provide a well-defined interface for each model type: users can see from the documentation and type hints exactly which kernels each model supports, whereas a generic method hides that.
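The trade-off described above can be sketched with a toy example (the function signatures here are illustrative, not the real liger_kernel API): the per-model function declares its supported kernels as typed keyword arguments, while a generic entry point hides them behind **kwargs.

```python
# Illustrative sketch, not the real liger_kernel signatures.

def apply_liger_kernel_to_llama(rope: bool = True, rms_norm: bool = True) -> dict:
    # A reader (or IDE) can see at a glance which kernels llama supports.
    return {"rope": rope, "rms_norm": rms_norm}

def apply_liger_kernel(model_type: str, **kwargs) -> dict:
    # Generic dispatch: the supported kernels are hidden behind **kwargs,
    # so type checkers and docs can't surface them per model.
    dispatch = {"llama": apply_liger_kernel_to_llama}
    return dispatch[model_type](**kwargs)

print(apply_liger_kernel("llama", rope=False))  # → {'rope': False, 'rms_norm': True}
```

Both styles can coexist: the generic function can simply dispatch to the per-model ones, which is why keeping the existing APIs costs little.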
Could you elaborate?
How is it related to CausalLM, QA, etc.? Also, let's put the PR on hold at least until after the first public release; we want to keep the public APIs intact.
From my understanding, when a new model is added to transformers there is a base model (e.g. LlamaModel) that contains all the core nn.Modules. Then there are task-specific variants like LlamaForCausalLM and LlamaForTokenClassification that reference the base model but swap the head layer for the specific task. The kernels are generally applicable to the core layers defined in the base model, so if someone wanted to train a LlamaForTokenClassification model, they would apply the Llama patch first and then train as usual. By mapping liger kernel application to the model type (e.g. `llama`), we cover all of its task-specific variants at once.
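The mechanism described above can be sketched with a toy monkey patch (the class and function names below are made up for illustration, not the real liger_kernel or transformers APIs): because the task-specific variant reuses the base model's layers, patching the base class automatically benefits every head built on top of it.

```python
# Toy illustration; names are hypothetical, not the real liger_kernel API.

class BaseModel:
    """Stand-in for e.g. LlamaModel, holding the core layers."""
    def forward(self, x):
        return x + 1  # "slow" reference implementation of the core layers

class ModelForTokenClassification(BaseModel):
    """Stand-in for a task-specific variant that reuses the base layers."""
    def forward(self, x):
        hidden = super().forward(x)  # delegates to the base model's layers
        return hidden * 10           # task-specific head on top

def apply_fast_kernel():
    # Monkey patch the base class in place; every subclass that delegates
    # to it (any "ForXxx" variant) picks up the patched implementation.
    BaseModel.forward = lambda self, x: x + 1  # pretend this is a fused kernel

apply_fast_kernel()
model = ModelForTokenClassification()
print(model.forward(3))  # → 40
```

This is why mapping patching to the model type, rather than to each task-specific class, is sufficient.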
Sounds good. Also, this would still keep the existing APIs intact going forward (see the other comment).
LGTM. We haven't tested convergence for the other classes; we can add a few more convergence tests later on, though functionality-wise it should work for the other classes.
Summary
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence