Merged
66 commits
4ae4f09
First draft of handling Resize in DynamicTransformInfoBuilder
jacobhinkle Apr 28, 2023
4d75a37
Concretize resizes, clone resize transforms
jacobhinkle Apr 28, 2023
a027f2b
Set resized IterDomains to symbolic if extent is non-constant
jacobhinkle Apr 28, 2023
5de8af5
Set resized id to symbolic, broadcast, or iteration
jacobhinkle Apr 28, 2023
8aa38e5
Return references for getRe{shape,size}Transforms()
jacobhinkle Apr 29, 2023
840b5ee
Improve logic for determining concrete resize IterType
jacobhinkle Apr 29, 2023
f6ccdd9
Merge remote-tracking branch 'origin/main' into dynamic_resize
jacobhinkle May 1, 2023
acfd7ee
Add note about -1 issue with link
jacobhinkle May 1, 2023
c60a13d
Add dynamic pad shmoo test
jacobhinkle May 1, 2023
45b60e6
Hold TensorDomain for resizes in order to replace in output rootdomain
jacobhinkle May 1, 2023
1b84942
Fix issues in concretization of resize
jacobhinkle May 1, 2023
cdb0f27
Remove replaced vals after resize concretization
jacobhinkle May 1, 2023
6dbdbb1
Add more cases to pad shmoo test
jacobhinkle May 1, 2023
c663ece
Re-use IterType computation in concretization
jacobhinkle May 1, 2023
4efc74c
Add zero-element reduction test
jacobhinkle May 2, 2023
81a6f10
Short-circuit outerReductionHeuristic on numel==0
jacobhinkle May 2, 2023
4174a63
Short-circuit if numel of input is zero
jacobhinkle May 2, 2023
37fc7ce
Also guard innerReductionHeuristic, update test
jacobhinkle May 2, 2023
89b55cf
Merge branch 'main' into reduce_zero_elt
jacobhinkle May 2, 2023
4e2c170
Add more test cases for zero-element reduction
jacobhinkle May 3, 2023
d7714ce
Merge branch 'main' into reduce_zero_elt
jacobhinkle May 3, 2023
8bcc12c
Merge remote-tracking branch 'origin/main' into dynamic_resize
jacobhinkle May 3, 2023
4726363
Merge remote-tracking branch 'jh/reduce_zero_elt' into dynamic_resize
jacobhinkle May 3, 2023
bf87c19
Move resizeOutputIterType to ir_utils
jacobhinkle May 9, 2023
3371303
Use initializer-if in csrc/dynamic_transform.cpp
jacobhinkle May 9, 2023
d5f5213
Add iter_type option to resize(), other fixes
jacobhinkle May 10, 2023
1febea7
Merge branch 'dynamic_resize' of github.com:jacobhinkle/Fuser into dy…
jacobhinkle May 10, 2023
dcdccb7
Merge branch 'main' into dynamic_resize
jacobhinkle May 10, 2023
90260c7
Merge branch 'dynamic_resize' of github.com:jacobhinkle/Fuser into dy…
jacobhinkle May 10, 2023
88bb239
Minor fixup for initializer-if
jacobhinkle May 10, 2023
fde271d
Remove stale TODO comment
jacobhinkle May 10, 2023
2d98beb
Silence clang-tidy
jacobhinkle May 10, 2023
b43e70f
Merge remote-tracking branch 'origin/main' into dynamic_resize
jacobhinkle May 10, 2023
db5b8fc
Remove unguarded reduction stats printing
jacobhinkle May 10, 2023
5db7cb4
Try harder for static IterDomain::resize
jacobhinkle May 10, 2023
a02fbb9
Fix skip resize ops in BestEffortReplay
naoyam May 10, 2023
23f383b
Replace getInt with evaluateInt in IterDomain::resize()
jacobhinkle May 10, 2023
6cd6932
Use concrete sizes in cat and slice tests.
jacobhinkle May 11, 2023
aec043d
Stop holding TensorView in resize_transforms_
jacobhinkle May 11, 2023
c76c2ee
Mutate all IterDomain expressions from root
jacobhinkle May 12, 2023
678e0ec
Don't bail if no ID expr outputs are symbolic
jacobhinkle May 12, 2023
6bd6f31
Merge remote-tracking branch 'origin/main' into dynamic_resize
jacobhinkle May 12, 2023
b6b085f
Update ID symbolic check
jacobhinkle May 12, 2023
3826967
De-indent to clarify diff
jacobhinkle May 12, 2023
ab987b8
Do simple mutation if nothing propagates from producer
jacobhinkle May 12, 2023
eb3fca6
Change checks to better explain cases, expand comments
jacobhinkle May 12, 2023
1e403e7
Merge branch 'main' into dynamic_resize
jacobhinkle May 12, 2023
7a9a4d6
Merge remote-tracking branch 'origin/main' into dynamic_resize
jacobhinkle May 13, 2023
c9448ca
Disable all-or-nothing symbolic output check on exprs
jacobhinkle May 15, 2023
fe3a26f
Merge remote-tracking branch 'origin/main' into dynamic_resize
jacobhinkle May 15, 2023
e5aa9c8
Defer fusion->hasDynamicTransform() in FEC
jacobhinkle May 15, 2023
5cf6971
Remove zero-element changes that leaked into this PR
jacobhinkle May 16, 2023
8f4cc09
Remove early returns in concretize root->rfactor
jacobhinkle May 16, 2023
15f0cd8
Concretize IterDomains in-place instead of building
jacobhinkle May 16, 2023
d40fcfe
Remove erroneous TODO comment in DynamicPadShmoo_CUDA
jacobhinkle May 16, 2023
65656a6
Revert "Concretize IterDomains in-place instead of building"
jacobhinkle May 16, 2023
c86e048
Remove resizeOutputIterType
jacobhinkle May 16, 2023
5885c80
Expand comment on resize(), explain iter_type arg
jacobhinkle May 16, 2023
dc2c28b
Point to #264 and #346 in DynamicPadShmoo test
jacobhinkle May 16, 2023
00005e1
Merge branch 'main' into dynamic_resize
jacobhinkle May 16, 2023
2581958
Minor clean up
jacobhinkle May 16, 2023
88703fa
Merge branch 'main' into dynamic_resize
jacobhinkle May 16, 2023
96f1876
Merge branch 'main' into dynamic_resize
jacobhinkle May 16, 2023
152a0a8
Fix DynamicTransformConcretizationInfo::operator==
jacobhinkle May 16, 2023
4f9b4ba
Remove rfactor&&broadcast check, uncomment failing test.
jacobhinkle May 16, 2023
e29c7f5
Merge branch 'dynamic_resize' of github.com:jacobhinkle/Fuser into dy…
jacobhinkle May 16, 2023
158 changes: 125 additions & 33 deletions csrc/dynamic_transform.cpp
@@ -30,6 +30,9 @@ class DynamicTransformInfoBuilder : public IterVisitor {
// Analyze a dynamic reshape and generate AnalyzeViewResult
void handle(ViewOp* op) override;

// We handle IterDomain "Resize" ops at TensorView level
void handle(TensorView* tv) override;

const auto& getInfo() const {
return info_;
}
@@ -77,20 +80,35 @@ bool DynamicTransformConcretizationInfo::operator==(
}
}

