Zjj/wk25 merge join——join_uu #1368
base: zjj/wk25_merge_join
Summary of Changes
Hello @cpz2024, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new secure join primitive, join_uu, for two-party computation within the SPU framework. It lets two parties join their private datasets while keeping the data confidential. The implementation uses a cuckoo-hash-based approach, combining secure pseudo-random function (SoPrf) outputs with cryptographic hashing (Blake3) to generate the permutations used to securely compare and merge records. This extends SPU's support for privacy-preserving data analysis.
Code Review
This pull request introduces a secure two-party join implementation (join_uu) based on the "Private Join and Compute from PIR with Default" paper. The changes are comprehensive, including the core join logic, a new MPC kernel for Cuckoo Hashing, build file updates, and corresponding tests.
I've identified a critical correctness issue in the join logic where it selects matching rows, which could lead to incorrect results. Additionally, there are several opportunities for improvement in code maintainability, such as removing test dependencies from production code, improving comments, and reducing code duplication. I've also noted some minor issues in the test files that should be addressed.
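For readers unfamiliar with the data structure, here is a minimal plaintext sketch of cuckoo hashing in general. It is not the PR's MPC kernel: the class name `CuckooTable`, the mixing function in `Bin`, and the eviction limit `kMaxKicks` are all assumptions made for illustration. The point it shows is that each key has `num_hash` candidate bins, so a lookup has to check every candidate location, much like the loops over `hash_idx` in the snippets reviewed here.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Plaintext cuckoo hash table (illustrative only; not the PR's MPC kernel).
// Assumes num_bins > 0 and num_hash > 0.
class CuckooTable {
 public:
  CuckooTable(std::size_t num_bins, std::size_t num_hash)
      : bins_(num_bins), num_hash_(num_hash) {}

  // Places `key` in one of its candidate bins, evicting occupants as needed.
  // Returns false if no free bin is found within kMaxKicks evictions; the key
  // that was being relocated at that point is dropped.
  bool Insert(std::uint64_t key) {
    std::size_t h = 0;
    for (std::size_t kick = 0; kick < kMaxKicks; ++kick) {
      std::optional<Slot>& bin = bins_[Bin(key, h)];
      if (!bin.has_value()) {
        bin = Slot{key, h};
        return true;
      }
      // Bin is occupied: take it over and relocate the previous occupant
      // using its next hash function.
      Slot evicted = *bin;
      bin = Slot{key, h};
      key = evicted.key;
      h = (evicted.hash_idx + 1) % num_hash_;
    }
    return false;
  }

  // A key can only live in one of its num_hash candidate bins, so a lookup
  // checks each of them.
  bool Contains(std::uint64_t key) const {
    for (std::size_t h = 0; h < num_hash_; ++h) {
      const std::optional<Slot>& bin = bins_[Bin(key, h)];
      if (bin.has_value() && bin->key == key) return true;
    }
    return false;
  }

 private:
  struct Slot {
    std::uint64_t key;
    std::size_t hash_idx;  // which hash function placed the key here
  };
  static constexpr std::size_t kMaxKicks = 500;  // hypothetical limit

  // Hypothetical mixing function (splitmix64-style finalizer) standing in
  // for the PR's SoPrf/Blake3-based hashing.
  std::size_t Bin(std::uint64_t key, std::size_t h) const {
    std::uint64_t x = key + 0x9E3779B97F4A7C15ULL * (h + 1);
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
    x ^= x >> 31;
    return static_cast<std::size_t>(x % bins_.size());
  }

  std::vector<std::optional<Slot>> bins_;
  std::size_t num_hash_;
};
```

In a secure join the bin contents and match flags would be secret-shared rather than visible, so the early `return true` in `Contains` has no direct equivalent; the first match has to be selected with masks instead.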
spu::Value col_result =
    hal::constant(ctx, 0, table_2[col_idx].dtype(), table_1[0].shape());
spu::Value control_bit = hal::constant(ctx, 0, join_result_cols[0].dtype(),
                                       join_result_cols[0].shape());
for (size_t hash_idx = 0; hash_idx < num_hash; ++hash_idx) {
  // Not operation on control_bit.
  control_bit = hal::bitwise_not(ctx, control_bit);
  // And control_bit and join_result_cols[hash_idx].
  control_bit =
      hal::bitwise_and(ctx, control_bit, join_result_cols[hash_idx]);
  // Multiply the corresponding columns in control_bit and table_t_1.
  spu::Value table_t_2_i_col =
      table_t_1[(hash_idx * (table_2.size() + 1)) + col_idx];
  spu::Value mul_result = hal::mul(ctx, table_t_2_i_col, control_bit);
  col_result = hal::add(ctx, col_result, mul_result);
}
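The logic for calculating control_bit to select the first match among multiple hash locations is incorrect. Each iteration computes control_bit = NOT(control_bit) AND join_result_cols[hash_idx], so control_bit only records whether the immediately preceding hash location was selected, not whether any earlier location matched. This works for the first two iterations but fails from the third onward: if a row from table_1 matches at hash locations 0 and 2, both are selected and their values are added together in col_result. The logic should ensure that, for each row, only the first match across all hash functions contributes.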
spu::Value col_result =
    hal::constant(ctx, 0, table_2[col_idx].dtype(), table_1[0].shape());
spu::Value processed_mask = hal::constant(ctx, 0, join_result_cols[0].dtype(),
                                          join_result_cols[0].shape());
for (size_t hash_idx = 0; hash_idx < num_hash; ++hash_idx) {
  // Get a mask for the first match only.
  spu::Value control_bit = hal::bitwise_and(
      ctx, join_result_cols[hash_idx], hal::bitwise_not(ctx, processed_mask));
  // Multiply the corresponding columns in control_bit and table_t_1.
  spu::Value table_t_2_i_col =
      table_t_1[(hash_idx * (table_2.size() + 1)) + col_idx];
  spu::Value mul_result = hal::mul(ctx, table_t_2_i_col, control_bit);
  col_result = hal::add(ctx, col_result, mul_result);
  // Update the processed mask.
  processed_mask =
      hal::bitwise_or(ctx, processed_mask, join_result_cols[hash_idx]);
}
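The difference between the two loops can be checked on plaintext 0/1 flags. The sketch below is a simplified model for a single row, assuming each match flag is exactly 0 or 1; the function names `select_original` and `select_first_match` are hypothetical and ordinary integers stand in for the secret-shared `spu::Value` columns.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Model of the original update: control = ~control & match[h].
// `match[h]` is the 0/1 flag for hash location h, `value[h]` the candidate
// value stored there.
std::uint64_t select_original(const std::vector<std::uint64_t>& match,
                              const std::vector<std::uint64_t>& value) {
  std::uint64_t result = 0;
  std::uint64_t control = 0;
  for (std::size_t h = 0; h < match.size(); ++h) {
    control = ~control & match[h];  // only remembers the previous iteration
    result += value[h] * control;
  }
  return result;
}

// Model of the suggested update: a running mask of all earlier matches.
std::uint64_t select_first_match(const std::vector<std::uint64_t>& match,
                                 const std::vector<std::uint64_t>& value) {
  std::uint64_t result = 0;
  std::uint64_t processed = 0;
  for (std::size_t h = 0; h < match.size(); ++h) {
    std::uint64_t control = match[h] & ~processed;  // 1 only for first match
    result += value[h] * control;
    processed |= match[h];
  }
  return result;
}
```

With match flags {1, 0, 1} and values {10, 20, 30}, the original update returns 40 (locations 0 and 2 are both selected), while the running-mask version returns 10.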
":shape_ops",
":utils",
"@yacl//yacl/utils:cuckoo_index",
"//libspu/kernel:test_util",
#include "libspu/kernel/hal/utils.h"
#include "libspu/kernel/hlo/permute.h"
#include "libspu/kernel/hlo/soprf.h"
#include "libspu/kernel/test_util.h"
// 当field == FieldType::FM64 && num_join_keys ==
// 1时,后面需要使用FM64,否则结果不对
This comment is in Chinese, while the rest of the codebase is in English. To ensure consistency and maintainability for all contributors, please translate this comment to English.
// When field is FM64 and there is only one join key, we must use FM64 for
// cuckoo hashing as well, otherwise the result is incorrect.
table_t_2_i_keys.reserve(num_join_keys);
for (size_t j = 0; j < num_join_keys; ++j) {
  table_t_2_i_keys.push_back(
      table_t_1[(i * (table_2.size() + 1)) +
if (lctx->Rank() == 0) {
  std::cout << "Join communication sent bytes: " << (b1 - b0)
            << ", sent actions: " << (r1 - r0) << std::endl;
}
// TEST(BigDataJoinTest, Work) {
//   FieldType field = FieldType::FM64;
//   ProtocolKind prot = ProtocolKind::SEMI2K;
//   size_t num_join_keys = 1;
//   const size_t num_hash = 3;
//   const size_t scale_factor = 15;

//   const Shape shape_1 = {2, 1000};
//   const Shape shape_2 = {2, 1000};
//   xt::xarray<uint64_t> data_1 = xt::random::randint<uint64_t>(shape_1, 0);
//   xt::xarray<uint64_t> data_2 = xt::random::randint<uint64_t>(shape_2, 0);
//   for (auto i = 0; i < shape_1[1]; ++i) {
//     data_1(0, i) = i;
//     data_1(1, i) = i + 100;
//   }
//   for (auto i = shape_2[1] - 1; i >= 0; --i) {
//     data_2(0, i) = i;
//     data_2(1, i) = i + 200;
//   }

//   mpc::utils::simulate(
//       2, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
//         SPUContext sctx = test::makeSPUContext(prot, field, lctx);

//         std::vector<Value> table1_columns;
//         for (int64_t i = 0; i < shape_1[0]; ++i) {
//           xt::xarray<uint64_t> col_data = xt::row(data_1, i);
//           Value col = test::makeValue(&sctx, col_data, VIS_SECRET);
//           table1_columns.push_back(col);
//         }

//         std::vector<Value> table2_columns;
//         for (int64_t i = 0; i < shape_2[0]; ++i) {
//           xt::xarray<uint64_t> col_data = xt::row(data_2, i);
//           Value col = test::makeValue(&sctx, col_data, VIS_SECRET);
//           table2_columns.push_back(col);
//         }

//         absl::Span<const Value> table1_span =
//             absl::MakeConstSpan(table1_columns);
//         absl::Span<const Value> table2_span =
//             absl::MakeConstSpan(table2_columns);

//         size_t b0 = lctx->GetStats()->sent_bytes;
//         size_t r0 = lctx->GetStats()->sent_actions;

//         auto ret = join_uu(&sctx, table1_span, table2_span, num_join_keys,
//                            num_hash, scale_factor, field);
//         size_t b1 = lctx->GetStats()->sent_bytes;
//         size_t r1 = lctx->GetStats()->sent_actions;

//         if (lctx->Rank() == 0) {
//           std::cout << "Join communication sent bytes: " << (b1 - b0)
//                     << ", sent actions: " << (r1 - r0) << std::endl;
//         }
//       });
// }
if (isOwner(ctx, e_1.eltype())) {
  DISPATCH_ALL_FIELDS(field, [&]() {
    NdArrayView<ring2k_t> e_1_view(e_1);
    for (int64_t i = 0; i < n_1; ++i) {
      blake3.Reset();
      element = e_1_view[i];
      blake3.Update(yacl::ByteContainerView(
          reinterpret_cast<const char*>(&element), sizeof(element)));
      hash_output = blake3.CumulativeHash();
      result_tmp = 0;
      memcpy(&result_tmp, hash_output.data(),
             std::min(hash_output.size(), sizeof(uint128_t)));
      result.push_back(result_tmp);
    }
  });
} else if (isOwner(ctx, e_2.eltype())) {
  DISPATCH_ALL_FIELDS(field, [&]() {
    NdArrayView<ring2k_t> e_2_view(e_2);
    for (int64_t i = 0; i < n_2; ++i) {
      blake3.Reset();
      element = e_2_view[i];
      blake3.Update(yacl::ByteContainerView(
          reinterpret_cast<const char*>(&element), sizeof(element)));
      hash_output = blake3.CumulativeHash();
      result_tmp = 0;
      memcpy(&result_tmp, hash_output.data(),
             std::min(hash_output.size(), sizeof(uint128_t)));
      result.push_back(result_tmp);
    }
  });
}
The code blocks for hashing e_1 (lines 1070-1084) and e_2 (lines 1085-1100) are nearly identical. This duplication can be avoided by extracting the hashing logic into a helper lambda or function. This would make the code more concise and easier to maintain.
auto hash_elements = [&](const NdArrayRef& e) {
  DISPATCH_ALL_FIELDS(field, [&]() {
    NdArrayView<ring2k_t> e_view(e);
    for (int64_t i = 0; i < e.numel(); ++i) {
      blake3.Reset();
      element = e_view[i];
      blake3.Update(yacl::ByteContainerView(
          reinterpret_cast<const char*>(&element), sizeof(element)));
      hash_output = blake3.CumulativeHash();
      result_tmp = 0;
      memcpy(&result_tmp, hash_output.data(),
             std::min(hash_output.size(), sizeof(uint128_t)));
      result.push_back(result_tmp);
    }
  });
};
if (isOwner(ctx, e_1.eltype())) {
  hash_elements(e_1);
} else if (isOwner(ctx, e_2.eltype())) {
  hash_elements(e_2);
}
Pull Request
What problem does this PR solve?
Issue Number: Fixed #
Possible side effects?
Performance:
Backward compatibility: