Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions tests/cpp/test_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2773,24 +2773,24 @@ TEST_P(TmaInnerReductionManualTest, Basic) {
auto cg_outputs = ke.run({t0});
testValidate(&fusion_copy, cg_outputs, {t0}, __LINE__, __FILE__);
}

INSTANTIATE_TEST_SUITE_P(
,
TmaInnerReductionManualTest,
::testing::Combine(
::testing::Values(2, 3), // ndims
::testing::ValuesIn([] { // inner_size
testing::Combine(
testing::Values(2, 3), // ndims
testing::ValuesIn([] { // inner_size
std::vector<int64_t> vals(
Pow2Vals1to1Million.begin(), Pow2Vals1to1Million.end());
// Add some irregular numbers
vals.insert(vals.end(), {1024 * 1024 + 8, 1024 * 1024 + 7, 1023});
return vals;
}())),
[](const testing::TestParamInfo<TmaInnerReductionManualTestParams>& info) {
int64_t ndims = std::get<0>(info.param);
int64_t inner_size = std::get<1>(info.param);
([](const testing::TestParamInfo<TmaInnerReductionManualTestParams>& info) {
auto [ndims, inner_size] = info.param;
return "ndim_" + std::to_string(ndims) + "_inner_size_" +
std::to_string(inner_size);
});
}));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra pair of parenthesis is the trick to allow structured bindings. I'm not sure about the reason -- it has something to do with C++ templates.


namespace tma_reduction_check {
bool isTmaParams(const FusionExecutorCache& executor_cache) {
Expand Down Expand Up @@ -2879,7 +2879,7 @@ TEST_P(TmaInnerReductionTest, Sum) {
INSTANTIATE_TEST_SUITE_P(
,
TmaInnerReductionTest,
::testing::Combine(
testing::Combine(
testing::Values(DataType::Float, DataType::BFloat16),
testing::ValuesIn([] {
std::vector<int64_t> vals(
Expand All @@ -2888,12 +2888,11 @@ INSTANTIATE_TEST_SUITE_P(
vals.insert(vals.end(), {1024 * 1024 + 8, 1024 * 1024 + 7, 1023});
return vals;
}())),
[](const testing::TestParamInfo<TmaInnerReductionTestParams>& info) {
auto dtype = std::get<0>(info.param);
auto reduction_size = std::get<1>(info.param);
([](const testing::TestParamInfo<TmaInnerReductionTestParams>& info) {
auto [dtype, reduction_size] = info.param;
std::ostringstream os;
os << dtype << "_" << reduction_size;
return os.str();
});
}));

} // namespace nvfuser