Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
e325d99
Add tests for padding to broadcast in various ways
jacobhinkle Jul 18, 2023
8d30c20
Set bcast consumer IDs with no bcast producers as origins
jacobhinkle Jul 18, 2023
e62d692
Only map root->rfactor IDs with same IterType in ca root dom map builder
jacobhinkle Jul 18, 2023
18cc3d1
Switch to more challenging static pad to broadcast test
jacobhinkle Jul 18, 2023
af340f4
Merge branch 'main' into resolve_resize_broadcasts
jacobhinkle Jul 18, 2023
bed1358
Remove commented code in static test
jacobhinkle Jul 18, 2023
ea05f07
Always skip mapping symbolic IterDomains that have different extents
jacobhinkle Jul 19, 2023
3fcafcb
Use maybeMutated id for isSymbolic check.
jacobhinkle Jul 19, 2023
78d8ffc
Merge branch 'main' into resolve_resize_broadcasts
jacobhinkle Jul 19, 2023
5d0b6c5
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Jul 21, 2023
77910f0
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Jul 24, 2023
e84ffe6
Update comment for isSymbolic check in mutate(TV)
jacobhinkle Jul 24, 2023
c6410b4
Add comment to #596 test
jacobhinkle Jul 24, 2023
b877331
Expand comment about condition 5
jacobhinkle Jul 24, 2023
dc3a156
Add comment to itertype check in ComputeAtRootDomainMapBuilder
jacobhinkle Jul 24, 2023
6020cd6
Map Symbolic with non-Broadcast in propagateFromP2C
jacobhinkle Jul 24, 2023
9b48a6f
Rename test
jacobhinkle Jul 24, 2023
6e2945a
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Jul 24, 2023
012e878
Remove mapSymbolicNonBroadcast option. On always now
jacobhinkle Jul 25, 2023
1aca3af
Revert "Remove mapSymbolicNonBroadcast option. On always now"
jacobhinkle Jul 25, 2023
eeff747
Update mapSymbolic option. Expose it to ExactRootDomainMap
jacobhinkle Jul 25, 2023
3e974ba
Fix DynamicTransform4_CUDA, add long comment
jacobhinkle Jul 25, 2023
dbb7157
Fix DynamicTransform3_CUDA
jacobhinkle Jul 25, 2023
a5f2543
Fix DynamicTransform1_CUDA
jacobhinkle Jul 25, 2023
8b8b947
Merge branch 'main' into resolve_resize_broadcasts
jacobhinkle Jul 25, 2023
84f9b69
Update doxygen comment for mapSymbolic
jacobhinkle Jul 25, 2023
648c1f4
Merge branch 'main' into resolve_resize_broadcasts
jacobhinkle Jul 26, 2023
ec1f978
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Aug 15, 2023
dd30b1b
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Sep 7, 2023
82cbaa5
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Sep 11, 2023
8af36ea
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Sep 13, 2023
7ff24e6
Register concretization from unmutated ID in root->rfactor
jacobhinkle Sep 13, 2023
1fe841a
Clean up comment in condition 5
jacobhinkle Sep 13, 2023
4cd57bf
Remove map_symbolic from ExactRootDomainMap
jacobhinkle Sep 22, 2023
25e27c1
Assert instead of ignoring missing c2p mapping
jacobhinkle Sep 22, 2023
15be185
Clean up
jacobhinkle Sep 22, 2023
0f7952f
clang-format
jacobhinkle Sep 22, 2023
10e46ee
Merge remote-tracking branch 'origin/main' into resolve_resize_broadc…
jacobhinkle Sep 22, 2023
166ff69
Handle TensorView instead of BroadcastOp
jacobhinkle Sep 22, 2023
480dee3
Check directly that rfactor bcast is not root
jacobhinkle Sep 22, 2023
8cc4076
Remove check for missing entry in origin map
jacobhinkle Sep 22, 2023
d8fce7b
Restore handle(BroadcastOp*)
jacobhinkle Sep 22, 2023
f43a4b5
Undo trivial prior change to minimize diff
jacobhinkle Sep 22, 2023
9db969a
Merge branch 'main' into resolve_resize_broadcasts
jacobhinkle Sep 22, 2023
3da217b
Add comments in trivial_broadcast.cpp
jacobhinkle Sep 25, 2023
b30622d
Typos
jacobhinkle Sep 25, 2023
95d6da9
Merge branch 'main' into resolve_resize_broadcasts
jacobhinkle Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions csrc/device_lower/analysis/trivial_broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,30 @@ std::unordered_set<IterDomain*> ConcretizedBroadcastDomains::
return {};
}

