Merged
19 changes: 8 additions & 11 deletions csrc/fusion_segmenter.cpp
@@ -3881,15 +3881,12 @@ bool SegmentCandidateFinder::codeGenSupportedMerge(
   NVF_ERROR(
       areDirectlyConnected(group1, group2),
       "only support testing immediate producer-consumer groups");
-  if (options_.only_segment_resharding_exprs) {
-    for (auto group : {group1, group2}) {
-      for (auto expr : group->exprs()) {
-        if (isResharding(expr)) {
-          return false;
-        }
-      }
-    }
-    return true;
+  // The segmenter should ideally be redesigned to be more flexible and
+  // decoupled from the schedulers, but for now, we just return
+  // `SchedulerType::None` as it is not relevant when the segmenter is
+  // used with a custom should-merge function.
+  if (options_.custom_should_merge_groups != nullptr) {
Collaborator

There's also segment_set that serves a similar purpose. But it's hard to tell why it's not sufficient without reviewing the following PRs.

Collaborator Author

You are right that both mechanisms serve the same purpose. This was already true before this PR, with the only_segment_resharding_exprs option. For the time being, passing a function stays much closer to the existing code than moving to segmenter sets. I also find it more lightweight and usable in this context. More precisely, it saves me:

  1. a pass for adding the sets
  2. adding an option in the segmenter to only segment according to the segmenter set
  3. a pass for removing the segmenter sets
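
For illustration, the behavior of the removed only_segment_resharding_exprs branch could be expressed as a custom should-merge function roughly as follows. This is a minimal sketch, not the code from this PR: `Expr`, `SegmentedGroup`, and `isResharding` are simplified stand-ins for the real nvFuser types, and `makeReshardingAwareShouldMerge` is a hypothetical helper name.

```cpp
#include <cassert>
#include <functional>
#include <initializer_list>
#include <vector>

// Hypothetical stand-ins for nvFuser's Expr and SegmentedGroup. The real
// classes carry far more state; only what the predicate needs is modeled.
struct Expr {
  bool resharding = false;
};

struct SegmentedGroup {
  std::vector<Expr*> exprs_;
  const std::vector<Expr*>& exprs() const {
    return exprs_;
  }
};

// Stand-in for the real isResharding(Expr*) query (assumption).
bool isResharding(const Expr* expr) {
  return expr->resharding;
}

using ShouldMergeFn = std::function<bool(SegmentedGroup*, SegmentedGroup*)>;

// Same logic as the removed branch: refuse to merge two groups if either
// of them contains a resharding expression, otherwise allow the merge.
ShouldMergeFn makeReshardingAwareShouldMerge() {
  return [](SegmentedGroup* group1, SegmentedGroup* group2) {
    for (SegmentedGroup* group : {group1, group2}) {
      for (const Expr* expr : group->exprs()) {
        if (isResharding(expr)) {
          return false;
        }
      }
    }
    return true;
  };
}
```

A caller would assign the returned function to `custom_should_merge_groups` in the segmenter options, so no extra passes for adding or removing segmenter sets are needed.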

+    return (options_.custom_should_merge_groups)(group1, group2);
   }
   return tryMerge(segmented_fusion_.get(), runtimeInfo(), group1, group2) !=
       SchedulerType::None;
@@ -3900,7 +3897,7 @@ bool SegmentCandidateFinder::codeGenSupportedMerge(
 SchedulerType SegmentCandidateFinder::deriveSchedulerType(
     SegmentedGroup* group) {
   FUSER_PERF_SCOPE("SegmentCandidateFinder::deriveSchedulerType");
-  if (options_.only_segment_resharding_exprs) {
+  if (options_.custom_should_merge_groups != nullptr) {
Collaborator

Why is this always None?

Collaborator Author

@samnordmann, Apr 15, 2025

It is nullptr by default, and if it is, we fall back to the traditional single-device segmenter using the schedulers.
Does that answer your question?

Collaborator Author

@samnordmann, Apr 15, 2025

Sorry I misunderstood your question. I guess this one is more for @wujingyue -- here I'm only reproducing the previous behavior, but replacing the option "only_segment_resharding_exprs" with a more agnostic one.

The idea of returning None here has something to do with how FusionExecutorCache decides to lower segments. However, this is not used in MultiDeviceExecutor, so I am not so familiar with this part.

Collaborator

I think this is where the extension of the custom "should merge" function feels more like a hack. The overall design of the segmenter is tightly coupled with scheduling, so it is assumed to have this scheduler type. However, what we are finding is that sometimes we also want to use this without scheduling.

This is a good learning for when we redesign the segmenter. For now, can you please leave a note? Something like:

The segmenter should ideally be redesigned to be more flexible and decoupled from the schedulers, but for now, we just return `SchedulerType::None` as it is not relevant when the segmenter is used with a custom should-merge function.

Collaborator Author

Yes, agreed. For the record, this hack has been present for quite a long time now. I'll add the comment as you suggest.

Collaborator

> The overall design of the segmenter is tightly coupled with scheduling, so it is assumed to have this scheduler type

That's correct, @naoyam. FWIW, this flag is only turned on for MultiDeviceExecutor. In FusionExecutorCache, schedulers test isResharding as you suggested.
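
The coupling discussed in this thread can be sketched in isolation as below. This is an illustrative model under stated assumptions, not nvFuser's implementation: `Finder`, `Options`, `Group`, the reduced `SchedulerType` enum, and the always-succeeding `tryMerge` are all hypothetical stand-ins. It only shows the dispatch pattern: when a custom should-merge function is set, the schedulers are bypassed and the derived scheduler type is `None`.

```cpp
#include <cassert>
#include <functional>

// Reduced stand-in for nvFuser's SchedulerType enum (assumption).
enum class SchedulerType { None, PointWise };

// Hypothetical placeholder for SegmentedGroup.
struct Group {};

// Mirrors the shape of the new option: nullptr means "use the schedulers".
struct Options {
  std::function<bool(Group*, Group*)> custom_should_merge_groups = nullptr;
};

class Finder {
 public:
  explicit Finder(Options options) : options_(std::move(options)) {}

  bool codeGenSupportedMerge(Group* group1, Group* group2) const {
    if (options_.custom_should_merge_groups != nullptr) {
      // The custom predicate fully decides; schedulers are not consulted.
      return options_.custom_should_merge_groups(group1, group2);
    }
    return tryMerge(group1, group2) != SchedulerType::None;
  }

  SchedulerType deriveSchedulerType(Group* /*group*/) const {
    if (options_.custom_should_merge_groups != nullptr) {
      // No scheduler is relevant in the custom-predicate mode.
      return SchedulerType::None;
    }
    return SchedulerType::PointWise; // placeholder for real scheduler lookup
  }

 private:
  // Placeholder: pretend a scheduler always accepts the merged group.
  static SchedulerType tryMerge(Group*, Group*) {
    return SchedulerType::PointWise;
  }

  Options options_;
};
```

A redesigned segmenter could separate these two concerns so that the merge policy does not need to return a placeholder scheduler type at all.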

     // We don't need to generate a SchedulerType for multidevice segments at
     // this moment
     return SchedulerType::None;
@@ -3920,7 +3917,7 @@ SegmentCandidateFinder::SegmentCandidateFinder(
     : options_(options), runtime_inputs_(inputs) {
   FUSER_PERF_SCOPE("SegmentCandidateFinder::SegmentCandidateFinder");
   NVF_ERROR(
-      !options_.only_segment_resharding_exprs ||
+      options_.custom_should_merge_groups == nullptr ||
           (!options_.run_translate_welford &&
            !options_.run_combine_reductions && options_.run_herrmann_merge &&
            options_.run_final_merge),
8 changes: 7 additions & 1 deletion csrc/fusion_segmenter.h
@@ -19,6 +19,7 @@
 #include <visibility.h>

 #include <deque>
+#include <functional>
 #include <list>
 #include <unordered_set>
 #include <vector>
@@ -482,7 +483,12 @@ struct SegmentCandidateFinderOptions {
   bool run_combine_reductions = true;
   bool run_herrmann_merge = true;
   bool run_final_merge = true;
-  bool only_segment_resharding_exprs = false;
+  // if provided, this custom function will be used to determine if two groups
+  // should be merged. If not provided, the tryMerge function will be used. This
+  // option is used in the context of MultiGpus where we proceed to a first
+  // segmentation to scoop out communications from compute.
+  std::function<bool(SegmentedGroup*, SegmentedGroup*)>
+      custom_should_merge_groups = nullptr;
 };

//! SegmentCandidateFinder
7 changes: 6 additions & 1 deletion csrc/host_ir/container.cpp
@@ -26,7 +26,7 @@ HostIrContainer::HostIrContainer(int64_t num_kernel_executors)
 HostIrContainer::~HostIrContainer() = default;

 Stream* HostIrContainer::getDefaultStream() {
-  if (!default_stream_) {
+  if (default_stream_ == nullptr) {
     default_stream_ = IrBuilder::createInContainer<Stream>(this);
   }
   return default_stream_;
@@ -35,6 +35,11 @@ Stream* HostIrContainer::getDefaultStream() {
 std::ostream& HostIrContainer::print(std::ostream& os) const {
   IrMathPrinter op_exprs(os);
   op_exprs.handle(this);
+  os << "Aliases:{";
+  for (const auto& alias : alias_) {
+    os << "\n " << alias.first << " -> " << alias.second;
+  }
+  os << "\n}\n";
   return os;
 }

16 changes: 16 additions & 0 deletions csrc/host_ir/container.h
@@ -41,6 +41,10 @@ class HostIrContainer final : public Fusion {
   //! Print to an output stream
   std::ostream& print(std::ostream& os) const;

+  void resetTopLevelExprs(std::vector<Expr*> exprs) {
+    top_level_exprs_ = std::move(exprs);
+  }
+
   const std::vector<Expr*>& topLevelExprs() const;

   void pushBackTopLevelExprs(Expr* expr);
@@ -55,10 +59,22 @@ class HostIrContainer final : public Fusion {

   Stream* getDefaultStream();

+  void markAlias(TensorView* original, const TensorView* new_alias) {
+    while (alias_.count(original)) {
+      original = alias_[original]->as<TensorView>();
+    }
+    alias_[new_alias] = original;
+  }
+
+  const auto& alias() const {
+    return alias_;
+  }
+
  private:
   std::vector<Expr*> top_level_exprs_;
   std::vector<std::unique_ptr<KernelExecutor>> kernel_executors_;
   Stream* default_stream_ = nullptr;
+  std::unordered_map<const Val*, Val*> alias_;
 };

} // namespace hir
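
To illustrate what the markAlias loop in this hunk appears to do, here is a standalone sketch with simplified types. `Val`, `TensorView`, and `AliasTable` are hypothetical stand-ins (the real method lives on HostIrContainer and uses `as<TensorView>()`); the point is only that an alias of an alias is resolved to the root tensor before being recorded, so lookups never need to walk a chain.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical minimal stand-ins for nvFuser's Val and TensorView.
struct Val {};
struct TensorView : Val {};

class AliasTable {
 public:
  // Record new_alias as an alias of original. If original is itself a
  // recorded alias, follow the links first so new_alias maps to the root.
  void markAlias(TensorView* original, const TensorView* new_alias) {
    auto it = alias_.find(original);
    while (it != alias_.end()) {
      // In this sketch every stored value is a TensorView, so the
      // downcast is safe; the real code uses as<TensorView>().
      original = static_cast<TensorView*>(it->second);
      it = alias_.find(original);
    }
    alias_[new_alias] = original;
  }

  const std::unordered_map<const Val*, Val*>& alias() const {
    return alias_;
  }

 private:
  std::unordered_map<const Val*, Val*> alias_;
};
```

With this scheme, marking `b` as an alias of `a` and then `c` as an alias of `b` stores `c -> a` directly, which keeps every entry one hop from its root.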