Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 42 additions & 14 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,41 @@ TensorView* reshape(TensorView* inp_tv, const std::vector<Val*>& new_sizes) {
// domain.
std::vector<IterDomain*> logical_domain(new_sizes.size(), nullptr);
bool found_neg_one = false;

// We don't compute numel unless we need to, and then we compute it only once
// to encourage hoisting
Val* numel = nullptr;
const auto neg_one_size = [&numel, &inp_dom, &new_sizes](size_t pos) {
if (numel == nullptr) {
numel = FusionGuard::getCurFusion()->oneVal();
for (const auto j : arange(inp_dom.size())) {
numel = SimplifyingIrBuilder::mulExpr(numel, inp_dom.at(j)->extent());
}
}

Val* other_new_numel = FusionGuard::getCurFusion()->oneVal();
for (const auto j : arange(new_sizes.size())) {
if (pos == j) {
continue;
}
Val* new_size =
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, new_sizes.at(j));
other_new_numel =
SimplifyingIrBuilder::mulExpr(other_new_numel, new_size);
}
// In case numel is 0, other_new_numel might also be 0 and we would hit a
// division by zero. In such cases, using 1 as the denominator will cause
// us to properly compute 0 for this extent.
other_new_numel = SimplifyingIrBuilder::whereExpr(
eq(other_new_numel, FusionGuard::getCurFusion()->zeroVal()),
FusionGuard::getCurFusion()->oneVal(),
other_new_numel);

Val* new_size = SimplifyingIrBuilder::divExpr(numel, other_new_numel);
NVF_ERROR(new_size->dtype() == DataType::Index);
return simplifyExpr(new_size);
};

for (const auto i : arange(new_sizes.size())) {
auto new_size = new_sizes.at(i);
if (new_size->isConstScalar() && new_size->evaluate().as<int64_t>() == -1) {
Expand All @@ -162,20 +197,13 @@ TensorView* reshape(TensorView* inp_tv, const std::vector<Val*>& new_sizes) {
"A maximum of one value of -1 can be provided to reshape.");
found_neg_one = true;

Val* numel = FusionGuard::getCurFusion()->oneVal();
Val* other_new_numel = FusionGuard::getCurFusion()->oneVal();
for (const auto j : arange(inp_dom.size())) {
numel = SimplifyingIrBuilder::mulExpr(numel, inp_dom.at(j)->extent());
}
for (const auto j : arange(new_sizes.size())) {
if (i == j) {
continue;
}
other_new_numel =
SimplifyingIrBuilder::mulExpr(other_new_numel, new_sizes.at(j));
}
new_size = SimplifyingIrBuilder::divExpr(numel, other_new_numel);
new_size = simplifyExpr(new_size);
new_size = neg_one_size(i);
} else if (!new_size->isConstScalar()) {
// Dynamic size could be -1. Here we build a correct shape expression
new_size = where(
eq(new_size, IrBuilder::create<Val>(-1L, DataType::Index)),
neg_one_size(i),
new_size);
}
new_size = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, new_size);
auto rf_id =
Expand Down
37 changes: 17 additions & 20 deletions tests/cpp/test_dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@ TEST_F(DynamicTransformTest, DynamicTransform1) {
expr_eval.bind(tv0->axis(1)->extent(), 3L);
expr_eval.bind(reshape_shape0, 3L);
expr_eval.bind(reshape_shape1, -1L);
// We cannot infer the shape of tv1 from the above bound values, since
// either axis of tv2 might be broadcast against one from tv1.
expr_eval.bind(tv1->axis(0)->extent(), 3L);
expr_eval.bind(tv1->axis(1)->extent(), 4L);

// This should throw an exception since any reshape size of -1 must be
// specified as a definition-time constant, as opposed to an input scalar.
EXPECT_THAT(
[&]() {
auto initial_info = DynamicTransform::getInitialInfo(&fusion);
auto info =
DynamicTransformConcretizationInfo(&initial_info, &expr_eval);
},
::testing::ThrowsMessage<nvfError>(::testing::HasSubstr(
"Values of -1 passed to reshape must be constant at definition")));
auto initial_info = DynamicTransform::getInitialInfo(&fusion);
auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval);
NVF_CHECK(
info.getReshapeTransforms().size() == 1,
"Expected to have one reshape transform: ",
info.toString());
}