// In some cases an op like pad or slice will introduce a broadcast domain by
// truncating a longer dimension or expanding an empty dimension to size 1. In
// these cases tv will have RFactor Broadcast IterDomains that are not present
// in the root domain. Contrast this with BroadcastOp, whose output does not
// have RFactor domains and instead places new broadcast domains in the output
// root domain.
void ConcretizedBroadcastDomains::handle(TensorView* tv) {
  // Only tensors with an rfactor domain can carry broadcast IDs that were
  // introduced between root and rfactor (e.g. by pad/slice); everything else
  // is covered by handle(BroadcastOp*).
  if (!tv->hasRFactor()) {
    return;
  }
  const auto& root = tv->getRootDomain();
  for (auto rf_id : tv->getMaybeRFactorDomain()) {
    if (!rf_id->isBroadcast()) {
      continue;
    }
    // A broadcast rfactor ID that does not also appear in the root domain was
    // created by the root->rfactor transforms, so register it as the origin
    // of its own broadcast.
    const bool appears_in_root =
        std::find(root.begin(), root.end(), rf_id) != root.end();
    if (!appears_in_root) {
      broadcast_origin_map_.emplace(
          rf_id, std::unordered_set<IterDomain*>({rf_id}));
    }
  }
}

// Most broadcasts are handled with this method, since Broadcast domains are
// usually introduced through a BroadcastOp. Others are handled by the
// handle(TensorView*) method.
void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) {
// Create a new entry for each of new broadcast domains
auto out = bop->out()->as<TensorView>();
Expand Down
2 changes: 2 additions & 0 deletions csrc/device_lower/analysis/trivial_broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ConcretizedBroadcastDomains : private IterVisitor {
private:
using IterVisitor::handle;

void handle(TensorView* tv) final;

void handle(BroadcastOp* bop) final;

void dispatch(Expr* expr) final;
Expand Down
38 changes: 26 additions & 12 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,10 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
}
// Update the IterType of each output
for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
if (!out_id->isSymbolic()) {
auto mut_id = maybeMutated(out_id)->as<IterDomain>();
if (!mut_id->isSymbolic()) {
// We are only concretizing IterType here, so if we have already
// concretized the iter_type for this ID, we can skip this.
continue;
}

Expand All @@ -690,9 +693,7 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
expr->toString());

auto concretized_out_id =
IterDomainBuilder(maybeMutated(out_id)->as<IterDomain>())
.iter_type(iter_type)
.build();
IterDomainBuilder(mut_id).iter_type(iter_type).build();
registerConcretization(out_id, concretized_out_id);
}

Expand Down Expand Up @@ -794,6 +795,23 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(

auto def = consumer->definition();

// We will loop over IterDomains in the consumer root. For each, we need to
// inspect the consumer to producer map to all producers. Instead of
// recomputing these for each root IterDomain, we precompute them for each
// producer here then re-use them in the following loop.
std::vector<std::unordered_map<IterDomain*, IterDomain*>> c2p_maps;
for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
// We map symbolic domains here regardless of whether their extents match.
// This is safe because we are propagating from a producer which should have
// already been concretized. The consumer might have a different extent
// which will be equivalent to (but not necessarily sameAs) the producer's,
// and we just want to use its IterType to concretize the consumer ID.
root_map.mapSymbolic(true);
c2p_maps.push_back(
root_map.mapConsumerToProducer(consumer->domain(), producer->domain()));
}

bool is_concretized = false;

