diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 6c48be0a453..9c6d9c3b914 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -30,6 +30,9 @@ class DynamicTransformInfoBuilder : public IterVisitor { // Analyze a dynamic reshape and generate AnalyzeViewResult void handle(ViewOp* op) override; + // We handle IterDomain "Resize" ops at TensorView level + void handle(TensorView* tv) override; + const auto& getInfo() const { return info_; } @@ -77,6 +80,14 @@ bool DynamicTransformConcretizationInfo::operator==( } } + for (const auto i : c10::irange(resize_transforms_.size())) { + const auto& transform = resize_transforms_.at(i); + const auto& other_transform = other.resize_transforms_.at(i); + if (transform != other_transform) { + return false; + } + } + return true; } @@ -84,13 +95,20 @@ DynamicTransformConcretizationInfo DynamicTransformConcretizationInfo::clone( IrCloner& ir_cloner) const { DynamicTransformConcretizationInfo cloned_info( (Fusion*)ir_cloner.container()); - for (auto& pair : reshape_transforms_) { + for (const auto& [tv, analyze_result] : reshape_transforms_) { cloned_info.reshape_transforms_.emplace_back( - ir_cloner.clone(pair.first), + ir_cloner.clone(tv), // reshape_transforms_ holds pairs of TensorView* and AnalyzeViewResult // AnalyzeViewResult can be copied directly as it holds no references to // Statements that would need cloning, only integer indices of axes. - pair.second); + analyze_result); + } + for (const auto& [id, iter_type] : resize_transforms_) { + cloned_info.resize_transforms_.emplace_back( + ir_cloner.clone(id), + // Similar to reshape_transforms_, we only clone the IterDomains in + // resize_transforms_ + iter_type); } return cloned_info; } @@ -104,9 +122,56 @@ std::string DynamicTransformConcretizationInfo::toString() const { ss << indent << indent << kv.first->toString() << ", " << kv.second.toString() << "\n"; } + ss << indent << "Resize:\n"; + for (const auto& [id, iter_type] : resize_transforms_) { + ss << indent << indent << id->toString() << ", " << iter_type << "\n"; + } return ss.str(); } +void DynamicTransformInfoBuilder::handle(TensorView* tv) { + const auto& rfd = tv->getMaybeRFactorDomain(); + for (auto id : rfd) { + if (!id->definition()) { + continue; + } + if (auto op = dynamic_cast(id->definition()); + id->getIterType() == IterType::Symbolic && op != nullptr) { + auto out_extent_val = expr_eval_->evaluate(id->extent()); + TORCH_INTERNAL_ASSERT( + out_extent_val.has_value(), + "Cannot evaluate the extent of a resized IterDomain: ", + id->toString()); + + auto in_id = op->in()->as(); + auto in_extent_val = expr_eval_->evaluate(in_id->extent()); + TORCH_INTERNAL_ASSERT( + in_extent_val.has_value(), + "Cannot evaluate the extent of input to an IterDomain resize: ", + in_id->toString()); + + auto left = op->leftExpand()->as(); + auto left_val = expr_eval_->evaluate(left); + TORCH_INTERNAL_ASSERT( + left_val.has_value(), + "Cannot evaluate the left expansion of an IterDomain resize: ", + left->toString()); + + auto right = op->rightExpand()->as(); + auto right_val = expr_eval_->evaluate(right); + TORCH_INTERNAL_ASSERT( + right_val.has_value(), + "Cannot evaluate the right expansion of an IterDomain resize: ", + right->toString()); + + auto out_itertype = out_extent_val->as() == 1 + ? IterType::Broadcast + : IterType::Iteration; + info_.resize_transforms_.emplace_back(id, out_itertype); + } + } +} + void DynamicTransformInfoBuilder::handle(ViewOp* op) { auto inp_tv = op->in()->as(); auto out_tv = op->out()->as(); @@ -204,6 +269,8 @@ class DynamicTransformConcretizer : public OptOutMutator { void concretizeReshape(); + void concretizeResize(); + using OptOutMutator::mutate; void mutate(TensorView* tv) final; @@ -216,15 +283,17 @@ class DynamicTransformConcretizer : public OptOutMutator { private: const DynamicTransformConcretizationInfo& info_; - std::unordered_map update_map_; }; void DynamicTransformConcretizer::concretize() { // First, concretize all dynamic reshape ops concretizeReshape(); - // Second, propagate concretized domains - auto all_stmts = StmtSort::getStmts(info_.fusion(), false); + // Set output IterTypes for dynamic resize ops + concretizeResize(); + + // Finally, propagate concretized domains + auto all_stmts = StmtSort::getStmts(info_.fusion(), true); for (auto stmt : all_stmts) { if (stmt->isA()) { mutate(stmt); @@ -257,6 +326,24 @@ void DynamicTransformConcretizer::concretizeReshape() { } } +void DynamicTransformConcretizer::concretizeResize() { + // Concretize each resize op. + for (const auto& [id, iter_type] : info_.getResizeTransforms()) { + TORCH_CHECK( + id->definition() && id->definition()->isA(), + "Resized IterDomain must have a Resize definition"); + auto def = id->definition()->as(); + auto new_id = IterDomain::resize( + def->in(), + def->leftExpand(), + def->rightExpand(), + id->isRFactorProduct(), + iter_type); + + registerMutation(id, new_id); + } +} + // Concretizes inherited symbolic domains. Note that when this is // called, it is assumed that all dynamic ops themselves are // concretized. Since symbolic IDs may be propagated down to @@ -268,15 +355,12 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // First, try to concretize the root domain as there may be symbolic // axes inherited from the producers - auto propagated = propagateFromProducerToConsumer(tv); - - // If no root domain is altered, nothing to do further - if (!propagated) { - return; - } + propagateFromProducerToConsumer(tv); - // Root IDs are altered. Need to propagate the changes to rfactor - // domain + // If no root domain is altered by producer, we don't need to propagate back + // up to rfactor. We could return early, but instead we go ahead and check the + // root to rfactor transforms to be sure we have concretized any intermediate + // IterDomains. // At this point, there should be no expr beyond rfactor root TORCH_INTERNAL_ASSERT( @@ -308,20 +392,21 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { ". IterDomain was expected."); } - // If none of the output IDs is symbolic, nothing to concretize - if (std::all_of( - expr->outputs().begin(), expr->outputs().end(), [](Val* output) { - return output->as()->getIterType() != - IterType::Symbolic; - })) { - continue; - } - // If any of output IDs is symbolic, all outputs should be symbolic - TORCH_INTERNAL_ASSERT(std::all_of( - expr->outputs().begin(), expr->outputs().end(), [](Val* output) { - return output->as()->getIterType() == - IterType::Symbolic; - })); + // NOTE: We do not return early if all outputs are concrete as there may + // still be concrete inputs. For example, a Symbolic IterDomain might be + // padded with constant pad widths (1, 1), in which case although we do + // not know the exact extent of the output, we know it is at least as + // large as the sum of the pad widths, 2. In such cases, the output + // IterDomain is concrete at definition, since if the extent is >1 we know + // the IterType is Iteration. In these cases, we must continue to + // concretize intermediate expressions between the root and R-factor + // domain. See test DynamicTransform5_CUDA which demonstrates this + // behavior. + // NOTE: We also do not assume that if one output ID is symbolic, that + // they all must be. See test FusionSliceForNanoGPT3_CUDA for an example + // that does a static split by a factor of 16 of a symbolic input domain. + // The static split in that case results in a concrete IterDomain with + // extent 16 along with a symbolic one (extent ceilDiv(n / 16)). // Determine the output IterType IterType iter_type = IterType::Symbolic; @@ -336,13 +421,13 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { - auto concreteized_out_id = + auto concretized_out_id = IterDomainBuilder(out_id).iter_type(iter_type).build(); - registerMutation(out_id, concreteized_out_id); + registerMutation(out_id, concretized_out_id); } - // Outputs are mutated. The expr itself needs to be mutated as - // well, which can be done by the mutate method + // The expr itself needs to be mutated as well in case the outputs are + // mutated, which can be done by the mutate method OptOutMutator::mutate(expr); } } @@ -457,7 +542,14 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( } TORCH_INTERNAL_ASSERT( - id_type.has_value() && id_type != IterType::Symbolic, + id_type.has_value(), + "Did not find id_type for consumer root domain ", + root_id->toString(), + ". Perhaps consumer def has no inputs. Consumer definition = ", + def->toString()); + + TORCH_INTERNAL_ASSERT( + id_type != IterType::Symbolic, "Failed to concretize ", root_id->toString(), " of ", diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 4d4b6477218..8dc21f60c3a 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -27,11 +27,16 @@ class DynamicTransformInfoBuilder; //! of the fusion inputs class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { public: - const std::vector> + const std::vector>& getReshapeTransforms() const { return reshape_transforms_; } + const std::vector>& getResizeTransforms() + const { + return resize_transforms_; + } + bool operator==(const DynamicTransformConcretizationInfo& other) const; bool operator!=(const DynamicTransformConcretizationInfo& other) const { @@ -53,8 +58,15 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { private: Fusion* fusion_ = nullptr; + + // Holds, for each dynamic reshape, the output TensorView, and the result of + // analyzeView std::vector> reshape_transforms_; + // Holds the resized IterDomain (output of the Resize op) along with the + // TensorView where it appears, and its concretized IterType + std::vector> resize_transforms_; + friend class DynamicTransformInfoBuilder; }; diff --git a/csrc/ir_internal_base_nodes.h b/csrc/ir_internal_base_nodes.h index ca2e82c930d..dbd29fc7c15 100644 --- a/csrc/ir_internal_base_nodes.h +++ b/csrc/ir_internal_base_nodes.h @@ -152,11 +152,25 @@ class TORCH_CUDA_CU_API IterDomain : public Val { //! is marked as an rfactor domain. For example, expressions such as //! PadOp and SliceOp resize IterDomains and generate rfactor //! resized domains. + //! + //! Usually, the IterType of the output IterDomain will be Symbolic. This is + //! because unless the left and right expansions are known at Fusion + //! definition we cannot be sure that the output will have an extent != 1. In + //! case the output extent is in fact 1, we will set the IterType to + //! Broadcast. If the left and right expansions are constant, and sum to at + //! least two, then even an empty input will result in an Iteration IterType. + //! In these cases, we will set the output IterType to Iteration at + //! definition. Otherwise, it will be set to Symbolic and will be resolved + //! when concretization is performed by FusionExecutorCache. + //! + //! The optional iter_type argument can be used to force the output IterType, + //! but for safety its use should typically be confined to concretization. static IterDomain* resize( IterDomain* in, Val* left_expansion, Val* right_expansion, - bool mark_as_rfactor = false); + bool mark_as_rfactor = false, + std::optional iter_type = std::nullopt); bool isReduction() const { return getIterType() == IterType::Reduction; diff --git a/csrc/ir_nodes.cpp b/csrc/ir_nodes.cpp index 8ef30f9283f..4c4634841a4 100644 --- a/csrc/ir_nodes.cpp +++ b/csrc/ir_nodes.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include #include @@ -2101,9 +2102,12 @@ IterDomain::IterDomain( is_padded_dimension_(is_padded_dimension), padded_to_size_(padded_to_size), is_mma_swizzled_(is_mma_swizzled) { - TORCH_CHECK( - !(isRFactorProduct() && isBroadcast()), - "IterDomain cannot be both a broadcast and rfactor domain."); + // NOTE: We previously asserted !(isRFactorProduct() && isBroadcast()), i.e. + // that an IterDomain could not be both a broadcast and an rfactor domain. + // However, since the introduction of the resize op, we now have a legitimate + // case where this may be true; namely, whenever we resize an IterDomain to + // size 1, we will mark it as Broadcast, but the resize must lie between root + // and rfactor. TORCH_INTERNAL_ASSERT( extent->isIntegralScalar(), @@ -2459,7 +2463,8 @@ IterDomain* IterDomain::resize( IterDomain* in, Val* left_expansion, Val* right_expansion, - bool mark_as_rfactor) { + bool mark_as_rfactor, + std::optional iter_type_opt) { TORCH_CHECK( left_expansion->isIntegralScalar(), "Expansion factor must be an integer scalar: ", @@ -2502,10 +2507,28 @@ IterDomain* IterDomain::resize( right_expansion); } + // If output IterType is provided, use it. Otherwise, if we can prove the + // resized extent is 1, set to Broadcast, if we can prove it is >1 set to + // Iteration, and otherwise fall back to Symbolic. + IterType iter_type = IterType::Symbolic; + if (iter_type_opt.has_value()) { + iter_type = iter_type_opt.value(); + } else if (left_expansion->isConstInt() && right_expansion->isConstInt()) { + if (resized_id_size->isConstInt()) { + // Means input extent is also known + auto out_extent = resized_id_size->evaluateInt(); + iter_type = out_extent == 1 ? IterType::Broadcast : IterType::Iteration; + } else if ( + left_expansion->evaluateInt() + right_expansion->evaluateInt() > 1) { + // Input extent is non-negative, so we know out_extent > 1 + iter_type = IterType::Iteration; + } + } + auto resized_id = IterDomainBuilder(in->container()->zeroVal(), resized_id_size->as()) .is_rfactor_domain(mark_as_rfactor) - .iter_type(in->getIterType()) + .iter_type(iter_type) .build(); IrBuilder::create( diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index a12348ef9a4..47eaa22b5fd 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -112,8 +112,7 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( } FusionExecutorCache::FusionExecutorCache(std::unique_ptr fusion) - : fusion_(std::move(fusion)), - has_dynamic_reshape_(fusion_->hasDynamicTransform()) {} + : fusion_(std::move(fusion)) {} KernelArgumentHolder FusionExecutorCache::prepareInputs( const at::ArrayRef& inputs) { @@ -131,7 +130,7 @@ KernelArgumentHolder FusionExecutorCache::prepareInputs( // short-circuiting here, resulting in avoidable rebuilds of concretization // info. auto id_lookup_ret = - inputs_id_lookup_.lookupId(inputs, /*hash_scalars*/ has_dynamic_reshape_); + inputs_id_lookup_.lookupId(inputs, /*hash_scalars*/ isDynamic()); if (id_lookup_ret.eviction) { evictCache(id_lookup_ret.evict_id); } @@ -375,7 +374,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( // will be used only as a cache key. std::optional conc_info = std::nullopt; size_t conc_info_index = 0; - if (has_dynamic_reshape_) { + if (isDynamic()) { conc_info = DynamicTransform::getConcretizationInfo(fusion_.get(), &args); TORCH_CHECK( conc_info.has_value(), @@ -426,7 +425,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( // concretize fusion_ for use in this runtime auto fusion = std::make_unique(*fusion_); FusionGuard fg(fusion.get()); - if (has_dynamic_reshape_) { + if (isDynamic()) { const auto& cloned_conc_info = fusion->getManagedSafe( conc_info_index); @@ -451,7 +450,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( } } - if (has_dynamic_reshape_) { + if (isDynamic()) { // In the case of cache hits, we tend to accumulate managed data in // fusion_. Here we release the concretization info we created to avoid // cloning more and more entries. diff --git a/csrc/kernel_cache.h b/csrc/kernel_cache.h index 9d6b9d624e5..dee82055b98 100644 --- a/csrc/kernel_cache.h +++ b/csrc/kernel_cache.h @@ -516,6 +516,19 @@ class TORCH_CUDA_CU_API FusionExecutorCache { const KernelArgumentHolder& inputs, std::optional forced_index_type = std::nullopt); + //! Check whether the input `fusion_` has dynamic elements such as non-static + //! reshapes. Note that `fusion_` might be updated after initializing + //! `FusionExecutorCache` as is done by `FusionDefinition` in the Python + //! frontend. In that case care must be taken to delay this check until the + //! entire Fusion is defined. For that reason, this function is private, and + //! should only be called inside runFusionWithInputs. + bool isDynamic() { + if (!is_dynamic_.has_value()) { + is_dynamic_ = fusion_->hasDynamicTransform(); + } + return is_dynamic_.value(); + } + private: //! original un-scheduled `Fusion`. This may contain dynamic transforms and //! Symbolic IterDomains. @@ -551,8 +564,9 @@ class TORCH_CUDA_CU_API FusionExecutorCache { //! caching profiles. Currently it just makes it easier to test FusionKernelRuntime* most_recent_runtime_ = nullptr; - //! Whether fusion_ contains dynamic reshapes - bool has_dynamic_reshape_ = false; + //! Whether fusion_ contains dynamic reshapes. This is cached by + //! `fusionIsDynamic()` + std::optional is_dynamic_ = std::nullopt; }; class GraphCache { diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index c84db4deeda..eb7515173ee 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -94,7 +94,7 @@ std::shared_ptr innerReductionHeuristic( std::max((int64_t)n_tensor_inputs >> 2, (int64_t)1))); // Conservative value, could be set to larger based on arch if necessary. - constexpr int64_t l1_cache = 32 * 1024; + constexpr int64_t l1_cache = (int64_t)32 * 1024; // Could change per generation, but for l1 we want to consider active threads, // not resident constexpr int64_t active_threads = 1024; diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 99ea4e377cc..cd7d835f9a3 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -386,7 +386,7 @@ BestEffortReplay::BestEffortReplay( } if (skip_resize) { - skipResizes(); + skipResizes(target_exprs, replay_exprs); } std::string err_str( @@ -626,7 +626,7 @@ BestEffortReplay::BestEffortReplay( } if (skip_resize) { - skipResizes(); + skipResizes(target_exprs, replay_exprs); } } } @@ -1089,9 +1089,18 @@ void BestEffortReplay::skipSwizzles( } // Same logic as skipSwizzles -void BestEffortReplay::skipResizes() { - auto isResizeInput = [](IterDomain* id) -> bool { - return id->uses().size() == 1 && id->uses().front()->isA(); +void BestEffortReplay::skipResizes( + const std::vector& target_exprs, + const std::vector& replay_exprs) { + auto getResizeUse = [](IterDomain* id, + const std::vector& exprs) -> Resize* { + for (auto id_use : id->uses()) { + if (std::find(exprs.begin(), exprs.end(), id_use) == exprs.end()) { + continue; + } + return dynamic_cast(id_use); + } + return nullptr; }; bool updated = true; @@ -1103,11 +1112,13 @@ void BestEffortReplay::skipResizes() { auto new_target_id = target_id; auto replay_id = it.second; auto new_replay_id = replay_id; - if (isResizeInput(target_id)) { - new_target_id = target_id->uses().front()->as()->out(); + if (auto target_resize = getResizeUse(target_id, target_exprs); + target_resize != nullptr) { + new_target_id = target_resize->out(); } - if (isResizeInput(replay_id)) { - new_replay_id = replay_id->uses().front()->as()->out(); + if (auto replay_resize = getResizeUse(replay_id, replay_exprs); + replay_resize != nullptr) { + new_replay_id = replay_resize->out(); } if (new_target_id == target_id && new_replay_id == replay_id) { diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 076ece21f95..9ad78b46980 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -306,7 +306,9 @@ class TORCH_CUDA_CU_API BestEffortReplay { const std::unordered_map& replay_id2expr); // Skip resize in both target and replay domains - void skipResizes(); + void skipResizes( + const std::vector& target_exprs, + const std::vector& replay_exprs); public: // When skip_resize is true, resize is ignored or in other words forwarded diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 2c6f9491e98..50566371655 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -747,7 +747,7 @@ void reductionDynamicViewAddFusion( (reshape_before_reduction) ? add(x, bias) : sum(x, {kReductionAxis}); // create vectors of input scalars describing this reshape std::vector output_shape(output_dims); - for (int i : c10::irange(output_dims)) { + for (size_t i : c10::irange(output_dims)) { output_shape[i] = IrBuilder::create(); fusion.addInput(output_shape[i]); } @@ -784,6 +784,9 @@ void reductionDynamicViewAddFusion( auto output_shape = std::get<1>(inv); auto expect_miss = std::get<2>(inv); + TORCH_INTERNAL_ASSERT(input_shape.size() == input_dims); + TORCH_INTERNAL_ASSERT(output_shape.size() == output_dims); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor at_x = at::randn(input_shape, options); @@ -802,14 +805,14 @@ void reductionDynamicViewAddFusion( } } if (negone_dim >= 0) { - bias_shape[negone_dim] = at_x.numel() / other_numel; + bias_shape[negone_dim] = (int64_t)at_x.numel() / (int64_t)other_numel; } } at::Tensor at_bias = at::randn(bias_shape, options); std::vector aten_inputs = {at_x, at_bias}; // Add input scalars describing the reshape size for concretization - for (int i : c10::irange(output_dims)) { - aten_inputs.push_back(output_shape[i]); + for (size_t i : c10::irange(output_dims)) { + aten_inputs.emplace_back(output_shape[i]); } auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); @@ -838,11 +841,117 @@ TEST_F(NVFuserTest, FusionDynamicReshapeReductionShmoo_CUDA) { {8, 3, 4 * 7, 5}, false}, // merge(1) merge(2) osplit(1, 3) {{8, 3 * 5, 7, 9}, {8, 3, 5 * 7, 9}, false}, // merge(1) osplit(1, 3) + // test passing -1 dynamically for dimension size + // This currently fails. see https://github.com/NVIDIA/Fuser/issues/249 //{{8, 3 * 5, 7, 9}, {8, 3, -1, 9}, false} // merge(1) osplit(1, 3) }; reductionDynamicViewAddFusion( invocations, true /* reshape_before_reduction */); } +using dynamic_pad_invocation = std::tuple< + std::vector, // input_shape + std::vector, // pad_widths + bool // expect miss + >; + +void reductionDynamicPadAddFusion( + std::vector& invocations) { + constexpr int kReductionAxis = -1; + + auto input_dims = std::get<0>(invocations[0]).size(); + auto num_pad_widths = std::get<1>(invocations[0]).size(); + + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + TensorView* x = makeSymbolicTensor(input_dims); + fusion.addInput(x); + + std::vector pad_width_vals(num_pad_widths); + for (auto i : c10::irange(num_pad_widths)) { + pad_width_vals[i] = IrBuilder::create(); + fusion.addInput(pad_width_vals[i]); + } + auto x_pad = pad(x, pad_width_vals); + auto y = sum(x_pad, {kReductionAxis}); + fusion.addOutput(y); + + FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); + + // Return pair of: number of concretizations & total number of kernel runtimes + auto countConcretizations = [&fusion_executor_cache]() { + return fusion_executor_cache.getKernelRuntimes().size(); + }; + size_t num_concretizations = countConcretizations(); + // Check that concretizations and runtimes are cache misses only when they + // should be + auto checkCache = [&countConcretizations, + &num_concretizations](bool expect_miss) { + auto current = countConcretizations(); + ASSERT_EQ(current, num_concretizations + (size_t)expect_miss); + num_concretizations = current; + }; + + for (auto& inv : invocations) { + auto input_shape = std::get<0>(inv); + auto pad_widths = std::get<1>(inv); + auto expect_miss = std::get<2>(inv); + + TORCH_INTERNAL_ASSERT(input_shape.size() == input_dims); + TORCH_INTERNAL_ASSERT(pad_widths.size() == num_pad_widths); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor at_x = at::randn(input_shape, options); + std::vector aten_inputs = {at_x}; + // Add input scalars describing the reshape size for concretization + for (size_t i : c10::irange(pad_widths.size())) { + aten_inputs.emplace_back(pad_widths[i]); + } + + auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); + checkCache(expect_miss); + + auto at_x_pad = at::pad(at_x, pad_widths); + auto at_y = at::sum(at_x_pad, kReductionAxis); + + testValidate(&fusion, outputs, aten_inputs, {at_y}, __LINE__, __FILE__); + } +} + +// Test dynamic pad for various inputs +TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) { + // NOLINTBEGIN(bugprone-implicit-widening-of-multiplication-result) + auto invocations = std::vector{ + {{3, 5}, {0, 0}, true}, // trivial + + {{3, 5}, {2, 1}, false}, // simple pad of both sides + {{3, 5}, {-1, 1}, false}, // shift by one + // The following fails with a SIGFPE in innerReductionHeuristic + // See https://github.com/NVIDIA/Fuser/issues/264 + //{{3, 5}, {-3, -2}, false}, // output is zero-dimensional + + // Output has size 1 so is set to broadcast. + {{3, 5}, {0, -4}, true}, + + // Test full negative shifts, so output doesn't overlap input + {{3, 5}, {-5, 2}, false}, + {{3, 5}, {2, -5}, false}, // full shift the other direction, re-use + + // The following reuses the schedule of {3, 5} inputs, and does not set + // broadcast on the second input dimension. + {{3, 1}, {1, 1}, false}, + + // Test zero-dimensional input + //{{3, 0}, {0, 0}, false}, // SIGFPE (see #264 above) + {{3, 0}, {1, 1}, false}, + //{{3, 0}, {-1, 1}, false}, // SIGFPE (see #264 above) + }; + // NOLINTEND(bugprone-implicit-widening-of-multiplication-result) + reductionDynamicPadAddFusion(invocations); +} + } // namespace nvfuser diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 0214d8f2e62..d8fe4cffa76 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -567,10 +567,11 @@ TEST_F(NVFuserTest, FusionResizeCat3_CUDA) { std::vector shape0({4, 2}); std::vector shape1({4, 3}); - auto tv0 = makeSymbolicTensor(2); + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape0); fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); + auto tv1 = makeConcreteTensor(shape1); fusion.addInput(tv1); auto tv2 = cat({tv0, tv1}, 1); @@ -608,10 +609,11 @@ TEST_F(NVFuserTest, FusionResizeCat4_CUDA) { std::vector shape0({11, 12}); std::vector shape1({11, 13}); - auto tv0 = makeSymbolicTensor(2); + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape0); fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); + auto tv1 = makeConcreteTensor(shape1); fusion.addInput(tv1); auto tv2 = cat({tv0, tv1}, 1); @@ -649,11 +651,12 @@ TEST_F(NVFuserTest, FusionResizeCat5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - auto tv0 = makeSymbolicTensor(2); + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor({11, 12}); fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); + auto tv1 = makeConcreteTensor({11, 13}); fusion.addInput(tv1); - auto tv2 = makeSymbolicTensor(2); + auto tv2 = makeConcreteTensor({11, 25}); fusion.addInput(tv2); auto tv3 = cat({tv0, tv1}, 1); @@ -743,6 +746,7 @@ TEST_F(NVFuserTest, FusionResizeCat6_CUDA) { // Cat many tensors TEST_F(NVFuserTest, FusionResizeCat7_CUDA) { int num_tensors_to_concat = 10; + std::vector base_shape({11, 13}); for (int concat_dim : {0, 1}) { Fusion fusion; @@ -751,7 +755,10 @@ TEST_F(NVFuserTest, FusionResizeCat7_CUDA) { std::vector inputs; for (const auto i : c10::irange(num_tensors_to_concat)) { (void)i; - auto tv = makeSymbolicTensor(2); + // concrete shapes to avoid dynamic Fusion + auto shape = base_shape; + shape[concat_dim] = 10 + (i % 5); + auto tv = makeConcreteTensor(shape); fusion.addInput(tv); inputs.push_back(tv); } @@ -774,7 +781,6 @@ TEST_F(NVFuserTest, FusionResizeCat7_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); - std::vector base_shape({11, 13}); std::vector aten_inputs; for (const auto i : c10::irange(num_tensors_to_concat)) { auto shape = base_shape; @@ -914,7 +920,8 @@ TEST_F(NVFuserTest, FusionResizeSlice1_CUDA) { std::vector shape({9}); - auto tv0 = makeSymbolicTensor(1); + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); auto tv1 = slice( @@ -1004,7 +1011,8 @@ TEST_F(NVFuserTest, FusionResizeSlice4_CUDA) { std::vector shape({5, 100}); - auto tv0 = makeSymbolicTensor(2); + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); // Consider a fusion of: @@ -1083,7 +1091,10 @@ TEST_F(NVFuserTest, FusionResizeSlice5_CUDA) { auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - auto tv0 = makeSymbolicTensor(2); + std::vector shape({11, 1000}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); fusion.addInput(tv0); auto tv1 = slice( @@ -1118,7 +1129,6 @@ TEST_F(NVFuserTest, FusionResizeSlice5_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); - std::vector shape({11, 1000}); auto t0 = at::randn(shape, options); std::vector aten_inputs({t0});