{
Expand Down Expand Up @@ -858,17 +860,14 @@ TEST_F(DynamicTransformTest, FusionDynamicReshapeReductionShmoo) {
{{8, 3 * 5, 7, 9}, {8, 3, 5 * 7, 9}, false}, // merge(1) osplit(1, 3)

// test passing -1 dynamically for dimension size
// This is unsupported. See https://github.com/NVIDIA/Fuser/issues/249
// Values of -1 must be passed as constants instead of input-dependent
// scalars.
//{{8, 3 * 5, 7, 9}, {8, 3, -1, 9}, false} // merge(1) osplit(1, 3)
{{8, 3 * 5, 7, 9}, {8, 3, -1, 9}, false}, // merge(1) osplit(1, 3)

// Empty reshapes should translate to FullOp
{{8, 0, 7, 9}, {7, 8, 0, 9}, true}, // symbolic_sizes = [ -1, -1, 0, -1 ]
// In the case below there's now a separate Val introduced for the output
// extent, which is zero. This is represented in
// DynamicTransformConcretizationInfo causing cache miss
{{8, 0, 7, 9}, {7, 8, -1, 9}, true}, // symbolic_sizes = [ -1, -1, 0, -1 ]
// This hits the same conc info as {8, 3 * 5, 7, 9}, {8, 3, 5 * 7, 9}
{{8, 0, 7, 9},
{7, 8, -1, 9},
false}, // symbolic_sizes = [ -1, -1, 0, -1 ]
{{8, 0, 7, 9}, {7, 8, 0, 0}, true}, // symbolic_sizes = [ -1, -1, 0, 0 ]
{{8, 0, 7, 9}, {47, 0, 13, 0}, true}, // symbolic_sizes = [ -1, 0, -1, 0 ]
};
Expand Down Expand Up @@ -1160,10 +1159,8 @@ TEST_F(DynamicTransformTest, Issue249InputNegative1) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor at_x = at::randn({2, 3, 4, 5}, options);

// Dynamic reshape sizes that are not constant at definition must be explicit:
// no -1 allowed
EXPECT_THROW(
executor_cache.runFusionWithInputs({at_x, 2, 4, -1}), std::exception);
// Test that running with dynamic -1 works as expected
executor_cache.runFusionWithInputs({at_x, 2, 4, -1});

// Passing explicit sizes works fine
auto outputs = executor_cache.runFusionWithInputs({at_x, 2, 4, 15});
Expand Down
24 changes: 24 additions & 0 deletions tests/cpp/test_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,4 +825,28 @@ TEST_F(ExprEvalTest, NamedScalar) {
EXPECT_EQ(cache_id_pvalue.as<int64_t>(), kCacheIdValue);
}

TEST_F(ExprEvalTest, View_FlattenToMinusOneAsInput) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Reshape a 2D symbolic tensor to a single dimension whose size is an
  // input scalar to the fusion. Binding -1 for that scalar should make the
  // evaluator infer the full flattened extent, mirroring
  // at::Tensor::view({-1}) semantics.
  TensorView* in = makeSymbolicTensor(2);
  auto* out_dim_size = IrBuilder::create<Val>(DataType::Int);
  TensorView* out = reshape(in, {out_dim_size});
  fusion.addInput(in);
  fusion.addInput(out_dim_size);
  fusion.addOutput(out);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  at::Tensor in_tensor = at::randn({2, 3}, options);

  ExpressionEvaluator evaluator;
  evaluator.bind(in, in_tensor);
  // -1 is supplied only at evaluation time, so the output extent must be
  // computed dynamically as numel / (product of the other new sizes).
  evaluator.bind(out_dim_size, -1);
  auto out_tensor = evaluator.evaluate(out).as<at::Tensor>();

  EXPECT_TRUE(at::allclose(out_tensor, in_tensor.view({-1})));
}

} // namespace nvfuser