
Added more generic monkey patch function#42

Merged
lancerts merged 3 commits into main from sshimizu/monkey-patch-refactor
Aug 19, 2024
Conversation

@shimizust (Collaborator) commented Aug 17, 2024

Summary

  • Added a more generic monkey patch function, to be used primarily in the transformers integration. It maps the specified model_type to the corresponding monkey patch function.
  • Using model_type (e.g. llama) covers cases more broadly than specifying a model architecture (e.g. LlamaForCausalLM, LlamaForQuestionAnswering, etc.)
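The dispatch described above can be sketched as follows. This is an illustrative sketch only: the function and dictionary names here are hypothetical, not the exact identifiers introduced in this PR.

```python
# Illustrative sketch: map each transformers model_type string to the
# corresponding Liger monkey patch function. Names are hypothetical.

def apply_liger_kernel_to_llama(**kwargs):
    # In the real integration this would swap in Triton kernels for the
    # llama modules; here we just return a marker string.
    return "llama patched"

def apply_liger_kernel_to_mistral(**kwargs):
    return "mistral patched"

# model_type -> patch function
MODEL_TYPE_TO_APPLY_FN = {
    "llama": apply_liger_kernel_to_llama,
    "mistral": apply_liger_kernel_to_mistral,
}

def apply_liger_kernel(model_type: str, **kwargs):
    """Apply Liger kernels for the given model_type, if supported."""
    if model_type not in MODEL_TYPE_TO_APPLY_FN:
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    return MODEL_TYPE_TO_APPLY_FN[model_type](**kwargs)
```

Because the key is the model_type string rather than a concrete architecture class, one entry covers every task-specific variant of that model family.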

Testing Done

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
jobuser [ ~/Liger-Kernel ]$ make checkstyle
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Skipped 1 files
All done! ✨ 🍰 ✨
45 files left unchanged.
jobuser [ ~/Liger-Kernel ]$ make test
pytest --disable-warnings test/ --ignore=test/convergence
===================================================================================================================== test session starts ======================================================================================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 114 items                                                                                                                                                                                                                                            

test/transformers/test_cross_entropy.py ..........................................................                                                                                                                                                       [ 50%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                                                                                                                                              [ 56%]
test/transformers/test_geglu.py ........                                                                                                                                                                                                                 [ 63%]
test/transformers/test_rms_norm.py ................                                                                                                                                                                                                      [ 77%]
test/transformers/test_rope.py ............                                                                                                                                                                                                              [ 87%]
test/transformers/test_swiglu.py ........                                                                                                                                                                                                                [ 94%]
test/transformers/test_trainer_integration.py ...                                                                                                                                                                                                        [ 97%]
test/transformers/test_transformers_monkey_patch.py .                                                                                                                                                                                                    [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                                                                                                                               [100%]

================================================================================================================ 114 passed in 63.54s (0:01:03) ================================================================================================================
jobuser [ ~/Liger-Kernel ]$ make test-convergence
HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence
===================================================================================================================== test session starts ======================================================================================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 8 items                                                                                                                                                                                                                                              

test/convergence/test_mini_models.py ......                                                                                                                                                                                                              [ 75%]
test/convergence/test_mini_models_no_logits.py ..                                                                                                                                                                                                        [100%]

================================================================================================================= 8 passed in 92.32s (0:01:32) =================================================================================================================

@shimizust shimizust marked this pull request as ready for review August 17, 2024 08:12
Comment on lines 3 to 5
apply_liger_kernel_to_gemma,
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
Collaborator
I'm thinking we should expose only one generic patch function instead of all the individual models' functions.

Collaborator Author

One reason to keep these apply_liger_kernel_to_{model_type} functions is to provide a more well-defined interface for each model type. Users can see documentation/type hints on exactly which kernels are supported vs. the generic method.
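A hypothetical sketch of such a per-model entry point: explicit keyword arguments document exactly which kernels the model type supports. The kernel names below are inferred from the test files in this PR (rope, rms_norm, swiglu, cross_entropy); the real signature may differ.

```python
# Hypothetical per-model entry point with a self-documenting signature.
# Users see exactly which kernels llama supports via type hints/docs,
# unlike a generic function keyed on a free-form model_type string.

def apply_liger_kernel_to_llama(
    rope: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    cross_entropy: bool = True,
) -> list:
    """Return the names of the kernels that would be patched in."""
    flags = {
        "rope": rope,
        "rms_norm": rms_norm,
        "swiglu": swiglu,
        "cross_entropy": cross_entropy,
    }
    return [name for name, enabled in flags.items() if enabled]
```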

Collaborator

Makes sense.

@ByronHsu (Contributor)

Could you elaborate?

> Using model_type (e.g. llama) covers cases more broadly than specifying a model architecture (e.g. LlamaForCausalLM, LlamaForQuestionAnswering, etc.)

How is this related to CausalLM, QA, etc.?

Also, let's put this PR on hold at least until the first public release; we want to keep the public APIs intact.

@shimizust (Collaborator Author) commented Aug 19, 2024

> Could you elaborate?
>
> Using model_type (e.g. llama) covers cases more broadly than specifying a model architecture (e.g. LlamaForCausalLM, LlamaForQuestionAnswering, etc.)
>
> How is this related to CausalLM, QA, etc.?
>
> Also, let's put this PR on hold at least until the first public release; we want to keep the public APIs intact.

From my understanding, when a new model is added to transformers, there is a base model (e.g. LlamaModel) that contains all the core nn.Modules. Then there are task-specific variants like LlamaForCausalLM and LlamaForTokenClassification that reference the base model but swap out the head layer to accomplish a specific task.

The kernels are generally applicable to the core model layers defined in the base model. If someone wanted to train a LlamaForTokenClassification model, they would do something like:

```python
model = LlamaForTokenClassification.from_pretrained("some_model_path", num_labels=...)
apply_liger_kernel_to_llama()

# Do training on the model
```

So by mapping liger kernel application to the model type (e.g. llama), this would cover all potential task-specific model arch variants (e.g. LlamaForCausalLM, LlamaForTokenClassification, LlamaForQuestionAnswering, etc.)
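A toy mirror of the transformers pattern makes the point concrete: every task-specific architecture shares the base config's model_type, so dispatching on model_type covers all of them. The class names below mimic transformers for illustration; this is not the real library code.

```python
# Toy mirror of the transformers convention: the config class carries a
# model_type string, and every task-specific architecture for a model
# family points at the same config class.

class LlamaConfig:
    model_type = "llama"

class LlamaPreTrainedModel:
    config_class = LlamaConfig

class LlamaForCausalLM(LlamaPreTrainedModel):
    pass

class LlamaForTokenClassification(LlamaPreTrainedModel):
    pass

class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    pass

variants = [LlamaForCausalLM, LlamaForTokenClassification, LlamaForQuestionAnswering]
# Every variant resolves to the single model_type "llama".
model_types = {cls.config_class.model_type for cls in variants}
```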

@shimizust (Collaborator Author)

> Could you elaborate?
>
> Using model_type (e.g. llama) covers cases more broadly than specifying a model architecture (e.g. LlamaForCausalLM, LlamaForQuestionAnswering, etc.)
>
> How is this related to CausalLM, QA, etc.?
>
> Also, let's put this PR on hold at least until the first public release; we want to keep the public APIs intact.

Sounds good. Also, this would still keep the existing APIs going forward (see the other comment).

@JasonZhu1313 (Collaborator)

LGTM. We haven't tested convergence for the other classes; we can add a few more convergence tests later on, though functionality-wise it should work for the other classes.

@lancerts lancerts merged commit 9109842 into main Aug 19, 2024
@ByronHsu ByronHsu deleted the sshimizu/monkey-patch-refactor branch August 23, 2024 06:20
