Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 6 additions & 96 deletions csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,9 @@ size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) {
// Gets maximum vectorizable width of tv, assumes we can merge across all
// iteration domains if contiguous. Cannot permute the dimensions to fix
// contiguity.
size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) {
size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(
TensorView* tv,
bool contig_merge) {
// Gets the vectorizable width of the tv starting from the inner most
// dimension, working its way towards the outer most dimension, if they're
// contiguous. Ignores broadcast and reduction domains.
Expand Down Expand Up @@ -1130,117 +1132,25 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) {

// Still contiguous
numel *= dim_size->as<int64_t>();
}

// Assuming intermediate tensors have friendly alignment, and
// all contiguity true. Determine the largest power of 2 below
// innermost dimension size for the word size of vectorization
size_t vector_size = 1;
size_t next_vector_size = 2;
while (next_vector_size <= max_vector_size &&
next_vector_size <= (size_t)numel && numel % next_vector_size == 0) {
vector_size = next_vector_size;
next_vector_size *= 2;
}

// save output to avoid re-compute
max_vectorword_map_[tv] = vector_size;

return vector_size;
}

// Gets the vectorizable width of the inner most dimension of tv if it's
// contiguous. Ignores inner most dimensions that are broadcast or reduction.
size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) {
auto inner_vectorword_map_it_ = inner_vectorword_map_.find(tv);
if (inner_vectorword_map_it_ != inner_vectorword_map_.end()) {
return inner_vectorword_map_it_->second;
}

// If we don't have a record, either it is a tv with innermost broadcast,
// or it is an intermediate tensor allocated by fuser. Logic copied to get
// root according to scheduler_utils::innerMostRootDim.
auto tv_root = tv->hasReduction() && tv->hasRFactor()
? tv->getRootDomain()
: tv->getMaybeRFactorDomain();

auto tv_root_no_reductions = TensorDomain::noReductions(tv_root);

auto contiguity = tv->domain()->contiguity();
// Appears after reductions the reduction domain often has a contiguity entry.
// This only matters if the result of the reduction is an output
if (contiguity.size() == tv_root.size() &&
contiguity.size() != tv_root_no_reductions.size()) {
std::vector<std::optional<bool>> new_contiguity;
for (auto i : c10::irange(tv_root.size())) {
if (!tv_root[i]->isReduction()) {
new_contiguity.push_back(contiguity[i]);
}
}
contiguity = new_contiguity;
}
tv_root = tv_root_no_reductions;

auto tv_root_no_reductions_size = tv_root_no_reductions.size();

// Filter out 0-dim tensors
if (tv_root_no_reductions_size < 1) {
return 1;
}

// Filter out mismatched contiguity info
if (tv_root_no_reductions_size != contiguity.size()) {
return 1;
}

auto inner_most_dim = scheduler_utils::innerMostRootDim(tv);

int id_pos = -1;
for (auto root_i : c10::irange((int)tv_root_no_reductions_size)) {
if (tv_root_no_reductions[root_i] == inner_most_dim) {
id_pos = root_i;
if (!contig_merge) {
break;
}
}

// Something went wrong with finding the inner most dimension, just
// return 1.
if (id_pos == -1) {
return 1;
}

// If the inner most dimension is not contiguous return 1
auto contiguity_opt = contiguity.at(id_pos);
TORCH_INTERNAL_ASSERT(contiguity_opt.has_value());
if (!*contiguity_opt) {
return 1;
}

size_t item_size = dataTypeSize(tv->dtype(), getIndexType());

// Alignment should always at least be the data type size
TORCH_INTERNAL_ASSERT(getAlignmentSize(tv) % item_size == 0);
size_t max_vector_size = getAlignmentSize(tv) / item_size;

// Assuming intermediate tensors have friendly alignment, and
// all contiguity true. Determine the largest power of 2 below
// innermost dimension size for the word size of vectorization
size_t vector_size = 1;
size_t next_vector_size = 2;
auto maybe_inner_dimension_size =
expression_evaluator_->evaluate(inner_most_dim->extent());
TORCH_INTERNAL_ASSERT(maybe_inner_dimension_size.has_value());
size_t inner_dimension_size = maybe_inner_dimension_size->as<int64_t>();

while (next_vector_size <= max_vector_size &&
next_vector_size <= inner_dimension_size &&
inner_dimension_size % next_vector_size == 0) {
next_vector_size <= (size_t)numel && numel % next_vector_size == 0) {
vector_size = next_vector_size;
next_vector_size *= 2;
}

// save output to avoid re-compute
inner_vectorword_map_[tv] = vector_size;
max_vectorword_map_[tv] = vector_size;

return vector_size;
}
Expand Down
13 changes: 4 additions & 9 deletions csrc/scheduler/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,10 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
size_t getAlignmentSize(TensorView* tv);

// Gets maximum vectorizable width of tv, assumes we can merge across all
// iteration domains if contiguous. Cannot permute the dimensions to fix
// contiguity. Ignores dimensions that are broadcast or reduction.
size_t getMaxVectorizableWidth(TensorView* tv);

// Gets the vectorizable width of the inner most dimension of tv if it's
// contiguous. Ignores inner most dimensions that are broadcast or reduction.
size_t getInnerDimVectorizableWidth(TensorView* tv);
// iteration domains if contiguous, unless contig_merge=false. Cannot permute
// the dimensions to fix contiguity. Ignores dimensions that are broadcast or
// reduction.
size_t getMaxVectorizableWidth(TensorView* tv, bool contig_merge = true);

// Computes alignment size in bytes for provided ptr address
static size_t computeAlignmentSize(size_t ptr_address);
Expand Down Expand Up @@ -129,8 +126,6 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
std::unordered_map<TensorView*, size_t> alignment_map_;
// Cache for getMaxVectorizableWidth
std::unordered_map<TensorView*, size_t> max_vectorword_map_;
// Cache for getInnerDimVectorizableWidth
std::unordered_map<TensorView*, size_t> inner_vectorword_map_;

// Found index mode kernel needs to be run in
PrimDataType index_type_ = PrimDataType::Int;
Expand Down
4 changes: 2 additions & 2 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(

for (auto tv : grouped_inputs_outputs[0]) {
const auto tv_vectorize_factor =
runtime_info.getInnerDimVectorizableWidth(tv);
runtime_info.getMaxVectorizableWidth(tv, false);
vectorize_factor1 = std::min(vectorize_factor1, tv_vectorize_factor);
}
// TODO: Since group2 only has global->shared and shared->global set op, we
Expand All @@ -709,7 +709,7 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
// group 2
for (auto tv : grouped_inputs_outputs[1]) {
const auto tv_vectorize_factor =
runtime_info.getInnerDimVectorizableWidth(tv);
runtime_info.getMaxVectorizableWidth(tv, false);
vectorize_factor2 = std::min(vectorize_factor2, tv_vectorize_factor);
}

Expand Down
100 changes: 34 additions & 66 deletions csrc/scheduler/vectorize_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,11 @@ int64_t getVectorizationSize(
auto denominator = denominator_optional->as<int64_t>();
auto extent = extent_optional->as<int64_t>();

// TODO: we should clean this up with expr simplifier
auto gcd = std::gcd(numerator, denominator);
numerator = numerator / gcd;
denominator = denominator / gcd;

if (denominator != 1) {
break;
}
Expand Down Expand Up @@ -1201,52 +1206,61 @@ int64_t getVectorizationSize(
return vectorize_size;
}

size_t getExpandedVectorization(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A large portion of getVectorizationFactor has been removed, so I simply cut and pasted the body of this function (getExpandedVectorization) into getVectorizationFactor and removed this function.

const std::vector<ContiguousInnerDimensionsMapper>& reference_maps,
size_t getVectorizationFactor(
SchedulerRuntimeInfo& runtime_info,
const std::vector<TensorView*> vectorizable_inputs_outputs,
TensorView* reference_tv,
int break_point,
size_t default_word_size) {
HeuristicSummary* data_cache,
int break_point) {
auto vectorizable_inputs_outputs_entry =
HeuristicSummaryEntry<HeuristicCompileTime::VectorizableInputsAndOutputs>(
data_cache, [&reference_tv]() {
return std::make_unique<std::vector<TensorView*>>(
scheduler_utils::getInputsOutputsWithInnerDim(
reference_tv, true, true));
});

auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();

auto vectorize_maps_entry =
HeuristicSummaryEntry<HeuristicCompileTime::VectorizeMaps>(
data_cache, [&reference_tv]() {
return std::make_unique<
std::vector<vectorize_helper::ContiguousInnerDimensionsMapper>>(
vectorize_helper::getAllVectorizedMapsOf(reference_tv));
});

if (vectorizable_inputs_outputs.empty()) {
return 1;
}

size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
size_t max_vec_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
size_t common_alignment_size =
SchedulerRuntimeInfo::max_alignment_size_in_byte;

for (auto inp_or_out : vectorizable_inputs_outputs) {
auto dtype_size =
dataTypeSize(inp_or_out->dtype(), runtime_info.getIndexType());

max_expand_size = std::min(
max_expand_size,
max_vec_size = std::min(
max_vec_size,
SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size);
max_expand_size = std::min(
max_expand_size, runtime_info.getMaxVectorizableWidth(inp_or_out));
max_vec_size = std::min(
max_vec_size, runtime_info.getMaxVectorizableWidth(inp_or_out));
common_alignment_size = std::min(
common_alignment_size, runtime_info.getAlignmentSize(inp_or_out));
}

// If there's no possibility to increase vector size of provided tensors,
// then don't bother doing a more complex analysis to try and do so, just
// return early.
if (max_expand_size == default_word_size) {
return default_word_size;
}

auto reference_map = reference_maps[break_point];
auto reference_map = vectorize_maps_entry.get().at(break_point);
// Initialize to max the tensors could support.
size_t max_supported_vector_size = max_expand_size;
size_t max_supported_vector_size = max_vec_size;
for (auto inp_or_out : vectorizable_inputs_outputs) {
size_t contig_dim_size = getVectorizationSize(
getContigVectorSizesOf(inp_or_out, reference_map),
runtime_info.expressionEvaluator());
size_t local_max_vec_size = 1;

while (contig_dim_size > 1 && contig_dim_size % 2 == 0 &&
local_max_vec_size < max_expand_size) {
local_max_vec_size < max_vec_size) {
contig_dim_size /= 2;
local_max_vec_size *= 2;
}
Expand All @@ -1257,51 +1271,5 @@ size_t getExpandedVectorization(
return max_supported_vector_size;
}

// Computes the vectorization factor to use for reference_tv.
//
// First takes the minimum inner-dimension vectorizable width over the
// inputs/outputs returned by getInputsOutputsWithInnerDim, then hands that
// value to getExpandedVectorization as its default word size and returns
// whatever that call produces.
//
// Parameters:
//   runtime_info - queried for the per-tensor inner-dim vectorizable width and
//                  forwarded to getExpandedVectorization.
//   reference_tv - tensor the vectorizable inputs/outputs and vectorize maps
//                  are derived from.
//   data_cache   - passed to both HeuristicSummaryEntry constructions.
//                  NOTE(review): presumably lets the entries be reused instead
//                  of recomputed; confirm against HeuristicSummaryEntry.
//   break_point  - forwarded unchanged to getExpandedVectorization.
//
// NOTE(review): another definition with this same signature appears earlier in
// this view; this one looks like the pre-change version that still calls
// getInnerDimVectorizableWidth and getExpandedVectorization. Confirm which one
// is live before relying on it.
size_t getVectorizationFactor(
SchedulerRuntimeInfo& runtime_info,
TensorView* reference_tv,
HeuristicSummary* data_cache,
int break_point) {
// Entry holding the inputs/outputs of reference_tv that share its inner
// dimension; the lambda builds the list from reference_tv.
auto vectorizable_inputs_outputs_entry =
HeuristicSummaryEntry<HeuristicCompileTime::VectorizableInputsAndOutputs>(
data_cache, [&reference_tv]() {
return std::make_unique<std::vector<TensorView*>>(
scheduler_utils::getInputsOutputsWithInnerDim(
reference_tv, true, true));
});

auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();

// Start from the largest size_t so the first std::min below always wins; the
// same value doubles as a "nothing seen yet" sentinel.
size_t vectorize_factor = std::numeric_limits<size_t>::max();

// The factor can be no larger than the smallest inner-dim width among the
// listed tensors.
for (auto tv : vectorizable_inputs_outputs) {
const auto tv_vectorize_factor =
runtime_info.getInnerDimVectorizableWidth(tv);
vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
}

// Sentinel untouched (e.g. the list was empty): fall back to a factor of 1.
if (vectorize_factor == std::numeric_limits<size_t>::max()) {
vectorize_factor = 1;
}

// Entry holding the ContiguousInnerDimensionsMapper objects produced by
// getAllVectorizedMapsOf(reference_tv).
auto vectorize_maps_entry =
HeuristicSummaryEntry<HeuristicCompileTime::VectorizeMaps>(
data_cache, [&reference_tv]() {
return std::make_unique<
std::vector<vectorize_helper::ContiguousInnerDimensionsMapper>>(
vectorize_helper::getAllVectorizedMapsOf(reference_tv));
});

// The inner-dim factor computed above is passed as the last argument
// (default_word_size); the result replaces it.
vectorize_factor = vectorize_helper::getExpandedVectorization(
vectorize_maps_entry.get(),
runtime_info,
vectorizable_inputs_outputs,
reference_tv,
break_point,
vectorize_factor);

return vectorize_factor;
}

} // namespace vectorize_helper
} // namespace nvfuser
11 changes: 0 additions & 11 deletions csrc/scheduler/vectorize_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,17 +594,6 @@ std::vector<std::pair<ProjectedExtent&, IterDomain*>> getContigVectorSizesOf(
TensorView* of_tv,
ContiguousInnerDimensionsMapper& mapper);

// TODO: vectorizable_inputs_outputs should actually be known based on the
// computed mappings. If nothing is mapped for a tensorview it's not
// vectorizable.
size_t getExpandedVectorization(
const std::vector<ContiguousInnerDimensionsMapper>& reference_maps,
SchedulerRuntimeInfo& runtime_info,
const std::vector<TensorView*> vectorizable_inputs_outputs,
TensorView* reference_tv,
int break_point,
size_t default_word_size);

size_t getVectorizationFactor(
SchedulerRuntimeInfo& runtime_info,
TensorView* reference_tv,
Expand Down
12 changes: 6 additions & 6 deletions test/test_gpu1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,17 +1201,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) {
// 2. use a fuzzy compare (ignore non-significant whitespaces for example)
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T3) {
int64_t i86;
i86 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
if ((i86 < T0.size[0])) {
int64_t i87;
i87 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
if ((i87 < T0.size[0])) {
float T5[1];
T5[0] = 0;
T5[0]
= T1[i86];
= T1[i87];
float T4[1];
T4[0] = 0;
T4[0]
= T0[i86];
= T0[i87];
float T2[1];
T2[0]
= T4[0]
Expand All @@ -1220,7 +1220,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 1, 1> T0, Tensor<float, 1, 1>
T6[0]
= T2[0]
* T4[0];
T3[i86]
T3[i87]
= T6[0];
}
}
Expand Down
Loading