Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions csrc/device_lower/analysis/trivial_broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,53 @@ void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) {
}
}

void ConcretizedBroadcastDomains::handle(CatOp* op) {
  // The concatenated axis of a CatOp output may carry the broadcast flag even
  // though no BroadcastOp is involved. Record such an axis as its own
  // broadcast origin so later analysis can track its concretization.
  auto* out_tv = op->out()->as<TensorView>();
  IterDomain* concat_id =
      out_tv->getMaybeRFactorDomain().at(op->concatenatedDim());
  if (!concat_id->isBroadcast()) {
    return;
  }
  broadcast_origin_map_.emplace(
      concat_id, std::unordered_set<IterDomain*>({concat_id}));
}

void ConcretizedBroadcastDomains::handle(PadOp* op) {
  // Instead of the root domain of the output, as with BroadcastOp, we set the
  // origin as the RFactor domain, since PadOp inserts Resize ops between root
  // and rfactor.
  auto* out_tv = op->out()->as<TensorView>();
  const auto& out_rfactor = out_tv->getMaybeRFactorDomain();
  for (auto axis : op->getPaddedAxes()) {
    IterDomain* padded_id = out_rfactor.at(axis);
    if (!padded_id->isBroadcast()) {
      continue;
    }
    broadcast_origin_map_.emplace(
        padded_id, std::unordered_set<IterDomain*>({padded_id}));
  }
}

void ConcretizedBroadcastDomains::handle(SliceOp* op) {
  // Walk the consumer root domain alongside the producer rfactor domain
  // (with reductions removed) and record an origin for every broadcast
  // consumer axis.
  auto consumer_root = op->out()->as<TensorView>()->getMaybeRFactorDomain();
  auto producer_rfactor = TensorDomain::noReductions(
      op->in()->as<TensorView>()->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
      consumer_root.size() == producer_rfactor.size(),
      "Consumer root size ",
      consumer_root.size(),
      " does not match producer rfactor size ",
      producer_rfactor.size());
  for (auto pos : c10::irange(consumer_root.size())) {
    IterDomain* consumer_id = consumer_root.at(pos);
    if (!consumer_id->isBroadcast()) {
      continue;
    }
    IterDomain* producer_id = producer_rfactor.at(pos);
    // Map to producer ID if it was already broadcast. Otherwise to consumer
    // ID
    if (producer_id->isBroadcast()) {
      broadcast_origin_map_.emplace(
          producer_id,
          std::unordered_set<IterDomain*>({consumer_id, producer_id}));
    } else {
      broadcast_origin_map_.emplace(
          consumer_id, std::unordered_set<IterDomain*>({consumer_id}));
    }
  }
}

void ConcretizedBroadcastDomains::handle(Expr* expr) {
IterVisitor::handle(expr);

Expand Down
7 changes: 7 additions & 0 deletions csrc/device_lower/analysis/trivial_broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {

void handle(BroadcastOp* bop) final;

// After concretization, ops with Resized IterDomains in their outputs may set
// the broadcast flag, even though they are not BroadcastOps themselves. In
// these cases, we set the output as the origin.
void handle(CatOp* op) final;
void handle(PadOp* op) final;
void handle(SliceOp* op) final;

void handle(Expr* expr) final;

void markAsConcretized(
Expand Down
6 changes: 5 additions & 1 deletion csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,10 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {

// Update the IterType of each output
for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
if (!out_id->isSymbolic()) {
if (!out_id->isSymbolic() ||
mutations_.find(out_id) != mutations_.end()) {
// Skip symbolic outputs and outputs that have already been registered
// for mutation
continue;
}
auto concretized_out_id =
Expand Down Expand Up @@ -644,6 +647,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(

for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
root_map.mapSymbolic(true);
auto c2p = root_map.mapConsumerToProducer(
consumer->domain(), producer->domain());

Expand Down
4 changes: 4 additions & 0 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,10 @@ class TORCH_CUDA_CU_API CatOp : public Expr {
return attribute(0)->as<Attribute<int>>()->value;
}

//! Returns the single output of this concatenation. Provided so analyses can
//! access the result uniformly with other ops that expose an out() accessor.
Val* out() const {
  return output(0);
}

//! The index val that determines which input tensor should be used
//! to fill the particular output position of this expression. Only
//! valid after indexing
Expand Down
74 changes: 68 additions & 6 deletions csrc/ops/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ std::vector<IterDomain*> newOutputDomain(
std::vector<int64_t> start_offsets(out_domain.size(), 0);
std::vector<int64_t> stop_offsets(out_domain.size(), 0);
std::vector<Val*> extent_vals(out_domain.size(), nullptr);
std::vector<bool> mismatched_symbolic_extents(out_domain.size(), false);
std::vector<Val*> expanded_extent_vals(out_domain.size(), nullptr);
std::vector<c10::optional<IterType>> iter_types(
out_domain.size(), c10::nullopt);
std::vector<std::optional<IterType>> iter_types(
out_domain.size(), std::nullopt);

for (auto tv : tvs) {
auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
Expand All @@ -235,7 +236,62 @@ std::vector<IterDomain*> newOutputDomain(
dom.size(),
" dimensions but expected ",
out_domain.size());
// If there is any Iteration domain, we should use the first one's
// extent.
//
// If all inputs are Symbolic or Broadcast, then we can use the
// symbolic extent if all the symbolic extents agree.
//
// Otherwise, we don't know the output extent and iter_type should be
// Symbolic if there are any Symbolic inputs else Broadcast.
for (const auto i : c10::irange(dom.size())) {
auto iter_type = dom[i]->getIterType();
auto prev_iter_type = iter_types[i];
if (prev_iter_type.has_value()) {
// Clang-tidy complains about unchecked access to optional value here
if (iter_type == IterType::Iteration &&
prev_iter_type.value() == IterType::Symbolic) {
// Prefer the Iteration extent, since Symbolic could be broadcast
extent_vals[i] = nullptr;
} else if (iter_type == IterType::Symbolic) {
switch (prev_iter_type.value()) {
case IterType::Iteration:
// Previously found Iteration domain, so ignore all Symbolic
// domains
continue;
case IterType::Symbolic:
if (extent_vals[i]->sameAs(dom[i]->extent())) {
// matching symbolic extent
continue;
} else {
// Mismatched symbolic input extents. Any one of the symbolic
// inputs could be a Broadcast or Iteration domain. Until
// concretization, we will not know which one holds the true
// extent (or whether they all are Broadcast, so that the output
// is also Broadcast). We record that these symbolic extents
// mismatched so that we can introduce a new symbolic extent
// later.
mismatched_symbolic_extents[i] = true;
}
break;
case IterType::Broadcast:
// Previously found only broadcast, so this will either also
// broadcast or resolve those broadcasts. If the expanded
// extent of any of the broadcasts is not 1, then it will need to
// match that of the dom[i]. In either case, prefer dom[i]'s
// extent, so clear iter_types[i] and extent_vals[i] so that the
// rest of this iteration will mark output as Symbolic.
iter_types[i] = std::nullopt;
extent_vals[i] = nullptr;
break;
default:
TORCH_CHECK(
false,
"Encountered unexpected IterType when creating new output domain: ",
prev_iter_type.value());
}
}
}
if (dom[i]->isBroadcast()) {
if (dom[i]->hasExpandedExtent()) {
expanded_extent_vals[i] =
Expand All @@ -244,9 +300,9 @@ std::vector<IterDomain*> newOutputDomain(
continue;
}
extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
if (iter_types[i].has_value()) {
if (prev_iter_type.has_value()) {
iter_types[i] =
promoteIterType(iter_types[i].value(), dom[i]->getIterType());
promoteIterType(prev_iter_type.value(), dom[i]->getIterType());
} else {
iter_types[i] = dom[i]->getIterType();
}
Expand All @@ -268,15 +324,21 @@ std::vector<IterDomain*> newOutputDomain(
}
}
for (const auto dim_i : c10::irange(out_domain.size())) {
auto iter_type = iter_types[dim_i];
if (iter_type == IterType::Symbolic && mismatched_symbolic_extents[dim_i]) {
// if we have a symbolic output but the input symbolic extents did not
// match, create a new extent
extent_vals[dim_i] = nullptr;
}
if (extent_vals[dim_i] != nullptr) {
TORCH_INTERNAL_ASSERT(
iter_types[dim_i].has_value(),
iter_type.has_value(),
"Could not deduce iter type for new tensor view.");
out_domain[dim_i] =
IterDomainBuilder(
IrBuilder::create<Int>(start_offsets[dim_i]), extent_vals[dim_i])
.stop_offset(IrBuilder::create<Int>(stop_offsets[dim_i]))
.iter_type(iter_types[dim_i].value())
.iter_type(iter_type.value())
.build();
} else {
out_domain[dim_i] = IterDomainBuilder(
Expand Down
14 changes: 13 additions & 1 deletion csrc/root_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
// domains of torch_gather)
// 3. Squeeze and unsqueeze
// 4. Broadcast and non broadcast
// 5. Symbolic IDs

// Condition 1: when the producer ID is the dim of a select-like op
if (producer_id == indexed_producer_id) {
Expand Down Expand Up @@ -181,6 +182,15 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
continue;
}

// Condition 5
if (!map_symbolic_ &&
(producer_id->getIterType() == IterType::Symbolic ||
consumer_id->getIterType() == IterType::Symbolic)) {
itc++;
itp++;
continue;
}

IterDomain* map_key_id = producer_id;
IterDomain* map_value_id = consumer_id;
if (!producer_to_consumer) {
Expand Down Expand Up @@ -861,7 +871,9 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped(
}

if (consumer_id->isBroadcast()) {
TORCH_INTERNAL_ASSERT(producer_id->isBroadcast());
// Note that consumer may be broadcast even though producer is not if it is
// the output of a Resize op.

// Get bcast_map_ entry for consumer_id
const auto consumer_bcast_domains =
root_map_.getConcretizedKeys(consumer_td, consumer_id);
Expand Down
7 changes: 7 additions & 0 deletions csrc/root_domain_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap {
return *this;
}

//! Enable or disable mapping of Symbolic IterDomains to one another
//! (controls map_symbolic_, which is off by default). Returns *this so
//! option setters can be chained fluently.
PairwiseRootDomainMap& mapSymbolic(bool b) {
  map_symbolic_ = b;
  return *this;
}

PairwiseRootDomainMap& mapDifferentExtents(bool b) {
map_different_extents_ = b;
return *this;
Expand Down Expand Up @@ -136,6 +141,8 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap {
//! Map broadcast and non-broadcast domains. Note that this is on by
//! default
bool map_broadcast_ = true;
//! Map symbolic domains with one another.
bool map_symbolic_ = false;
//! Map domains that may have different extents, e.g., torch_gather
bool map_different_extents_ = false;
//! Map domains that are indirectly accessed, e.g., index_select
Expand Down
61 changes: 61 additions & 0 deletions test/test_dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,67 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) {
reductionDynamicPadAddFusion(invocations);
}

// Test dynamic pad followed by broadcast resolution
TEST_F(NVFuserTest, DynamicPadBroadcast_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr.get();
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);

  // 2d axis order here is YX
  auto ypad = IrBuilder::create<Int>();
  fusion.addInput(ypad);
  auto xpad = IrBuilder::create<Int>();
  fusion.addInput(xpad);

  // two-way resizes to cut square tv down to broadcastable size in each axis
  auto tv0_pad = pad(tv0, {fusion.zeroVal(), xpad, fusion.zeroVal(), ypad});

  // This will potentially resolve the y or x broadcast
  auto p = mul(tv0_pad, tv1);
  fusion.addOutput(p);
  fusion.addOutput(tv0_pad);

  FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor at_x = at::randn({5, 5}, options);
  at::Tensor at_y = at::randn({5, 5}, options);

  // Pad amounts start at zero (trivial resize); the case below overwrites
  // them to shrink the first axis down to a broadcastable extent of 1.
  std::vector<c10::IValue> aten_inputs({at_x, at_y, 0, 0});
  std::vector<at::Tensor> outputs;

  // shrink first axis
  aten_inputs[2] = -4;
  aten_inputs[3] = 0;
  outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
  testValidate(
      fusion_executor_cache.fusion(),
      outputs,
      aten_inputs,
      {at_x.slice(0, 0, 1) * at_y, at_x.slice(0, 0, 1)},
      __LINE__,
      __FILE__);
}

// Test that a Symbolic root/Broadcast rfactor is not concretized to
// Iteration/Iteration
TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) {
Expand Down