From 050db70a603c76b0c814eb786edb63fd81d16620 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 30 Oct 2020 18:42:33 -0700 Subject: [PATCH] Fix mutate auto unroll --- python/tvm/auto_scheduler/measure.py | 6 ++++-- .../search_policy/sketch_policy_rules.cc | 10 +++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 0121ddf37d03..642e8f85e86b 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -42,8 +42,6 @@ from tvm.runtime import Object, module, ndarray from tvm.driver import build_module from tvm.ir import transform -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk @@ -481,6 +479,10 @@ def __init__( cooldown_interval=0.0, enable_cpu_cache_flush=False, ): + # pylint: disable=import-outside-toplevel + from tvm.rpc.tracker import Tracker + from tvm.rpc.server import Server + ctx = tvm.context("cuda", 0) if ctx.exist: cuda_arch = "sm_" + "".join(ctx.compute_version.split(".")) diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 1b6cc06a4c45..692ace103be3 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -998,10 +998,14 @@ PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* p ICHECK(ps); // Mutate its value to a random candidates - auto val = std::to_string(auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]); + int val = auto_unroll_configs[(*rand_gen)() % auto_unroll_configs.size()]; StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps.Set(step_id, PragmaStep(ps->stage_id, ps->iter_id, - std::string("auto_unroll_max_step") + "$" + val)); + pstate->transform_steps.Set( + step_id, PragmaStep(ps->stage_id, ps->iter_id, + std::string("auto_unroll_max_step") + "$" + std::to_string(val))); + Stage new_stage = pstate->stages[ps->stage_id]; + new_stage.CopyOnWrite()->attrs.auto_unroll_max_step = val; + pstate->stages.Set(ps->stage_id, new_stage); return ResultKind::kValid; }