Merged
23 changes: 21 additions & 2 deletions csrc/dynamic_transform.cpp
@@ -454,8 +454,8 @@ void DynamicTransformConcretizer::concretize() {
// Finally, propagate concretized domains
auto all_stmts = StmtSort::getStmts(
info_->fusion(),
-      /*traverse_members*/ false,
-      /*traverse_attributes*/ false,
+      /*traverse_members*/ true,
+      /*traverse_attributes*/ true,
Comment on lines +457 to +458

Collaborator (Author): Since we propagate scalars, we should do this anywhere they might occur, namely in members and attributes. This was not necessary when we were only replacing IterTypes, since IterDomains are traversed regardless of these settings.

Collaborator: Didn't we discuss duplicated replacements if members were also traversed? Would it be a concern?

/*traverse_siblings*/ true);
for (auto tv : ir_utils::filterByType<TensorView>(all_stmts)) {
mutate(tv);
@@ -497,6 +497,25 @@ void DynamicTransformConcretizer::concretizeReshape() {
// replacement is valid
checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv);

// Extent expressions often change when concretizing a reshape. Here we
// replace these in all downstream expressions so that the Fusion looks just
// like it would have if we had used a static reshape instead.
auto old_rfactor = incomplete_out_tv->getMaybeRFactorDomain();
auto new_rfactor = concrete_reshape_out_tv->getMaybeRFactorDomain();
TORCH_INTERNAL_ASSERT(
old_rfactor.size() == new_rfactor.size(),
"Concretized reshape rfactor size does not match symbolic rfactor");
for (auto idx : c10::irange(new_rfactor.size())) {
auto old_extent = old_rfactor.at(idx)->extent();
auto new_extent = new_rfactor.at(idx)->extent();
// If the old extent did not have a definition, we don't need to replace
// it, since it will get bound whenever this tensor is a segmentation
// edge.
if (old_extent->definition() && !new_extent->sameAs(old_extent)) {
registerConcretization(old_extent, new_extent);
}
}

// Replace the old tensor with the new concretized tensor
auto uses = incomplete_out_tv->uses();
for (auto use_of_old_tv : uses) {
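As an aside, the record-then-rewrite pattern used above (register an old extent's replacement, then mutate every downstream use) can be sketched in isolation. The types below are hypothetical, simplified stand-ins for nvFuser's Val machinery, not the actual IR classes:

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for nvFuser's Val: a value is either a leaf (no
// operands, i.e. no definition) or defined as an expression over other values.
struct Val {
  int id;
  std::vector<Val*> operands;
  bool hasDefinition() const { return !operands.empty(); }
};

// Analogue of registerConcretization + mutation: every operand that appears
// as a key in `replacements` is rewritten to point at its replacement, so all
// downstream expressions see the new value.
void replaceValue(
    std::vector<std::unique_ptr<Val>>& all_vals,
    const std::unordered_map<Val*, Val*>& replacements) {
  for (auto& v : all_vals) {
    for (auto& op : v->operands) {
      auto it = replacements.find(op);
      if (it != replacements.end()) {
        op = it->second;
      }
    }
  }
}
```

This mirrors why the replacement is only registered when the old extent has a definition: a definition-free leaf can simply be bound at a segmentation edge, whereas a derived extent must be rewritten everywhere it is consumed.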
62 changes: 59 additions & 3 deletions csrc/fusion_segmenter.cpp
@@ -1437,6 +1437,62 @@ std::string toString(const SegmentedFusion* segmented_fusion) {
return ss.str();
}

//! Sets the rfactor as root and erases rfactor of all inputs in fusion. Any
//! non-constant expressions in those extents are replaced by new scalars with
//! no definition. These mutations are performed throughout the Fusion so that
//! downstream expressions dependent on the original inputs' rfactor extents can
//! be computed properly.
void convertInputRfactorsToRoots(Fusion* fusion) {
FusionGuard fg(fusion);

// Holds all Val replacements across all inputs
std::unordered_map<Val*, Val*> replacement_map;

for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
// Create a new root domain and replacement TensorDomain.
// Given an rfactor domain, create a new IterDomain.
// Otherwise, clone the previous IterDomain
std::vector<IterDomain*> new_root_domain;
auto rfactor = tv->getMaybeRFactorDomain();
new_root_domain.reserve(rfactor.size());

// Does the domain (root / rfactor) contain all concrete sized extents?
bool tv_is_concrete = true;
for (auto id : rfactor) {
if (!id->extent()->isConstScalar()) {
tv_is_concrete = false;
break;
}
}

for (const auto& id : rfactor) {
if (id->isRFactorProduct()) {
// Create new symbolic extents for rfactor iterDomains
auto domain_extent = (!tv_is_concrete)
    ? IrBuilder::create<Val>(DataType::Int)
    : id->extent();
replacement_map.emplace(id->extent(), domain_extent);

Collaborator: I think this function is fine as it's mostly just copied from tensor_view.cpp. However, I don't know why we use all symbolic extents unless tv_is_concrete. It seems this means that if there's any IterDomain with a symbolic extent, all IterDomains would have symbolic extents.

This is not about this PR itself, but does it make sense?

Collaborator (Author): > Didn't we discuss duplicated replacements if members were also traversed? Would it be a concern?

This is an important point, since it could lead to us concretizing one aspect of a Val but losing a previously concretized aspect (for example, losing the concretized extent when we later concretize IterType). But I think we are OK if we are careful to use maybeMutated when creating replacements. In one case we already do multiple mutations: we call OptOutMutator::mutate(id) at the beginning of our mutate(IterDomain* id) override.

Collaborator (Author): > It seems this means that if there's any IterDomain with a symbolic extent, all IterDomains would have symbolic extents.

Yeah, that is strange. All IterDomains in the rfactor domain are checked to determine concreteness, but then the condition is only applied to rfactor products. The purpose of this function is to make the root a "standalone" domain so that we can bind input shapes to it, so it seems like all extents, regardless of whether they were already rfactor products, should be either constant ints or purely symbolic.

We always need to create a new IterDomain for rfactor products in order to cut the connection to the root domain. It seems like a simpler condition would be

auto domain_extent = id->extent()->isConstScalar()
    ? id->extent()
    : IrBuilder::create<Val>(DataType::Int);

And in light of the above, we could move this outside of the id->isRFactorProduct() check, so that we would also replace IDs that are not rfactor products if they have non-constant derived extents.

Collaborator: Please create a separate issue for this. Let's merge this PR as is.

Collaborator (Author): Sounds good. Merging without that change. I will do that in another PR.
new_root_domain.push_back(IterDomainBuilder(id)
.extent(domain_extent)
.resetSchedulingParams()
.build());
} else {
new_root_domain.push_back(id->cloneWithoutRFactor());
}
}

TORCH_INTERNAL_ASSERT(
new_root_domain.size() == tv->domain()->contiguity().size());
auto new_td = IrBuilder::create<TensorDomain>(
new_root_domain, tv->domain()->contiguity());
replacement_map.emplace(tv->domain(), new_td);
}

// This will replace the values in the mapping replacement_map throughout the
// Fusion
ir_utils::replaceValue(fusion, replacement_map);
}
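The keep-constants, replace-symbolic rule that convertInputRfactorsToRoots applies to extents can be illustrated standalone. Extent and makeStandalone below are hypothetical simplifications for illustration, not nvFuser API:

```cpp
#include <cassert>
#include <string>

// Hypothetical miniature of the extent rule above: a constant extent is kept
// as-is, while a non-constant one is replaced by a fresh symbolic scalar with
// no definition, so segment inputs can later be bound to it.
struct Extent {
  bool is_const;
  long value;        // meaningful only when is_const is true
  std::string name;  // symbolic name otherwise
};

Extent makeStandalone(const Extent& e, int& fresh_counter) {
  if (e.is_const) {
    return e;  // concrete sizes survive unchanged
  }
  // Analogous to IrBuilder::create<Val>(DataType::Int): a fresh scalar that
  // has no definition and therefore can be bound directly at segment inputs.
  return Extent{false, 0, "i" + std::to_string(fresh_counter++)};
}
```

The key property is that the returned extent never depends on upstream expressions, which is what makes the resulting root domain "standalone".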

std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
std::unique_ptr<Fusion> fusion_segment = std::make_unique<Fusion>();

@@ -1469,9 +1525,9 @@ std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
fusion_segment->addOutput(complete_to_segment_map.clone(out));
}

-  for (auto tv : view_tvs) {
-    tv->convertRfactorToRootDomain();
-  }
+  // Replace all vals that are rfactor extents in fusion_segment->inputs() with
+  // new Vals so that they can be bound to the segment inputs.
+  convertInputRfactorsToRoots(fusion_segment.get());

return fusion_segment;
}
7 changes: 0 additions & 7 deletions csrc/ir/interface_nodes.h
@@ -123,13 +123,6 @@ class TORCH_CUDA_CU_API TensorView : public Val {
return domain_;
}

-  //! This is for a TensorView with an rFactor domain that is an input to a
-  //! fusion segment. We convert the rfactor domain into a new root domain.
-  //! Any dynamic-sized rfactor iterDomains are given a new symbolic extent.
-  //! Concrete integer extents are kept. Output TensorViews of any subsequent
-  //! expressions that use this TensorView are also updated.
-  void convertRfactorToRootDomain();

void setContiguity(const std::vector<std::optional<bool>>& contig) {
domain()->setContiguity(contig);
}
63 changes: 0 additions & 63 deletions csrc/tensor_view.cpp
@@ -240,69 +240,6 @@ std::string TensorView::toInlineString(int indent_size) const {
return toString(indent_size);
}

void TensorView::convertRfactorToRootDomain() {
Collaborator: Is this change completely unrelated to the propagation of concretized vals?

Collaborator (Author): Yes. At first I thought it was going to be necessary. I realized it was not necessary for this PR, but it is a little faster and reduces the TensorView interface, so I figured I would keep it. I can split it into another PR if you prefer.

// For a given TensorView, does its domain (root / rfactor) contain any
// concrete sized extents?
auto is_concrete_tensor = [](TensorView* tv) {
for (auto id : tv->getMaybeRFactorDomain()) {
if (!id->extent()->isConstScalar()) {
return false;
}
}
return true;
};

// Create a new root domain and replacement TensorDomain.
// Given an rfactor domain, create a new IterDomain.
// Otherwise, clone the previous IterDomain
auto createReplacementDomain =
[this](const std::vector<Val*>& replacement_extents) {
TORCH_INTERNAL_ASSERT(
!replacement_extents.empty() &&
getMaybeRFactorDomain().size() == replacement_extents.size());
Comment on lines -260 to -262

Collaborator (Author): I omitted the check that replacement_extents is not empty since it did not seem necessary.

size_t idx = 0;
std::vector<IterDomain*> new_root_domain(
getMaybeRFactorDomain().size());
for (const auto& id : getMaybeRFactorDomain()) {
if (replacement_extents[idx] != nullptr) {
new_root_domain[idx] = IterDomainBuilder(id)
.extent(replacement_extents[idx])
.resetSchedulingParams()
.build();
++idx;
} else {
TORCH_INTERNAL_ASSERT(!id->isRFactorProduct());
new_root_domain[idx++] = id->cloneWithoutRFactor();
}
}

TORCH_INTERNAL_ASSERT(
new_root_domain.size() == domain()->contiguity().size());
setDomain(IrBuilder::create<TensorDomain>(
container(), new_root_domain, domain()->contiguity()));
};

std::vector<Val*> rfactor_extents;
std::unordered_map<Val*, Val*> replacement_map;
const auto kThisIsConcreteTensor = is_concrete_tensor(this);
for (const auto& id : getMaybeRFactorDomain()) {
if (id->isRFactorProduct()) {
// Create new symbolic extents for rfactor iterDomains
auto domain_extent = (!kThisIsConcreteTensor)
? IrBuilder::create<Val>(container(), DataType::Int)
: id->extent();
rfactor_extents.push_back(domain_extent);
replacement_map.emplace(id->extent(), domain_extent);
} else {
rfactor_extents.push_back(nullptr);
}
}
createReplacementDomain(rfactor_extents);

// Propagate new extent throughout fusion using ValReplacementMutator
ir_utils::replaceValue(fusion(), replacement_map);
}

TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
: Val(src, ir_cloner),
domain_(ir_cloner->clone(src->domain_)),
52 changes: 23 additions & 29 deletions test/test_dynamic_transform.cpp
@@ -849,7 +849,9 @@ TEST_F(NVFuserTest, FusionDynamicReshapeReductionShmoo_CUDA) {
{{8, 3 * 5, 7, 9}, {8, 3, 5 * 7, 9}, false}, // merge(1) osplit(1, 3)

// test passing -1 dynamically for dimension size
-      // This currently fails. see https://github.com/NVIDIA/Fuser/issues/249
+      // This is unsupported. See https://github.com/NVIDIA/Fuser/issues/249
+      // Values of -1 must be passed as constants instead of input-dependent
+      // scalars.
//{{8, 3 * 5, 7, 9}, {8, 3, -1, 9}, false} // merge(1) osplit(1, 3)
};
reductionDynamicViewAddFusion(
@@ -1068,7 +1070,7 @@ TEST_F(NVFuserTest, FusionDynamicEmptyCat2_CUDA) {
}

// Repro of https://github.com/NVIDIA/Fuser/issues/418
-TEST_F(NVFuserTest, DynamicTransformIssue418Concretization_CUDA) {
+TEST_F(NVFuserTest, DynamicTransformIssue418_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -1077,39 +1079,31 @@ TEST_F(NVFuserTest, DynamicTransformIssue418Concretization_CUDA) {
auto s0 = IrBuilder::create<Val>(DataType::Int);
fusion->addInput(s0);

-  auto v00 = tv0->axis(0)->extent();
-  auto v01 = tv0->axis(1)->extent();
-  auto v02 = tv0->axis(2)->extent();
-  auto v03 = tv0->axis(3)->extent();
-
-  auto tv1 = reshape(tv0, {v00, div(v01, s0), s0, v02, v03});
+  auto sh = tensor_sizes(tv0);
+  auto tv1 = reshape(tv0, {sh[0], div(sh[1], s0), s0, sh[2], sh[3]});
// Reducing along axis 2 in tv1 is equivalent to a partial reduction across
// axis 1 of tv0.
auto vm = variance_mean(tv1, {2, 3, 4}, 0, true);
fusion->addOutput(vm.mean);
fusion->addOutput(vm.var);

-  {
-    ExpressionEvaluator expr_eval;
-
-    expr_eval.bind(tv0->axis(0)->extent(), 256L);
-    expr_eval.bind(tv0->axis(1)->extent(), 128L);
-    expr_eval.bind(tv0->axis(2)->extent(), 28L);
-    expr_eval.bind(tv0->axis(3)->extent(), 28L);
-    expr_eval.bind(s0, 4L);
-
-    auto initial_info = DynamicTransform::getInitialInfo(fusion.get());
-    auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval);
-
-    TORCH_CHECK(
-        info.getReshapeTransforms().size() == 1,
-        "Expected to have one reshape transform: ",
-        info.toString());
-
-    DynamicTransform::concretizeFusion(fusion.get(), &info);
-
-    TORCH_CHECK(
-        !fusion->hasDynamicTransform(),
-        "Expected to have no dynamic transform");
-  }
+  FusionExecutorCache fusion_executor_cache(std::move(fusion));
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor at0 = at::randn({256, 128, 28, 28}, options);
+  std::vector<c10::IValue> aten_inputs = {at0, 32};
+  auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
+  auto at1 = at0.reshape({256, 4, 32, 28, 28});
+  auto atmean = at1.mean({2, 3, 4}, /*keepdim*/ true);
+  auto atvar = at1.var({2, 3, 4}, /*unbiased*/ true, /*keepdim*/ true);
+
+  testValidate(
+      fusion_executor_cache.fusion(),
+      outputs,
+      aten_inputs,
+      {atmean, atvar},
+      __LINE__,
+      __FILE__);
}

TEST_F(NVFuserTest, Issue249_CUDA) {