for (const auto i : c10::irange(resize_transforms_.size())) {
const auto& transform = resize_transforms_.at(i);
const auto& other_transform = other.resize_transforms_.at(i);
if (transform != other_transform) {
return false;
}
}

return true;
}

DynamicTransformConcretizationInfo DynamicTransformConcretizationInfo::clone(
IrCloner& ir_cloner) const {
DynamicTransformConcretizationInfo cloned_info(
(Fusion*)ir_cloner.container());
for (auto& pair : reshape_transforms_) {
for (const auto& [tv, analyze_result] : reshape_transforms_) {
cloned_info.reshape_transforms_.emplace_back(
ir_cloner.clone(pair.first),
ir_cloner.clone(tv),
// reshape_transforms_ holds pairs of TensorView* and AnalyzeViewResult
// AnalyzeViewResult can be copied directly as it holds no references to
// Statements that would need cloning, only integer indices of axes.
pair.second);
analyze_result);
}
for (const auto& [id, iter_type] : resize_transforms_) {
cloned_info.resize_transforms_.emplace_back(
ir_cloner.clone(id),
// Similar to reshape_transforms_, we only clone the IterDomains in
// resize_transforms_
iter_type);
}
return cloned_info;
}
@@ -104,9 +122,56 @@ std::string DynamicTransformConcretizationInfo::toString() const {
ss << indent << indent << kv.first->toString() << ", "
<< kv.second.toString() << "\n";
}
ss << indent << "Resize:\n";
for (const auto& [id, iter_type] : resize_transforms_) {
ss << indent << indent << id->toString() << ", " << iter_type << "\n";
}
return ss.str();
}

