diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index cdc80329206..6b1ad79a3dd 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -150,6 +150,41 @@ TensorView* reshape(TensorView* inp_tv, const std::vector<Val*>& new_sizes) { // domain. std::vector<IterDomain*> logical_domain(new_sizes.size(), nullptr); bool found_neg_one = false; + + // We don't compute numel unless we need to, and then we compute it only once + // to encourage hoisting + Val* numel = nullptr; + const auto neg_one_size = [&numel, &inp_dom, &new_sizes](size_t pos) { + if (numel == nullptr) { + numel = FusionGuard::getCurFusion()->oneVal(); + for (const auto j : arange(inp_dom.size())) { + numel = SimplifyingIrBuilder::mulExpr(numel, inp_dom.at(j)->extent()); + } + } + + Val* other_new_numel = FusionGuard::getCurFusion()->oneVal(); + for (const auto j : arange(new_sizes.size())) { + if (pos == j) { + continue; + } + Val* new_size = + SimplifyingIrBuilder::maybeCastExpr(DataType::Index, new_sizes.at(j)); + other_new_numel = + SimplifyingIrBuilder::mulExpr(other_new_numel, new_size); + } + // In case numel is 0, other_new_numel might also be 0 and we would hit a + // division by zero. In such cases, using 1 as the denominator will cause + // us to properly compute 0 for this extent. 
+ other_new_numel = SimplifyingIrBuilder::whereExpr( + eq(other_new_numel, FusionGuard::getCurFusion()->zeroVal()), + FusionGuard::getCurFusion()->oneVal(), + other_new_numel); + + Val* new_size = SimplifyingIrBuilder::divExpr(numel, other_new_numel); + NVF_ERROR(new_size->dtype() == DataType::Index); + return simplifyExpr(new_size); + }; + + for (const auto i : arange(new_sizes.size())) { auto new_size = new_sizes.at(i); if (new_size->isConstScalar() && new_size->evaluate().as<int64_t>() == -1) { @@ -162,20 +197,13 @@ TensorView* reshape(TensorView* inp_tv, const std::vector<Val*>& new_sizes) { "A maximum of one value of -1 can be provided to reshape."); found_neg_one = true; - Val* numel = FusionGuard::getCurFusion()->oneVal(); - Val* other_new_numel = FusionGuard::getCurFusion()->oneVal(); - for (const auto j : arange(inp_dom.size())) { - numel = SimplifyingIrBuilder::mulExpr(numel, inp_dom.at(j)->extent()); - } - for (const auto j : arange(new_sizes.size())) { - if (i == j) { - continue; - } - other_new_numel = - SimplifyingIrBuilder::mulExpr(other_new_numel, new_sizes.at(j)); - } - new_size = SimplifyingIrBuilder::divExpr(numel, other_new_numel); - new_size = simplifyExpr(new_size); + new_size = neg_one_size(i); + } else if (!new_size->isConstScalar()) { + // Dynamic size could be -1. 
Here we build a correct shape expression + new_size = where( + eq(new_size, IrBuilder::create<Val>(-1L, DataType::Index)), + neg_one_size(i), + new_size); } new_size = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, new_size); auto rf_id = diff --git a/tests/cpp/test_dynamic_transform.cpp b/tests/cpp/test_dynamic_transform.cpp index 090ac43af23..81d37b8e57c 100644 --- a/tests/cpp/test_dynamic_transform.cpp +++ b/tests/cpp/test_dynamic_transform.cpp @@ -88,17 +88,19 @@ TEST_F(DynamicTransformTest, DynamicTransform1) { expr_eval.bind(tv0->axis(1)->extent(), 3L); expr_eval.bind(reshape_shape0, 3L); expr_eval.bind(reshape_shape1, -1L); + // We cannot infer the shape of tv1 from the above bound values, since + // either axis of tv2 might be broadcast against one from tv1. + expr_eval.bind(tv1->axis(0)->extent(), 3L); + expr_eval.bind(tv1->axis(1)->extent(), 4L); // This should throw an exception since any reshape size of -1 must be // specified as a definition-time constant, as opposed to an input scalar. - EXPECT_THAT( - [&]() { - auto initial_info = DynamicTransform::getInitialInfo(&fusion); - auto info = - DynamicTransformConcretizationInfo(&initial_info, &expr_eval); - }, - ::testing::ThrowsMessage<nvfuser::nvfError>(::testing::HasSubstr( - "Values of -1 passed to reshape must be constant at definition"))); + auto initial_info = DynamicTransform::getInitialInfo(&fusion); + auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); + NVF_CHECK( + info.getReshapeTransforms().size() == 1, + "Expected to have one reshape transform: ", + info.toString()); } { @@ -858,17 +860,14 @@ TEST_F(DynamicTransformTest, FusionDynamicReshapeReductionShmoo) { {{8, 3 * 5, 7, 9}, {8, 3, 5 * 7, 9}, false}, // merge(1) osplit(1, 3) // test passing -1 dynamically for dimension size - // This is unsupported. See https://github.com/NVIDIA/Fuser/issues/249 - // Values of -1 must be passed as constants instead of input-dependent - // scalars. 
- //{{8, 3 * 5, 7, 9}, {8, 3, -1, 9}, false} // merge(1) osplit(1, 3) + {{8, 3 * 5, 7, 9}, {8, 3, -1, 9}, false}, // merge(1) osplit(1, 3) // Empty reshapes should translate to FullOp {{8, 0, 7, 9}, {7, 8, 0, 9}, true}, // symbolic_sizes = [ -1, -1, 0, -1 ] - // In the case below there's now a separate Val introduced for the output - // extent, which is zero. This is represented in - // DynamicTransformConcretizationInfo causing cache miss - {{8, 0, 7, 9}, {7, 8, -1, 9}, true}, // symbolic_sizes = [ -1, -1, 0, -1 ] + // This hits the same conc info as {8, 3 * 5, 7, 9}, {8, 3, 5 * 7, 9} + {{8, 0, 7, 9}, + {7, 8, -1, 9}, + false}, // symbolic_sizes = [ -1, -1, 0, -1 ] {{8, 0, 7, 9}, {7, 8, 0, 0}, true}, // symbolic_sizes = [ -1, -1, 0, 0 ] {{8, 0, 7, 9}, {47, 0, 13, 0}, true}, // symbolic_sizes = [ -1, 0, -1, 0 ] }; @@ -1160,10 +1159,8 @@ TEST_F(DynamicTransformTest, Issue249InputNegative1) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor at_x = at::randn({2, 3, 4, 5}, options); - // Dynamic reshape sizes that are not constant at definition must be explicit: - // no -1 allowed - EXPECT_THROW( - executor_cache.runFusionWithInputs({at_x, 2, 4, -1}), std::exception); + // Test that running with dynamic -1 works as expected + executor_cache.runFusionWithInputs({at_x, 2, 4, -1}); // Passing explicit sizes works fine auto outputs = executor_cache.runFusionWithInputs({at_x, 2, 4, 15}); diff --git a/tests/cpp/test_evaluator.cpp b/tests/cpp/test_evaluator.cpp index a29d22a3e19..75c4e1b005e 100644 --- a/tests/cpp/test_evaluator.cpp +++ b/tests/cpp/test_evaluator.cpp @@ -825,4 +825,28 @@ TEST_F(ExprEvalTest, NamedScalar) { EXPECT_EQ(cache_id_pvalue.as<int64_t>(), kCacheIdValue); } +TEST_F(ExprEvalTest, View_FlattenToMinusOneAsInput) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* in = makeSymbolicTensor(2); + auto* out_dim_size = IrBuilder::create<Val>(DataType::Int); + TensorView* out = reshape(in, {out_dim_size}); + fusion.addInput(in); 
+ fusion.addInput(out_dim_size); + fusion.addOutput(out); + + fusion.printMath(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA); + at::Tensor in_tensor = at::randn({2, 3}, options); + + ExpressionEvaluator evaluator; + evaluator.bind(in, in_tensor); + evaluator.bind(out_dim_size, -1); + auto out_tensor = evaluator.evaluate(out).as<at::Tensor>(); + + EXPECT_TRUE(at::allclose(out_tensor, in_tensor.view({-1}))); +} + } // namespace nvfuser