Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmark/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ static void Softmax_WarpReduceReference(benchmark::State& benchmark_state) {
// Schedule through magic scheduler:
SchedulerRuntimeInfo runtime_info(fusion, aten_inputs);
NVF_ERROR(SchedulerEntry::canSchedule(
ScheduleHeuristic::InnerPersistent, fusion, runtime_info));
ScheduleHeuristic::Persistent, fusion, runtime_info));
auto scheduler = SchedulerEntry::makeEntry(
ScheduleHeuristic::InnerPersistent, fusion, runtime_info);
ScheduleHeuristic::Persistent, fusion, runtime_info);
scheduler->schedule(fusion);

FusionExecutor fe;
Expand Down
1 change: 0 additions & 1 deletion csrc/scheduler/all_schedulers.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <scheduler/normalization.h>
#include <scheduler/normalization_inner.h>
#include <scheduler/normalization_inner_outer.h>
#include <scheduler/normalization_outer.h>
#include <scheduler/pointwise.h>
#include <scheduler/reduction.h>
#include <scheduler/transpose.h>
2 changes: 0 additions & 2 deletions csrc/scheduler/heuristic_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ std::string toString(ScheduleHeuristic sh) {
return "reduction";
case ScheduleHeuristic::InnerPersistent:
return "inner_persistent";
case ScheduleHeuristic::OuterPersistent:
return "outer_persistent";
case ScheduleHeuristic::InnerOuterPersistent:
return "inner_outer_persistent";
case ScheduleHeuristic::Persistent:
Expand Down
4 changes: 1 addition & 3 deletions csrc/scheduler/heuristic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,17 @@ enum class ScheduleHeuristic {
InnerPersistent,
InnerOuterPersistent,
Persistent,
OuterPersistent,
Transpose,
Matmul
};

//! Define a schedule table to loop over all the heuristics in priority order.
constexpr std::array<ScheduleHeuristic, 9> all_heuristics_in_priority_order = {
constexpr std::array<ScheduleHeuristic, 8> all_heuristics_in_priority_order = {
ScheduleHeuristic::NoOp,
ScheduleHeuristic::Reduction,
ScheduleHeuristic::Transpose,
ScheduleHeuristic::PointWise,
ScheduleHeuristic::InnerPersistent,
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::InnerOuterPersistent,
ScheduleHeuristic::Persistent,
ScheduleHeuristic::Matmul};
Expand Down
4 changes: 0 additions & 4 deletions csrc/scheduler/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ void PersistentKernelScheduler::schedule(Fusion* fusion) {
}

bool PersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
// This scheduler is being divided into three separate schedulers and should
// be deleted. Disable the use of this scheduler for now.
return false;

// Needs at least one reduction to consider.
auto reduction_ops = ir_utils::getReductionOps(fusion);
if (reduction_ops.empty()) {
Expand Down
72 changes: 31 additions & 41 deletions csrc/scheduler/normalization_outer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ OuterPersistentKernelScheduler::OuterPersistentKernelScheduler(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache)
: SchedulerEntry(ScheduleHeuristic::OuterPersistent) {
: SchedulerEntry(ScheduleHeuristic::Persistent) {
computeHeuristics(fusion, runtime_info, data_cache);
}

Expand All @@ -50,33 +50,33 @@ bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
auto reduction_ops = ir_utils::getReductionOps(fusion);
if (reduction_ops.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "needs a reduction op");
ScheduleHeuristic::Persistent, "needs a reduction op");
return false;
}

if (ir_utils::filterByType<TensorView>(fusion->inputs()).empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"Scheduling not supported with no input");
return false;
}

// Check that inputs of all select/gather-like ops are fusion inputs
if (registry_utils::rejectScheduleForMemoryPromotion(
fusion, ScheduleHeuristic::OuterPersistent)) {
fusion, ScheduleHeuristic::Persistent)) {
return false;
}

// Fusions handled by persistent kernel scheduler cannot have MmaOp.
if (!ir_utils::getMmaOps(fusion).empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "no support for mma ops.");
ScheduleHeuristic::Persistent, "no support for mma ops.");
return false;
}

if (registry_utils::hasNonUniqueBcast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"Broadcasting dimension might be broadcasting to multiple sizes.");
return false;
}
Expand All @@ -86,7 +86,7 @@ bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
if (reduction_tvs.empty()) {
// Use pointwise logic
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "no reduction tv");
ScheduleHeuristic::Persistent, "no reduction tv");
return false;
}

Expand All @@ -101,14 +101,6 @@ bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
}
bool combined_inner_outer =
!inner_reduction_tvs.empty() && !outer_reduction_tvs.empty();

if (!inner_reduction_tvs.empty() || outer_reduction_tvs.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
"ScheduleHeuristic::OuterPersistent requires outer reduction tvs without inner reduction tvs.");
return false;
}

if (!checkReductionPattern(
fusion, inner_reduction_tvs, outer_reduction_tvs)) {
return false;
Expand All @@ -123,7 +115,7 @@ bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
ComputeAtMap ca_map(fusion);
if (registry_utils::requiresForwardViewReplay(fusion, ca_map)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"Fusion requires view being reversible.");
return false;
}
Expand All @@ -133,7 +125,7 @@ bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
if (registry_utils::reductionInterferingView(
fusion, ca_map, reference_tv)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"View may interfere with normalization scheduling.");
return false;
}
Expand Down Expand Up @@ -161,7 +153,7 @@ bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
} else {
if (reduction_root_size(red) != axis_count) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"inconsistent reduction root size: ",
red->toString(),
", expected: ",
Expand All @@ -175,22 +167,22 @@ bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) {
auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion);
if (persistent_buffer_info.persistent_buffers.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "no persistent buffer identified");
ScheduleHeuristic::Persistent, "no persistent buffer identified");
return false;
}

if (registry_utils::SchedulerTopologyChecker::
hasNonNormalizePostReductionBCast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"unsupported post reduction normalization");
return false;
}

if (registry_utils::SchedulerTopologyChecker::
hasGatherToBroadcastBeforeReduction(fusion, reduction_tvs)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"has unsupported gather-like ops before normalization");
return false;
}
Expand Down Expand Up @@ -251,7 +243,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTime(

if (persistent_buffer_size > available_persistent_buffer_size) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"not enough registers or shared memory for persistence");
return false;
}
Expand All @@ -277,7 +269,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTime(
false)
.first.has_value()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"Required batch number is larger than available batch number! Will cause register spills!");
return false;
}
Expand All @@ -300,8 +292,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTime(
if (required_sm_per_norm >
scheduler_utils::safeDiv(device_multiprocessor_count, 3)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
"requires over half GPU persistence.");
ScheduleHeuristic::Persistent, "requires over half GPU persistence.");
return false;
}

Expand All @@ -314,7 +305,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTime(
!(norm_per_sm >= warp_size / 2 ||
max_multi_reduction_factor >= warp_size)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "not enough threads");
ScheduleHeuristic::Persistent, "not enough threads");
return false;
}

Expand All @@ -330,7 +321,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTime(
// half warp
: (warp_size / 8) * device_multiprocessor_count)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "not enough blocks");
ScheduleHeuristic::Persistent, "not enough blocks");
return false;
}

Expand Down Expand Up @@ -360,7 +351,7 @@ bool OuterPersistentKernelScheduler::checkReductionPattern(
if (!registry_utils::checkPatternEquivalence(
rtvs[it - 1], rtvs[it], root_map)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"unmapped reduction ",
rtvs[it - 1],
" and ",
Expand All @@ -375,23 +366,23 @@ bool OuterPersistentKernelScheduler::checkReductionPattern(
if (!normalization_scheduler_utils::checkIfReductionsAreInnerOuter(
inner_reduction_tvs, outer_reduction_tvs)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"to use combined reduction, inner reduction tensor should be [I,I,...,R,R] and outer reduction tensor should be [R,R,...,I,I]");
return false;
}

if (!normalization_scheduler_utils::hasSharedInput(
inner_reduction_tvs, outer_reduction_tvs)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"to use combined reduction, inner reduction and outer reduction should have shared input.");
return false;
}

if (!normalization_scheduler_utils::isConnectedOnlyThroughReductionProducer(
inner_reduction_tvs, outer_reduction_tvs)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"to use combined reduction, inner reduction and outer reduction should not have shared consumer, their consumers should not have shared non-outer-reduction producer.");
return false;
}
Expand Down Expand Up @@ -529,8 +520,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(

if (persistent_buffer_size > available_persistent_buffer_size) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
"not enough registers for persistence");
ScheduleHeuristic::Persistent, "not enough registers for persistence");
return false;
}

Expand All @@ -557,7 +547,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(
if (required_sm_per_norm >
scheduler_utils::safeDiv(device_multiprocessor_count, 2)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"requires over half GPU persistence.",
" required SMs per normalization: ",
required_sm_per_norm);
Expand All @@ -575,7 +565,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(
// TODO: Is this necessary for block persistence as well?
if (vectorization_factor < 4) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "not enough vectorized");
ScheduleHeuristic::Persistent, "not enough vectorized");
return false;
}

Expand All @@ -589,7 +579,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(

if (!cross_grid_params.has_value()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "no valid launch config found");
ScheduleHeuristic::Persistent, "no valid launch config found");
return false;
}
}
Expand All @@ -606,7 +596,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(
// factor
if (max_multi_reduction_factor < min_multi_reduction_factor) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"Not enough threads.",
" Multi reduction factor, ",
max_multi_reduction_factor,
Expand All @@ -631,7 +621,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(
if (is_cross_grid &&
max_used_sms < scheduler_utils::safeDiv(device_multiprocessor_count, 2)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent,
ScheduleHeuristic::Persistent,
"cross grid - not enough used SMs: ",
max_used_sms);
return false;
Expand All @@ -645,7 +635,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(
device_max_threads_per_multiprocessor * 4 && // Large reduction dim
max_used_sms < min_fraction_of_sms) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "not enough used SMs");
ScheduleHeuristic::Persistent, "not enough used SMs");
return false;
}

Expand All @@ -662,7 +652,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(
return !reduction_tv->definition()->isA<WelfordOp>();
})) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "non-Welford not enabled yet");
ScheduleHeuristic::Persistent, "non-Welford not enabled yet");
return false;
}

Expand All @@ -679,7 +669,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTimeOuter(
0) &&
device_prop->major == 7) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::OuterPersistent, "iteration not evenly divided");
ScheduleHeuristic::Persistent, "iteration not evenly divided");
return false;
}

Expand Down
13 changes: 0 additions & 13 deletions csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,6 @@ bool SchedulerEntry::canSchedule(
case ScheduleHeuristic::InnerPersistent:
return checkCanSchedule<InnerPersistentKernelScheduler>(
fusion, runtime_info, data_cache);
case ScheduleHeuristic::OuterPersistent:
return checkCanSchedule<OuterPersistentKernelScheduler>(
fusion, runtime_info, data_cache);
case ScheduleHeuristic::InnerOuterPersistent:
return checkCanSchedule<InnerOuterPersistentKernelScheduler>(
fusion, runtime_info, data_cache);
Expand Down Expand Up @@ -230,10 +227,6 @@ std::unique_ptr<SchedulerEntry> SchedulerEntry::makeEntry(
scheduler_entry = std::make_unique<InnerPersistentKernelScheduler>(
fusion, runtime_info, data_cache);
break;
case ScheduleHeuristic::OuterPersistent:
scheduler_entry = std::make_unique<OuterPersistentKernelScheduler>(
fusion, runtime_info, data_cache);
break;
case ScheduleHeuristic::InnerOuterPersistent:
scheduler_entry = std::make_unique<InnerOuterPersistentKernelScheduler>(
fusion, runtime_info, data_cache);
Expand Down Expand Up @@ -318,11 +311,6 @@ HeuristicSummary::HeuristicSummary(
InnerPersistentKernelScheduler::canScheduleRunTime(
fusion, runtime_info, this);
break;
case ScheduleHeuristic::OuterPersistent:
getOuterPersistentHeuristics(fusion, runtime_info, this);
OuterPersistentKernelScheduler::canScheduleRunTime(
fusion, runtime_info, this);
break;
case ScheduleHeuristic::InnerOuterPersistent:
getInnerOuterPersistentHeuristics(fusion, runtime_info, this);
InnerOuterPersistentKernelScheduler::canScheduleRunTime(
Expand Down Expand Up @@ -394,7 +382,6 @@ void HeuristicSummary::validate() const {
break;
}
case ScheduleHeuristic::InnerPersistent:
case ScheduleHeuristic::OuterPersistent:
case ScheduleHeuristic::InnerOuterPersistent: {
NVF_ERROR(entry_type_map_.count(EntryType::REDUCTION_TVS));
NVF_ERROR(
Expand Down
Loading