Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
757f685
Add C++ repro for #418
jacobhinkle May 30, 2023
03ac754
Process all outputs of each Expr in concretization
jacobhinkle May 30, 2023
36c70c5
Add fusion_ir_dynamic debug dump option
jacobhinkle May 31, 2023
9727388
Propagate extent from P2C during concretization
jacobhinkle Jun 1, 2023
d675161
Merge remote-tracking branch 'origin/main' into fix_issue418
jacobhinkle Jun 1, 2023
031f327
Revert "Propagate extent from P2C during concretization"
jacobhinkle Jun 1, 2023
3724a00
Match concretized reshape extents to desired extents.
jacobhinkle Jun 1, 2023
72b29ee
Remove debugging print statements
jacobhinkle Jun 1, 2023
78ad7e2
Add output extents as ViewOp inputs
jacobhinkle Jun 1, 2023
788f500
Remove concrete_reshape_out_tv after replacing
jacobhinkle Jun 1, 2023
bf9aa1e
Find expr outputs before calling StmtSort again
jacobhinkle Jun 1, 2023
4f05ad5
Revert to d675161 but keep bf9aa1e5
jacobhinkle Jun 1, 2023
8160c6f
Fix DynamicTransformIssue418_CUDA ATen keepdim args
jacobhinkle Jun 1, 2023
3c173e0
Merge branch 'main' into fix_issue418
jacobhinkle Jun 1, 2023
83062b8
Add failing full groupnorm test
jacobhinkle Jun 2, 2023
82e7ee8
Fix bug in failing test
jacobhinkle Jun 2, 2023
5c8a4c2
Properly set broadcast axes in inputs to 418Full test
jacobhinkle Jun 2, 2023
9ffadb9
Add channels-last test
jacobhinkle Jun 2, 2023
596ceab
Grab iter_type fix in newOutputDomain from #358
jacobhinkle Jun 2, 2023
fd6eca8
Merge remote-tracking branch 'origin/main' into fix_issue418
jacobhinkle Jun 2, 2023
6b85cf4
Propagate replacements for placeholder extents
jacobhinkle Jun 2, 2023
c8a1176
Merge branch 'main' into fix_issue418
jacobhinkle Jun 2, 2023
f8a0610
Temporarily revert "Propagate replacements for placeholder extents"
jacobhinkle Jun 2, 2023
8008ece
Merge branch 'main' into fix_issue418
jacobhinkle Jun 2, 2023
bd17931
Print dynamic fusion if fusion_ir_concretized is given
jacobhinkle Jun 6, 2023
9aaedec
Set proper output extent in dynamic cat
jacobhinkle Jun 6, 2023
735278c
Merge branch 'main' into fix_issue418
jacobhinkle Jun 6, 2023
140298f
Add failing PadBroadcast test
jacobhinkle Jun 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,35 @@ void DynamicTransformConcretizer::concretize() {
concretizeResize();

// Finally, propagate concretized domains
auto all_stmts = StmtSort::getStmts(info_.fusion(), true);

// We need to concretize all immediate outputs of all intermediate
// expressions, even those leading to dead-code branches. To do this, we
// insert all outputs of all intermediate expressions into "leaves". If we
// don't insert anything new, we are done. Otherwise, we traverse again
// treating these as outputs as well, which ensures the result is
// topologically sorted.
auto leaves = info_.fusion()->getTerminatingOutputs();
auto leaves_set =
std::unordered_set<Statement*>(leaves.begin(), leaves.end());
std::vector<Statement*> all_stmts;
bool inserted = true;
while (inserted) {
all_stmts = StmtSort::getStmts(info_.fusion(), leaves, true);
inserted = false;
for (auto stmt : all_stmts) {
if (stmt->isExpr()) {
for (auto o : stmt->as<Expr>()->outputs()) {
if (leaves_set.find(o) == leaves_set.end()) {
leaves.push_back(o);
leaves_set.insert(o);
inserted = true;
}
}
}
}
}
// Concretize all vals in the final vector
for (auto stmt : all_stmts) {
if (stmt->isA<Val>()) {
if (stmt->isVal()) {
mutate(stmt);
}
}
Expand Down Expand Up @@ -641,6 +667,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
// corresponding producer IDs

std::optional<IterType> id_type;
Val* extent = nullptr;

for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
Expand All @@ -663,6 +690,11 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
} else {
id_type = input_id->getIterType();
}

// Set extent expression based on producer, overwriting that of consumer
if (!extent) {
extent = input_id->extent();
}
}

TORCH_INTERNAL_ASSERT(
Expand All @@ -680,7 +712,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
consumer->toString());

auto concretized_id =
IterDomainBuilder(root_id).iter_type(*id_type).build();
IterDomainBuilder(root_id).extent(extent).iter_type(*id_type).build();

registerConcretization(root_id, concretized_id);
is_concretized = true;
Expand Down
5 changes: 4 additions & 1 deletion csrc/kernel_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,10 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
kernel_runtime->updateHeuristicsLaunchParams(new_heuristics.get());
} else {
// cache miss, need to re-build an optimized graph for this case

if (isDebugDumpEnabled(DebugDumpOption::FusionIrConcretized)) {
std::cout << "Fusion Before Concretization:" << std::endl;
fusion()->printMath();
}
// concretize fusion_ for use in this runtime
auto fusion = std::make_unique<Fusion>(*fusion_);
FusionGuard fg(fusion.get());
Expand Down
17 changes: 15 additions & 2 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ TensorView* reshape(TensorView* inp_tv, const std::vector<Val*>& new_sizes) {
return static_reshape_output;
}

auto root_domain = ops::newOutputDomain({inp_tv}, inp_tv->dtype());
auto root_domain = ops::newOutputDomain({inp_tv});

// Create placeholder rfactor domain. Note it's not connected with the root
// domain.
Expand Down Expand Up @@ -632,7 +632,20 @@ TensorView* cat(const std::vector<TensorView*>& inputs, int64_t cat_dim) {
}

// Now all of resized_inputs have the same shape as the out tensor
auto out = ops::newOutputTV(resized_inputs, dtype);
// NOTE: ops::newOutputTV would not necessarily be able to infer that the
// padded dimensions are all of the same size. However, we know that they
// are constructed such that that is the case, so we can use
// ops::newOutputDomain and override the concatenated extent below.
auto out_domain = ops::newOutputDomain(resized_inputs);
// Override the concatenated dimension and insert an IterDomain with the true
// extent, if needed
if (!out_domain.at(cat_dim)->extent()->sameAs(concat_ext)) {
out_domain[cat_dim] =
IterDomainBuilder(out_domain.at(cat_dim)).extent(concat_ext).build();
}
auto out = IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)),
dtype);
Comment on lines +635 to +648
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes a breakage that occurred when I made the other changes. It makes the output extent of cat look as it should (e.g. (i0 + i2) + i4) instead of creating a new symbolic extent as was done previously, which complicated concretization.


IrBuilder::create<CatOp>(out, resized_inputs, cat_dim);

Expand Down
72 changes: 62 additions & 10 deletions csrc/ops/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,7 @@ IterType promoteIterType(IterType type1, IterType type2) {
}
}

std::vector<IterDomain*> newOutputDomain(
const std::vector<Val*>& vals,
DataType dtype) {
std::vector<IterDomain*> newOutputDomain(const std::vector<Val*>& vals) {
std::vector<TensorView*> tvs;
for (auto val : vals) {
if (val->getValType() == ValType::TensorView) {
Expand All @@ -223,9 +221,10 @@ std::vector<IterDomain*> newOutputDomain(
std::vector<int64_t> start_offsets(out_domain.size(), 0);
std::vector<int64_t> stop_offsets(out_domain.size(), 0);
std::vector<Val*> extent_vals(out_domain.size(), nullptr);
std::vector<bool> mismatched_symbolic_extents(out_domain.size(), false);
std::vector<Val*> expanded_extent_vals(out_domain.size(), nullptr);
std::vector<c10::optional<IterType>> iter_types(
out_domain.size(), c10::nullopt);
std::vector<std::optional<IterType>> iter_types(
out_domain.size(), std::nullopt);

for (auto tv : tvs) {
auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
Expand All @@ -236,6 +235,53 @@ std::vector<IterDomain*> newOutputDomain(
" dimensions but expected ",
out_domain.size());
for (const auto i : c10::irange(dom.size())) {
auto iter_type = dom[i]->getIterType();
auto prev_iter_type = iter_types[i];
if (prev_iter_type.has_value()) {
// Clang-tidy complains about unchecked access to optional value here
if (iter_type == IterType::Iteration &&
prev_iter_type.value() == IterType::Symbolic) {
// Prefer the Iteration extent, since Symbolic could be broadcast
extent_vals[i] = nullptr;
} else if (iter_type == IterType::Symbolic) {
switch (prev_iter_type.value()) {
case IterType::Iteration:
// Previously found Iteration domain, so ignore all Symbolic
// domains
continue;
case IterType::Symbolic:
if (extent_vals[i]->sameAs(dom[i]->extent())) {
// matching symbolic extent
continue;
} else {
// Mismatched symbolic input extents. Any one of the symbolic
// inputs could be a Broadcast or Iteration domain. Until
// concretization, we will not know which one holds the true
// extent (or whether they all are Broadcast, so that the output
// is also Broadcast). We record that these symbolic extents
// mismatched so that we can introduce a new symbolic extent
// later.
mismatched_symbolic_extents[i] = true;
}
break;
case IterType::Broadcast:
// Previously found only broadcast, so this will either also
// broadcast or resolve those broadcasts. If the expanded
// extent of any of the broadcasts is not 1, then it will need to
// match that of the dom[i]. In either case, prefer dom[i]'s
// extent, so clear iter_types[i] and extent_vals[i] so that the
// rest of this iteration will mark output as Symbolic.
iter_types[i] = std::nullopt;
extent_vals[i] = nullptr;
break;
default:
TORCH_CHECK(
false,
"Encountered unexpected IterType when creating new output domain: ",
prev_iter_type.value());
}
}
}
if (dom[i]->isBroadcast()) {
if (dom[i]->hasExpandedExtent()) {
expanded_extent_vals[i] =
Expand All @@ -244,9 +290,9 @@ std::vector<IterDomain*> newOutputDomain(
continue;
}
extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
if (iter_types[i].has_value()) {
if (prev_iter_type.has_value()) {
iter_types[i] =
promoteIterType(iter_types[i].value(), dom[i]->getIterType());
promoteIterType(prev_iter_type.value(), dom[i]->getIterType());
} else {
iter_types[i] = dom[i]->getIterType();
}
Expand All @@ -268,15 +314,21 @@ std::vector<IterDomain*> newOutputDomain(
}
}
for (const auto dim_i : c10::irange(out_domain.size())) {
auto iter_type = iter_types[dim_i];
if (iter_type == IterType::Symbolic && mismatched_symbolic_extents[dim_i]) {
// if we have a symbolic output but the input symbolic extents did not
// match, create a new extent
extent_vals[dim_i] = IrBuilder::create<Int>();
}
if (extent_vals[dim_i] != nullptr) {
TORCH_INTERNAL_ASSERT(
iter_types[dim_i].has_value(),
iter_type.has_value(),
"Could not deduce iter type for new tensor view.");
out_domain[dim_i] =
IterDomainBuilder(
IrBuilder::create<Int>(start_offsets[dim_i]), extent_vals[dim_i])
.stop_offset(IrBuilder::create<Int>(stop_offsets[dim_i]))
.iter_type(iter_types[dim_i].value())
.iter_type(iter_type.value())
.build();
} else {
out_domain[dim_i] = IterDomainBuilder(
Expand All @@ -292,7 +344,7 @@ std::vector<IterDomain*> newOutputDomain(
}

TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
auto out_domain = newOutputDomain(vals, dtype);
auto out_domain = newOutputDomain(vals);
return IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)),
Expand Down
4 changes: 1 addition & 3 deletions csrc/ops/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ Val* newScalar(ValType vtype, DataType dtype);

IterType promoteIterType(IterType type1, IterType type2);

std::vector<IterDomain*> newOutputDomain(
const std::vector<Val*>& vals,
DataType dtype);
std::vector<IterDomain*> newOutputDomain(const std::vector<Val*>& vals);

TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype);

Expand Down
Loading