
Add support for uint8_t as data type for GatherBlockQuantized#24239

Merged
sushraja-msft merged 8 commits into main from user/sushraja/gather_dequantize on Apr 4, 2025
Conversation

@sushraja-msft
Contributor

@sushraja-msft sushraja-msft commented Mar 28, 2025

Description

This change adds support for GatherBlockQuantized to use uint8_t as the data type, with the same semantics as MatMulNBits. Zero points and gather axes other than 0 are not yet supported, to keep the change scoped.

Motivation and Context

With newer llama-style models like Phi4 trained with shared embeddings, the weights of the lm_head matrix and the embeddings table are exactly the same. These embeddings are huge: unquantized they are 1.2GB in Phi4 mini instruct, and even at int4 quantization the weights are still 300MB. We can go a step further and have the two ops, the lm_head MatMulNBits and GatherBlockQuantized, share the same weights, which would save 300MB of model size.

Two things hinder that: the shape expectations of GatherBlockQuantized, and the data types it supports for data. The shape can be solved with a simple Reshape op, but the data type needs code changes, and that is what this change does.
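For context, the operator's dequantize-on-gather semantics can be sketched roughly as follows. This is an illustrative sketch only, with hypothetical names and shapes; it is not the actual ORT kernel code, and it assumes gather axis 0, no zero_points input, and already-unpacked 4-bit values with the default zero point of 8:

```python
import numpy as np

def gather_block_dequantize(data, indices, scales, block_size=32, zero_point=8):
    """Gather rows of block-quantized `data`, then dequantize them.

    data:   (rows, cols) unpacked 4-bit values stored as uint8
    scales: (rows, cols // block_size) per-block scales
    """
    gathered = data[indices]            # gather along axis 0
    gathered_scales = scales[indices]
    # Broadcast each per-block scale across its block of columns.
    expanded = np.repeat(gathered_scales, block_size, axis=-1)
    return (gathered.astype(np.float32) - zero_point) * expanded

data = np.full((4, 64), 9, dtype=np.uint8)       # quantized value 9 everywhere
scales = np.full((4, 2), 0.5, dtype=np.float32)  # two blocks of 32 per row
out = gather_block_dequantize(data, np.array([0, 2]), scales)
print(out.shape, out[0, 0])  # → (2, 64) 0.5
```

The point of sharing weights is that this same quantized buffer can also feed a MatMulNBits node, since both interpret the data with identical block-quantization semantics.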

Here is Phi4 modified with shared weights between lm_head and MatMulNBits; this model is just 2.1GB on disk.
[image: screenshot of the modified model's size]

@sushraja-msft sushraja-msft requested a review from guschmue March 28, 2025 19:30
@sushraja-msft sushraja-msft marked this pull request as ready for review March 28, 2025 19:30
Contributor

@github-actions github-actions bot left a comment


You can commit the suggested changes from lintrunner.

sushraja-msft and others added 2 commits March 31, 2025 10:58
…d.cc

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…d.cc

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
guschmue
guschmue previously approved these changes Mar 31, 2025
@guschmue
Contributor

In theory we'd need to rev the opdef, but I believe there is no harm in this case:
it is a contrib op, and if a model with uint8 hits an older version of onnxruntime, there is no registration for that type, so it will still fail gracefully.

Contributor

@liqunfu liqunfu left a comment


shall there be test?

@sushraja-msft
Contributor Author

shall there be test?

Yes, there shall. Can you tell me where the existing tests are and how to run them locally so I can add to them? I am new to making ORT CPU changes.

@sushraja-msft
Contributor Author

Are the tests here

https://github.com/microsoft/onnxruntime/blob/24620e70d9f14956a0dc84bb8a332dcd64c95a94/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc#L125C3-L125C27 ?

That test seems to indicate that I would fail because I added uint8_t support, yet I am passing CI.

Contributor

@yihonglyu yihonglyu left a comment


I believe some .md files under the docs should be updated as well.

@sushraja-msft
Contributor Author

I believe some .md files under the docs should be updated as well.

Could you point me to it, please? I see a documentation string in contrib_defs.cc, which has been updated. Is there some other documentation?

@sushraja-msft
Contributor Author

I believe some .md files under the docs should be updated as well.

Done, updated.

@sushraja-msft sushraja-msft requested a review from liqunfu April 4, 2025 16:54
liqunfu
liqunfu previously approved these changes Apr 4, 2025
@sushraja-msft sushraja-msft merged commit a4976e3 into main Apr 4, 2025
85 of 89 checks passed
@sushraja-msft sushraja-msft deleted the user/sushraja/gather_dequantize branch April 4, 2025 22:43
zhaoxul-qti pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Apr 17, 2025
`block_size` must be a power of 2 and not smaller than 16, e.g. 16, 32, 64, 128, ...
2. Input `data`'s scale and zero point are specified by inputs `scales` and `zero_points`. `scales` and `zero_points` are also constants.
If `zero_points` is not provided, the zero point is 0, except when `data` is uint8 type, in which case the default zero point is 8.
Contributor

@tianleiwu tianleiwu May 1, 2025


Why is the default zero point 8 for uint8? That does not sound reasonable to me.
Normally, the default is the middle value 2^(bits - 1), so 128 for 8 bits and 8 for 4 bits.

Maybe add a description that this operator only supports 4 bits.

Contributor Author


This uint8 stores two packed uint4s, because that is how MatMulNBits works. To resolve this issue, I was recently discussing adding a bits attribute that would let the uint8_t be interpreted as packed uint4s or a single uint8.
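The packing discussed above can be sketched as follows. This is an illustrative example, not the ORT implementation; it assumes the low nibble holds the first of each pair of uint4 values, and it shows why 8, the midpoint of the uint4 range [0, 15], is the natural default zero point even though the storage type is uint8:

```python
import numpy as np

def unpack_uint4(packed):
    """Unpack each uint8 element into the two uint4 values it stores."""
    low = packed & 0x0F           # first value of the pair (assumed order)
    high = (packed >> 4) & 0x0F   # second value of the pair
    # Interleave low/high nibbles back into the original element order.
    return np.stack([low, high], axis=-1).reshape(packed.shape[0], -1)

packed = np.array([[0x98, 0x21]], dtype=np.uint8)
print(unpack_uint4(packed))  # → [[8 9 1 2]]
```

A bits attribute, as discussed, would make this interpretation explicit: bits=4 would mean each uint8 is a packed pair like this, while bits=8 would mean each uint8 is a single quantized value with a midpoint of 128.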