void DynamicTransformInfoBuilder::handle(TensorView* tv) {
const auto& rfd = tv->getMaybeRFactorDomain();
for (auto id : rfd) {
if (!id->definition()) {
continue;
}
if (auto op = dynamic_cast<Resize*>(id->definition());
id->getIterType() == IterType::Symbolic && op != nullptr) {
auto out_extent_val = expr_eval_->evaluate(id->extent());
TORCH_INTERNAL_ASSERT(
out_extent_val.has_value(),
"Cannot evaluate the extent of a resized IterDomain: ",
id->toString());

auto in_id = op->in()->as<IterDomain>();
auto in_extent_val = expr_eval_->evaluate(in_id->extent());
TORCH_INTERNAL_ASSERT(
in_extent_val.has_value(),
"Cannot evaluate the extent of input to an IterDomain resize: ",
in_id->toString());

auto left = op->leftExpand()->as<Int>();
auto left_val = expr_eval_->evaluate(left);
TORCH_INTERNAL_ASSERT(
left_val.has_value(),
"Cannot evaluate the left expansion of an IterDomain resize: ",
left->toString());

auto right = op->rightExpand()->as<Int>();
auto right_val = expr_eval_->evaluate(right);
TORCH_INTERNAL_ASSERT(
right_val.has_value(),
"Cannot evaluate the right expansion of an IterDomain resize: ",
right->toString());

auto out_itertype = out_extent_val->as<int64_t>() == 1
? IterType::Broadcast
: IterType::Iteration;
info_.resize_transforms_.emplace_back(id, out_itertype);
}
}
}

