diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp
index 1c07c0a4340..4bbea966fbf 100644
--- a/csrc/scheduler/registry.cpp
+++ b/csrc/scheduler/registry.cpp
@@ -1051,7 +1051,9 @@ size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) {
 // Gets maximum vectorizable width of tv, assumes we can merge across all
 // iteration domains if contiguous. Cannot permute the dimensions to fix
 // contiguity.
-size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) {
+size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(
+    TensorView* tv,
+    bool contig_merge) {
   // Gets the vectorizable width of the tv starting from the inner most
   // dimension, working its way towards the outer most dimension, if they're
   // contiguous. Ignores broadcast and reduction domains.
@@ -1130,117 +1132,25 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) {
 
     // Still contiguous
     numel *= dim_size->as<int64_t>();
-  }
-
-  // Assuming intermediate tensors have friendly alignment, and
-  // all contiguity true. Determine the largest power of 2 below
-  // innermost dimension size for the word size of vectorizaiton
-  size_t vector_size = 1;
-  size_t next_vector_size = 2;
-  while (next_vector_size <= max_vector_size &&
-         next_vector_size <= (size_t)numel && numel % next_vector_size == 0) {
-    vector_size = next_vector_size;
-    next_vector_size *= 2;
-  }
-  // save output to avoid re-compute
-  max_vectorword_map_[tv] = vector_size;
-
-  return vector_size;
-}
-
-// Gets the vectorizable width of the inner most dimension of tv if it's
-// contiguous. Ignores inner most dimensions that are broadcast or reduction.
-size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) {
-  auto inner_vectorword_map_it_ = inner_vectorword_map_.find(tv);
-  if (inner_vectorword_map_it_ != inner_vectorword_map_.end()) {
-    return inner_vectorword_map_it_->second;
-  }
-
-  // If we don't have an record, either it is a tv with innermost broadcast,
-  // or it is an intermediate tensor allocated by fuser. Logic copied to get
-  // root according to scheduler_utils::innerMostRootDim.
-  auto tv_root = tv->hasReduction() && tv->hasRFactor()
-      ? tv->getRootDomain()
-      : tv->getMaybeRFactorDomain();
-
-  auto tv_root_no_reductions = TensorDomain::noReductions(tv_root);
-
-  auto contiguity = tv->domain()->contiguity();
-  // Appears after reductions the reduction domain often has a contiguity entry.
-  // This only matters if the result of the reduction is an output
-  if (contiguity.size() == tv_root.size() &&
-      contiguity.size() != tv_root_no_reductions.size()) {
-    std::vector<c10::optional<bool>> new_contiguity;
-    for (auto i : c10::irange(tv_root.size())) {
-      if (!tv_root[i]->isReduction()) {
-        new_contiguity.push_back(contiguity[i]);
-      }
-    }
-    contiguity = new_contiguity;
-  }
-  tv_root = tv_root_no_reductions;
-
-  auto tv_root_no_reductions_size = tv_root_no_reductions.size();
-
-  // Filter out 0-dim tensors
-  if (tv_root_no_reductions_size < 1) {
-    return 1;
-  }
-
-  // Filter out mismatched contiguity info
-  if (tv_root_no_reductions_size != contiguity.size()) {
-    return 1;
-  }
-
-  auto inner_most_dim = scheduler_utils::innerMostRootDim(tv);
-
-  int id_pos = -1;
-  for (auto root_i : c10::irange((int)tv_root_no_reductions_size)) {
-    if (tv_root_no_reductions[root_i] == inner_most_dim) {
-      id_pos = root_i;
+    if (!contig_merge) {
       break;
     }
   }
 
-  // Something went wrong with finding the inner most dimension, just
-  // return 1.
-  if (id_pos == -1) {
-    return 1;
-  }
-
-  // If the inner most dimension is not contiguous return 1
-  auto contiguity_opt = contiguity.at(id_pos);
-  TORCH_INTERNAL_ASSERT(contiguity_opt.has_value());
-  if (!*contiguity_opt) {
-    return 1;
-  }
-
-  size_t item_size = dataTypeSize(tv->dtype(), getIndexType());
-
-  // Alignment should always at least be the data type size
-  TORCH_INTERNAL_ASSERT(getAlignmentSize(tv) % item_size == 0);
-  size_t max_vector_size = getAlignmentSize(tv) / item_size;
-
   // Assuming intermediate tensors have friendly alignment, and
   // all contiguity true. Determine the largest power of 2 below
   // innermost dimension size for the word size of vectorizaiton
   size_t vector_size = 1;
   size_t next_vector_size = 2;
-  auto maybe_inner_dimension_size =
-      expression_evaluator_->evaluate(inner_most_dim->extent());
-  TORCH_INTERNAL_ASSERT(maybe_inner_dimension_size.has_value());
-  size_t inner_dimension_size = maybe_inner_dimension_size->as<int64_t>();
-
   while (next_vector_size <= max_vector_size &&
-         next_vector_size <= inner_dimension_size &&
-         inner_dimension_size % next_vector_size == 0) {
+         next_vector_size <= (size_t)numel && numel % next_vector_size == 0) {
     vector_size = next_vector_size;
     next_vector_size *= 2;
   }
   // save output to avoid re-compute
-  inner_vectorword_map_[tv] = vector_size;
+  max_vectorword_map_[tv] = vector_size;
 
   return vector_size;
 }
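Note on the registry.cpp change above: with `contig_merge=false` the merge loop stops after the innermost dimension, which reproduces what the deleted `getInnerDimVectorizableWidth` computed, so both queries now share one power-of-2 search and one cache (`max_vectorword_map_`). A minimal standalone sketch of that core behavior, assuming every dimension is contiguous and ignoring the broadcast/reduction filtering the real function performs (`maxVectorizableWidth` and its parameters are illustrative, not part of the patch):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Largest power of 2 that divides the (optionally merged) contiguous extent
// and fits the alignment budget, mirroring the loop kept by the patch.
size_t maxVectorizableWidth(
    const std::vector<int64_t>& extents, // ordered outermost -> innermost
    size_t max_vector_size, // alignment size in bytes / element size
    bool contig_merge) {
  int64_t numel = 1;
  for (auto it = extents.rbegin(); it != extents.rend(); ++it) {
    numel *= *it;
    if (!contig_merge) {
      break; // innermost dimension only
    }
  }
  size_t vector_size = 1;
  size_t next_vector_size = 2;
  while (next_vector_size <= max_vector_size &&
         next_vector_size <= (size_t)numel && numel % next_vector_size == 0) {
    vector_size = next_vector_size;
    next_vector_size *= 2;
  }
  return vector_size;
}

int main() {
  std::vector<int64_t> extents{4, 2}; // innermost extent is 2
  std::cout << maxVectorizableWidth(extents, 4, true) << "\n";  // 4: merged 4*2
  std::cout << maxVectorizableWidth(extents, 4, false) << "\n"; // 2: inner only
}
```

The `{4, 2}` case shows the exact distinction the flag preserves: merging contiguous dimensions exposes a width of 4, while the innermost-only query caps it at 2.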
diff --git a/csrc/scheduler/registry.h b/csrc/scheduler/registry.h
index 81ce5f4b997..b1ffc5b1d84 100644
--- a/csrc/scheduler/registry.h
+++ b/csrc/scheduler/registry.h
@@ -67,13 +67,10 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
   size_t getAlignmentSize(TensorView* tv);
 
   // Gets maximum vectorizable width of tv, assumes we can merge across all
-  // iteration domains if contiguous. Cannot permute the dimensions to fix
-  // contiguity. Ignores dimensions that are broadcast or reduction.
-  size_t getMaxVectorizableWidth(TensorView* tv);
-
-  // Gets the vectorizable width of the inner most dimension of tv if it's
-  // contiguous. Ignores inner most dimensions that are broadcast or reduction.
-  size_t getInnerDimVectorizableWidth(TensorView* tv);
+  // iteration domains if contiguous, unless contig_merge=false. Cannot permute
+  // the dimensions to fix contiguity. Ignores dimensions that are broadcast or
+  // reduction.
+  size_t getMaxVectorizableWidth(TensorView* tv, bool contig_merge = true);
 
   // Computes alignment size in bytes for provided ptr address
   static size_t computeAlignmentSize(size_t ptr_address);
@@ -129,8 +126,6 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
   std::unordered_map<TensorView*, size_t> alignment_map_;
   // Cache for getMaxVectorizableWidth
   std::unordered_map<TensorView*, size_t> max_vectorword_map_;
-  // Cache for getInnerDimVectorizableWidth
-  std::unordered_map<TensorView*, size_t> inner_vectorword_map_;
 
   // Found index mode kernel needs to be run in
   PrimDataType index_type_ = PrimDataType::Int;
diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp
index 69175bbe002..9c10f9865ec 100644
--- a/csrc/scheduler/transpose.cpp
+++ b/csrc/scheduler/transpose.cpp
@@ -700,7 +700,7 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
 
   for (auto tv : grouped_inputs_outputs[0]) {
     const auto tv_vectorize_factor =
-        runtime_info.getInnerDimVectorizableWidth(tv);
+        runtime_info.getMaxVectorizableWidth(tv, false);
     vectorize_factor1 = std::min(vectorize_factor1, tv_vectorize_factor);
   }
   // TODO: Since group2 only has global->shared and shared->global set op, we
@@ -709,7 +709,7 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
   // group 2
   for (auto tv : grouped_inputs_outputs[1]) {
     const auto tv_vectorize_factor =
-        runtime_info.getInnerDimVectorizableWidth(tv);
+        runtime_info.getMaxVectorizableWidth(tv, false);
     vectorize_factor2 = std::min(vectorize_factor2, tv_vectorize_factor);
   }
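Note on the transpose.cpp change: both group loops are the same min-reduction over the tensors in a group, now routed through `getMaxVectorizableWidth(tv, false)`. A hypothetical in-repo helper to make that shape explicit (`groupVectorizeFactor` is illustrative and not part of this patch; it assumes the usual nvFuser headers plus <algorithm> and <limits>):

```cpp
// Illustrative only: the widest width that every tensor in the group supports.
// Taking std::min means a single misaligned or oddly-sized tensor caps the
// vectorization width of the whole group.
size_t groupVectorizeFactor(
    SchedulerRuntimeInfo& runtime_info,
    const std::vector<TensorView*>& group) {
  size_t factor = std::numeric_limits<size_t>::max();
  for (auto tv : group) {
    factor = std::min(
        factor,
        runtime_info.getMaxVectorizableWidth(tv, /*contig_merge=*/false));
  }
  return factor == std::numeric_limits<size_t>::max() ? 1 : factor;
}
```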
diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp
index ca2fcf27b86..26f99d0d5bd 100644
--- a/csrc/scheduler/vectorize_helper.cpp
+++ b/csrc/scheduler/vectorize_helper.cpp
@@ -1157,6 +1157,11 @@ int64_t getVectorizationSize(
     auto denominator = denominator_optional->as<int64_t>();
     auto extent = extent_optional->as<int64_t>();
 
+    // TODO: we should clean this up with expr simplifier
+    auto gcd = std::gcd(numerator, denominator);
+    numerator = numerator / gcd;
+    denominator = denominator / gcd;
+
     if (denominator != 1) {
       break;
     }
@@ -1201,18 +1206,34 @@ int64_t getVectorizationSize(
   return vectorize_size;
 }
 
-size_t getExpandedVectorization(
-    const std::vector<ContiguousInnerDimensionsMapper>& reference_maps,
+size_t getVectorizationFactor(
     SchedulerRuntimeInfo& runtime_info,
-    const std::vector<TensorView*> vectorizable_inputs_outputs,
     TensorView* reference_tv,
-    int break_point,
-    size_t default_word_size) {
+    HeuristicSummary* data_cache,
+    int break_point) {
+  auto vectorizable_inputs_outputs_entry =
+      HeuristicSummaryEntry<HeuristicCompileTime::VectorizableInputsAndOutputs>(
+          data_cache, [&reference_tv]() {
+            return std::make_unique<std::vector<TensorView*>>(
+                scheduler_utils::getInputsOutputsWithInnerDim(
+                    reference_tv, true, true));
+          });
+
+  auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();
+
+  auto vectorize_maps_entry =
+      HeuristicSummaryEntry<HeuristicCompileTime::VectorizeMaps>(
+          data_cache, [&reference_tv]() {
+            return std::make_unique<
+                std::vector<vectorize_helper::ContiguousInnerDimensionsMapper>>(
+                vectorize_helper::getAllVectorizedMapsOf(reference_tv));
+          });
+
   if (vectorizable_inputs_outputs.empty()) {
     return 1;
   }
 
-  size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
+  size_t max_vec_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
   size_t common_alignment_size =
       SchedulerRuntimeInfo::max_alignment_size_in_byte;
 
@@ -1220,25 +1241,18 @@ size_t getExpandedVectorization(
     auto dtype_size =
         dataTypeSize(inp_or_out->dtype(), runtime_info.getIndexType());
 
-    max_expand_size = std::min(
-        max_expand_size,
+    max_vec_size = std::min(
+        max_vec_size,
         SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size);
-    max_expand_size = std::min(
-        max_expand_size,
-        runtime_info.getMaxVectorizableWidth(inp_or_out));
+    max_vec_size = std::min(
+        max_vec_size, runtime_info.getMaxVectorizableWidth(inp_or_out));
     common_alignment_size = std::min(
         common_alignment_size, runtime_info.getAlignmentSize(inp_or_out));
   }
 
-  // If there's no possibility to increase vector size of provided tensors,
-  // then don't bother doing a more complex analysis to try and do so, just
-  // return early.
-  if (max_expand_size == default_word_size) {
-    return default_word_size;
-  }
-
-  auto reference_map = reference_maps[break_point];
+  auto reference_map = vectorize_maps_entry.get().at(break_point);
 
   // Initialize to max the tensors could support.
-  size_t max_supported_vector_size = max_expand_size;
+  size_t max_supported_vector_size = max_vec_size;
   for (auto inp_or_out : vectorizable_inputs_outputs) {
     size_t contig_dim_size = getVectorizationSize(
         getContigVectorSizesOf(inp_or_out, reference_map),
@@ -1246,7 +1260,7 @@ size_t getExpandedVectorization(
 
     size_t local_max_vec_size = 1;
 
     while (contig_dim_size > 1 && contig_dim_size % 2 == 0 &&
-           local_max_vec_size < max_expand_size) {
+           local_max_vec_size < max_vec_size) {
       contig_dim_size /= 2;
       local_max_vec_size *= 2;
     }
@@ -1257,51 +1271,5 @@ size_t getExpandedVectorization(
   return max_supported_vector_size;
 }
 
-size_t getVectorizationFactor(
-    SchedulerRuntimeInfo& runtime_info,
-    TensorView* reference_tv,
-    HeuristicSummary* data_cache,
-    int break_point) {
-  auto vectorizable_inputs_outputs_entry =
-      HeuristicSummaryEntry<HeuristicCompileTime::VectorizableInputsAndOutputs>(
-          data_cache, [&reference_tv]() {
-            return std::make_unique<std::vector<TensorView*>>(
-                scheduler_utils::getInputsOutputsWithInnerDim(
-                    reference_tv, true, true));
-          });
-
-  auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();
-
-  size_t vectorize_factor = std::numeric_limits<size_t>::max();
-
-  for (auto tv : vectorizable_inputs_outputs) {
-    const auto tv_vectorize_factor =
-        runtime_info.getInnerDimVectorizableWidth(tv);
-    vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
-  }
-
-  if (vectorize_factor == std::numeric_limits<size_t>::max()) {
-    vectorize_factor = 1;
-  }
-
-  auto vectorize_maps_entry =
-      HeuristicSummaryEntry<HeuristicCompileTime::VectorizeMaps>(
-          data_cache, [&reference_tv]() {
-            return std::make_unique<
-                std::vector<vectorize_helper::ContiguousInnerDimensionsMapper>>(
-                vectorize_helper::getAllVectorizedMapsOf(reference_tv));
-          });
-
-  vectorize_factor = vectorize_helper::getExpandedVectorization(
-      vectorize_maps_entry.get(),
-      runtime_info,
-      vectorizable_inputs_outputs,
-      reference_tv,
-      break_point,
-      vectorize_factor);
-
-  return vectorize_factor;
-}
-
 } // namespace vectorize_helper
 } // namespace nvfuser
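Note on the gcd TODO block added to getVectorizationSize: projected extents are tracked as numerator/denominator pairs, and a fraction such as 12/4 would previously fail the `denominator != 1` check even though it is exactly 3. A standalone sketch of the reduction (`simplifiedIntegerExtent` is illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>

// Reduce the fraction before testing whether it is integral, as the patch
// does; returns -1 when the extent does not simplify to a whole number.
int64_t simplifiedIntegerExtent(int64_t numerator, int64_t denominator) {
  auto gcd = std::gcd(numerator, denominator);
  numerator /= gcd;
  denominator /= gcd;
  return denominator == 1 ? numerator : -1;
}

int main() {
  std::cout << simplifiedIntegerExtent(12, 4) << "\n"; // 3: accepted now
  std::cout << simplifiedIntegerExtent(7, 2) << "\n";  // -1: still rejected
}
```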
diff --git a/csrc/scheduler/vectorize_helper.h b/csrc/scheduler/vectorize_helper.h
index 0c8b6536195..2204782fe85 100644
--- a/csrc/scheduler/vectorize_helper.h
+++ b/csrc/scheduler/vectorize_helper.h
@@ -594,17 +594,6 @@ std::vector<std::pair<ProjectedExtent&, IterDomain*>> getContigVectorSizesOf(
     TensorView* of_tv,
     ContiguousInnerDimensionsMapper& mapper);
 
-// TODO: vectorizable_inputs_outputs should actually be known based on the
-// computed mappings. If nothing is mapped for a tensorview it's not
-// vectorizable.
-size_t getExpandedVectorization(
-    const std::vector<ContiguousInnerDimensionsMapper>& reference_maps,
-    SchedulerRuntimeInfo& runtime_info,
-    const std::vector<TensorView*> vectorizable_inputs_outputs,
-    TensorView* reference_tv,
-    int break_point,
-    size_t default_word_size);
-
 size_t getVectorizationFactor(
     SchedulerRuntimeInfo& runtime_info,
     TensorView* reference_tv,
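Note on the header change: getExpandedVectorization is gone from the public surface, and getVectorizationFactor now pulls the vectorizable inputs/outputs and the per-break-point mapper list out of the HeuristicSummary cache itself. A hypothetical call site under that assumption (the variable names simply mirror the parameters):

```cpp
// Scheduler heuristic code: the cache entries are built on first use and
// reused when the same fusion is re-scheduled for a new input shape.
auto vectorize_factor = vectorize_helper::getVectorizationFactor(
    runtime_info, reference_tv, data_cache, break_point);
```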
diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp
index 902e08afdff..a74671215f6 100644
--- a/test/test_gpu1.cpp
+++ b/test/test_gpu1.cpp
@@ -1201,17 +1201,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T3) {
-  int64_t i86;
-  i86 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
-  if ((i86 < T0.size[0])) {
+  int64_t i87;
+  i87 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
+  if ((i87 < T0.size[0])) {
     float T5[1];
     T5[0] = 0;
     T5[0]
-       = T1[i86];
+       = T1[i87];
     float T4[1];
     T4[0] = 0;
     T4[0]
-       = T0[i86];
+       = T0[i87];
     float T2[1];
     T2[0]
       = T4[0]
@@ -1220,7 +1220,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 1, 1> T0, Tensor
     T6[0]
       = T2[0]
       * T4[0];
-    T3[i86]
+    T3[i87]
       = T6[0];
   }
 }
diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp
index 86c27ba710b..c92b0b0d5ff 100644
--- a/test/test_gpu2.cpp
+++ b/test/test_gpu2.cpp
@@ -9043,27 +9043,27 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<__half, 4, 4> T0, Tensor<__half, 4, 4> T2, Tensor<__half, 4, 4> T7) {
-  int64_t i1201;
-  i1201 = T0.size[2] * T0.size[1];
-  int64_t i1204;
-  i1204 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
-  int64_t i1206;
-  i1206 = (T0.size[1] * T0.size[2]) * T0.size[3];
-  int64_t i1238;
-  i1238 = i1204 % i1206;
-  int64_t i1215;
-  i1215 = T0.size[2] * T0.size[3];
+  int64_t i1202;
+  i1202 = T0.size[2] * T0.size[1];
+  int64_t i1205;
+  i1205 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
+  int64_t i1207;
+  i1207 = (T0.size[1] * T0.size[2]) * T0.size[3];
   int64_t i1239;
-  i1239 = i1238 % i1215;
-  if ((i1204 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) {
+  i1239 = i1205 % i1207;
+  int64_t i1216;
+  i1216 = T0.size[2] * T0.size[3];
+  int64_t i1240;
+  i1240 = i1239 % i1216;
+  if ((i1205 < (((T0.size[0] * T0.size[1]) * T0.size[2]) * T0.size[3]))) {
     __half T9[1];
     T9[0] = 0;
     T9[0]
-       = T2[(((((i1201 * T0.size[3]) * (i1204 / i1206)) + (i1201 * (i1239 % T0.size[3]))) + (T0.size[2] * (i1238 / i1215))) + (i1239 / T0.size[3]))];
+       = T2[(((((i1202 * T0.size[3]) * (i1205 / i1207)) + (i1202 * (i1240 % T0.size[3]))) + (T0.size[2] * (i1239 / i1216))) + (i1240 / T0.size[3]))];
     __half T8[1];
     T8[0] = 0;
     T8[0]
-       = T0[i1204];
+       = T0[i1205];
     float T3[1];
     T3[0]
       = __half2float(T9[0]);
@@ -9083,7 +9083,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4, 4> T0, Tensor<__half, 4, 4
     __half T10[1];
     T10[0]
      = __float2half(T6[0]);
-    T7[i1204]
+    T7[i1205]
      = T10[0];
   }
 }