for (const auto i : c10::irange(root_domain.size())) {
Expand All @@ -807,17 +825,13 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(

std::optional<IterType> id_type;

for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
auto c2p = root_map.mapConsumerToProducer(
consumer->domain(), producer->domain());

for (const auto& c2p : c2p_maps) {
auto p_it = c2p.find(root_id);
NVF_ERROR(
c2p.find(root_id) != c2p.end(),
p_it != c2p.end(),
"No input ID found to map with output ID: ",
root_id->toString());

auto input_id = c2p.at(root_id);
auto input_id = p_it->second;
NVF_ERROR(
input_id == maybeMutated(input_id),
"Consumer IterDomain ",
Expand Down
3 changes: 3 additions & 0 deletions csrc/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ void ExpressionEvaluator::print() const {
}

void ExpressionEvaluator::propagateBoundValuesThroughExactMaps(Fusion* fusion) {
// We map Symbolic IterDomains here only if their extents match. This avoids
// mapping between symbolic domains that might concretize to an (Iteration,
// Broadcast) pair from a resolved broadcast.
const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets();

for (const auto& set : mapped_sets.disjointSets()) {
Expand Down
31 changes: 30 additions & 1 deletion csrc/root_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
// domains of torch_gather)
// 3. Squeeze and unsqueeze
// 4. Broadcast and non broadcast
// 5. Symbolic ID with different extent from other ID

// Condition 1: when the producer ID is the dim of a select-like op
if (producer_id == indexed_producer_id) {
Expand Down Expand Up @@ -182,6 +183,27 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
continue;
}

// Condition 5
// At least one ID is symbolic.
//
// If map_symbolic_ is true:
// Map these IDs regardless of other considerations.
//
// If map_symbolic_ is false (default):
// Map these only if their extents are identical. IterType::Symbolic
// reflects that the extent might evaluate to 1 for some inputs, in which
// case it may be valid to use those domains in a broadcast op. If the
// extents are exactly the same between two aligned IterDomains, the
// Symbolic one will be concretized to the same IterType as the other, so
// they should be mapped with one another.
if (!map_symbolic_ &&
(producer_id->isSymbolic() || consumer_id->isSymbolic()) &&
(!producer_id->extent()->sameAs(consumer_id->extent()))) {
itc++;
itp++;
continue;
}

IterDomain* map_key_id = producer_id;
IterDomain* map_value_id = consumer_id;
if (!producer_to_consumer) {
Expand Down Expand Up @@ -1185,7 +1207,14 @@ void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) {
if (root_set.find(id) == root_set.end() || rf_id == id) {
continue;
}
setMaybeMapped(td, id, td, rf_id);
// Usually, the itertypes between IterDomain expression inputs and
// outputs will match. However, it is possible for a Resize operation to
// take an Iteration input and reduce it to size 1, after which it
// becomes Broadcast. This check avoids mapping an Iteration and
// Broadcast domain in such a case.
if (id->getIterType() == rf_id->getIterType()) {
setMaybeMapped(td, id, td, rf_id);
}
}
}
// Once mappings for rfactor axes are propagated to root axes,
Expand Down
12 changes: 12 additions & 0 deletions csrc/root_domain_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ class PairwiseRootDomainMap : public RootDomainMap {
return *this;
}

  //! Control mapping of Symbolic IterDomains.
  //!
  //! \param b If true, map symbolic domains with other IterDomains even if
  //!   their extents don't match. If false (default), map symbolic domains
  //!   with other IterDomains only if their extents match.
  //! \return Reference to this object, enabling fluent chaining with the
  //!   other map* setters.
  PairwiseRootDomainMap& mapSymbolic(bool b) {
    map_symbolic_ = b;
    return *this;
  }

PairwiseRootDomainMap& mapDifferentExtents(bool b) {
map_different_extents_ = b;
return *this;
Expand Down Expand Up @@ -137,6 +145,10 @@ class PairwiseRootDomainMap : public RootDomainMap {
//! Map broadcast and non-broadcast domains. Note that this is on by
//! default
bool map_broadcast_ = true;
//! Map symbolic domains with other IterDomains, even if their extents don't
//! match. Note that this is off by default, in which case they are mapped
//! only if their extents match.
bool map_symbolic_ = false;
//! Map domains that may have different extents, e.g., torch_gather
bool map_different_extents_ = false;
//! Map domains that are indirectly accessed, e.g., index_select
Expand Down
16 changes: 16 additions & 0 deletions test/test_dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ TEST_F(NVFuserTest, DynamicTransform1_CUDA) {
expr_eval.bind(tv0->axis(1)->extent(), 3L);
expr_eval.bind(reshape_shape0, 3L);
expr_eval.bind(reshape_shape1, 4L);
// We cannot infer the shape of tv1 from the above bound values, since
// either axis of tv2 might be broadcast against one from tv1.
expr_eval.bind(tv1->axis(0)->extent(), 3L);
expr_eval.bind(tv1->axis(1)->extent(), 4L);

auto initial_info = DynamicTransform::getInitialInfo(&fusion);
auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval);
Expand Down Expand Up @@ -187,6 +191,11 @@ TEST_F(NVFuserTest, DynamicTransform3_CUDA) {
expr_eval.bind(tv0->axis(1)->extent(), shape_before.at(1));
expr_eval.bind(tv1->axis(0)->extent(), shape_after.at(0));
expr_eval.bind(tv1->axis(1)->extent(), shape_after.at(1));
// We cannot infer reshape_shape0 and reshape_shape1 from tv0's and tv1's
// extents alone, since either of these reshaped extents could either match
// that of tv1 or be 1, resulting in a broadcast.
expr_eval.bind(reshape_shape0, shape_after.at(0));
expr_eval.bind(reshape_shape1, shape_after.at(1));

auto initial_info = DynamicTransform::getInitialInfo(&fusion);
auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval);
Expand Down Expand Up @@ -251,6 +260,13 @@ TEST_F(NVFuserTest, DynamicTransform4_CUDA) {

for (const auto i : c10::irange(after_shape.size())) {
expr_eval.bind(tv2->axis((int)i)->extent(), after_shape.at(i));
// We must bind tv1's extents, since they cannot be inferred until after
// concretization. Because tv2 is a dynamic reshape both its IterDomains
// are Symbolic, which means both of tv3's IterDomains are also Symbolic.
    // tv1 has both IterDomains of type Iteration, but since we add tv3 to
// it to get tv4, we do not know whether this will resolve broadcasts from
// tv3 or not until concretization.
expr_eval.bind(tv1->axis((int)i)->extent(), after_shape.at(i));
}

auto initial_info = DynamicTransform::getInitialInfo(&fusion);
Expand Down
Loading