void DynamicTransformInfoBuilder::handle(ViewOp* op) {
auto inp_tv = op->in()->as<TensorView>();
auto out_tv = op->out()->as<TensorView>();
@@ -204,6 +269,8 @@ class DynamicTransformConcretizer : public OptOutMutator {

void concretizeReshape();

void concretizeResize();

using OptOutMutator::mutate;

void mutate(TensorView* tv) final;
@@ -216,15 +283,17 @@ class DynamicTransformConcretizer : public OptOutMutator {

private:
const DynamicTransformConcretizationInfo& info_;
std::unordered_map<IterDomain*, IterDomain*> update_map_;
};

void DynamicTransformConcretizer::concretize() {
// First, concretize all dynamic reshape ops
concretizeReshape();

// Second, propagate concretized domains
auto all_stmts = StmtSort::getStmts(info_.fusion(), false);
// Set output IterTypes for dynamic resize ops
concretizeResize();

// Finally, propagate concretized domains
auto all_stmts = StmtSort::getStmts(info_.fusion(), true);
for (auto stmt : all_stmts) {
if (stmt->isA<Val>()) {
mutate(stmt);
@@ -257,6 +326,24 @@ void DynamicTransformConcretizer::concretizeReshape() {
}
}

void DynamicTransformConcretizer::concretizeResize() {
// Concretize each resize op.
for (const auto& [id, iter_type] : info_.getResizeTransforms()) {
TORCH_CHECK(
id->definition() && id->definition()->isA<Resize>(),
"Resized IterDomain must have a Resize definition");
auto def = id->definition()->as<Resize>();
auto new_id = IterDomain::resize(
def->in(),
def->leftExpand(),
def->rightExpand(),
id->isRFactorProduct(),
iter_type);

registerMutation(id, new_id);
}
}

// Concretizes inherited symbolic domains. Note that when this is
// called, it is assumed that all dynamic ops themselves are
// concretized. Since symbolic IDs may be propagated down to
@@ -268,15 +355,12 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {

// First, try to concretize the root domain as there may be symbolic
// axes inherited from the producers
auto propagated = propagateFromProducerToConsumer(tv);

// If no root domain is altered, nothing to do further
if (!propagated) {
return;
}
propagateFromProducerToConsumer(tv);

// Root IDs are altered. Need to propagate the changes to rfactor
// domain
// If no root domain is altered by producer, we don't need to propagate back
// up to rfactor. We could return early, but instead we go ahead and check the
// root to rfactor transforms to be sure we have concretized any intermediate
// IterDomains.

// At this point, there should be no expr beyond rfactor root
TORCH_INTERNAL_ASSERT(
@@ -308,20 +392,21 @@
". IterDomain was expected.");
}

// If none of the output IDs is symbolic, nothing to concretize
if (std::all_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->as<IterDomain>()->getIterType() !=
IterType::Symbolic;
})) {
continue;
}
// If any of output IDs is symbolic, all outputs should be symbolic
TORCH_INTERNAL_ASSERT(std::all_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->as<IterDomain>()->getIterType() ==
IterType::Symbolic;
}));
// NOTE: We do not return early if all outputs are concrete as there may
// still be concrete inputs. For example, a Symbolic IterDomain might be
// padded with constant pad widths (1, 1), in which case although we do
// not know the exact extent of the output, we know it is at least as
// large as the sum of the pad widths, 2. In such cases, the output
// IterDomain is concrete at definition, since if the extent is >1 we know
// the IterType is Iteration. In these cases, we must continue to
// concretize intermediate expressions between the root and R-factor
// domain. See test DynamicTransform5_CUDA which demonstrates this
// behavior.
// NOTE: We also do not assume that if one output ID is symbolic, that
// they all must be. See test FusionSliceForNanoGPT3_CUDA for an example
// that does a static split by a factor of 16 of a symbolic input domain.
// The static split in that case results in a concrete IterDomain with
// extent 16 along with a symbolic one (extent ceilDiv(n / 16)).

// Determine the output IterType
IterType iter_type = IterType::Symbolic;
@@ -336,13 +421,13 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {

// Update the IterType of each output
for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
auto concreteized_out_id =
auto concretized_out_id =
IterDomainBuilder(out_id).iter_type(iter_type).build();
registerMutation(out_id, concreteized_out_id);
registerMutation(out_id, concretized_out_id);
}

// Outputs are mutated. The expr itself needs to be mutated as
// well, which can be done by the mutate method
// The expr itself needs to be mutated as well in case the outputs are
// mutated, which can be done by the mutate method
OptOutMutator::mutate(expr);
}
}
@@ -457,7 +542,14 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
}

TORCH_INTERNAL_ASSERT(
id_type.has_value() && id_type != IterType::Symbolic,
id_type.has_value(),
"Did not find id_type for consumer root domain ",
root_id->toString(),
". Perhaps consumer def has no inputs. Consumer definition = ",
def->toString());

