Skip to content

Conversation

@Welement-zyf
Copy link

Pull Request

What problem does this PR solve?

A new HE SIMD batch MatMat multiplication protocol, which can accelerate online HE MatMat computation. This protocol is focus on MatMul between ciphertext input and plaintext weight.

Possible side effects?

Increase the computation of encoding plaintext weight. But this part can be done offline when there are many inputs.

cd src
bazelisk test //libspu/mpc/cheetah/arith:simd_batchmm_test --test_output=all

@github-actions
Copy link

github-actions bot commented Oct 16, 2025

All contributors have signed the CLA ✍️ ✅
Posted by the CLA Assistant Lite bot.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Welement-zyf, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a novel protocol for performing batch matrix-matrix multiplication within a Homomorphic Encryption (HE) context. The primary goal is to significantly speed up online computations where one input is a ciphertext and the other is a plaintext weight. The protocol intelligently adapts its strategy based on the input matrix dimensions relative to the SIMD lane size, employing techniques like Baby-Step Giant-Step for efficient rotations when applicable, and providing robust methods for data preparation, encryption, and result extraction.

Highlights

  • New Protocol Introduction: A new Homomorphic Encryption (HE) SIMD batch Matrix-Matrix multiplication protocol has been introduced, designed to accelerate online HE computations.
  • Targeted Acceleration: This protocol specifically focuses on optimizing matrix multiplication where the input is a ciphertext and the weight is a plaintext, a common scenario in HE applications.
  • Adaptive Strategy: The SIMDBatchMMProt class implements an adaptive multiplication strategy, utilizing Baby-Step Giant-Step (BSGS) rotations for efficient computation when the input shape is small relative to the SIMD lane, and a simpler approach otherwise.
  • Comprehensive Data Handling: The protocol includes robust methods for preparing input and weight vectors, batch encoding/decoding, symmetric encryption, and parsing the final results back into a matrix format.
  • Testing and Build System Integration: New build rules and a dedicated test suite have been added to integrate the protocol into the existing system and ensure its correctness and performance.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new HE SIMD batch matrix-matrix multiplication protocol, which is a significant feature. The implementation is well-structured and handles different input sizes with specific strategies, including a BSGS-style algorithm for efficiency. My review focuses on improving code quality by removing leftover debugging artifacts like print statements, cleaning up commented-out code, and eliminating unused variables and methods. These changes will enhance the readability and maintainability of the new protocol and its tests.

Comment on lines 86 to 95
// yacl::parallel_for(0, polys.size(), [&](int64_t bgn, int64_t end) {
// for (int64_t i = bgn; i < end; ++i) {
// seal::util::encrypt_zero_symmetric(secret_key, context,
// context.first_parms_id(), false,
// save_seed, out[i]);
// seal::util::multiply_add_plain_with_scaling_variant(
// polys[i], *context.first_context_data(),
// seal::util::RNSIter{out[i].data(), out[i].poly_modulus_degree()});
// }
// });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function contains a large block of commented-out code. If this code is no longer needed, it should be removed to improve readability and maintainability.

uint64_t baby_step = absl::bit_ceil(
static_cast<uint64_t>(std::sqrt(block_size * meta.dims[2] / (double)meta.dims[1])));
baby_step = std::min(baby_step, block_size);
std::cout << "baby_step: " << baby_step << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This std::cout statement seems to be for debugging purposes and should be removed from the final code.

absl::Span<const uint64_t> ans_poly,
absl::Span<uint64_t> res_mat) const;

// Shape2D GetInShape() const { return in_shape_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out function declaration should be removed to keep the header file clean.



private:
void NoiseFloodInplace(RLWECt &ct, const seal::SEALContext &context);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The private method NoiseFloodInplace is declared here but is not used within the class. This appears to be dead code and should be removed along with its implementation.

std::shared_ptr<SIMDBatchMMProt> simd_batchmm_prot_;

void SetUp() override {
std::cout << "setup" << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This std::cout statement appears to be for debugging and should be removed from the test's SetUp method.

Comment on lines 54 to 57
// modulus_bits = {60, 30, 52, plain_bits};
} else {
// modulus_bits = {60, 45, 45, 58, plain_bits};
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of commented-out code in the SetUp method should be removed to improve test readability.

return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});

TEST_P(SIMDBatchMMTest, ) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test case has an empty name. Please provide a descriptive name that indicates what is being tested, for example, Correctness.

Suggested change
TEST_P(SIMDBatchMMTest, ) {
TEST_P(SIMDBatchMMTest, Correctness) {

int64_t d0 = 24;
int64_t d1 = 2048;
int64_t d2 = 1408;
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test contains several std::cout statements for debugging. Tests should be silent on success. Please remove this and other print statements throughout the test.

size_t num_row_blocks = CeilDiv(static_cast<uint64_t>(d1), block_size);
size_t num_col_blocks = CeilDiv(static_cast<uint64_t>(d2), block_size);
size_t simd_lane = simd_batchmm_prot_->SIMDLane();
// size_t row_size = simd_batchmm_prot_->SIMDLane() / 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out line should be removed to improve test code clarity.


namespace spu::mpc::cheetah {

class SIMDBatchMMProt {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should inherit EnableCPRNG here, rather than in the unittest.

};


static constexpr int kNoiseFloodBits = 40;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the logic of noise flooding has not been utilized; it may need to be added.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the cheetah_dot.h defines BatchDotOLE. You might consider testing its performance against your implementation.

return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});

TEST_P(SIMDBatchMMTest, ) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Regarding how to elegantly generate test data and perform correctness testing, you can refer to cheetah_dot_test.cc.
  • For measuring elapsed time, you can use yacl::ElapsedTimer, and the usage can be referenced in src/libspu/mpc/utils/lowmc_test.cc.
  • If you need to track communication volume and round counts, you can directly retrieve link statistics from yacl; the usage can be referenced in src/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc.

int64_t d0 = 24;
int64_t d1 = 2048;
int64_t d2 = 1408;
std::cout << "Testing Batch MatMat " << batch << "x" << d0 << "x" << d1 << " * " << d1 << "x" << d2 << std::endl;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer SPDLOG_INFO if you really need to print something, but you should make them as less as possible.

});

TEST_P(SIMDBatchMMTest, ) {
size_t batch = 64;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the parameterized testing framework of gtest to elegantly test the correctness and performance of the implementation with various configurations.

You can define something like:

class CheetahDotTest
    : public ::testing::TestWithParam<std::tuple<FieldType, Shape3D>> {};

INSTANTIATE_TEST_SUITE_P(
    Cheetah, CheetahDotTest,
    testing::Combine(testing::Values(FieldType::FM64, FieldType::FM128),
                     testing::Values(Shape3D{8, 7, 5}, Shape3D{57, 30, 1},
                                     Shape3D{30, 57, 1}, Shape3D{18, 8, 41},
                                     Shape3D{500, 13, 25},
                                     Shape3D{1, 2048, 768},
                                     Shape3D{18, 768, 78})),
    [](const testing::TestParamInfo<CheetahDotTest::ParamType>& p) {
      return fmt::format("{}x{}x{}x{}", std::get<0>(std::get<1>(p.param)),
                         std::get<1>(std::get<1>(p.param)),
                         std::get<2>(std::get<1>(p.param)),
                         std::get<0>(p.param));
    });


SIMDBatchMMProt(SIMDBatchMMProt&&) = delete;

// Same as SIMDMulProt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Unless there is a necessity for external calls, the APIs should be made private as much as possible.
  • You should refer to the API design in cheetah_dot.h, where the input provided to the upper-level APIs needs to be of the NdArrayRef type.

@deadlywing
Copy link
Collaborator

Please sign the CLA, btw.

@deadlywing
Copy link
Collaborator

I feel that your batch matmul also needs to define a State, which you can refer to in src/libspu/mpc/cheetah/state.h, where CheetahDotState is defined. In src/libspu/mpc/cheetah/arithmetic.cc, you can see that the matmul series of interfaces are basically implemented using CheetahDotState to achieve specific functionalities.

@Welement-zyf
Copy link
Author

I have read the CLA Document and I hereby sign the CLA

@Welement-zyf
Copy link
Author

cd src
bazelisk test //libspu/mpc/cheetah/arith:simd_batchmm_test --test_output=all
bazelisk test //libspu/mpc/cheetah/arith:batch_matmul_test --test_output=all

I imitate the implement in spu/src/libspu/mpc/cheetah/arith/cheetah_mul.cc to support FM32 and FM64 ring data (support NdArrayRef type API) for my protocol, which works on BFV SIMD encoding.
The wrapped implement for spu/src/libspu/mpc/cheetah/arith/simd_batchmm_prot.cc is in spu/src/libspu/mpc/cheetah/arith/batch_matmul.cc. It can also simulate the communication between the client and the server.

Because Seal only support at most 60 bit plaintext modulus, we need to use CRT to transfer data in 2^n bit ring into several plaintexts. I use the setting in spu/src/libspu/mpc/cheetah/arith/cheetah_mul.cc, where data in FM32 is transferred into 3 plaintexts, data in FM64 is transferred into 4 plaintexts. This makes the computation overhead increase largely. Maybe using CKKS can solve this problem.

I compare my BatchMatMul with BatchDotOLE in spu/src/libspu/mpc/cheetah/arith/cheetah_dot.cc. Considering the 4x overhead introduced by CRT, the online computation latency of them are comparable, both far more faster than DotOLE.

Shape SIMD BatchMatMul 44bit field Comp. (ms) SIMD BatchMatMul FM64 Comp. (ms) Coefficient BatchDot FM64 Comp. (ms)
4x1x2048x768 268.335 1031.279 208.212
4x18x768x768 151.154 449.103 125.368
4x1024x16x16 62.753 271.044 190.668

SIMD BatchMatMul reported the total online computation time. BatchDot reported the interleaving (key-switch) overhead.

@deadlywing
Copy link
Collaborator

cd src
bazelisk test //libspu/mpc/cheetah/arith:simd_batchmm_test --test_output=all
bazelisk test //libspu/mpc/cheetah/arith:batch_matmul_test --test_output=all

I imitate the implement in spu/src/libspu/mpc/cheetah/arith/cheetah_mul.cc to support FM32 and FM64 ring data (support NdArrayRef type API) for my protocol, which works on BFV SIMD encoding. The wrapped implement for spu/src/libspu/mpc/cheetah/arith/simd_batchmm_prot.cc is in spu/src/libspu/mpc/cheetah/arith/batch_matmul.cc. It can also simulate the communication between the client and the server.

Because Seal only support at most 60 bit plaintext modulus, we need to use CRT to transfer data in 2^n bit ring into several plaintexts. I use the setting in spu/src/libspu/mpc/cheetah/arith/cheetah_mul.cc, where data in FM32 is transferred into 3 plaintexts, data in FM64 is transferred into 4 plaintexts. This makes the computation overhead increase largely. Maybe using CKKS can solve this problem.

I compare my BatchMatMul with BatchDotOLE in spu/src/libspu/mpc/cheetah/arith/cheetah_dot.cc. Considering the 4x overhead introduced by CRT, the online computation latency of them are comparable, both far more faster than DotOLE.

Shape SIMD BatchMatMul 44bit field Comp. (ms) SIMD BatchMatMul FM64 Comp. (ms) Coefficient BatchDot FM64 Comp. (ms)
4x1x2048x768 268.335 1031.279 208.212
4x18x768x768 151.154 449.103 125.368
4x1024x16x16 62.753 271.044 190.668
SIMD BatchMatMul reported the total online computation time. BatchDot reported the interleaving (key-switch) overhead.

Hello, I'm having a little trouble understanding your experimental results. Based on the data, it seems that BatchDot's latency is actually lower than BatchMatMul in most cases.
Could you clarify why you stated that 'Considering the 4x overhead introduced by CRT, the online computation latency of them are comparable, both far more faster than DotOLE.'?

@Welement-zyf
Copy link
Author

Thank you for your instructions.

  1. '4x overhead introduced by CRT': To support NdArrayRef type of API, the I/O data must be in power of 2 ring, such as FM32/FM64. However, seal support at most 60 bit of plaintext modulus. I use the method in spu/src/libspu/mpc/cheetah/arith/cheetah_mul.cc to transform one data in FM64 into four data in BFV plaintext modulus, which means that to complete computation on FM64, I need to invoke my simd_batchmm_prot.cc four times. The BatchDot in spu/src/libspu/mpc/cheetah/arith/cheetah_dot.cc is implemented with CKKS and can direct encode data in FM64 into a plaintext. I might consider to try CKKS for my protocol.
  2. BatchMatMul and BatchDot have better performance in different cases. For the batch matmul of b x m x n x k, BatchMatMul needs $2\sqrt{\frac{bmnk}{N}}$ key switch (rotation) ops, and BatchDot needs about $\frac{bmk\log N}{2\sqrt{N}}$ key switch ops. Thus BatchMatMul performs better especially when b, m or k is very large.
Shape SIMD BatchMatMul 44bit field Comp. (ms) SIMD BatchMatMul FM64 Comp. (ms) Coefficient BatchDot FM64 Comp. (ms)
64x32x1024x512 3346.478 14125.025 19559.728

The dimension of MatMul in my work is about 64x32x2048x2048 and BatchMatMul is very suitable for this case. I use smaller dimension here because it seems that [external/yacl~/yacl/link/transport/channel_mem.cc:53] Get data timeout, key=sim.2:P2P-3:1->0 is thrown if the computation time is larger than 30s.


if (nxt_rank == 0) { // server
// receive sk for debug
// TODO: remove the debug codes?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should remove these debug codes?

std::make_shared<seal::SecretKey>(keygen.secret_key()));

// send sk for debug
// TODO: remove the debug codes?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should remove these debug codes?

std::vector<NdArrayRef> input(kWorldSize);
NdArrayRef weight;

// TODO: inputs should be multi-dim tensors.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original implementation of your input is a long one-dimensional vector, but what we actually need is to support a three-dimensional vector as follows. However, if I simply change the dimensions, it seems that there will be correctness issues in the unit tests. Could you please make the necessary modifications?



// [x], w => [x*w] for private input and plaintext weight
NdArrayRef MatMulServer(const NdArrayRef& x, const NdArrayRef& w,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个api设计似乎和cheetah其他协议不太一致,比如cheetah_dot.h中的BatchDotOLE,这种设计方式可以更好的兼容MPC的kernel层 BatchMatMulAV (一方sharing一方明文) 和 BatchMatMulAA(两方sharing) 的实现,kernel层后续使用可参考:https://github.com/AntCPLab/OpenBumbleBee/blob/47c5560d069543591f3b0176eb530ede28e33fc3/libspu/mpc/cheetah/arithmetic.cc#L476

当然,如果你这样设计api是为了更高效的实现,那是可以讨论的,如果性能差不多,可以修改一下api设计。


NdArrayRef out;
if (rank != owner) {
// TODO: refactor bmm_prot api to support both av and aa cases.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里我是仿照了https://github.com/AntCPLab/OpenBumbleBee/blob/47c5560d069543591f3b0176eb530ede28e33fc3/libspu/mpc/cheetah/arithmetic.cc#L476
中的实现,你可以看看是否合适?如果ok的话,可以修改一下bmm_prot的实现,填充一下TODO中的调用

@deadlywing
Copy link
Collaborator

@Welement-zyf

我往你的代码分支push了一些流程代码,然后稍微清理了一下你的代码,其中仍有一些问题。
辛苦修改一下,修改完后我接着做上下层算子打通,我当前初步做了一些工作,但是因为算子层还没有跑通所以不能做测试;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants