Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 15 additions & 62 deletions csrc/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,27 +590,8 @@ LaunchParams FusionExecutor::computeLaunchParams(
});
auto& parallel_iter_extents = parallel_iter_extent_entry.get();

auto simplified_parallel_iter_extent_entry =
executor_utils::caching::ExecutorCompileTimeEntry<
executor_utils::caching::SimplifiedParallelIterExtentMap>(
data_cache, [&parallel_binding_ids, &lower]() {
return executor_utils::getSimplifiedParallelIterExtents(
lower, parallel_binding_ids);
});
auto& simplified_parallel_iter_extents =
simplified_parallel_iter_extent_entry.get();

auto warp_padded_parallel_entry =
executor_utils::caching::ExecutorCompileTimeEntry<
executor_utils::caching::WarpPaddedParallelExtents>(
data_cache, [&parallel_binding_ids, &lower]() {
return executor_utils::getWarpPaddedExtentsInfo(
lower->kernel(), parallel_binding_ids);
});
auto& warp_padded_extent_set =
warp_padded_parallel_entry.get().warp_padded_extent_set;
auto& warp_padded_constant =
warp_padded_parallel_entry.get().warp_padded_constant;
const auto& simplified_parallel_iter_extents =
lower->parallelDimensionMap().getMap();

// TODO: Need to redesign this part a bit to
// find the right place to trigger evaluate
Expand Down Expand Up @@ -656,48 +637,20 @@ LaunchParams FusionExecutor::computeLaunchParams(
}

// Run through the rest of the parallel IterDomains and infer their size
for (auto& entry : simplified_parallel_iter_extents) {
for (auto [p_type, extent] : simplified_parallel_iter_extents) {
FUSER_PERF_SCOPE("FusionExecutor::ParallelBindingResolution");
auto p_type = entry.first;
auto parallel_extents = entry.second;
// Select the maximum value out of all the parallel extents
int64_t maximum_value = std::numeric_limits<int64_t>::min();
for (auto extent : parallel_extents) {
auto val = expr_eval.evaluate(extent);
TORCH_INTERNAL_ASSERT(
val.has_value(),
"Tried to evaluate the extent, ",
extent->toInlineString(),
" for the ptype: ",
p_type,
" to set launch bounds but could not.");

// apply padding to the extent if needed
if (warp_padded_extent_set.count(extent)) {
// Check if the extent has const value
auto padded_constant_it = warp_padded_constant.find(extent);

if (padded_constant_it != warp_padded_constant.end()) {
// If already specified padded to constant, need to check
// runtime value not over the constant bound
TORCH_INTERNAL_ASSERT(*val <= padded_constant_it->second);
*val = EvaluatorValue(padded_constant_it->second);
} else {
// If no specified constant, pad to the smallest multiple of warp
// above the value.
auto padded_number_of_warps = (*val + warp_size - 1) / warp_size;
*val = warp_size * padded_number_of_warps;
}
TORCH_INTERNAL_ASSERT(
*val <= 1024, "padded dimension larger than max block size");
}
maximum_value = std::max(maximum_value, val->as<int64_t>());
}
// Protect for size-0 tensors, they still have a value so would prefer to
// bind nothing than 0
if (maximum_value > 0) {
expr_eval.bind(p_type, maximum_value);
launch_params.bind(maximum_value, p_type);
auto val = expr_eval.evaluate(extent);
TORCH_INTERNAL_ASSERT(
val.has_value(),
"Tried to evaluate the extent, ",
extent->toInlineString(),
" for the ptype: ",
p_type,
" to set launch bounds but could not.");

if (val->as<int64_t>() > 0) {
expr_eval.bind(p_type, val->as<int64_t>());
launch_params.bind(val->as<int64_t>(), p_type);
}
}

Expand Down
59 changes: 0 additions & 59 deletions csrc/executor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1436,8 +1436,6 @@ ExecutorCompileTimeEntry<EntryClass>::ExecutorCompileTimeEntry(
// Template instantiation
template class ExecutorCompileTimeEntry<ParallelBindingIterDomains>;
template class ExecutorCompileTimeEntry<ParallelIterExtentMap>;
template class ExecutorCompileTimeEntry<SimplifiedParallelIterExtentMap>;
template class ExecutorCompileTimeEntry<WarpPaddedParallelExtents>;
template class ExecutorCompileTimeEntry<VectorizedTensorValidation>;
template class ExecutorCompileTimeEntry<InputAliasIndices>;
template class ExecutorCompileTimeEntry<OutputAliasIndices>;
Expand Down Expand Up @@ -1498,62 +1496,5 @@ std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
return parallel_iter_extents_ptr;
}

//! Builds a reduced ParallelExtentMap containing one representative extent
//! per LOOP-mapped disjoint set of the given parallel-bound IterDomains.
//!
//! For each id in parallel_binding_ids, the id is skipped when some earlier
//! id was already LOOP-mapped to it via the compute-at map — except that a
//! TIDx id is re-inserted anyway when TIDx is warp padded (presumably so the
//! padded extent is not lost; TODO confirm against callers).
//!
//! \param lower lowering context providing the compute-at map and warp
//!        padding info (read-only here, but passed as non-const pointer)
//! \param parallel_binding_ids parallelized IterDomains to simplify
//! \return newly allocated map from ParallelType to the surviving extents
std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
    GpuLower* lower,
    std::vector<IterDomain*>& parallel_binding_ids) {
  auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>();
  const auto& ca_map = lower->caMap();
  // IterDomains already represented in the output; used to de-duplicate
  // LOOP-mapped ids.
  std::vector<IterDomain*> mapped;
  bool is_tidx_warp_padded = lower->getWarpPaddedParallelInfo().is_tidx_padded;

  for (auto id : parallel_binding_ids) {
    if (std::any_of(
            mapped.begin(), mapped.end(), [id, &ca_map](IterDomain* mapped_id) {
              return ca_map->areMapped(mapped_id, id, IdMappingMode::LOOP);
            })) {
      // Already covered by a previously inserted id. Only fall through (and
      // insert again) for warp-padded TIDx ids.
      if (id->getParallelType() != ParallelType::TIDx || !is_tidx_warp_padded) {
        continue;
      }
    }

    // Insert the concrete id of the LOOP disjoint set, not `id` itself.
    insertParallelExtent(
        ca_map->getConcreteMappedID(id, IdMappingMode::LOOP),
        parallel_iter_extents_ptr);
    mapped.push_back(id);
  }

  return parallel_iter_extents_ptr;
}

//! Collects warp-padding information for the given parallel-bound
//! IterDomains.
//!
//! An id's extent is recorded in warp_padded_extent_set when the id is
//! padded to a multiple of the warp size (either per-id padding or padding
//! of its whole ParallelType). When the padded size is a known constant,
//! it is additionally recorded in warp_padded_constant.
//!
//! \param kernel kernel queried for warp-reduction / padding info
//! \param parallel_binding_ids parallelized IterDomains to inspect
//! \return newly allocated WarpPaddedExtentsInfo (empty when the kernel has
//!         no warp reduction)
std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
    kir::Kernel* kernel,
    std::vector<IterDomain*>& parallel_binding_ids) {
  auto info = std::make_unique<caching::WarpPaddedExtentsInfo>();

  // Warp padding only applies when the kernel actually contains warp
  // reductions; otherwise return the empty info.
  if (!kernel->getWarpPaddedParallelInfo().has_warp_reduction) {
    return info;
  }

  for (auto id : parallel_binding_ids) {
    const bool is_padded = id->hasPaddingToMultipleOfWarp() ||
        kernel->isParallelTypePadded(id->getParallelType());
    if (!is_padded) {
      continue;
    }

    auto extent = id->extent();
    info->warp_padded_extent_set.insert(extent);

    // Record the padded size when it is a compile-time constant.
    auto maybe_padded_size = id->getMaybeSizeAfterPadding();
    if (maybe_padded_size.has_value()) {
      info->warp_padded_constant[extent] = maybe_padded_size.value();
    }
  }
  return info;
}

} // namespace executor_utils
} // namespace nvfuser
51 changes: 0 additions & 51 deletions csrc/executor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,45 +147,6 @@ class ParallelIterExtentMap {
CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP;
};

//! Compile-time info to be cached in each FusionExecutor:
//! SimplifiedParallelIterExtentMap
//! This entry type is a simplified version of ParallelIterExtentMap.
//!
//! For launch parameter binding we only need the most concrete iterdomain
//! in each disjoint set stored in CaParallelMap. This entry stores the
//! remaining list of extents for binding after this simplification.
//!
//! We still need ParallelIterExtentMap since we want to bind the concrete
//! values to the extents of all parallelized iterdomains. We would be
//! able to save these bindings if the integer machine has a notion of
//! equality and could be configured compile time. But that'd be a longer
//! term target.
class SimplifiedParallelIterExtentMap {
public:
using DataType =
std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP;
};

//! WarpPaddedExtentsInfo:
//! Auxiliary data type for entry class WarpPaddedParallelExtents
struct WarpPaddedExtentsInfo {
std::unordered_set<const Val*> warp_padded_extent_set;
std::unordered_map<const Val*, int64_t> warp_padded_constant;
};

//! Compile-time info to be cached in each FusionExecutor:
//! WarpPaddedParallelExtents
//! Stores the symbolic and constant extents of warp
//! padded parallel iterdomains.
class WarpPaddedParallelExtents {
public:
using DataType = WarpPaddedExtentsInfo;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::WARP_PADDED_PARALLEL_EXTENTS;
};

//! VectorizedTensorInfo:
//! Auxiliary data type for entry class VectorizedTensorValidation
struct VectorizedTensorInfo {
Expand Down Expand Up @@ -330,18 +291,6 @@ using ParallelExtentMap =
std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
std::vector<IterDomain*>& parallel_binding_ids);

//! Returns the simplified set of extents necessary for launch parameter
//! binding.
std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
GpuLower* lower,
std::vector<IterDomain*>& parallel_binding_ids);

//! Returns the symbolic or constant extents of warp padded parallel
//! iterdomains in the given vector.
std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
kir::Kernel* lower,
std::vector<IterDomain*>& parallel_binding_ids);

void validateVectorizedTensors(
kir::Kernel* kernel,
const KernelArgumentHolder& args,
Expand Down
43 changes: 21 additions & 22 deletions csrc/parallel_dimension_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,36 +98,35 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
}

const auto tidx_pt = ParallelType::TIDx;
auto warp_size = 32;

// If the dimension of TIDx is actually a multiple of the warp size
// before padding, it can be left as exact
if (isExact(tidx_pt)) {
auto tidx_dim = dynamic_cast<Int*>(getRaw(tidx_pt));
if (tidx_dim) {
if (tidx_dim->isConst()) {
auto tidx_dim_val = tidx_dim->value().value();
if (tidx_dim_val % warp_size == 0) {
// Dimension of TIDx is a multiple of the warp size
return;
}
}
// If tidx is strictly defined as blockDim.x then it must be set to a
// multiple of the warp and can be considered exact
if (tidx_dim->sameAs(NamedScalar::getParallelDim(tidx_pt))) {
return;
}
}
auto warp_size_val = IrBuilder::create<Int>(32);
auto tidx_dim = getRaw(tidx_pt);

TORCH_INTERNAL_ASSERT(tidx_dim != nullptr);

// If tidx is strictly defined as blockDim.x then it must be set to a
// multiple of the warp, there is nothing to do
if (tidx_dim->sameAs(NamedScalar::getParallelDim(tidx_pt))) {
return;
}

// If already multiple of warp, nothing to do
if (simplifyExpr(SimplifyingIrBuilder::eqExpr(
SimplifyingIrBuilder::modExpr(tidx_dim, warp_size_val),
tidx_dim->container()->zeroVal()))
->getBool() == true) {
return;
}

// TIDx is padded to a multiple of warp. If it's known to be a
// single warp, use the constant warp size as the dimension of
// TIDx. Otherwise, just use blockDim.x.
if (warp_info.is_tidx_single_warp) {
dim_map_.at(ParallelType::TIDx) = IrBuilder::create<Int>(warp_size);
dim_map_.at(ParallelType::TIDx) = warp_size_val;
} else {
dim_map_.at(ParallelType::TIDx) =
NamedScalar::getParallelDim(ParallelType::TIDx);
simplifyExpr(SimplifyingIrBuilder::mulExpr(
SimplifyingIrBuilder::ceilDivExpr(tidx_dim, warp_size_val),
warp_size_val));
}

// TIDx is no longer exact
Expand Down
4 changes: 4 additions & 0 deletions csrc/parallel_dimension_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class TORCH_CUDA_CU_API ParallelDimensionMap {

std::string toString() const;

  //! Read-only access to the underlying map from each ParallelType to the
  //! Val holding its dimension (extent), as maintained in dim_map_.
  const std::unordered_map<ParallelType, Val*, TypeHash>& getMap() const {
    return dim_map_;
  }

private:
//! TIDx may need to be marked as non-exact as it may be padded to a
//! multiple of the warp size.
Expand Down