Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,20 +627,40 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {

// Determine the output IterType
IterType iter_type = IterType::Symbolic;
for (auto inp_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
const auto input_ids =
ir_utils::filterByType<IterDomain>(expr->inputs()).vector();
for (auto i : c10::irange(input_ids.size())) {
auto inp_id = input_ids.at(i);
auto updated_id = maybeMutated(inp_id)->as<IterDomain>();
iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
TORCH_CHECK(
updated_id == inp_id || !updated_id->isSymbolic(),
"Mutated IterDomains between root and rfactor should not be symbolic");
if (i == 0) {
// ops::promoteIterType will favor Symbolic if it encounters it
// alongside Broadcast. This is preferable at fusion definition, but
// here we are propagating, and if we only see Broadcast in some
// dimension, then we should not retain Symbolic. To work around this,
// we always overwrite Symbolic with the first concrete IterType we
// encounter.
iter_type = updated_id->getIterType();
} else {
iter_type =
ops::promoteIterType(iter_type, updated_id->getIterType());
}
}
TORCH_INTERNAL_ASSERT(
iter_type != IterType::Symbolic,
"Failed to concretize an output IterType for expression: ",
expr->toString());

// Update the IterType of each output
for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
if (!out_id->isSymbolic()) {
continue;
}

// If out_id is Symbolic, we need to concretize it here. If we did not
// yet determine its IterType, then we've missed our chance.
TORCH_INTERNAL_ASSERT(
iter_type != IterType::Symbolic,
"Failed to concretize an output IterType for expression: ",
expr->toString());

auto concretized_out_id =
IterDomainBuilder(maybeMutated(out_id)->as<IterDomain>())
.iter_type(iter_type)
Expand Down
18 changes: 16 additions & 2 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2783,6 +2783,19 @@ IterDomain* IterDomain::resize(
"Expansion factor must be an integer scalar: ",
right_expansion->toString());

if (left_expansion->isConstInt() && right_expansion->isConstInt()) {
auto left = left_expansion->evaluateInt();
auto right = right_expansion->evaluateInt();
if (left == 0 && right == 0) {
// This is a trivial resize. Check that we are not changing the IterType,
// then return the input.
TORCH_CHECK(
!iter_type_opt.has_value() ||
iter_type_opt.value() == in->getIterType(),
"If IterType is specified in pad with zero expansion then it must match input");
return in;
}
}
TORCH_CHECK(
in->getIterType() == IterType::Iteration ||
in->getIterType() == IterType::Broadcast ||
Expand Down Expand Up @@ -2823,12 +2836,13 @@ IterDomain* IterDomain::resize(
if (iter_type_opt.has_value()) {
iter_type = iter_type_opt.value();
} else if (left_expansion->isConstInt() && right_expansion->isConstInt()) {
auto left = left_expansion->evaluateInt();
auto right = right_expansion->evaluateInt();
if (resized_id_size->isConstInt()) {
// Means input extent is also known
auto out_extent = resized_id_size->evaluateInt();
iter_type = out_extent == 1 ? IterType::Broadcast : IterType::Iteration;
} else if (
left_expansion->evaluateInt() + right_expansion->evaluateInt() > 1) {
} else if (left + right > 1) {
// Input extent is non-negative, so we know out_extent > 1
iter_type = IterType::Iteration;
}
Expand Down
40 changes: 40 additions & 0 deletions test/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,46 @@ TEST_F(ResizeTest, FusionResizePadScheduler4) {
__FILE__);
}

// Pad a broadcast
// See https://github.com/NVIDIA/Fuser/issues/798
TEST_F(ResizeTest, FusionResizePadBroadcastInput) {
  auto fusion_ptr = std::make_unique<Fusion>();
  FusionGuard fg(fusion_ptr.get());

  // Input IterTypes are {Broadcast, Iteration}
  auto in_tv = makeConcreteTensor({1, -1});
  fusion_ptr->addInput(in_tv);

  // Apply a trivial (zero-width on the broadcast axis) pad
  std::vector<Val*> pad_widths{
      fusion_ptr->oneVal(),
      fusion_ptr->zeroVal(),
      fusion_ptr->zeroVal(),
      fusion_ptr->zeroVal()};
  auto out_tv = pad(in_tv, pad_widths);
  fusion_ptr->addOutput(out_tv);

  std::vector<int64_t> input_shape({1, 2});

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto at_in = at::randn(input_shape, options);
  std::vector<c10::IValue> aten_inputs({at_in});

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

  // Reference result computed with ATen
  auto at_ref = at::pad(at_in, {1, 0, 0, 0});

  testValidate(
      executor_cache.fusion(),
      cg_outputs,
      aten_inputs,
      {at_ref},
      __LINE__,
      __FILE__);
}

// Trivial cat
TEST_F(ResizeTest, FusionResizeCat1) {
Fusion fusion;
Expand Down