TORCH_INTERNAL_ASSERT(
id_type != IterType::Symbolic,
"Failed to concretize ",
root_id->toString(),
" of ",
14 changes: 13 additions & 1 deletion csrc/dynamic_transform.h
@@ -27,11 +27,16 @@ class DynamicTransformInfoBuilder;
//! of the fusion inputs
class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
public:
const std::vector<std::pair<TensorView*, AnalyzeViewResult>>
const std::vector<std::pair<TensorView*, AnalyzeViewResult>>&
getReshapeTransforms() const {
return reshape_transforms_;
}

const std::vector<std::pair<IterDomain*, IterType>>& getResizeTransforms()
const {
return resize_transforms_;
}

bool operator==(const DynamicTransformConcretizationInfo& other) const;

bool operator!=(const DynamicTransformConcretizationInfo& other) const {
@@ -53,8 +58,15 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {

private:
Fusion* fusion_ = nullptr;

// Holds, for each dynamic reshape, the output TensorView, and the result of
// analyzeView
std::vector<std::pair<TensorView*, AnalyzeViewResult>> reshape_transforms_;

// Holds the resized IterDomain (output of the Resize op) along with the
// TensorView where it appears, and its concretized IterType
std::vector<std::pair<IterDomain*, IterType>> resize_transforms_;

friend class DynamicTransformInfoBuilder;
};

16 changes: 15 additions & 1 deletion csrc/ir_internal_base_nodes.h
@@ -152,11 +152,25 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
//! is marked as an rfactor domain. For example, expressions such as
//! PadOp and SliceOp resize IterDomains and generate rfactor
//! resized domains.
//!
//! Usually, the IterType of the output IterDomain will be Symbolic. This is
//! because unless the left and right expansions are known at Fusion
//! definition we cannot be sure that the output will have an extent != 1. In
//! case the output extent is in fact 1, we will set the IterType to
//! Broadcast. If the left and right expansions are constant, and sum to at
//! least two, then even an empty input will result in an Iteration IterType.
//! In these cases, we will set the output IterType to Iteration at
//! definition. Otherwise, it will be set to Symbolic and will be resolved
//! when concretization is performed by FusionExecutorCache.
//!
//! The optional iter_type argument can be used to force the output IterType,
//! but for safety its use should typically be confined to concretization.
static IterDomain* resize(
IterDomain* in,
Val* left_expansion,
Val* right_expansion,
bool mark_as_rfactor = false);
bool mark_as_rfactor = false,
std::optional<IterType> iter_type = std::nullopt);

bool isReduction() const {
return getIterType() == IterType::Reduction;
33 changes: 28 additions & 5 deletions csrc/ir_nodes.cpp
@@ -7,6 +7,7 @@
// clang-format on
#include <device_lower/lower2device.h>
#include <disjoint_set.h>
#include <dynamic_transform.h>
#include <ir_cloner.h>
#include <ir_interface_nodes.h>
#include <ir_iostream.h>
@@ -2101,9 +2102,12 @@ IterDomain::IterDomain(
is_padded_dimension_(is_padded_dimension),
padded_to_size_(padded_to_size),
is_mma_swizzled_(is_mma_swizzled) {
TORCH_CHECK(
!(isRFactorProduct() && isBroadcast()),
"IterDomain cannot be both a broadcast and rfactor domain.");
// NOTE: We previously asserted !(isRFactorProduct() && isBroadcast()), i.e.
// that an IterDomain could not be both a broadcast and an rfactor domain.
// However, since the introduction of the resize op, we now have a legitimate
// case where this may be true; namely, whenever we resize an IterDomain to
// size 1, we will mark it as Broadcast, but the resize must lie between root
// and rfactor.

TORCH_INTERNAL_ASSERT(
extent->isIntegralScalar(),
@@ -2459,7 +2463,8 @@ IterDomain* IterDomain::resize(
IterDomain* in,
Val* left_expansion,
Val* right_expansion,
bool mark_as_rfactor) {
bool mark_as_rfactor,
std::optional<IterType> iter_type_opt) {
TORCH_CHECK(
left_expansion->isIntegralScalar(),
"Expansion factor must be an integer scalar: ",
@@ -2502,10 +2507,28 @@
right_expansion);
}

// If output IterType is provided, use it. Otherwise, if we can prove the
// resized extent is 1, set to Broadcast, if we can prove it is >1 set to
// Iteration, and otherwise fall back to Symbolic.
IterType iter_type = IterType::Symbolic;
if (iter_type_opt.has_value()) {
iter_type = iter_type_opt.value();
} else if (left_expansion->isConstInt() && right_expansion->isConstInt()) {
if (resized_id_size->isConstInt()) {
// Means input extent is also known
auto out_extent = resized_id_size->evaluateInt();
iter_type = out_extent == 1 ? IterType::Broadcast : IterType::Iteration;
} else if (
left_expansion->evaluateInt() + right_expansion->evaluateInt() > 1) {
// Input extent is non-negative, so we know out_extent > 1
iter_type = IterType::Iteration;
}
}

auto resized_id =
IterDomainBuilder(in->container()->zeroVal(), resized_id_size->as<Int>())
.is_rfactor_domain(mark_as_rfactor)
.iter_type(in->getIterType())
.iter_type(iter_type)
.build();

IrBuilder::create<Resize>(
Expand Down