Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "sketch_policy.h"

#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>

#include <algorithm>
#include <iomanip>
Expand Down Expand Up @@ -334,28 +335,44 @@ Array<State> SketchPolicyNode::GenerateSketches() {
Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches, int out_size) {
int fail_ct = 0;
Array<State> out_states;
std::vector<std::mt19937> rand_gens;
rand_gens.reserve(out_size);
for (int i = 0; i < out_size; i++) {
rand_gens.push_back(std::mt19937(rand_gen()));
}
auto tic_begin = std::chrono::high_resolution_clock::now();

while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
// Randomly choose a starting sketch
// TODO(jcf94, merrymercy): Maybe choose sketches with different probabilities, since they may
// have different potential for generating states with better performance
State tmp_s = sketches[(rand_gen)() % sketches.size()];

// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
std::vector<State> temp_states(out_size);

support::parallel_for(0, out_size - out_states.size(),
[this, &temp_states, &sketches, &rand_gens](int index) {
// Randomly choose a starting sketch
// TODO(jcf94, merrymercy): Maybe choose sketches with different
// probabilities, since they may have different potential for generating
// states with better performance
State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
}
if (valid) {
temp_states[index] = std::move(tmp_s);
}
});

for (int i = 0; i < out_size; i++) {
if (temp_states[i].defined()) {
out_states.push_back(std::move(temp_states[i]));
} else {
fail_ct++;
}
}

if (valid) {
out_states.push_back(std::move(tmp_s));
} else {
fail_ct++;
}
}

double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
Expand Down Expand Up @@ -461,7 +478,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul

if (dis(rand_gen) < mutation_prob) {
const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)];
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) {
if (rule->Apply(this, &tmp_s, &rand_gen) == PopulationGenerationRule::ResultKind::kValid) {
pnext->push_back(std::move(tmp_s));
mutation_success_ct++;
} else {
Expand Down
66 changes: 34 additions & 32 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,8 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(

/********** Init Population **********/

PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
StateNode* pstate = state->CopyOnWrite();
// Scan the transformation history and randomly fill tiles size for all SplitStep
for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
Expand All @@ -461,7 +461,7 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
extent, ps->lengths.size(),
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor));
const auto& candidate_lengths = candidate_lens[(policy->rand_gen)() % candidate_lens.size()];
const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];

pstate->transform_steps.Set(
step_id,
Expand All @@ -475,8 +475,8 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(
SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const {
if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
return PopulationGenerationRule::ResultKind::kValid;
}
Expand All @@ -495,7 +495,7 @@ PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPoli
std::vector<std::pair<int, int>> candidates =
GetComputeLocationCandidates(policy->search_task, *state, stage_id);

int choice = (policy->rand_gen)() % (candidates.size() + 2);
int choice = (*rand_gen)() % (candidates.size() + 2);

if (choice == 0) {
if (!HasReduceIter(stage)) {
Expand All @@ -518,8 +518,8 @@ PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPoli
return PopulationGenerationRule::ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
std::function<void(const SketchPolicyNode&, State*, int stage_id, int iter_offset)>
annotate_parallel;
annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state,
Expand Down Expand Up @@ -583,8 +583,8 @@ PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* polic
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
std::vector<int>& auto_unroll_configs =
IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
Expand Down Expand Up @@ -625,7 +625,7 @@ PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy,

if (HasReduceIter(stage)) {
// Use auto unroll for multi level tiled stage
int value = auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()];
int value = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()];
state->pragma(stage_id, (*state)->stages[stage_id]->iters[0],
std::string("auto_unroll_max_step") + "$" + std::to_string(value));
}
Expand All @@ -635,7 +635,8 @@ PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy,
}

PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
State* state) const {
State* state,
std::mt19937* rand_gen) const {
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
const Stage& stage = (*state)->stages[stage_id];
// Skip the inlined stage and placeholder stage
Expand Down Expand Up @@ -679,7 +680,7 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode*

if (num_fusible > 1) {
// Select a random range to fuse
num_fusible = 1 + (policy->rand_gen)() % (num_fusible - 1);
num_fusible = 1 + (*rand_gen)() % (num_fusible - 1);
}

if (num_fusible == 1) {
Expand All @@ -693,8 +694,8 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode*
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
std::set<int> multi_level_tiling_root_set;
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
Expand Down Expand Up @@ -847,8 +848,8 @@ PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* pol
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
int max_innermost_split_factor =
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

Expand Down Expand Up @@ -877,7 +878,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol
const SplitStepNode* ps;

do {
step_id = split_step_ids[(policy->rand_gen)() % split_step_ids.size()];
step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()];
ps = (*state)->transform_steps[step_id].as<SplitStepNode>();
CHECK(ps != nullptr);
extent = GetIntImm(ps->extent.value());
Expand All @@ -898,7 +899,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol

// Randomly permute the tile size order.
std::vector<int> random_perm;
RandomPermutation(lengths.size(), &random_perm, &(policy->rand_gen));
RandomPermutation(lengths.size(), &random_perm, rand_gen);

// Try to divide a factor from one tile size and multiply it to another.
for (size_t i = 0; i < random_perm.size(); ++i) {
Expand Down Expand Up @@ -926,9 +927,9 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol
// Failed on this dst_idx, try next one.
continue;
}
divide_factor = factors[1 + (policy->rand_gen)() % (max_factor_index)];
divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)];
} else {
divide_factor = factors[1 + (policy->rand_gen)() % (factors.size() - 1)];
divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)];
}

// Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx].
Expand All @@ -955,8 +956,8 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol
return ResultKind::kInvalid;
}

PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
// Extract all auto_unroll_max_step pragma steps.
std::vector<int> pragma_steps;
for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
Expand All @@ -974,20 +975,21 @@ PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* p
IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu;

// Randomly pick up an auto unroll pragma step
auto step_id = pragma_steps[(policy->rand_gen)() % pragma_steps.size()];
auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()];
auto ps = (*state)->transform_steps[step_id].as<PragmaStepNode>();
CHECK(ps);

// Mutate its value to a random candidate
auto val = std::to_string(auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()]);
auto val = std::to_string(auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]);
StateNode* pstate = state->CopyOnWrite();
pstate->transform_steps.Set(step_id, PragmaStep(ps->stage_id, ps->iter_id,
std::string("auto_unroll_max_step") + "$" + val));
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
State* state,
std::mt19937* rand_gen) const {
if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
return PopulationGenerationRule::ResultKind::kInvalid;
}
Expand All @@ -1013,7 +1015,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo
}

// Randomly pick one step
size_t step_id = compute_at_steps[(policy->rand_gen)() % compute_at_steps.size()];
size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()];
auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>();
int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id;
CHECK(ps != nullptr);
Expand All @@ -1025,7 +1027,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo
return PopulationGenerationRule::ResultKind::kInvalid;
}

int choice = (policy->rand_gen)() % (candidates.size());
int choice = (*rand_gen)() % (candidates.size());
int new_compute_at_stage_id = candidates[choice].first;
int new_compute_at_iter_id = candidates[choice].second;

Expand All @@ -1049,8 +1051,8 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo
return PopulationGenerationRule::ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
// This mutation rule only focuses on a case that parallel was added to
// the outermost loop and the loop is generated by fusing other loops.
// In short, we mutate the fusion step before the parallel step.
Expand All @@ -1074,7 +1076,7 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol
}

// Randomly pick one parallel step.
size_t step_id = parallel_steps[(policy->rand_gen)() % parallel_steps.size()];
size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()];
auto ps = (*state)->transform_steps[step_id].as<AnnotationStepNode>();
CHECK(ps);
size_t stage_id = ps->stage_id;
Expand Down Expand Up @@ -1113,7 +1115,7 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol

// Mutate the fusion iters and replay the mutated fused/annotation steps.
int iter_offset = 0;
if (RandomChoose(fuse_dir, &(policy->rand_gen)) == 0) {
if (RandomChoose(fuse_dir, rand_gen) == 0) {
fused_ids.pop_back();
iter_offset = 1;
} else {
Expand Down
21 changes: 11 additions & 10 deletions src/auto_scheduler/search_policy/sketch_policy_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,18 @@ class PopulationGenerationRule {
* \param state The state to apply this rule, update inplace.
* \return The result of this rule, indicating whether any valid state was generated.
*/
virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0;
virtual ResultKind Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const = 0;

/*! \brief The destructor */
virtual ~PopulationGenerationRule() = default;
};

// A helper to define population initialization rules
#define DEFINE_INIT_POPULATION_RULE(rule_name) \
class rule_name : public PopulationGenerationRule { \
public: \
ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \
#define DEFINE_INIT_POPULATION_RULE(rule_name) \
class rule_name : public PopulationGenerationRule { \
public: \
ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \
};

/*! \brief The rule that fills the incomplete SplitSteps. */
Expand Down Expand Up @@ -185,11 +186,11 @@ class PopulationMutationRule : public PopulationGenerationRule {
};

// A helper to define mutation rules used in the evolutionary search
#define DEFINE_MUTATE_POPULATION_RULE(rule_name) \
class rule_name : public PopulationMutationRule { \
public: \
explicit rule_name(double weight) : PopulationMutationRule(weight) {} \
ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \
#define DEFINE_MUTATE_POPULATION_RULE(rule_name) \
class rule_name : public PopulationMutationRule { \
public: \
explicit rule_name(double weight) : PopulationMutationRule(weight) {} \
ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \
};

/*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor
Expand Down
Loading