From d1184fa30a557597874021966798fde382ee4242 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Mar 2023 17:11:57 -0700 Subject: [PATCH 1/3] Cleanup `computeLaunchParams` and warp-pad handling --- csrc/executor.cpp | 107 ++++++++++++++------------------ csrc/executor_utils.cpp | 28 --------- csrc/executor_utils.h | 27 -------- csrc/parallel_dimension_map.cpp | 49 ++++++++------- csrc/parallel_dimension_map.h | 4 ++ 5 files changed, 78 insertions(+), 137 deletions(-) diff --git a/csrc/executor.cpp b/csrc/executor.cpp index def03c9ee01..e9937fa4340 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -590,27 +590,20 @@ LaunchParams FusionExecutor::computeLaunchParams( }); auto& parallel_iter_extents = parallel_iter_extent_entry.get(); - auto simplified_parallel_iter_extent_entry = - executor_utils::caching::ExecutorCompileTimeEntry< - executor_utils::caching::SimplifiedParallelIterExtentMap>( - data_cache, [&parallel_binding_ids, &lower]() { - return executor_utils::getSimplifiedParallelIterExtents( - lower, parallel_binding_ids); - }); - auto& simplified_parallel_iter_extents = - simplified_parallel_iter_extent_entry.get(); - - auto warp_padded_parallel_entry = - executor_utils::caching::ExecutorCompileTimeEntry< - executor_utils::caching::WarpPaddedParallelExtents>( - data_cache, [&parallel_binding_ids, &lower]() { - return executor_utils::getWarpPaddedExtentsInfo( - lower->kernel(), parallel_binding_ids); - }); - auto& warp_padded_extent_set = - warp_padded_parallel_entry.get().warp_padded_extent_set; - auto& warp_padded_constant = - warp_padded_parallel_entry.get().warp_padded_constant; + const auto& simplified_parallel_iter_extents = + lower->parallelDimensionMap().getMap(); + + // auto warp_padded_parallel_entry = + // executor_utils::caching::ExecutorCompileTimeEntry< + // executor_utils::caching::WarpPaddedParallelExtents>( + // data_cache, [&parallel_binding_ids, &lower]() { + // return executor_utils::getWarpPaddedExtentsInfo( + // lower->kernel(), 
parallel_binding_ids); + // }); + // auto& warp_padded_extent_set = + // warp_padded_parallel_entry.get().warp_padded_extent_set; + // auto& warp_padded_constant = + // warp_padded_parallel_entry.get().warp_padded_constant; // TODO: Need to redesign this part a bit to // find the right place to trigger evaluate @@ -656,48 +649,42 @@ LaunchParams FusionExecutor::computeLaunchParams( } // Run through the rest of the parallel IterDomains and infer their size - for (auto& entry : simplified_parallel_iter_extents) { + for (auto [p_type, extent] : simplified_parallel_iter_extents) { FUSER_PERF_SCOPE("FusionExecutor::ParallelBindingResolution"); - auto p_type = entry.first; - auto parallel_extents = entry.second; - // Select the maxmimum value out of all the parallel extents - int64_t maximum_value = std::numeric_limits::min(); - for (auto extent : parallel_extents) { - auto val = expr_eval.evaluate(extent); - TORCH_INTERNAL_ASSERT( - val.has_value(), - "Tried to evaluate the extent, ", - extent->toInlineString(), - " for the ptype: ", - p_type, - " to set launch bounds but could not."); - - // apply padding to the extent if needed - if (warp_padded_extent_set.count(extent)) { - // Check if the extent has const value - auto padded_constant_it = warp_padded_constant.find(extent); - - if (padded_constant_it != warp_padded_constant.end()) { - // If already specified padded to constant, need to check - // runtime value not over the constant bound - TORCH_INTERNAL_ASSERT(*val <= padded_constant_it->second); - *val = EvaluatorValue(padded_constant_it->second); - } else { - // If no specified constant, pad to the smallest multiple of warp - // above the value. 
- auto padded_number_of_warps = (*val + warp_size - 1) / warp_size; - *val = warp_size * padded_number_of_warps; - } - TORCH_INTERNAL_ASSERT( - *val <= 1024, "padded dimension larger than max block size"); - } - maximum_value = std::max(maximum_value, val->as()); - } + auto val = expr_eval.evaluate(extent); + // std::cout << "extent: " << extent->toInlineString() << std::endl; + TORCH_INTERNAL_ASSERT( + val.has_value(), + "Tried to evaluate the extent, ", + extent->toInlineString(), + " for the ptype: ", + p_type, + " to set launch bounds but could not."); + + // // apply padding to the extent if needed + // if (warp_padded_extent_set.count(extent)) { + // // Check if the extent has const value + // auto padded_constant_it = warp_padded_constant.find(extent); + + // if (padded_constant_it != warp_padded_constant.end()) { + // // If already specified padded to constant, need to check + // // runtime value not over the constant bound + // TORCH_INTERNAL_ASSERT(*val <= padded_constant_it->second); + // *val = EvaluatorValue(padded_constant_it->second); + // } else { + // // If no specified constant, pad to the smallest multiple of warp + // // above the value. 
+ // auto padded_number_of_warps = (*val + warp_size - 1) / warp_size; + // *val = warp_size * padded_number_of_warps; + // } + // TORCH_INTERNAL_ASSERT( + // *val <= 1024, "padded dimension larger than max block size"); + // } // Protect for size-0 tensors, they still have a value so would prefer to // bind nothing than 0 - if (maximum_value > 0) { - expr_eval.bind(p_type, maximum_value); - launch_params.bind(maximum_value, p_type); + if (val->as() > 0) { + expr_eval.bind(p_type, val->as()); + launch_params.bind(val->as(), p_type); } } diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 0f5743f8090..c4dd16244d1 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -1436,7 +1436,6 @@ ExecutorCompileTimeEntry::ExecutorCompileTimeEntry( // Template instantiation template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; -template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; @@ -1498,33 +1497,6 @@ std::unique_ptr getParallelIterExtents( return parallel_iter_extents_ptr; } -std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower* lower, - std::vector& parallel_binding_ids) { - auto parallel_iter_extents_ptr = std::make_unique(); - const auto& ca_map = lower->caMap(); - std::vector mapped; - bool is_tidx_warp_padded = lower->getWarpPaddedParallelInfo().is_tidx_padded; - - for (auto id : parallel_binding_ids) { - if (std::any_of( - mapped.begin(), mapped.end(), [id, &ca_map](IterDomain* mapped_id) { - return ca_map->areMapped(mapped_id, id, IdMappingMode::LOOP); - })) { - if (id->getParallelType() != ParallelType::TIDx || !is_tidx_warp_padded) { - continue; - } - } - - insertParallelExtent( - ca_map->getConcreteMappedID(id, IdMappingMode::LOOP), - parallel_iter_extents_ptr); - mapped.push_back(id); - } - - return parallel_iter_extents_ptr; -} - std::unique_ptr getWarpPaddedExtentsInfo( kir::Kernel* 
kernel, std::vector& parallel_binding_ids) { diff --git a/csrc/executor_utils.h b/csrc/executor_utils.h index 24f81e167e7..efe0a4acb7d 100644 --- a/csrc/executor_utils.h +++ b/csrc/executor_utils.h @@ -147,27 +147,6 @@ class ParallelIterExtentMap { CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP; }; -//! Compile-time info to be cached in each FusionExecutor: -//! SimplifiedParallelIterExtentMap -//! This entry type is a simplified version of ParallelIterExtentMap. -//! -//! For launch parameter binding we only need the most concrete iterdomain -//! in each disjoint set stored in CaParallelMap. This entry stores the -//! remaining list of extents for binding after this simplification. -//! -//! We still need ParallelIterExtentMap since we want to bind the concrete -//! values to the extents of all parallelized iterdomains. We would be -//! able to save these bindings if the integer machine has a notion of -//! equality and could be configured compile time. But that'd be a longer -//! term target. -class SimplifiedParallelIterExtentMap { - public: - using DataType = - std::unordered_map, TypeHash>; - static const CompileTimeEntryType EntryType = - CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP; -}; - //! WarpPaddedExtentsInfo: //! Auxiliary data type for entry class WarpPaddedParallelExtents struct WarpPaddedExtentsInfo { @@ -330,12 +309,6 @@ using ParallelExtentMap = std::unique_ptr getParallelIterExtents( std::vector& parallel_binding_ids); -//! Returns the simplified set of extents necessary for launch parameter -//! binding. -std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower* lower, - std::vector& parallel_binding_ids); - //! Returns the symbolic or constant extetns of warp padded parallel //! iterdomains in the given vector. 
std::unique_ptr getWarpPaddedExtentsInfo( diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 2b80e413d65..d55b1716c9c 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -98,36 +98,41 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { } const auto tidx_pt = ParallelType::TIDx; - auto warp_size = 32; - - // If the dimension of TIDx is actually a multple of the warp size - // before padding, it can be left as exact - if (isExact(tidx_pt)) { - auto tidx_dim = dynamic_cast(getRaw(tidx_pt)); - if (tidx_dim) { - if (tidx_dim->isConst()) { - auto tidx_dim_val = tidx_dim->value().value(); - if (tidx_dim_val % warp_size == 0) { - // Dimension of TIDx is a multiple of the warp size - return; - } - } - // If tidx is strictly defined as blockDim.x then it must be set to a - // multiple of the warp and can be considered exact - if (tidx_dim->sameAs(NamedScalar::getParallelDim(tidx_pt))) { - return; - } - } + auto warp_size_val = IrBuilder::create(32); + auto tidx_dim = getRaw(tidx_pt); + + TORCH_INTERNAL_ASSERT(tidx_dim != nullptr); + if (false) { + dim_map_.at(ParallelType::TIDx) = + NamedScalar::getParallelDim(ParallelType::TIDx); + exact_types_.erase(ParallelType::TIDx); + return; + } + + // If tidx is strictly defined as blockDim.x then it must be set to a + // multiple of the warp, there is nothing to do + if (tidx_dim->sameAs(NamedScalar::getParallelDim(tidx_pt))) { + return; + } + + // If already multiple of warp, nothing to do + if (simplifyExpr(SimplifyingIrBuilder::eqExpr( + SimplifyingIrBuilder::modExpr(tidx_dim, warp_size_val), + tidx_dim->container()->zeroVal())) + ->getBool() == true) { + return; } // TIDx is padded to a multiple of warp. If it's known to be a // single warp, use the constant warp size as the dimension of // TIDx. Otherwise, just use blockDim.x. 
if (warp_info.is_tidx_single_warp) { - dim_map_.at(ParallelType::TIDx) = IrBuilder::create(warp_size); + dim_map_.at(ParallelType::TIDx) = warp_size_val; } else { dim_map_.at(ParallelType::TIDx) = - NamedScalar::getParallelDim(ParallelType::TIDx); + simplifyExpr(SimplifyingIrBuilder::mulExpr( + SimplifyingIrBuilder::ceilDivExpr(tidx_dim, warp_size_val), + warp_size_val)); } // TIDx is no longer exact diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index a5eccf0f42e..4cf1eab3c98 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -35,6 +35,10 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { std::string toString() const; + const std::unordered_map& getMap() const { + return dim_map_; + } + private: //! TIDx may need to be marked as non-exact as it may be padded to a //! multiple of the warp size. From e6e19bd2750e9bc8af150f16a4d184de7e9f4d27 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Mar 2023 17:15:56 -0700 Subject: [PATCH 2/3] more cleanup --- csrc/executor.cpp | 34 ---------------------------------- csrc/executor_utils.cpp | 31 ------------------------------- csrc/executor_utils.h | 24 ------------------------ 3 files changed, 89 deletions(-) diff --git a/csrc/executor.cpp b/csrc/executor.cpp index e9937fa4340..55f04666c96 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -593,18 +593,6 @@ LaunchParams FusionExecutor::computeLaunchParams( const auto& simplified_parallel_iter_extents = lower->parallelDimensionMap().getMap(); - // auto warp_padded_parallel_entry = - // executor_utils::caching::ExecutorCompileTimeEntry< - // executor_utils::caching::WarpPaddedParallelExtents>( - // data_cache, [&parallel_binding_ids, &lower]() { - // return executor_utils::getWarpPaddedExtentsInfo( - // lower->kernel(), parallel_binding_ids); - // }); - // auto& warp_padded_extent_set = - // warp_padded_parallel_entry.get().warp_padded_extent_set; - // auto& warp_padded_constant = - // 
warp_padded_parallel_entry.get().warp_padded_constant; - // TODO: Need to redesign this part a bit to // find the right place to trigger evaluate if (expr_eval.precomputedValues()) { @@ -652,7 +640,6 @@ LaunchParams FusionExecutor::computeLaunchParams( for (auto [p_type, extent] : simplified_parallel_iter_extents) { FUSER_PERF_SCOPE("FusionExecutor::ParallelBindingResolution"); auto val = expr_eval.evaluate(extent); - // std::cout << "extent: " << extent->toInlineString() << std::endl; TORCH_INTERNAL_ASSERT( val.has_value(), "Tried to evaluate the extent, ", @@ -661,27 +648,6 @@ LaunchParams FusionExecutor::computeLaunchParams( p_type, " to set launch bounds but could not."); - // // apply padding to the extent if needed - // if (warp_padded_extent_set.count(extent)) { - // // Check if the extent has const value - // auto padded_constant_it = warp_padded_constant.find(extent); - - // if (padded_constant_it != warp_padded_constant.end()) { - // // If already specified padded to constant, need to check - // // runtime value not over the constant bound - // TORCH_INTERNAL_ASSERT(*val <= padded_constant_it->second); - // *val = EvaluatorValue(padded_constant_it->second); - // } else { - // // If no specified constant, pad to the smallest multiple of warp - // // above the value. 
- // auto padded_number_of_warps = (*val + warp_size - 1) / warp_size; - // *val = warp_size * padded_number_of_warps; - // } - // TORCH_INTERNAL_ASSERT( - // *val <= 1024, "padded dimension larger than max block size"); - // } - // Protect for size-0 tensors, they still have a value so would prefer to - // bind nothing than 0 if (val->as() > 0) { expr_eval.bind(p_type, val->as()); launch_params.bind(val->as(), p_type); diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index c4dd16244d1..2d3c3e92c14 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -1436,7 +1436,6 @@ ExecutorCompileTimeEntry::ExecutorCompileTimeEntry( // Template instantiation template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; -template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; @@ -1497,35 +1496,5 @@ std::unique_ptr getParallelIterExtents( return parallel_iter_extents_ptr; } -std::unique_ptr getWarpPaddedExtentsInfo( - kir::Kernel* kernel, - std::vector& parallel_binding_ids) { - auto warp_padded_extent_info_ptr = - std::make_unique(); - auto& warp_padded_extent_set = - warp_padded_extent_info_ptr->warp_padded_extent_set; - auto& warp_padded_constant = - warp_padded_extent_info_ptr->warp_padded_constant; - bool has_warp_reduction = - kernel->getWarpPaddedParallelInfo().has_warp_reduction; - - for (auto id : parallel_binding_ids) { - // Apply warp padding only when there're warp reductions in - // the kernel. 
- if (has_warp_reduction) { - if (id->hasPaddingToMultipleOfWarp() || - kernel->isParallelTypePadded(id->getParallelType())) { - auto extent = id->extent(); - warp_padded_extent_set.insert(extent); - auto padded_value = id->getMaybeSizeAfterPadding(); - if (padded_value.has_value()) { - warp_padded_constant[extent] = padded_value.value(); - } - } - } - } - return warp_padded_extent_info_ptr; -} - } // namespace executor_utils } // namespace nvfuser diff --git a/csrc/executor_utils.h b/csrc/executor_utils.h index efe0a4acb7d..61cb8c3a56e 100644 --- a/csrc/executor_utils.h +++ b/csrc/executor_utils.h @@ -147,24 +147,6 @@ class ParallelIterExtentMap { CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP; }; -//! WarpPaddedExtentsInfo: -//! Auxiliary data type for entry class WarpPaddedParallelExtents -struct WarpPaddedExtentsInfo { - std::unordered_set warp_padded_extent_set; - std::unordered_map warp_padded_constant; -}; - -//! Compile-time info to be cached in each FusionExecutor: -//! WarpPaddedParallelExtents -//! Stores the symbolic and constant extents of warp -//! padded parallel iterdomains. -class WarpPaddedParallelExtents { - public: - using DataType = WarpPaddedExtentsInfo; - static const CompileTimeEntryType EntryType = - CompileTimeEntryType::WARP_PADDED_PARALLEL_EXTENTS; -}; - //! VectorizedTensorInfo: //! Auxiliary data type for entry class VectorizedTensorValidation struct VectorizedTensorInfo { @@ -309,12 +291,6 @@ using ParallelExtentMap = std::unique_ptr getParallelIterExtents( std::vector& parallel_binding_ids); -//! Returns the symbolic or constant extetns of warp padded parallel -//! iterdomains in the given vector. 
-std::unique_ptr getWarpPaddedExtentsInfo( - kir::Kernel* lower, - std::vector& parallel_binding_ids); - void validateVectorizedTensors( kir::Kernel* kernel, const KernelArgumentHolder& args, From 70a1f41f763aa2c7bb6a95c56869c6e9497ae8d2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Mar 2023 17:20:20 -0700 Subject: [PATCH 3/3] cleanup --- csrc/parallel_dimension_map.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index d55b1716c9c..9888f2fc848 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -102,12 +102,6 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { auto tidx_dim = getRaw(tidx_pt); TORCH_INTERNAL_ASSERT(tidx_dim != nullptr); - if (false) { - dim_map_.at(ParallelType::TIDx) = - NamedScalar::getParallelDim(ParallelType::TIDx); - exact_types_.erase(ParallelType::TIDx); - return; - } // If tidx is strictly defined as blockDim.x then it must be set to a // multiple of the warp, there is nothing to do