HE SIMD batch MatMat multiplication protocol #1271
Conversation
All contributors have signed the CLA ✍️ ✅
Summary of Changes

Hello @Welement-zyf, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a new protocol for performing batch matrix-matrix multiplication in a Homomorphic Encryption (HE) context. The primary goal is to significantly speed up online computation where one input is a ciphertext and the other is a plaintext weight. The protocol adapts its strategy based on the input matrix dimensions relative to the SIMD lane size, using techniques such as Baby-Step Giant-Step for efficient rotations when applicable, and it provides methods for data preparation, encryption, and result extraction.
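For reviewers unfamiliar with the rotation trick mentioned above, here is a standard (not PR-specific) sketch of why Baby-Step Giant-Step helps: to accumulate $d$ rotated copies of a ciphertext, one precomputes $g$ "baby" rotations and reuses them across $\lceil d/g \rceil$ "giant" steps, so the number of ciphertext rotations drops from about $d$ to roughly

$$g + \left\lceil \frac{d}{g} \right\rceil \;\approx\; 2\sqrt{d} \quad \text{when } g \approx \sqrt{d},$$

which is consistent with the square-root choice of baby_step visible in the quoted diff further down.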
Code Review
This pull request introduces a new HE SIMD batch matrix-matrix multiplication protocol, which is a significant feature. The implementation is well-structured and handles different input sizes with specific strategies, including a BSGS-style algorithm for efficiency. My review focuses on improving code quality by removing leftover debugging artifacts like print statements, cleaning up commented-out code, and eliminating unused variables and methods. These changes will enhance the readability and maintainability of the new protocol and its tests.
```cpp
// yacl::parallel_for(0, polys.size(), [&](int64_t bgn, int64_t end) {
//   for (int64_t i = bgn; i < end; ++i) {
//     seal::util::encrypt_zero_symmetric(secret_key, context,
//                                        context.first_parms_id(), false,
//                                        save_seed, out[i]);
//     seal::util::multiply_add_plain_with_scaling_variant(
//         polys[i], *context.first_context_data(),
//         seal::util::RNSIter{out[i].data(), out[i].poly_modulus_degree()});
//   }
// });
```
```cpp
uint64_t baby_step = absl::bit_ceil(
    static_cast<uint64_t>(std::sqrt(block_size * meta.dims[2] / (double)meta.dims[1])));
baby_step = std::min(baby_step, block_size);
std::cout << "baby_step: " << baby_step << std::endl;
```
```cpp
absl::Span<const uint64_t> ans_poly,
absl::Span<uint64_t> res_mat) const;

// Shape2D GetInShape() const { return in_shape_; }
```
```cpp
 private:
  void NoiseFloodInplace(RLWECt &ct, const seal::SEALContext &context);
```
```cpp
std::shared_ptr<SIMDBatchMMProt> simd_batchmm_prot_;

void SetUp() override {
  std::cout << "setup" << std::endl;
```
```cpp
  // modulus_bits = {60, 30, 52, plain_bits};
} else {
  // modulus_bits = {60, 45, 45, 58, plain_bits};
}
```
```cpp
  return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});

TEST_P(SIMDBatchMMTest, ) {
```
```cpp
int64_t d0 = 24;
int64_t d1 = 2048;
int64_t d2 = 1408;
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl;
```
```cpp
size_t num_row_blocks = CeilDiv(static_cast<uint64_t>(d1), block_size);
size_t num_col_blocks = CeilDiv(static_cast<uint64_t>(d2), block_size);
size_t simd_lane = simd_batchmm_prot_->SIMDLane();
// size_t row_size = simd_batchmm_prot_->SIMDLane() / 2;
```
```cpp
namespace spu::mpc::cheetah {

class SIMDBatchMMProt {
```
You should inherit EnableCPRNG here, rather than in the unittest.
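A minimal sketch of what that could look like, assuming EnableCPRNG is the same helper base class used by the other cheetah protocols (nothing below is taken from the PR; the body is elided):

```cpp
// Sketch only: inheriting EnableCPRNG in the protocol class lets the protocol
// sample its own randomness instead of relying on the unittest fixture.
class SIMDBatchMMProt : public EnableCPRNG {
  // ... existing members unchanged ...
};
```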
```cpp
};

static constexpr int kNoiseFloodBits = 40;
```
It seems that the logic of noise flooding has not been utilized; it may need to be added.
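For context, a standard (not PR-specific) description of what the flooding step would do: before a result ciphertext is sent back, add fresh noise whose magnitude is about $2^{40}$ times larger than the bound on the circuit-dependent noise (here kNoiseFloodBits = 40). By the usual smudging argument, with $B_{\mathrm{flood}} = 2^{40} B_{ct}$ and uniform flooding noise,

$$\Delta\bigl(e_{ct} + e_{\mathrm{flood}},\; e_{\mathrm{flood}}\bigr) \;\le\; \frac{B_{ct}}{B_{\mathrm{flood}}} \;\approx\; 2^{-40},$$

so the returned ciphertext's noise leaks essentially nothing about the plaintext weight.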
It seems that cheetah_dot.h defines BatchDotOLE. You might consider testing its performance against your implementation.
```cpp
  return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});

TEST_P(SIMDBatchMMTest, ) {
```
- Regarding how to elegantly generate test data and perform correctness testing, you can refer to cheetah_dot_test.cc.
- For measuring elapsed time, you can use yacl::ElapsedTimer; its usage can be referenced in src/libspu/mpc/utils/lowmc_test.cc.
- If you need to track communication volume and round counts, you can directly retrieve link statistics from yacl; the usage can be referenced in src/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc. A minimal sketch of the latter two points follows below.
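A rough sketch of both, assuming the usual yacl helpers (yacl::ElapsedTimer and the link Context's GetStats()); please double-check the exact headers and field names against lowmc_test.cc and compare_prot_test.cc:

```cpp
#include "yacl/utils/elapsed_timer.h"

// Sketch: measure wall-clock time plus bytes/rounds sent through the link.
void RunAndReport(const std::shared_ptr<yacl::link::Context>& lctx) {
  const size_t bytes_before = lctx->GetStats()->sent_bytes;
  const size_t rounds_before = lctx->GetStats()->sent_actions;

  yacl::ElapsedTimer timer;
  // ... run the batch MatMat protocol here ...
  const double ms = timer.CountMs();

  SPDLOG_INFO("batch MatMat: {} ms, {} bytes, {} rounds", ms,
              lctx->GetStats()->sent_bytes - bytes_before,
              lctx->GetStats()->sent_actions - rounds_before);
}
```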
```cpp
int64_t d0 = 24;
int64_t d1 = 2048;
int64_t d2 = 1408;
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl;
```
We prefer SPDLOG_INFO if you really need to print something, but you should keep such output to a minimum.
```cpp
});

TEST_P(SIMDBatchMMTest, ) {
  size_t batch = 64;
```
You can use the parameterized testing framework of gtest to elegantly test the correctness and performance of the implementation with various configurations.
You can define something like:
```cpp
class CheetahDotTest
    : public ::testing::TestWithParam<std::tuple<FieldType, Shape3D>> {};

INSTANTIATE_TEST_SUITE_P(
    Cheetah, CheetahDotTest,
    testing::Combine(testing::Values(FieldType::FM64, FieldType::FM128),
                     testing::Values(Shape3D{8, 7, 5}, Shape3D{57, 30, 1},
                                     Shape3D{30, 57, 1}, Shape3D{18, 8, 41},
                                     Shape3D{500, 13, 25},
                                     Shape3D{1, 2048, 768},
                                     Shape3D{18, 768, 78})),
    [](const testing::TestParamInfo<CheetahDotTest::ParamType>& p) {
      return fmt::format("{}x{}x{}x{}", std::get<0>(std::get<1>(p.param)),
                         std::get<1>(std::get<1>(p.param)),
                         std::get<2>(std::get<1>(p.param)),
                         std::get<0>(p.param));
    });
```
```cpp
SIMDBatchMMProt(SIMDBatchMMProt&&) = delete;

// Same as SIMDMulProt
```
- Unless there is a necessity for external calls, the APIs should be made private as much as possible.
- You should refer to the API design in cheetah_dot.h, where the input provided to the upper-level APIs needs to be of the NdArrayRef type (see the sketch after this list).
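To make the second point concrete, here is a rough illustration of what an NdArrayRef-facing entry point could look like, loosely modeled on BatchDotOLE in cheetah_dot.h; the class name, method name, Shape4D layout, and parameters below are assumptions for illustration, not the PR's actual API:

```cpp
// Hypothetical sketch, not the PR's API: one NdArrayRef-based public entry
// point, with the SEAL/polynomial-level helpers kept private.
class SIMDBatchMM {
 public:
  // prv_mat is this party's private matrix; dim4 packs {batch, rows, inner,
  // cols}; is_self_lhs marks whether this party holds the left operand.
  NdArrayRef BatchMatMulOLE(const NdArrayRef& prv_mat,
                            yacl::link::Context* conn, const Shape4D& dim4,
                            bool is_self_lhs);

 private:
  // encoding / encryption / extraction helpers stay here
};
```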
Please sign the CLA, btw.
I feel that your batch matmul also needs to define a State, which you can refer to in …
I have read the CLA Document and I hereby sign the CLA
I imitate the implementation in … Because SEAL only supports at most a 60-bit plaintext modulus, we need to use CRT to transfer data in a 2^n-bit ring into several plaintexts. I use the setting in … I compare my SIMD …
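For readers unfamiliar with the trick (standard CRT lifting, not something specific to this PR): to multiply matrices over $\mathbb{Z}_{2^k}$ with a plaintext modulus limited to 60 bits, one computes the product modulo several NTT-friendly primes $p_i < 2^{60}$ and recombines the exact integer result:

$$XW \bmod 2^k \;=\; \mathrm{CRT}^{-1}\!\bigl(XW \bmod p_1,\,\dots,\,XW \bmod p_t\bigr) \bmod 2^k, \qquad \text{valid whenever } \prod_{i=1}^{t} p_i > \max_{a,b} \bigl|(XW)_{a,b}\bigr|,$$

where $XW$ on the right denotes the product of the lifted integer representatives. For $k = 64$ and inner dimension $d_1 = 2048$, the exact entries stay below roughly $d_1 \cdot 2^{2k} \approx 2^{139}$, so three primes of about 60 bits each suffice.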
Hello, I'm having a little trouble understanding your experimental results. Based on the data, it seems that …
Thank you for your instructions.
The dimension of MatMul in my work is about …
```cpp
if (nxt_rank == 0) {  // server
  // receive sk for debug
  // TODO: remove the debug codes?
```
You should remove this debug code.
```cpp
std::make_shared<seal::SecretKey>(keygen.secret_key()));

// send sk for debug
// TODO: remove the debug codes?
```
You should remove this debug code as well.
```cpp
std::vector<NdArrayRef> input(kWorldSize);
NdArrayRef weight;

// TODO: inputs should be multi-dim tensors.
```
The original implementation of your input is a long one-dimensional vector, but what we actually need is to support a three-dimensional vector as follows. However, if I simply change the dimensions, it seems that there will be correctness issues in the unit tests. Could you please make the necessary modifications?
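As an illustration only (this is my own sketch, not the reviewer's original snippet; ring_rand, Shape, and the dimension variables follow the conventions of the other cheetah tests):

```cpp
// Illustrative sketch: build batch x rows x cols tensors instead of one long
// flat vector, so the upper layers can pass real 3-D inputs.
const Shape lhs_shape = {static_cast<int64_t>(batch), d0, d1};  // activations
const Shape rhs_shape = {d1, d2};                               // plaintext weight
NdArrayRef input = ring_rand(field, lhs_shape);
NdArrayRef weight = ring_rand(field, rhs_shape);
```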
```cpp
// [x], w => [x*w] for private input and plaintext weight
NdArrayRef MatMulServer(const NdArrayRef& x, const NdArrayRef& w,
```
This API design seems inconsistent with the other cheetah protocols, for example BatchDotOLE in cheetah_dot.h. That style of design fits better with the MPC kernel-layer implementations of BatchMatMulAV (one party holds a sharing, the other holds plaintext) and BatchMatMulAA (both parties hold sharings); for how the kernel layer would later use it, see: https://github.com/AntCPLab/OpenBumbleBee/blob/47c5560d069543591f3b0176eb530ede28e33fc3/libspu/mpc/cheetah/arithmetic.cc#L476
Of course, if you designed the API this way for a more efficient implementation, that is open to discussion; if the performance is about the same, please consider adjusting the API design.
```cpp
NdArrayRef out;
if (rank != owner) {
  // TODO: refactor bmm_prot api to support both av and aa cases.
```
Here I followed the implementation in https://github.com/AntCPLab/OpenBumbleBee/blob/47c5560d069543591f3b0176eb530ede28e33fc3/libspu/mpc/cheetah/arithmetic.cc#L476; could you check whether it is appropriate? If it looks OK, you can adjust the bmm_prot implementation and fill in the calls marked by the TODO.
I pushed some flow/scaffolding code to your branch and did a light cleanup of your code; some issues still remain in it.
Pull Request
What problem does this PR solve?
A new HE SIMD batch MatMat multiplication protocol that accelerates online HE MatMat computation. The protocol focuses on MatMul between a ciphertext input and a plaintext weight.
Possible side effects?
It increases the computation spent on encoding the plaintext weight, but this part can be done offline and amortized when there are many inputs.
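A rough way to see the amortization claim (informal; the symbols below are mine, not from the PR): if encoding the plaintext weight once costs $T_{\mathrm{enc}}$ and each online ciphertext-by-plaintext MatMul costs $T_{\mathrm{on}}$, then over $n$ inputs the per-input cost is

$$\frac{T_{\mathrm{enc}} + n\,T_{\mathrm{on}}}{n} \;=\; T_{\mathrm{on}} + \frac{T_{\mathrm{enc}}}{n},$$

which approaches $T_{\mathrm{on}}$ as the number of inputs grows, so the extra encoding work becomes a one-time offline cost.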