From fb9b5cb8f90ef94a92be164927c72e98728b89b2 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 22 Sep 2020 18:22:50 +0800 Subject: [PATCH 1/5] Add a ReadWrite Lock for parallel Initpopulation --- .../search_policy/sketch_policy.cc | 53 +++++++++++------ .../search_policy/sketch_policy_rules.cc | 58 +++++++++++-------- .../search_policy/sketch_policy_rules.h | 6 +- src/auto_scheduler/search_policy/utils.cc | 41 ++++++++++++- src/auto_scheduler/search_policy/utils.h | 15 +++++ 5 files changed, 125 insertions(+), 48 deletions(-) diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 6b4b6ae120bd..497ebb89889a 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -27,6 +27,7 @@ #include "sketch_policy.h" #include +#include #include #include @@ -332,29 +333,45 @@ Array SketchPolicyNode::GenerateSketches() { } Array SketchPolicyNode::SampleInitPopulation(const Array& sketches, int out_size) { - int fail_ct = 0; + std::atomic fail_ct(0); Array out_states; + std::vector rand_seeds; + rand_seeds.reserve(out_size); + for (int i = 0; i < out_size; i++) { + rand_seeds.push_back(std::mt19937(rand_gen())); + } auto tic_begin = std::chrono::high_resolution_clock::now(); while (static_cast(out_states.size()) < out_size && fail_ct < out_size) { - // Random choose a starting sketch - // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have - // different potential on generating state with better performance - State tmp_s = sketches[(rand_gen)() % sketches.size()]; - - // Derivation rule based enumeration - bool valid = true; - for (const auto& rule : init_rules) { - if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) { - valid = false; - break; + std::vector temp_states(out_size); + + support::parallel_for(0, out_size - out_states.size(), + [this, &temp_states, &sketches, &rand_seeds, &fail_ct](int index) { + // Random choose a starting sketch + // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have + // different potential on generating state with better performance + State tmp_s = sketches[(rand_seeds[index])() % sketches.size()]; + + // Derivation rule based enumeration + bool valid = true; + for (const auto& rule : init_rules) { + if (rule->Apply(this, &tmp_s, &rand_seeds[index]) == PopulationGenerationRule::ResultKind::kInvalid) { + valid = false; + break; + } } - } - if (valid) { - out_states.push_back(std::move(tmp_s)); - } else { - fail_ct++; + if (valid) { + temp_states[index] = std::move(tmp_s); + } else { + fail_ct++; + } + }); + + for (int i = 0; i < out_size; i++) { + if (temp_states[i].defined()) { + out_states.push_back(std::move(temp_states[i])); + } } } @@ -461,7 +478,7 @@ Array SketchPolicyNode::EvolutionarySearch(const Array& init_popul if (dis(rand_gen) < mutation_prob) { const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)]; - if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) { + if (rule->Apply(this, &tmp_s, &rand_gen) == PopulationGenerationRule::ResultKind::kValid) { pnext->push_back(std::move(tmp_s)); mutation_success_ct++; } else { diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 228dda461beb..4c0ef6fd2136 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -440,7 +440,8 @@ std::vector> RuleSpecialComputeLocationGPU::Apply( /********** Init Population **********/ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { StateNode* pstate = state->CopyOnWrite(); // Scan the transformation history and randomly fill tiles size for all SplitStep for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { @@ -461,7 +462,7 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes( extent, ps->lengths.size(), GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor)); - const auto& candidate_lengths = candidate_lens[(policy->rand_gen)() % candidate_lens.size()]; + const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()]; pstate->transform_steps.Set( step_id, @@ -476,7 +477,8 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p } PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { return PopulationGenerationRule::ResultKind::kValid; } @@ -495,7 +497,7 @@ PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPoli std::vector> candidates = GetComputeLocationCandidates(policy->search_task, *state, stage_id); - int choice = (policy->rand_gen)() % (candidates.size() + 2); + int choice = (*rand_gen)() % (candidates.size() + 2); if (choice == 0) { if (!HasReduceIter(stage)) { @@ -519,7 +521,8 @@ PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPoli } PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { std::function annotate_parallel; annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state, @@ -584,7 +587,8 @@ PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* polic } PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { std::vector& auto_unroll_configs = IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { @@ -625,7 +629,7 @@ PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, if (HasReduceIter(stage)) { // Use auto unroll for multi level tiled stage - int value = auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()]; + int value = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]; state->pragma(stage_id, (*state)->stages[stage_id]->iters[0], std::string("auto_unroll_max_step") + "$" + std::to_string(value)); } @@ -635,7 +639,8 @@ PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, } PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; // Skip the inlined stage and placeholder stage @@ -679,7 +684,7 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* if (num_fusible > 1) { // Select a random range to fuse - num_fusible = 1 + (policy->rand_gen)() % (num_fusible - 1); + num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); } if (num_fusible == 1) { @@ -694,7 +699,8 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* } PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { std::set multi_level_tiling_root_set; for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { @@ -848,7 +854,8 @@ PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* pol } PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { int max_innermost_split_factor = GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); @@ -877,7 +884,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol const SplitStepNode* ps; do { - step_id = split_step_ids[(policy->rand_gen)() % split_step_ids.size()]; + step_id = split_step_ids[(*rand_gen)() % split_step_ids.size()]; ps = (*state)->transform_steps[step_id].as(); CHECK(ps != nullptr); extent = GetIntImm(ps->extent.value()); @@ -898,7 +905,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol // Random permute the tile size order. std::vector random_perm; - RandomPermutation(lengths.size(), &random_perm, &(policy->rand_gen)); + RandomPermutation(lengths.size(), &random_perm, rand_gen); // Try to divide a factor from one tile size and multiple it to another. for (size_t i = 0; i < random_perm.size(); ++i) { @@ -926,9 +933,9 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol // Failed on this dst_idx, try next one. continue; } - divide_factor = factors[1 + (policy->rand_gen)() % (max_factor_index)]; + divide_factor = factors[1 + (*rand_gen)() % (max_factor_index)]; } else { - divide_factor = factors[1 + (policy->rand_gen)() % (factors.size() - 1)]; + divide_factor = factors[1 + (*rand_gen)() % (factors.size() - 1)]; } // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx]. @@ -956,7 +963,8 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol } PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { // Extract all auto_unroll_max_step pragma steps. std::vector pragma_steps; for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { @@ -974,12 +982,12 @@ PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* p IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; // Randomly pick up an auto unroll pragma step - auto step_id = pragma_steps[(policy->rand_gen)() % pragma_steps.size()]; + auto step_id = pragma_steps[(*rand_gen)() % pragma_steps.size()]; auto ps = (*state)->transform_steps[step_id].as(); CHECK(ps); // Mutate its value to a random candidates - auto val = std::to_string(auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()]); + auto val = std::to_string(auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]); StateNode* pstate = state->CopyOnWrite(); pstate->transform_steps.Set(step_id, PragmaStep(ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val)); @@ -987,7 +995,8 @@ PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* p } PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { return PopulationGenerationRule::ResultKind::kInvalid; } @@ -1013,7 +1022,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo } // Randomly pick one step - size_t step_id = compute_at_steps[(policy->rand_gen)() % compute_at_steps.size()]; + size_t step_id = compute_at_steps[(*rand_gen)() % compute_at_steps.size()]; auto ps = (*state)->transform_steps[step_id].as(); int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id; CHECK(ps != nullptr); @@ -1025,7 +1034,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo return PopulationGenerationRule::ResultKind::kInvalid; } - int choice = (policy->rand_gen)() % (candidates.size()); + int choice = (*rand_gen)() % (candidates.size()); int new_compute_at_stage_id = candidates[choice].first; int new_compute_at_iter_id = candidates[choice].second; @@ -1050,7 +1059,8 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo } PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, - State* state) const { + State* state, + std::mt19937* rand_gen) const { // This mutation rule only focuses on a case that parallel was added to // the outermost loop and the loop is generated by fusing other loops. // In short, we mutate the fusion step before the parallel step. @@ -1074,7 +1084,7 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol } // Randomly pick one parallel step. - size_t step_id = parallel_steps[(policy->rand_gen)() % parallel_steps.size()]; + size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()]; auto ps = (*state)->transform_steps[step_id].as(); CHECK(ps); size_t stage_id = ps->stage_id; @@ -1113,7 +1123,7 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol // Mutate the fusion iters and replay the mutated fused/annotation steps. int iter_offset = 0; - if (RandomChoose(fuse_dir, &(policy->rand_gen)) == 0) { + if (RandomChoose(fuse_dir, rand_gen) == 0) { fused_ids.pop_back(); iter_offset = 1; } else { diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index 4098df23a604..115378a96561 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -137,7 +137,7 @@ class PopulationGenerationRule { * \param state The state to apply this rule, update inplace. * \return The result of this rule, indicate if there's any valid state generated. */ - virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0; + virtual ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const = 0; /*! \brief The deconstructor */ virtual ~PopulationGenerationRule() = default; @@ -147,7 +147,7 @@ class PopulationGenerationRule { #define DEFINE_INIT_POPULATION_RULE(rule_name) \ class rule_name : public PopulationGenerationRule { \ public: \ - ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ + ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \ }; /*! \brief The rule that fills the incomplete SplitSteps. */ @@ -189,7 +189,7 @@ class PopulationMutationRule : public PopulationGenerationRule { class rule_name : public PopulationMutationRule { \ public: \ explicit rule_name(double weight) : PopulationMutationRule(weight) {} \ - ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ + ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \ }; /*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index 744573a44124..24d3fd37810f 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -414,19 +414,52 @@ void PruneInvalidState(const SearchTask& task, Array* states) { } } +/********** SplitFactorizationMemo **********/ + +void SplitFactorizationMemo::ReadWriteLock::GetRead() { + std::unique_lock lock(cv_mutex_); + cv_.wait(lock, [this](){ return !this->is_writing_; }); + read_count_++; +} + +void SplitFactorizationMemo::ReadWriteLock::GetWrite() { + std::unique_lock lock(cv_mutex_); + cv_.wait(lock, [this](){ return this->read_count_ == 0 && !this->is_writing_; }); + is_writing_ = true; +} + +void SplitFactorizationMemo::ReadWriteLock::UnlockRead() { + std::lock_guard lock(cv_mutex_); + read_count_--; + if (read_count_ == 0) { + cv_.notify_one(); + } +} + +void SplitFactorizationMemo::ReadWriteLock::UnlockWrite() { + std::lock_guard lock(cv_mutex_); + is_writing_ = false; + cv_.notify_one(); +} + const Array>& SplitFactorizationMemo::GetFactorizationSchemes( int extent, int n_lengths, int max_innermost_factor) { QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor); - auto it = memory_.find(key); - if (it != memory_.end()) { + const auto& const_memory = memory_; + lock_.GetRead(); + const auto& it = const_memory.find(key); + const auto& memory_end = const_memory.end(); + lock_.UnlockRead(); + if (it != memory_end) { return it->second; } + lock_.GetWrite(); tmp_stack_ = Array(n_lengths, Integer()); results_ = &memory_[key]; n_lengths_ = n_lengths; - DfsEnumerate(0, extent, max_innermost_factor); + lock_.UnlockWrite(); return *results_; } @@ -464,6 +497,8 @@ const std::vector& SplitFactorizationMemo::GetFactors(int n) { return res; } +/********** Utils interface API for ffi **********/ + TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled") .set_body_typed([](const Stage& stage) { return IsTiled(stage); }); diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index 75bf0d048c11..5705296b0b51 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -680,6 +681,20 @@ class SplitFactorizationMemo { private: void DfsEnumerate(int now, int remaining_length, int max_innermost_factor); + class ReadWriteLock { + public: + void GetRead(); + void GetWrite(); + void UnlockRead(); + void UnlockWrite(); + + private: + uint32_t read_count_ = 0; + bool is_writing_ = false; + std::mutex cv_mutex_; + std::condition_variable cv_; + } lock_; + std::unordered_map>> memory_; int n_lengths_; From ea02fcb155f54d8f3b6b894b94eea1d3fb4db949 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 22 Sep 2020 18:23:34 +0800 Subject: [PATCH 2/5] Fix the warning of test --- tests/python/unittest/test_auto_scheduler_layout_rewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index aba27840a61f..caa1d6a99f40 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -45,7 +45,7 @@ def test_layout_rewrite_correctness(): workload = matmul_auto_scheduler_test workload_key = auto_scheduler.make_workload_key(workload, (N, N, N)) dag = auto_scheduler.ComputeDAG(workload_key) - target = tvm.target.create(target) + target = tvm.target.Target(target) task = auto_scheduler.SearchTask(dag, workload_key, target) with tempfile.NamedTemporaryFile() as fp: From 98655d7105e85865cb97367249467e9313abf5a1 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 23 Sep 2020 09:56:08 +0800 Subject: [PATCH 3/5] Lint fix --- .../search_policy/sketch_policy.cc | 44 ++++++++++--------- .../search_policy/sketch_policy_rules.cc | 26 ++++------- .../search_policy/sketch_policy_rules.h | 17 +++---- src/auto_scheduler/search_policy/utils.cc | 4 +- 4 files changed, 43 insertions(+), 48 deletions(-) diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 497ebb89889a..785b19dee8d6 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -346,27 +346,29 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches std::vector temp_states(out_size); support::parallel_for(0, out_size - out_states.size(), - [this, &temp_states, &sketches, &rand_seeds, &fail_ct](int index) { - // Random choose a starting sketch - // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have - // different potential on generating state with better performance - State tmp_s = sketches[(rand_seeds[index])() % sketches.size()]; - - // Derivation rule based enumeration - bool valid = true; - for (const auto& rule : init_rules) { - if (rule->Apply(this, &tmp_s, &rand_seeds[index]) == PopulationGenerationRule::ResultKind::kInvalid) { - valid = false; - break; - } - } - - if (valid) { - temp_states[index] = std::move(tmp_s); - } else { - fail_ct++; - } - }); + [this, &temp_states, &sketches, &rand_seeds, &fail_ct](int index) { + // Random choose a starting sketch + // TODO(jcf94, merrymercy): Maybe choose sketches in different + // possibility for they may have different potential on generating state + // with better performance + State tmp_s = sketches[(rand_seeds[index])() % sketches.size()]; + + // Derivation rule based enumeration + bool valid = true; + for (const auto& rule : init_rules) { + if (rule->Apply(this, &tmp_s, &rand_seeds[index]) == + PopulationGenerationRule::ResultKind::kInvalid) { + valid = false; + break; + } + } + + if (valid) { + temp_states[index] = std::move(tmp_s); + } else { + fail_ct++; + } + }); for (int i = 0; i < out_size; i++) { if (temp_states[i].defined()) { diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 4c0ef6fd2136..2eaa1329ff13 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -439,8 +439,7 @@ std::vector> RuleSpecialComputeLocationGPU::Apply( /********** Init Population **********/ -PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, - State* state, +PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { StateNode* pstate = state->CopyOnWrite(); // Scan the transformation history and randomly fill tiles size for all SplitStep @@ -476,9 +475,8 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p return ResultKind::kValid; } -PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, - State* state, - std::mt19937* rand_gen) const { +PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply( + SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { return PopulationGenerationRule::ResultKind::kValid; } @@ -520,8 +518,7 @@ PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPoli return PopulationGenerationRule::ResultKind::kValid; } -PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, - State* state, +PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { std::function annotate_parallel; @@ -586,8 +583,7 @@ PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* polic return ResultKind::kValid; } -PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, - State* state, +PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { std::vector& auto_unroll_configs = IsGPUTask(policy->search_task) ? auto_unroll_configs_gpu : auto_unroll_configs_cpu; @@ -698,8 +694,7 @@ PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* return ResultKind::kValid; } -PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, - State* state, +PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { std::set multi_level_tiling_root_set; for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { @@ -853,8 +848,7 @@ PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* pol return ResultKind::kValid; } -PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, - State* state, +PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { int max_innermost_split_factor = GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); @@ -962,8 +956,7 @@ PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* pol return ResultKind::kInvalid; } -PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, - State* state, +PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { // Extract all auto_unroll_max_step pragma steps. std::vector pragma_steps; @@ -1058,8 +1051,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNo return PopulationGenerationRule::ResultKind::kValid; } -PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, - State* state, +PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const { // This mutation rule only focuses on a case that parallel was added to // the outermost loop and the loop is generated by fusing other loops. diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index 115378a96561..928efc518827 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -137,16 +137,17 @@ class PopulationGenerationRule { * \param state The state to apply this rule, update inplace. * \return The result of this rule, indicate if there's any valid state generated. */ - virtual ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const = 0; + virtual ResultKind Apply(SketchPolicyNode* policy, State* state, + std::mt19937* rand_gen) const = 0; /*! \brief The deconstructor */ virtual ~PopulationGenerationRule() = default; }; // A helper to define population initialization rules -#define DEFINE_INIT_POPULATION_RULE(rule_name) \ - class rule_name : public PopulationGenerationRule { \ - public: \ +#define DEFINE_INIT_POPULATION_RULE(rule_name) \ + class rule_name : public PopulationGenerationRule { \ + public: \ ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \ }; @@ -185,10 +186,10 @@ class PopulationMutationRule : public PopulationGenerationRule { }; // A helper to define mutation rules used in the evolutionary search -#define DEFINE_MUTATE_POPULATION_RULE(rule_name) \ - class rule_name : public PopulationMutationRule { \ - public: \ - explicit rule_name(double weight) : PopulationMutationRule(weight) {} \ +#define DEFINE_MUTATE_POPULATION_RULE(rule_name) \ + class rule_name : public PopulationMutationRule { \ + public: \ + explicit rule_name(double weight) : PopulationMutationRule(weight) {} \ ResultKind Apply(SketchPolicyNode* policy, State* state, std::mt19937* rand_gen) const final; \ }; diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index 24d3fd37810f..0bcc7539d0a8 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -418,13 +418,13 @@ void PruneInvalidState(const SearchTask& task, Array* states) { void SplitFactorizationMemo::ReadWriteLock::GetRead() { std::unique_lock lock(cv_mutex_); - cv_.wait(lock, [this](){ return !this->is_writing_; }); + cv_.wait(lock, [this]() { return !this->is_writing_; }); read_count_++; } void SplitFactorizationMemo::ReadWriteLock::GetWrite() { std::unique_lock lock(cv_mutex_); - cv_.wait(lock, [this](){ return this->read_count_ == 0 && !this->is_writing_; }); + cv_.wait(lock, [this]() { return this->read_count_ == 0 && !this->is_writing_; }); is_writing_ = true; } From 9360208316835390d7d0153c597077f8ab46f84e Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 23 Sep 2020 10:29:28 +0800 Subject: [PATCH 4/5] Update --- .../search_policy/sketch_policy.cc | 20 +++++++++---------- src/auto_scheduler/search_policy/utils.cc | 4 ++++ src/auto_scheduler/search_policy/utils.h | 13 ++++++++++++ 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 785b19dee8d6..6aba8d091bf0 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -333,12 +333,12 @@ Array SketchPolicyNode::GenerateSketches() { } Array SketchPolicyNode::SampleInitPopulation(const Array& sketches, int out_size) { - std::atomic fail_ct(0); + int fail_ct = 0; Array out_states; - std::vector rand_seeds; - rand_seeds.reserve(out_size); + std::vector rand_gens; + rand_gens.reserve(out_size); for (int i = 0; i < out_size; i++) { - rand_seeds.push_back(std::mt19937(rand_gen())); + rand_gens.push_back(std::mt19937(rand_gen())); } auto tic_begin = std::chrono::high_resolution_clock::now(); @@ -346,33 +346,31 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches std::vector temp_states(out_size); support::parallel_for(0, out_size - out_states.size(), - [this, &temp_states, &sketches, &rand_seeds, &fail_ct](int index) { + [this, &temp_states, &sketches, &rand_gens, &fail_ct](int index) { // Random choose a starting sketch // TODO(jcf94, merrymercy): Maybe choose sketches in different // possibility for they may have different potential on generating state // with better performance - State tmp_s = sketches[(rand_seeds[index])() % sketches.size()]; - + State tmp_s = sketches[(rand_gens[index])() % sketches.size()]; // Derivation rule based enumeration bool valid = true; for (const auto& rule : init_rules) { - if (rule->Apply(this, &tmp_s, &rand_seeds[index]) == + if (rule->Apply(this, &tmp_s, &rand_gens[index]) == PopulationGenerationRule::ResultKind::kInvalid) { valid = false; break; } } - if (valid) { temp_states[index] = std::move(tmp_s); - } else { - fail_ct++; } }); for (int i = 0; i < out_size; i++) { if (temp_states[i].defined()) { out_states.push_back(std::move(temp_states[i])); + } else { + fail_ct++; } } } diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index 0bcc7539d0a8..174ca3105a4b 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -418,12 +418,14 @@ void PruneInvalidState(const SearchTask& task, Array* states) { void SplitFactorizationMemo::ReadWriteLock::GetRead() { std::unique_lock lock(cv_mutex_); + // Wake up and get the mutex lock if there's no writing thread cv_.wait(lock, [this]() { return !this->is_writing_; }); read_count_++; } void SplitFactorizationMemo::ReadWriteLock::GetWrite() { std::unique_lock lock(cv_mutex_); + // Wake up and get the mutex lock if there's no reading or writing threads cv_.wait(lock, [this]() { return this->read_count_ == 0 && !this->is_writing_; }); is_writing_ = true; } @@ -431,6 +433,7 @@ void SplitFactorizationMemo::ReadWriteLock::GetWrite() { void SplitFactorizationMemo::ReadWriteLock::UnlockRead() { std::lock_guard lock(cv_mutex_); read_count_--; + // Notify the other blocked threads if this is the last reading thread if (read_count_ == 0) { cv_.notify_one(); } @@ -439,6 +442,7 @@ void SplitFactorizationMemo::ReadWriteLock::UnlockRead() { void SplitFactorizationMemo::ReadWriteLock::UnlockWrite() { std::lock_guard lock(cv_mutex_); is_writing_ = false; + // Notify the other blocked threads cv_.notify_one(); } diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index 5705296b0b51..6c0fb4c4dcf4 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -681,11 +681,24 @@ class SplitFactorizationMemo { private: void DfsEnumerate(int now, int remaining_length, int max_innermost_factor); + /*! + * \brief A simple implementation of read-write lock. + * The guarded block can be read by multiple threads at the same time, while other operations will + * be blocked if one thread is writing. + * \note Writing threads will wait until all reading threads have finshed. If there're multiple + * writing threads, the process order of them is not guaranteed. + */ class ReadWriteLock { public: + /*! \brief The method to get the read lock. One thread can process read if there's on other + * writing threads. */ void GetRead(); + /*! \brief The method to get the write lock. One thread can process write if there's on other + * reading or writing threads. */ void GetWrite(); + /*! \brief The method to release the read lock. */ void UnlockRead(); + /*! \brief The method to release the write lock. */ void UnlockWrite(); private: From 29edab6766d3032d9875bdb933e9bc6b7b4e8310 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 23 Sep 2020 10:39:43 +0800 Subject: [PATCH 5/5] Update --- src/auto_scheduler/search_policy/sketch_policy.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 6aba8d091bf0..a89fa4b0c77a 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -346,7 +346,7 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches std::vector temp_states(out_size); support::parallel_for(0, out_size - out_states.size(), - [this, &temp_states, &sketches, &rand_gens, &fail_ct](int index) { + [this, &temp_states, &sketches, &rand_gens](int index) { // Random choose a starting sketch // TODO(jcf94, merrymercy): Maybe choose sketches in different // possibility for they may have different potential on generating state