Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 128 additions & 218 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,16 +357,19 @@ class DynamicTransformConcretizer : public OptOutMutator {

using OptOutMutator::mutate;

void mutate(TensorView* tv) final;

void mutate(TensorDomain* td) final;

//! Concretizes the root domain of a symbolic consumer tensor from
//! its producer domains. Returns true if any root ID is concretized.
bool propagateFromProducerToConsumer(TensorView* consumer);
void mutate(IterDomain* id) final;

private:
const DynamicTransformConcretizationInfo* info_;

//! This map is used during concretization to identify, for a given IterDomain,
//! the set of all IterDomains which are "aligned" with it in some TensorView
//! expression. This enables us to write mutate(IterDomain*) and propagate
//! information from producer IterDomains to consumers, which is otherwise not
//! represented in the graph since we do not connect IterDomains between
//! TensorViews with expressions.
std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
id_producers_;
};

void DynamicTransformConcretizer::concretize() {
Expand All @@ -376,13 +379,76 @@ void DynamicTransformConcretizer::concretize() {
// Set output IterTypes for dynamic resize ops
concretizeResize();

// Finally, propagate concretized domains
auto all_stmts = StmtSort::getStmts(info_->fusion(), true);
for (auto stmt : all_stmts) {
if (stmt->isA<Val>()) {
mutate(stmt);
Comment on lines -382 to -383
Copy link
Collaborator Author

@jacobhinkle jacobhinkle Jul 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of only mutating Vals, we now mutate Exprs as well, which replaces the Expr in place if any inputs or outputs have changed. Note that outputs of Exprs are mutated after their definition has been mutated, so we should be careful updating a Val that has a definition. But of course we should be careful in that case in the existing code too.

// This fixes the set of statements we will process over the course of
// multiple passes.
// Since we will traverse these Statements after some have been removed, we
// will not be able to safely check the types of each Statement in later
// loops. To avoid segfaults, we first split all_statements into subsets for
// each traversal.
std::vector<Val*> non_tds_tvs;
std::vector<Expr*> all_exprs;
std::vector<Val*> tvs_and_tds;
for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) {
if (stmt->isExpr()) {
all_exprs.push_back(stmt->asExpr());
} else {
auto val = stmt->asVal();
if (val->isA<TensorView>() || val->isA<TensorDomain>()) {
tvs_and_tds.push_back(val);
} else {
non_tds_tvs.push_back(val);
}
Comment on lines +388 to +400
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was previously a single loop over all_stmts is now three separate loops over these subsets.

}
}

// When propagating IterTypes across expressions, we need to know the producer
// IterDomains corresponding to a consumer ID. This mapping helps facilitate
// this and is used later by mutate(IterDomain*) in the first pass below.
for (auto expr : all_exprs) {
for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
for (auto [cid, pid] : root_map.mapConsumerToProducer(
consumer->domain(), producer->domain())) {
// Initialize set of producer IDs, if we haven't already
auto& producers = id_producers_.emplace(cid, 0).first->second;
producers.insert(pid);
}
}
}
}

// In this first pass, we only mutate Vals that are not TensorDomains or
// TensorViews.
//
// This pass does not modify the Fusion.
for (auto val : non_tds_tvs) {
mutate(val);
}

// In the second pass, we only mutate Exprs. For each expr, if any of its
// inputs, outputs, or attributes are registered for mutation, a new expr will
// be created and the original expr will be removed. This is the mechanism
// OptOutMutator provides for setting the `definition()` of replaced Vals.
//
// This pass may add and remove Exprs, so elements of all_exprs are invalid
// after this pass.
for (auto expr : all_exprs) {
mutate(expr);
}

// In the third pass, we mutate the TensorDomains and TensorViews, without
// touching any other Vals or Exprs. The only change made to the Fusion at
// this stage is that TensorViews have their domain() replaced if any of their
// IterDomains are registered for mutation. This must happen last, as Expr
// mutation is required in order to properly connect root and rfactor domains,
// which is checked when creating new TensorDomains.
//
// This pass modifies the Fusion by creating new TensorDomains and swapping
// them into TensorViews.
for (auto val : tvs_and_tds) {
mutate(val);
}
}

void DynamicTransformConcretizer::concretizeReshape() {
Expand Down Expand Up @@ -441,228 +507,72 @@ void DynamicTransformConcretizer::checkConcretizedUses(
}
}

// Concretizes inherited symbolic domains. Note that when this is
// called, it is assumed that all dynamic ops themselves are
// concretized. Since symbolic IDs may be propagated down to
// consumers, those domains need to be concretized accordingly.
void DynamicTransformConcretizer::mutate(TensorView* tv) {
if (!tv->domain()->hasSymbolicAxis()) {
return;
}

// First, try to concretize the root domain as there may be symbolic
// axes inherited from the producers
propagateFromProducerToConsumer(tv);

// If no root domain is altered by producer, we don't need to propagate back
// up to rfactor. We could return early, but instead we go ahead and check the
// root to rfactor transforms to be sure we have concretized any intermediate
// IterDomains.

// At this point, there should be no expr beyond rfactor root
TORCH_INTERNAL_ASSERT(
tv->getLeafDomain() == tv->getMaybeRFactorDomain(),
"Invalid tensor: ",
tv->toString());

// If it has an rfactor root domain, the IterTypes of the rfactor
// IDs may need to be updated as well. Traverse the rfactor exprs
// and mutate the IterTypes of output IDs if symbolic.
if (tv->hasRFactor()) {
// Note that it is assumed that there's no further expression
// beyond the rfactor domain as asserted above
auto all_id_exprs = StmtSort::getExprsBetween(
tv->fusion(),
{tv->getRootDomain().begin(), tv->getRootDomain().end()},
{tv->getMaybeRFactorDomain().begin(),
tv->getMaybeRFactorDomain().end()});
for (auto expr : all_id_exprs) {
// Assume outputs of IterDomain exprs are always IterDomains. If
// the assumption is invalidated, the logic here would need to
// be updated. Assert the assumption to immediately detect such
// a case if it ever happens.
for (auto out_val : expr->outputs()) {
TORCH_INTERNAL_ASSERT(
out_val->isA<IterDomain>(),
"Unexpected output: ",
out_val->toString(),
". IterDomain was expected.");
}
void DynamicTransformConcretizer::mutate(IterDomain* id) {
// This will register id for mutation if start, stop, or extent are registered
// for mutation
OptOutMutator::mutate(id);

// Use this to prototype new concretizations, since it will have replaced
// extent (see above)
auto mut_id = maybeMutated(id)->as<IterDomain>();

IterDomain* concretized_id = nullptr;

// NOTE: We do not return early if all outputs are concrete as there may
// still be concrete inputs. For example, a Symbolic IterDomain might be
// padded with constant pad widths (1, 1), in which case although we do
// not know the exact extent of the output, we know it is at least as
// large as the sum of the pad widths, 2. In such cases, the output
// IterDomain is concrete at definition, since if the extent is >1 we know
// the IterType is Iteration. In these cases, we must continue to
// concretize intermediate expressions between the root and R-factor
// domain. See test DynamicTransform5_CUDA which demonstrates this
// behavior.
// NOTE: We also do not assume that if one output ID is symbolic, then
// they all must be. See test FusionSliceForNanoGPT3_CUDA for an example
// that does a static split by a factor of 16 of a symbolic input domain.
// The static split in that case results in a concrete IterDomain with
// extent 16 along with a symbolic one (extent ceilDiv(n, 16)).

// Determine the output IterType
IterType iter_type = IterType::Symbolic;
for (auto inp_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
if (mut_id->isSymbolic()) {
if (auto def = id->definition()) {
IterType iter_type = mut_id->getIterType();
// Determine concrete IterType based on promotion of inputs to def
for (auto inp_id : ir_utils::filterByType<IterDomain>(def->inputs())) {
auto updated_id = maybeMutated(inp_id)->as<IterDomain>();
iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
}
TORCH_INTERNAL_ASSERT(
iter_type != IterType::Symbolic,
"Failed to concretize an output IterType for expression: ",
expr->toString());

// Update the IterType of each output
for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
if (!out_id->isSymbolic()) {
continue;
def->toString());
concretized_id = IterDomainBuilder(mut_id).iter_type(iter_type).build();
} else {
// IterDomains without definitions might be root domains for the output of
// a TensorView expression. If so, we should propagate their
// concretization in the producer to consumer direction.

auto producers_it = id_producers_.find(id);
if (producers_it != id_producers_.end()) {
// id was a consumer root ID in some TV expression

std::optional<IterType> id_type;
for (auto producer_id : producers_it->second) {
producer_id = maybeMutated(producer_id)->as<IterDomain>();
if (id_type.has_value()) {
id_type =
ops::promoteIterType(*id_type, producer_id->getIterType());
} else {
id_type = producer_id->getIterType();
}
}
auto concretized_out_id =
IterDomainBuilder(out_id).iter_type(iter_type).build();
registerConcretization(out_id, concretized_out_id);
}

// The expr itself needs to be mutated as well in case the outputs are
// mutated, which can be done by the mutate method
OptOutMutator::mutate(expr);
}
}

// Root and rfactor domains are updated. First mutate the
// TensorDomain and then TensorView
mutate(tv->domain());
OptOutMutator::mutate(tv);
}

// Near-duplicate of OptOutMutator::mutate(TensorDomain*). The one addition is
// that the contiguity vector is rebuilt alongside the IterDomains: a symbolic
// domain may be concretized to a broadcast domain, and in that case its
// contiguity entry has to be reset to nullopt.
//
// If none of the root, rfactor, or leaf IterDomains has a differing
// replacement, the TensorDomain is left untouched. Otherwise a new
// TensorDomain is created and registered as the concretization of td.
void DynamicTransformConcretizer::mutate(TensorDomain* td) {
  bool any_changed = false;

  // Looks up the (possibly) mutated counterpart of every ID in original_ids
  // and records in any_changed whether at least one of them differs.
  auto remap = [&](const std::vector<IterDomain*>& original_ids) {
    std::vector<IterDomain*> remapped;
    remapped.reserve(original_ids.size());
    for (IterDomain* original : original_ids) {
      auto replacement = maybeMutated(original)->as<IterDomain>();
      if (!replacement->sameAs(original)) {
        any_changed = true;
      }
      remapped.push_back(replacement);
    }
    return remapped;
  };

  std::vector<IterDomain*> new_root = remap(td->root());
  // The rfactor vector stays empty when td has no rfactor domain
  std::vector<IterDomain*> new_rfactor;
  if (td->hasRFactor()) {
    new_rfactor = remap(td->maybeRFactor());
  }
  std::vector<IterDomain*> new_leaf = remap(td->leaf());

  if (!any_changed) {
    return;
  }

  // Start from the existing contiguity and clear the entries whose ID was
  // concretized to broadcast
  auto new_contig = td->contiguity();

  // Contiguity is indexed by the rfactor domain if present, else by the root
  const auto& old_ids = td->maybeRFactor();
  const auto& new_ids = td->hasRFactor() ? new_rfactor : new_root;

  for (size_t pos = 0; pos < old_ids.size(); ++pos) {
    auto old_id = old_ids.at(pos);
    // Only IDs that were symbolic before concretization need attention
    if (old_id->getIterType() == IterType::Symbolic) {
      TORCH_INTERNAL_ASSERT(
          new_contig.at(pos),
          "Unexpected to have a non-contig symbolic domain: ",
          old_id->toString());

      if (new_ids.at(pos)->isBroadcast()) {
        new_contig.at(pos) = std::nullopt;
      }
    }
  }

  Val* replacement_td = IrBuilder::create<TensorDomain>(
      td->container(), new_root, new_rfactor, new_leaf, new_contig);
  registerConcretization(td, replacement_td);
}

bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
TensorView* consumer) {
if (consumer->definition() == nullptr ||
!consumer->domain()->hasSymbolicAxis()) {
return false;
}

const auto& root_domain = consumer->getRootDomain();

auto def = consumer->definition();

bool is_concretized = false;

for (const auto i : c10::irange(root_domain.size())) {
auto root_id = root_domain.at(i);
if (root_id->getIterType() != IterType::Symbolic) {
continue;
}

// Figure out the right IterType of this consumer root ID from its
// corresponding producer IDs

std::optional<IterType> id_type;

for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
auto c2p = root_map.mapConsumerToProducer(
consumer->domain(), producer->domain());
TORCH_INTERNAL_ASSERT(
id_type.has_value(),
"Did not find id_type for consumer root domain ",
id->toString(),
". Perhaps consumer def has no inputs.");

TORCH_INTERNAL_ASSERT(
c2p.find(root_id) != c2p.end(),
"No input ID found to map with output ID: ",
root_id->toString());
TORCH_INTERNAL_ASSERT(
id_type.value() != IterType::Symbolic,
"Failed to concretize ",
id->toString());

auto input_id = c2p.at(root_id);
TORCH_INTERNAL_ASSERT(
input_id->getIterType() != IterType::Symbolic,
"Producer ID not concretized: ",
input_id->toString());
if (id_type.value() != id->getIterType())

if (id_type.has_value()) {
id_type = ops::promoteIterType(*id_type, input_id->getIterType());
} else {
id_type = input_id->getIterType();
concretized_id =
IterDomainBuilder(mut_id).iter_type(id_type.value()).build();
}
}

TORCH_INTERNAL_ASSERT(
id_type.has_value(),
"Did not find id_type for consumer root domain ",
root_id->toString(),
". Perhaps consumer def has no inputs. Consumer definition = ",
def->toString());

TORCH_INTERNAL_ASSERT(
id_type != IterType::Symbolic,
"Failed to concretize ",
root_id->toString(),
" of ",
consumer->toString());

auto concretized_id =
IterDomainBuilder(root_id).iter_type(*id_type).build();

registerConcretization(root_id, concretized_id);
is_concretized = true;
}

return is_concretized;
if (concretized_id) {
registerConcretization(id, concretized_id);
}
}

DynamicTransformInitialInfo DynamicTransform::getInitialInfo(Fusion* fusion) {
Expand Down
Loading