Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions csrc/multidevice/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ MultiDeviceExecutor::MultiDeviceExecutor(
should_run_[group] = involvedDevices(expr).count(comm_.deviceId());
}
// prepare the order in which to launch the kernels/comms
prepareRuntimeOrder(staged_fusion_.get(), workspace);
prepareRuntimeOrder(staged_fusion_.get(), workspace_);

// Allocator setup
// vals_to_allocate_ stores the tensors that need to be allocated at runtime,
Expand Down Expand Up @@ -217,7 +217,7 @@ std::vector<at::Tensor> MultiDeviceExecutor::runWithInput(
}

// Run through the groups to launch kernels and comms
for (auto group : workspace.group_run_order) {
for (auto group : workspace_.group_run_order) {
if (!is_resharding_.at(group)) {
postKernel(group, launch_params);
} else {
Expand Down Expand Up @@ -261,7 +261,7 @@ std::string MultiDeviceExecutor::validate() const {
std::ostream& MultiDeviceExecutor::print() {
int compute_segment_counter = 0;
int communication_counter = 0;
for (auto group : workspace.group_run_order) {
for (auto group : workspace_.group_run_order) {
if (is_resharding_[group]) {
debug() << "Communication " << communication_counter << ": "
<< group->exprs().at(0) << "\n";
Expand All @@ -277,4 +277,18 @@ std::ostream& MultiDeviceExecutor::print() {
return debug();
}

std::vector<FusionExecutorCache*> MultiDeviceExecutor::
getFusionExecutorCaches() {
NVF_CHECK(
params_.use_fusion_executor_cache,
"MultideviceExecutor must be configured to use FusionExecutorCache");
std::vector<FusionExecutorCache*> fecs;
for (SegmentedGroup* group : workspace_.group_run_order) {
if (fec_.count(group) > 0) {
fecs.push_back(&(fec_.at(group)));
}
}
return fecs;
}

} // namespace nvfuser
8 changes: 7 additions & 1 deletion csrc/multidevice/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ class MultiDeviceExecutor {
//! Print to default debugging output stream
std::ostream& print();

// Returns a vector of FusionExecutorCache pointers, one per compute
// segment, in runtime order. This is only valid if the executor was
// configured to use FusionExecutorCache, i.e.,
// params.use_fusion_executor_cache == true.
std::vector<FusionExecutorCache*> getFusionExecutorCaches();

private:
// execute locally a SegmentedGroup that does not involve inter-device
// communication. Launch Params are used only if
Expand All @@ -129,7 +135,7 @@ class MultiDeviceExecutor {
// 2) a Fusion comprised of one Expr, representing inter-device communication
std::unique_ptr<SegmentedFusion> staged_fusion_;
// Stores the order in which the pipeline's stage should be executed
RuntimeWorkSpace workspace;
RuntimeWorkSpace workspace_;
// Cache Fusions, FusionExecutors, and Communications
std::unordered_map<SegmentedGroup*, FusionExecutor> fe_;
std::unordered_map<SegmentedGroup*, FusionExecutorCache> fec_;
Expand Down
12 changes: 12 additions & 0 deletions tests/cpp/test_multidevice_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ TEST_F(DistributedMatmulTest, LayoutTN_NoComms) {
{expected_output},
__LINE__,
__FILE__);

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);
}

TEST_F(DistributedMatmulTest, LayoutTN_Allgather) {
Expand Down Expand Up @@ -177,6 +180,9 @@ TEST_F(DistributedMatmulTest, LayoutTN_Allgather) {
{expected_output},
__LINE__,
__FILE__);

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);
}

TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) {
Expand Down Expand Up @@ -228,6 +234,9 @@ TEST_F(DistributedMatmulTest, LayoutNT_AllReduce) {

testValidate(
runtime.completeFusion(), outputs, inputs, {out}, __LINE__, __FILE__);

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);
}

TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) {
Expand Down Expand Up @@ -292,5 +301,8 @@ TEST_F(DistributedMatmulTest, LayoutNT_ReduceScatter) {
{expected_output},
__LINE__,
__FILE__);

std::vector<FusionExecutorCache*> fecs = runtime.getFusionExecutorCaches();
EXPECT_EQ(fecs.size(), 1);
}
} // namespace nvfuser