From a3371611a9b81b471a54e50137aaf3098d1b4e1c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Mar 2023 14:45:34 -0700 Subject: [PATCH 1/4] Rewrite ParallelDimensionMap with expr simplifier --- csrc/expr_simplifier.cpp | 5 +- csrc/parallel_dimension_map.cpp | 299 +++++++------------------------- csrc/parallel_dimension_map.h | 42 +---- test/test_expr_simplifier.cpp | 2 + test/test_gpu2.cpp | 5 - 5 files changed, 77 insertions(+), 276 deletions(-) diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index c42b279edd4..3bf451986fe 100644 --- a/csrc/expr_simplifier.cpp +++ b/csrc/expr_simplifier.cpp @@ -1522,8 +1522,9 @@ Val* eliminateTrivialComputation(Val* value, const Context& context) { return output; } } - { // b && b -> b, b || b -> b - if (op == BinaryOpType::And || op == BinaryOpType::Or) { + { // b && b -> b, b || b -> b, max(i, i) -> i, min(i, i) -> i + if (op == BinaryOpType::And || op == BinaryOpType::Or || + op == BinaryOpType::Max || op == BinaryOpType::Min) { std::vector dedup_input; for (auto v : fop->inputs()) { bool found_dup = false; diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index ca5a9a6bb8e..3efda2221b5 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -17,169 +18,64 @@ namespace nvfuser { void ParallelDimensionMap::build(Fusion* fusion) { - // Scan all TVs to build ParallelType maps + // Scan all TVs to build dim_map_ auto all_vals = fusion->usedMathVals(); for (auto tv : ir_utils::filterByType(all_vals)) { for (auto id : tv->domain()->domain()) { - registerConstantExtent(id); - if (!isParallelTypeThread(id->getParallelType())) { + auto ptype = id->getParallelType(); + if (!isParallelTypeThread(ptype)) { continue; } - handleParallelDomain(id); - } - } - - // Populate the dimension map for each parallel type - for (const auto& kv : concrete_dom_map_) { - auto pt = kv.first; - const auto& concrete_dom_set = kv.second; - TORCH_INTERNAL_ASSERT(!concrete_dom_set.empty()); - if (concrete_dom_set.size() == 1) { - populateDimensionMapWithSingleCASet(pt, concrete_dom_set); - } else { - populateDimensionMapWithMultipleCASet(pt, concrete_dom_set); - } - } - - adjustMappingsForWarpPadding(); -} - -void ParallelDimensionMap::registerConstantExtent(IterDomain* id) { - if (!id->extent()->isConstScalar()) { - // Nothing to do if not constant - return; - } - - TORCH_INTERNAL_ASSERT( - id->extent()->isConstInt(), - "Extent of ", - id->toString(), - " should have been constant, but could not be evaluated at compile time."); - - auto const_extent = id->extent()->evaluateInt(); - - // Uses index map - auto concrete_id = getCAMappedConcreteDomain(id); - - auto existing_it = constant_extent_map_.find(id); - - // Adds the constant extent to the set for the concrete domain. If - // multiple constants are found, this concrete domain has multiple - // distinctive extents, which can happen with broadcast. - if (existing_it == constant_extent_map_.end()) { - constant_extent_map_.insert({concrete_id, {const_extent}}); - } else { - existing_it->second.insert(const_extent); - } -} + exact_types_.insert(ptype); // insert now and cleanup later -// Adds the conrecte domain of id to the mappsed set for its -// parallel type -void ParallelDimensionMap::handleParallelDomain(IterDomain* id) { - auto pt = id->getParallelType(); - TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt)); - auto concrete_id = getCAMappedConcreteDomain(id); + auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); + if (concrete_id->isBroadcast()) { + // Broadcasted concrete id's don't specify anything about shape + continue; + } - auto it = concrete_dom_map_.find(pt); - if (it == concrete_dom_map_.end()) { - concrete_dom_map_.insert({pt, {concrete_id}}); - } else { - it->second.insert(concrete_id); + if (dim_map_.count(ptype) == 0) { + dim_map_[ptype] = concrete_id->extent(); + } else { + dim_map_.at(ptype) = SimplifyingIrBuilder::maxExpr( + dim_map_.at(ptype), concrete_id->extent()); + } + } } -} - -void ParallelDimensionMap::populateDimensionMapWithSingleCASet( - ParallelType pt, - const std::unordered_set& dom_set) { - TORCH_INTERNAL_ASSERT(dom_set.size() == 1); - - // pt is used by only one concrete domain - auto id = *dom_set.begin(); - auto it = constant_extent_map_.find(id); - if (it != constant_extent_map_.end()) { - TORCH_INTERNAL_ASSERT( - it->second.size() == 1, - "Only one value found mapped to parallel type ", - stringifyThread(pt), - " yet its bound to multiple extents."); - dim_map_.insert({pt, IrBuilder::create(*(it->second.begin()))}); - exact_types_.insert(pt); - } else { - // Prefer to use blockDim/gridDim if not constant - dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); - exact_types_.insert(pt); + // Simplify dim_map_ + for (auto& [k, v] : dim_map_) { + v = simplifyExpr(v); } -} - -void ParallelDimensionMap::populateDimensionMapWithMultipleCASet( - ParallelType pt, - const std::unordered_set& dom_set) { - TORCH_INTERNAL_ASSERT(dom_set.size() > 1); - bool all_equal = true; - // Use nullptr to signal it's not initialied yet - Val* known_dimension = nullptr; - // Use -1 to signal it's not initialied yet - int64_t known_const = -1; - - // Check all of concrete domains to see if they match all together. - for (auto concrete_id : dom_set) { - if (concrete_id->isBroadcast()) { - // Broadcasted concrete id's don't specify anything about shape - continue; - } - // If this concrete domain has a constant extent, check if it - // matches with the known constant extent. - auto it = constant_extent_map_.find(concrete_id); - if (it != constant_extent_map_.end()) { - const auto& const_extent_set = it->second; - // If multiple constants are detected, it's not exact. - if (const_extent_set.size() > 1) { - all_equal = false; - break; + // Compute exact_types_ + for (auto tv : ir_utils::filterByType(all_vals)) { + for (auto id : tv->domain()->domain()) { + auto ptype = id->getParallelType(); + if (exact_types_.count(ptype) == 0) { + continue; } - auto this_const = *(const_extent_set.begin()); - // known_const is initialized to -1 - if (known_const == -1) { - known_const = this_const; - } else if (known_const == this_const) { - // Matched with previously known const. The extent of this - // domain must be equal to that's previously known. + auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); + if (concrete_id->isBroadcast()) { + // Broadcasted concrete id's don't specify anything about shape continue; - } else { - // Unmatched. This dom_set extents may not be unique. - all_equal = false; - break; } - } - - // At this point, it still remains undetermined whether this id - // matches with those previously looked at. Constant check failed, - // but symbolic matching may succeed. - auto this_dimension = concrete_id->extent(); - if (known_dimension == nullptr) { - // No previous dimension found yet - known_dimension = this_dimension; - } else { - if (!equalDim(known_dimension, this_dimension)) { - all_equal = false; - break; + if (simplifyExpr(SimplifyingIrBuilder::eqExpr( + dim_map_.at(ptype), concrete_id->extent())) + ->getBool() != true) { + std::cout << dim_map_.at(ptype)->toInlineString() + << " != " << concrete_id->extent()->toInlineString() + << std::endl; + exact_types_.erase(ptype); } } } - // If all_equal is still true, the dimension of this paralel type - // must be exact. - if (all_equal) { - exact_types_.insert(pt); - } - // Use the const value, if found, as its dimension - if (all_equal && known_const != -1) { - dim_map_.insert({pt, IrBuilder::create(known_const)}); - } else { - dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); - } + adjustMappingsForWarpPadding(); + + // std::cout << toString() << std::endl; } void ParallelDimensionMap::adjustMappingsForWarpPadding() { @@ -196,32 +92,26 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { } const auto tidx_pt = ParallelType::TIDx; - auto warp_size = at::cuda::warp_size(); + auto warp_size = 32; // If the dimension of TIDx is actually a multple of the warp size // before padding, it can be left as exact if (isExact(tidx_pt)) { - auto tidx_dim = dynamic_cast(get(tidx_pt)); - if (tidx_dim && tidx_dim->isConst()) { - auto tidx_dim_val = tidx_dim->value().value(); - if (tidx_dim_val % warp_size == 0) { - // Dimension of TIDx is a multiple of the warp size - return; + auto tidx_dim = dynamic_cast(getRaw(tidx_pt)); + if (tidx_dim) { + if (tidx_dim->isConst()) { + auto tidx_dim_val = tidx_dim->value().value(); + if (tidx_dim_val % warp_size == 0) { + // Dimension of TIDx is a multiple of the warp size + return; + } } - } - // If tidx is strictly defined as blockDim.x then it must be set to a - // multiple of the warp and can be considered exact - bool tidx_def_trivial = true; - for (auto entry : concrete_dom_map_.at(tidx_pt)) { - if (!entry->isA() || - !entry->as()->sameAs( - NamedScalar::getParallelDim(tidx_pt))) { - tidx_def_trivial = false; + // If tidx is strictly defined as blockDim.x then it must be set to a + // multiple of the warp and can be considered exact + if (tidx_dim->sameAs(NamedScalar::getParallelDim(tidx_pt))) { + return; } } - if (tidx_def_trivial) { - return; - } } // TIDx is padded to a multiple of warp. If it's known to be a @@ -238,7 +128,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { exact_types_.erase(ParallelType::TIDx); } -Val* ParallelDimensionMap::get(ParallelType pt) const { +Val* ParallelDimensionMap::getRaw(ParallelType pt) const { TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt); auto it = dim_map_.find(pt); if (it == dim_map_.end()) { @@ -248,84 +138,25 @@ Val* ParallelDimensionMap::get(ParallelType pt) const { } } -bool ParallelDimensionMap::isExact(ParallelType pt) const { - return exact_types_.find(pt) != exact_types_.end(); -} - -IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) { - return GpuLower::current()->caMap()->getConcreteMappedID( - id, IdMappingMode::EXACT); -} - -// Symbolically compares equality of two KIR vals. Comparison is done -// conservatively, so returning false does not guarantee non-equality. -bool ParallelDimensionMap::equalDim(Val* dim1, Val* dim2) { - TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr); - - if (dim1 == dim2) { - return true; - } - - // When Both are Int, they are same if both have the same constant - auto dim1_int = dynamic_cast(dim1); - auto dim2_int = dynamic_cast(dim2); - if (dim1_int && dim2_int) { - if (dim1_int->isConst() && dim2_int->isConst()) { - return dim1_int->value() == dim2_int->value(); - } - } - - // When both are NamedScalar, they are same if Both have the same - // name - auto dim1_ns = dynamic_cast(dim1); - auto dim2_ns = dynamic_cast(dim2); - if (dim1_ns && dim2_ns) { - return dim1_ns->name() == dim2_ns->name(); - } - - // Check recursively their definitions - - auto dim1_def = dim1->definition(); - auto dim2_def = dim2->definition(); - - if (dim1_def == nullptr || dim2_def == nullptr) { - return false; - } - - // If both are BinaryOp or UnaryOp, check their inputs. Since these - // Vals are IterDomain extents, UnaryOp should not occur, but - // checking shouldn't be harmful. - // TODO: - // We might be able to replace this with dim1->toInlineString() == - // dim2->toInlineString() - // If we want this less conservative we could make an "exact map" which - // could be another mode in compute at that maps all iter domains, but not - // concretized broadcast axes and only forwards through non-concretized - // broadcast axes. - if ((dim1_def->isA() && dim2_def->isA() && - (dim1_def->as()->getBinaryOpType() == - dim2_def->as()->getBinaryOpType())) || - (dim1_def->isA() && dim2_def->isA() && - (dim1_def->as()->getUnaryOpType() == - dim2_def->as()->getUnaryOpType()))) { - for (const auto i : c10::irange(dim1_def->inputs().size())) { - if (!equalDim(dim1_def->inputs().at(i), dim2_def->inputs().at(i))) { - return false; - } - } - return true; +Val* ParallelDimensionMap::get(ParallelType pt) const { + auto raw = getRaw(pt); + if (raw != nullptr && !raw->isConstInt()) { + return NamedScalar::getParallelDim(pt); } + return raw; +} - return false; +bool ParallelDimensionMap::isExact(ParallelType pt) const { + return exact_types_.find(pt) != exact_types_.end(); } std::string ParallelDimensionMap::toString() const { std::stringstream ss; for (auto pt : kParallelTypeThreads) { ss << pt << ": "; - auto dim = get(pt); + auto dim = getRaw(pt); if (dim != nullptr) { - ss << dim->toString(); + ss << dim->toInlineString(); if (isExact(pt)) { ss << ", exact"; } else { diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 238cdf64448..43ddb4ec029 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -15,46 +15,30 @@ namespace nvfuser { -//! Maps TID/BID to its dimension. It is by default blockDim/gridDim, -//! but if use of a ParallelType is mapped to a unique constant -//! extent, the constant value is used instead since presumably it's -//! more efficient. +//! Maps TID/BID to its dimension. class TORCH_CUDA_CU_API ParallelDimensionMap { public: void build(Fusion* fusion); //! Returns the dimension of a ParallelType. nullptr is returned if - //! a ParallelType is unused. + //! a ParallelType is unused. If a dimension is not a constant, return + //! blockDim/gridDim instead. Val* get(ParallelType pt) const; + //! Returns the raw dimension of a ParallelType. nullptr is returned if + //! a ParallelType is unused. + Val* getRaw(ParallelType pt) const; + //! True if the dimension of a ParallelType is known to be exact bool isExact(ParallelType pt) const; std::string toString() const; - //! Symbolically analyze if two extent vals are equal - static bool equalDim(Val* dim1, Val* dim2); - private: - //! Register the extent of an IterDomain if its constant - void registerConstantExtent(IterDomain* id); - - void handleParallelDomain(IterDomain* id); - - void populateDimensionMapWithSingleCASet( - ParallelType pt, - const std::unordered_set& dom_set); - - void populateDimensionMapWithMultipleCASet( - ParallelType pt, - const std::unordered_set& dom_set); - //! TIDx may need to be marked as non-exact as it may be padded to a //! multiple of the warp size. void adjustMappingsForWarpPadding(); - static IterDomain* getCAMappedConcreteDomain(IterDomain* id); - private: //! Maps from parallel types to dimensions, which are constant if //! a unique value is found. @@ -62,18 +46,6 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { //! Set of parallel types whose dimensions are identified to be //! exactly the same as extents of mapped domains. std::unordered_set exact_types_; - - // Below are temporary maps to build the ParallelType-to-dimension - // map. Only used during build(). - - //! Map from a parallel type to a set of concrete domains where the - //! parallel type is used. - std::unordered_map, TypeHash> - concrete_dom_map_; - //! Keep track of constant extents found for a CA domain set - //! represented by the concrete domain. - std::unordered_map> - constant_extent_map_; }; } // namespace nvfuser diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index b6ad97ea5e1..577708594ec 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -440,6 +440,8 @@ TEST_F(ExprSimplifierTest, EliminateTrivialComputation_CUDA) { TORCH_CHECK(simplifyExpr("b && b"_)->sameAs("b"_)); TORCH_CHECK(simplifyExpr("b || b"_)->sameAs("b"_)); + TORCH_CHECK(simplifyExpr(IrBuilder::maxExpr("i"_, "i"_))->sameAs("i"_)); + TORCH_CHECK(simplifyExpr(IrBuilder::minExpr("i"_, "i"_))->sameAs("i"_)); TORCH_CHECK(simplifyExpr("i / 1"_)->sameAs("i"_)); TORCH_CHECK(simplifyExpr("d / 1.0"_)->sameAs("d"_)); diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 138467b5f95..815bf092001 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -7889,11 +7889,6 @@ TEST_F(NVFuserTest, FusionParallelDimensionMap1_CUDA) { // actual values are not statically known GpuLower gpulw(fusion.get()); const auto& pdmap = gpulw.parallelDimensionMap(); - for (const auto i : c10::irange(tv1->domain()->domain().size())) { - auto dom1 = tv1->domain()->domain()[i]; - auto dom2 = tv2->domain()->domain()[i]; - TORCH_INTERNAL_ASSERT(pdmap.equalDim(dom1->extent(), dom2->extent())); - } TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( From 1bac1307bb44d7a6af91bbebb8b353322fa0baa8 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Mar 2023 14:47:08 -0700 Subject: [PATCH 2/4] cleanup --- csrc/parallel_dimension_map.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 3efda2221b5..694df152816 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -65,17 +65,12 @@ void ParallelDimensionMap::build(Fusion* fusion) { if (simplifyExpr(SimplifyingIrBuilder::eqExpr( dim_map_.at(ptype), concrete_id->extent())) ->getBool() != true) { - std::cout << dim_map_.at(ptype)->toInlineString() - << " != " << concrete_id->extent()->toInlineString() - << std::endl; exact_types_.erase(ptype); } } } adjustMappingsForWarpPadding(); - - // std::cout << toString() << std::endl; } void ParallelDimensionMap::adjustMappingsForWarpPadding() { From 548d5d2698f55205cbe84bec8ca2aff2051fb88b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Mar 2023 14:54:31 -0700 Subject: [PATCH 3/4] fix test --- test/test_gpu1.cpp | 12 ++++++------ test/test_gpu2.cpp | 45 ++++++++++++++++++++------------------------- test/test_gpu3.cpp | 16 ++++++++-------- 3 files changed, 34 insertions(+), 39 deletions(-) diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index 4f47271c189..f51fd3bfce1 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -1202,17 +1202,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i164; - i164 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i164 < T0.size[0])) { + int64_t i409; + i409 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + if ((i409 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i164]; + = T1[i409]; float T4[1]; T4[0] = 0; T4[0] - = T0[i164]; + = T0[i409]; float T2[1]; T2[0] = T4[0] @@ -1221,7 +1221,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te T6[0] = T2[0] * T4[0]; - T3[i164] + T3[i409] = T6[0]; } } diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 815bf092001..7b5201a9ca0 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -7950,7 +7950,6 @@ TEST_F(NVFuserTest, FusionParallelDimensionMap2_CUDA) { fusion.get(), outputs, {input1, input2}, {ref}, __LINE__, __FILE__); } -// Mix symbolic and concrete tensors TEST_F(NVFuserTest, FusionParallelDimensionMap3_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -7983,14 +7982,10 @@ TEST_F(NVFuserTest, FusionParallelDimensionMap3_CUDA) { GpuLower gpulw(fusion.get()); const auto& pdmap = gpulw.parallelDimensionMap(); - TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); - TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDy)->isConst() && - pdmap.get(ParallelType::TIDy)->as()->value().value() == 10); + ASSERT_FALSE(pdmap.isExact(ParallelType::TIDx)); + ASSERT_EQ(pdmap.get(ParallelType::TIDx)->getInt(), 20); + ASSERT_TRUE(pdmap.isExact(ParallelType::TIDy)); + ASSERT_EQ(pdmap.get(ParallelType::TIDy)->getInt(), 10); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({13}, options); @@ -9035,27 +9030,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i1307; - i1307 = T0.size[2] * T0.size[1]; - int64_t i1310; - i1310 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - int64_t i1312; - i1312 = (T0.size[1] * T0.size[2]) * T0.size[3]; - int64_t i1344; - i1344 = i1310 % i1312; - int64_t i1321; - i1321 = T0.size[2] * T0.size[3]; - int64_t i1345; - i1345 = i1344 % i1321; - if ((i1310 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { + int64_t i1829; + i1829 = T0.size[2] * T0.size[1]; + int64_t i1832; + i1832 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + int64_t i1834; + i1834 = (T0.size[1] * T0.size[2]) * T0.size[3]; + int64_t i1866; + i1866 = i1832 % i1834; + int64_t i1843; + i1843 = T0.size[2] * T0.size[3]; + int64_t i1867; + i1867 = i1866 % i1843; + if ((i1832 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { __half T9[1]; T9[0] = 0; T9[0] - = T2[(((((i1307 * T0.size[3]) * (i1310 / i1312)) + (i1307 * (i1345 % T0.size[3]))) + (T0.size[2] * (i1344 / i1321))) + (i1345 / T0.size[3]))]; + = T2[(((((i1829 * T0.size[3]) * (i1832 / i1834)) + (i1829 * (i1867 % T0.size[3]))) + (T0.size[2] * (i1866 / i1843))) + (i1867 / T0.size[3]))]; __half T8[1]; T8[0] = 0; T8[0] - = T0[i1310]; + = T0[i1832]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -9075,7 +9070,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i1310] + T7[i1832] = T10[0]; } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index dbaacc1ce51..0bfd159cff8 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1750,21 +1750,21 @@ TEST_F(NVFuserTest, FusionIndexHoist3_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) { - int64_t i111; - i111 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); + int64_t i238; + i238 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i7; i7 = T0.size[0] * T0.size[1]; - bool b241; - b241 = i111 < i7; + bool b368; + b368 = i238 < i7; float f8; f8 = (float)(i7); float T1[1]; - if (b241) { + if (b368) { T1[0] - = sinf(T0[i111]); + = sinf(T0[i238]); } - if (b241) { - T2[i111] + if (b368) { + T2[i238] = T1[0] + f8; } From 1fd73d687fc86e01a611069ffe5988a8916ed568 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Mar 2023 10:10:06 -0700 Subject: [PATCH 4/4] update --- csrc/parallel_dimension_map.cpp | 63 +++++++++++++++++++-------------- csrc/parallel_dimension_map.h | 3 +- test/test_gpu1.cpp | 12 +++---- test/test_gpu2.cpp | 32 ++++++++--------- test/test_gpu3.cpp | 16 ++++----- 5 files changed, 69 insertions(+), 57 deletions(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 694df152816..2b80e413d65 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -8,17 +8,37 @@ #include #include +#include #include #include #include #include +#include #include +#include +#include + +using PAndID = std::pair; + +namespace std { + +template <> +struct hash { + std::size_t operator()(const PAndID& data) const noexcept { + size_t ptype = static_cast(data.first); + size_t address = reinterpret_cast(data.second); + size_t combined = (address << 8) | ptype; + return std::hash()(combined); + } +}; + +} // namespace std namespace nvfuser { void ParallelDimensionMap::build(Fusion* fusion) { - // Scan all TVs to build dim_map_ + VectorOfUniqueEntries all_concrete_ids; auto all_vals = fusion->usedMathVals(); for (auto tv : ir_utils::filterByType(all_vals)) { for (auto id : tv->domain()->domain()) { @@ -26,21 +46,24 @@ void ParallelDimensionMap::build(Fusion* fusion) { if (!isParallelTypeThread(ptype)) { continue; } - exact_types_.insert(ptype); // insert now and cleanup later - auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( id, IdMappingMode::EXACT); if (concrete_id->isBroadcast()) { // Broadcasted concrete id's don't specify anything about shape continue; } + all_concrete_ids.pushBack(std::make_pair(ptype, concrete_id)); + } + } - if (dim_map_.count(ptype) == 0) { - dim_map_[ptype] = concrete_id->extent(); - } else { - dim_map_.at(ptype) = SimplifyingIrBuilder::maxExpr( - dim_map_.at(ptype), concrete_id->extent()); - } + // Scan all TVs to build dim_map_ + for (auto [ptype, concrete_id] : all_concrete_ids) { + exact_types_.insert(ptype); // insert now and cleanup later + if (dim_map_.count(ptype) == 0) { + dim_map_[ptype] = concrete_id->extent(); + } else { + dim_map_.at(ptype) = SimplifyingIrBuilder::maxExpr( + dim_map_.at(ptype), concrete_id->extent()); } } @@ -50,23 +73,11 @@ void ParallelDimensionMap::build(Fusion* fusion) { } // Compute exact_types_ - for (auto tv : ir_utils::filterByType(all_vals)) { - for (auto id : tv->domain()->domain()) { - auto ptype = id->getParallelType(); - if (exact_types_.count(ptype) == 0) { - continue; - } - auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( - id, IdMappingMode::EXACT); - if (concrete_id->isBroadcast()) { - // Broadcasted concrete id's don't specify anything about shape - continue; - } - if (simplifyExpr(SimplifyingIrBuilder::eqExpr( - dim_map_.at(ptype), concrete_id->extent())) - ->getBool() != true) { - exact_types_.erase(ptype); - } + for (auto [ptype, concrete_id] : all_concrete_ids) { + if (simplifyExpr(SimplifyingIrBuilder::eqExpr( + dim_map_.at(ptype), concrete_id->extent())) + ->getBool() != true) { + exact_types_.erase(ptype); } } diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 43ddb4ec029..a5eccf0f42e 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -10,8 +10,9 @@ #include #include -#include +#include #include +#include namespace nvfuser { diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp index f51fd3bfce1..738b35b02c0 100644 --- a/test/test_gpu1.cpp +++ b/test/test_gpu1.cpp @@ -1202,17 +1202,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i409; - i409 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - if ((i409 < T0.size[0])) { + int64_t i241; + i241 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + if ((i241 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i409]; + = T1[i241]; float T4[1]; T4[0] = 0; T4[0] - = T0[i409]; + = T0[i241]; float T2[1]; T2[0] = T4[0] @@ -1221,7 +1221,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te T6[0] = T2[0] * T4[0]; - T3[i409] + T3[i241] = T6[0]; } } diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 7b5201a9ca0..dfcf004fe7b 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -9030,27 +9030,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i1829; - i1829 = T0.size[2] * T0.size[1]; - int64_t i1832; - i1832 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); - int64_t i1834; - i1834 = (T0.size[1] * T0.size[2]) * T0.size[3]; - int64_t i1866; - i1866 = i1832 % i1834; - int64_t i1843; - i1843 = T0.size[2] * T0.size[3]; - int64_t i1867; - i1867 = i1866 % i1843; - if ((i1832 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { + int64_t i1405; + i1405 = T0.size[2] * T0.size[1]; + int64_t i1408; + i1408 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)); + int64_t i1410; + i1410 = (T0.size[1] * T0.size[2]) * T0.size[3]; + int64_t i1442; + i1442 = i1408 % i1410; + int64_t i1419; + i1419 = T0.size[2] * T0.size[3]; + int64_t i1443; + i1443 = i1442 % i1419; + if ((i1408 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) { __half T9[1]; T9[0] = 0; T9[0] - = T2[(((((i1829 * T0.size[3]) * (i1832 / i1834)) + (i1829 * (i1867 % T0.size[3]))) + (T0.size[2] * (i1866 / i1843))) + (i1867 / T0.size[3]))]; + = T2[(((((i1405 * T0.size[3]) * (i1408 / i1410)) + (i1405 * (i1443 % T0.size[3]))) + (T0.size[2] * (i1442 / i1419))) + (i1443 / T0.size[3]))]; __half T8[1]; T8[0] = 0; T8[0] - = T0[i1832]; + = T0[i1408]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -9070,7 +9070,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i1832] + T7[i1408] = T10[0]; } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 0bfd159cff8..21d34e7a688 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1750,21 +1750,21 @@ TEST_F(NVFuserTest, FusionIndexHoist3_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) { - int64_t i238; - i238 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); + int64_t i194; + i194 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i7; i7 = T0.size[0] * T0.size[1]; - bool b368; - b368 = i238 < i7; + bool b324; + b324 = i194 < i7; float f8; f8 = (float)(i7); float T1[1]; - if (b368) { + if (b324) { T1[0] - = sinf(T0[i238]); + = sinf(T0[i194]); } - if (b368) { - T2[i238] + if (b324) { + T2[i194] = T1[0] + f8; }