diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 3f7c4f03cdc..ae1e6e28380 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -305,7 +305,7 @@ void DynamicTransformConcretizationInfo::analyzeResizes( "Found non-dynamic Resize in initial concretization info: ", op->toString()); - auto extent_val = expr_eval->evaluate(out_id->extent()); + auto extent_val = expr_eval->evaluate(out_id->getMaybeExpandedExtent()); NVF_ERROR( extent_val.hasValue(), "Cannot evaluate the extent of a resized domain: ", diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index f71dc030c87..0298fe5d329 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include @@ -342,6 +343,32 @@ void PrecomputedValues::bindTensorMetaData( bindValue(extent->evaluatorIndex(), value); } } + + // Here we bind TensorMetaData so that GetMetaData expressions can be + // evaluated. Note that we do not bind the at::Tensor itself here since that + // would mean PrecomputedValues will own the tensor. Unlike + // ExpressionEvaluator, PrecomputedValues objects are typically long-lived, so + // we do not want them to own large objects. + // To do this we create a temporary ExpressionEvaluator so that we can compute + // the metadata once, then save it + ExpressionEvaluator ee; + ee.bindPrecomputedValues(this); + ee.bind(tv, tensor); + auto metadata_val = IrBuilder::metadataExpr(tv); + auto metadata = ee.evaluate(metadata_val); + // NOTE: In some cases we may not be able to evaluate metadata. For example, + // if there exists a split expression between the root and rfactor domains + // of tv whose split factor is not able to be evaluated. For that reason, + // calling code should ensure that all inputs required to propagate strides + // from root to allocation domains are already bound to "this" before binding + // a TensorView's metadata. 
+ NVF_ERROR( + metadata.hasValue(), + "Could not evaluate metadata expression for ", + tv->toString(), + " with input tensor ", + tensor); + bindValue(metadata_val->evaluatorIndex(), metadata); } NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values) diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index 6f9b52ea794..573ae4e1d9e 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -722,6 +722,9 @@ ExpressionEvaluator bindInputs( ExpressionEvaluator expr_eval; const auto& inputs = kernel->inputs(); for (const auto i : c10::irange(inputs.size())) { + // NOTE: we bind all inputs here, including at::Tensors. This means that + // expr_eval will create a PolymorphicValue containing *args[i], which means + // that at::Tensor's lifetime will be at least as long as that of expr_eval. expr_eval.bind(inputs[i], *args[i], true); } diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h index 5eb3b0c5244..86016e3e295 100644 --- a/csrc/polymorphic_value.h +++ b/csrc/polymorphic_value.h @@ -181,6 +181,12 @@ class StructHandle { StructHandle& operator=(const StructHandle& other) = default; StructHandle& operator=(StructHandle&& other) = default; + //! This is a shallow comparison operator that just checks whether we point to + //! 
the same exact Struct + bool operator==(const StructHandle& other) const { + return struct_ptr_ == other.struct_ptr_; + } + template <typename T> bool is() const { return std::dynamic_pointer_cast<T>(struct_ptr_) != nullptr; diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 8dc7d61608f..682d6b8fb60 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -2481,6 +2481,27 @@ def fusion_func(fd: FusionDefinition) -> None: self.assertEqual(nvf_out[0].shape, (0, 0)) self.assertEqual(nvf_out[1].shape, (0, 0)) + # Test that a pad of an expanded empty tensor works properly + # See https://github.com/NVIDIA/Fuser/issues/596#issuecomment-1714465618 + def test_pad_expanded_empty(self): + inputs = [ + torch.randn((0,), dtype=torch.float64, device="cuda:0").as_strided( + (2, 0, 3), (0, 0, 0) + ), + ] + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.from_pytorch(inputs[0]) + S1 = fd.define_scalar(-3.70753, dtype=DataType.Double) + T2 = fd.ops.pad(T0, [0, 0, 1, 1, 1, 0], S1) + fd.add_output(T2) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + + torch_ref = F.pad(inputs[0], (0, 0, 1, 1, 1, 0), "constant", -3.70753) + + self.assertEqual(nvf_out[0], torch_ref) + if __name__ == "__main__": run_tests() diff --git a/test/test_resize.cpp b/test/test_resize.cpp index fde3162a468..0fbd1b7c27a 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -3107,4 +3107,48 @@ TEST_F(ResizeTest, CatOfExpandedBroadcast) { NVF_CHECK(ref.equal(cg_outputs[0])); } +// Test that an empty input which is expanded in some non-zero directions can be +// padded in the empty dim as well as the expanded dims. 
+// This should match test_python_frontend.py::test_pad_expanded_empty +// See https://github.com/NVIDIA/Fuser/issues/870 +TEST_F(ResizeTest, PadExpandedEmpty) { + auto fusion_ptr = std::make_unique<Fusion>(); + auto& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto i0 = IrBuilder::create<Val>(DataType::Index); + auto i1 = IrBuilder::create<Val>(DataType::Index); + auto i2 = IrBuilder::create<Val>(DataType::Index); + + auto tv0 = TensorViewBuilder() + .shape({i0, i1, i2}) + .expanded({true, false, true}) + .dtype(DataType::Double) + .build(); + fusion.addInput(tv0); + + auto s0 = IrBuilder::create<Val>(-3.70753); + + std::vector<Val*> pad_widths( + {fusion.zeroVal(DataType::Index), + fusion.zeroVal(DataType::Index), + fusion.oneVal(DataType::Index), + fusion.oneVal(DataType::Index), + fusion.oneVal(DataType::Index), + fusion.zeroVal(DataType::Index)}); + auto tv1 = pad(tv0, pad_widths, s0); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); + + auto t0 = at::randn({0}, options).as_strided({2, 0, 3}, {0, 0, 0}); + std::vector<c10::IValue> aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + testValidate( + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser