diff --git a/csrc/executor.cpp b/csrc/executor.cpp index def03c9ee01..55f04666c96 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -590,27 +590,8 @@ LaunchParams FusionExecutor::computeLaunchParams( }); auto& parallel_iter_extents = parallel_iter_extent_entry.get(); - auto simplified_parallel_iter_extent_entry = - executor_utils::caching::ExecutorCompileTimeEntry< - executor_utils::caching::SimplifiedParallelIterExtentMap>( - data_cache, [&parallel_binding_ids, &lower]() { - return executor_utils::getSimplifiedParallelIterExtents( - lower, parallel_binding_ids); - }); - auto& simplified_parallel_iter_extents = - simplified_parallel_iter_extent_entry.get(); - - auto warp_padded_parallel_entry = - executor_utils::caching::ExecutorCompileTimeEntry< - executor_utils::caching::WarpPaddedParallelExtents>( - data_cache, [&parallel_binding_ids, &lower]() { - return executor_utils::getWarpPaddedExtentsInfo( - lower->kernel(), parallel_binding_ids); - }); - auto& warp_padded_extent_set = - warp_padded_parallel_entry.get().warp_padded_extent_set; - auto& warp_padded_constant = - warp_padded_parallel_entry.get().warp_padded_constant; + const auto& simplified_parallel_iter_extents = + lower->parallelDimensionMap().getMap(); // TODO: Need to redesign this part a bit to // find the right place to trigger evaluate @@ -656,48 +637,20 @@ LaunchParams FusionExecutor::computeLaunchParams( } // Run through the rest of the parallel IterDomains and infer their size - for (auto& entry : simplified_parallel_iter_extents) { + for (auto [p_type, extent] : simplified_parallel_iter_extents) { FUSER_PERF_SCOPE("FusionExecutor::ParallelBindingResolution"); - auto p_type = entry.first; - auto parallel_extents = entry.second; - // Select the maxmimum value out of all the parallel extents - int64_t maximum_value = std::numeric_limits::min(); - for (auto extent : parallel_extents) { - auto val = expr_eval.evaluate(extent); - TORCH_INTERNAL_ASSERT( - val.has_value(), - "Tried to evaluate the extent, 
", - extent->toInlineString(), - " for the ptype: ", - p_type, - " to set launch bounds but could not."); - - // apply padding to the extent if needed - if (warp_padded_extent_set.count(extent)) { - // Check if the extent has const value - auto padded_constant_it = warp_padded_constant.find(extent); - - if (padded_constant_it != warp_padded_constant.end()) { - // If already specified padded to constant, need to check - // runtime value not over the constant bound - TORCH_INTERNAL_ASSERT(*val <= padded_constant_it->second); - *val = EvaluatorValue(padded_constant_it->second); - } else { - // If no specified constant, pad to the smallest multiple of warp - // above the value. - auto padded_number_of_warps = (*val + warp_size - 1) / warp_size; - *val = warp_size * padded_number_of_warps; - } - TORCH_INTERNAL_ASSERT( - *val <= 1024, "padded dimension larger than max block size"); - } - maximum_value = std::max(maximum_value, val->as()); - } - // Protect for size-0 tensors, they still have a value so would prefer to - // bind nothing than 0 - if (maximum_value > 0) { - expr_eval.bind(p_type, maximum_value); - launch_params.bind(maximum_value, p_type); + auto val = expr_eval.evaluate(extent); + TORCH_INTERNAL_ASSERT( + val.has_value(), + "Tried to evaluate the extent, ", + extent->toInlineString(), + " for the ptype: ", + p_type, + " to set launch bounds but could not."); + + if (val->as() > 0) { + expr_eval.bind(p_type, val->as()); + launch_params.bind(val->as(), p_type); } } diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 0f5743f8090..2d3c3e92c14 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -1436,8 +1436,6 @@ ExecutorCompileTimeEntry::ExecutorCompileTimeEntry( // Template instantiation template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; -template class ExecutorCompileTimeEntry; -template class ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; template class 
ExecutorCompileTimeEntry; template class ExecutorCompileTimeEntry; @@ -1498,62 +1496,5 @@ std::unique_ptr getParallelIterExtents( return parallel_iter_extents_ptr; } -std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower* lower, - std::vector& parallel_binding_ids) { - auto parallel_iter_extents_ptr = std::make_unique(); - const auto& ca_map = lower->caMap(); - std::vector mapped; - bool is_tidx_warp_padded = lower->getWarpPaddedParallelInfo().is_tidx_padded; - - for (auto id : parallel_binding_ids) { - if (std::any_of( - mapped.begin(), mapped.end(), [id, &ca_map](IterDomain* mapped_id) { - return ca_map->areMapped(mapped_id, id, IdMappingMode::LOOP); - })) { - if (id->getParallelType() != ParallelType::TIDx || !is_tidx_warp_padded) { - continue; - } - } - - insertParallelExtent( - ca_map->getConcreteMappedID(id, IdMappingMode::LOOP), - parallel_iter_extents_ptr); - mapped.push_back(id); - } - - return parallel_iter_extents_ptr; -} - -std::unique_ptr getWarpPaddedExtentsInfo( - kir::Kernel* kernel, - std::vector& parallel_binding_ids) { - auto warp_padded_extent_info_ptr = - std::make_unique(); - auto& warp_padded_extent_set = - warp_padded_extent_info_ptr->warp_padded_extent_set; - auto& warp_padded_constant = - warp_padded_extent_info_ptr->warp_padded_constant; - bool has_warp_reduction = - kernel->getWarpPaddedParallelInfo().has_warp_reduction; - - for (auto id : parallel_binding_ids) { - // Apply warp padding only when there're warp reductions in - // the kernel. 
- if (has_warp_reduction) { - if (id->hasPaddingToMultipleOfWarp() || - kernel->isParallelTypePadded(id->getParallelType())) { - auto extent = id->extent(); - warp_padded_extent_set.insert(extent); - auto padded_value = id->getMaybeSizeAfterPadding(); - if (padded_value.has_value()) { - warp_padded_constant[extent] = padded_value.value(); - } - } - } - } - return warp_padded_extent_info_ptr; -} - } // namespace executor_utils } // namespace nvfuser diff --git a/csrc/executor_utils.h b/csrc/executor_utils.h index 24f81e167e7..61cb8c3a56e 100644 --- a/csrc/executor_utils.h +++ b/csrc/executor_utils.h @@ -147,45 +147,6 @@ class ParallelIterExtentMap { CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP; }; -//! Compile-time info to be cached in each FusionExecutor: -//! SimplifiedParallelIterExtentMap -//! This entry type is a simplified version of ParallelIterExtentMap. -//! -//! For launch parameter binding we only need the most concrete iterdomain -//! in each disjoint set stored in CaParallelMap. This entry stores the -//! remaining list of extents for binding after this simplification. -//! -//! We still need ParallelIterExtentMap since we want to bind the concrete -//! values to the extents of all parallelized iterdomains. We would be -//! able to save these bindings if the integer machine has a notion of -//! equality and could be configured compile time. But that'd be a longer -//! term target. -class SimplifiedParallelIterExtentMap { - public: - using DataType = - std::unordered_map, TypeHash>; - static const CompileTimeEntryType EntryType = - CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP; -}; - -//! WarpPaddedExtentsInfo: -//! Auxiliary data type for entry class WarpPaddedParallelExtents -struct WarpPaddedExtentsInfo { - std::unordered_set warp_padded_extent_set; - std::unordered_map warp_padded_constant; -}; - -//! Compile-time info to be cached in each FusionExecutor: -//! WarpPaddedParallelExtents -//! 
Stores the symbolic and constant extents of warp -//! padded parallel iterdomains. -class WarpPaddedParallelExtents { - public: - using DataType = WarpPaddedExtentsInfo; - static const CompileTimeEntryType EntryType = - CompileTimeEntryType::WARP_PADDED_PARALLEL_EXTENTS; -}; - //! VectorizedTensorInfo: //! Auxiliary data type for entry class VectorizedTensorValidation struct VectorizedTensorInfo { @@ -330,18 +291,6 @@ using ParallelExtentMap = std::unique_ptr getParallelIterExtents( std::vector& parallel_binding_ids); -//! Returns the simplified set of extents necessary for launch parameter -//! binding. -std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower* lower, - std::vector& parallel_binding_ids); - -//! Returns the symbolic or constant extetns of warp padded parallel -//! iterdomains in the given vector. -std::unique_ptr getWarpPaddedExtentsInfo( - kir::Kernel* lower, - std::vector& parallel_binding_ids); - void validateVectorizedTensors( kir::Kernel* kernel, const KernelArgumentHolder& args, diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 2b80e413d65..9888f2fc848 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -98,36 +98,35 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { } const auto tidx_pt = ParallelType::TIDx; - auto warp_size = 32; - - // If the dimension of TIDx is actually a multple of the warp size - // before padding, it can be left as exact - if (isExact(tidx_pt)) { - auto tidx_dim = dynamic_cast(getRaw(tidx_pt)); - if (tidx_dim) { - if (tidx_dim->isConst()) { - auto tidx_dim_val = tidx_dim->value().value(); - if (tidx_dim_val % warp_size == 0) { - // Dimension of TIDx is a multiple of the warp size - return; - } - } - // If tidx is strictly defined as blockDim.x then it must be set to a - // multiple of the warp and can be considered exact - if (tidx_dim->sameAs(NamedScalar::getParallelDim(tidx_pt))) { - return; - } - } + auto warp_size_val = 
IrBuilder::create(32); + auto tidx_dim = getRaw(tidx_pt); + + TORCH_INTERNAL_ASSERT(tidx_dim != nullptr); + + // If tidx is strictly defined as blockDim.x then it must be set to a + // multiple of the warp, there is nothing to do + if (tidx_dim->sameAs(NamedScalar::getParallelDim(tidx_pt))) { + return; + } + + // If already multiple of warp, nothing to do + if (simplifyExpr(SimplifyingIrBuilder::eqExpr( + SimplifyingIrBuilder::modExpr(tidx_dim, warp_size_val), + tidx_dim->container()->zeroVal())) + ->getBool() == true) { + return; } // TIDx is padded to a multiple of warp. If it's known to be a // single warp, use the constant warp size as the dimension of // TIDx. Otherwise, just use blockDim.x. if (warp_info.is_tidx_single_warp) { - dim_map_.at(ParallelType::TIDx) = IrBuilder::create(warp_size); + dim_map_.at(ParallelType::TIDx) = warp_size_val; } else { dim_map_.at(ParallelType::TIDx) = - NamedScalar::getParallelDim(ParallelType::TIDx); + simplifyExpr(SimplifyingIrBuilder::mulExpr( + SimplifyingIrBuilder::ceilDivExpr(tidx_dim, warp_size_val), + warp_size_val)); } // TIDx is no longer exact diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index a5eccf0f42e..4cf1eab3c98 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -35,6 +35,10 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { std::string toString() const; + const std::unordered_map& getMap() const { + return dim_map_; + } + private: //! TIDx may need to be marked as non-exact as it may be padded to a //! multiple of the warp size.