diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 46d606c628d9..f0d076e75f02 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -19,7 +19,6 @@ from . import compute_dag from . import dispatcher -from . import env from . import feature from . import loop_state from . import measure @@ -36,7 +35,6 @@ from .compute_dag import ComputeDAG from .cost_model import RandomModel, XGBModel from .dispatcher import DispatchContext, ApplyHistoryBest -from .env import enable_relay_integration, is_relay_integration_enabled from .measure import ( MeasureInput, MeasureResult, diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index 8822f3963f7b..19bae8622355 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -44,7 +44,7 @@ class DispatchContext(object): def __init__(self): self._old_ctx = DispatchContext.current - def query(self, target, workload_key): + def query(self, target, workload_key, has_complex_op, dag): """ Query the context to get the specific config for a workload. If cannot find the result inside this context, this function will query it @@ -56,6 +56,10 @@ def query(self, target, workload_key): The current target workload_key : str The workload key + has_complex_op: bool + Whether this workload has at least one complex op. + dag: ComputeDAG + The ComputeDAG of the workload. Returns ------- @@ -64,7 +68,7 @@ def query(self, target, workload_key): """ ret = self._query_inside(target, workload_key) if ret is None: - ret = self._old_ctx.query(target, workload_key) + ret = self._old_ctx.query(target, workload_key, has_complex_op, dag) return ret def update(self, target, workload_key, state): @@ -220,11 +224,11 @@ def _query_inside(self, target, workload_key): def update(self, target, workload_key, state): model = target.model - key = (model, workload) + key = (model, workload_key) self._best_user_defined[key] = state for k in target.keys: - key = (k, workload) + key = (k, workload_key) self._best_user_defined[key] = state @@ -237,21 +241,27 @@ class FallbackContext(DispatchContext): def __init__(self): super(FallbackContext, self).__init__() self.memory = {} - self.silent = False + + # Verbose level: + # 0: Completely silent. + # 1: Warning the missing configs for querying complex tasks. + # 2: Warning the missing configs for querying all tasks. + self.verbose = 1 # a set to prevent print duplicated message self.messages = set() - def query(self, target, workload_key): + def query(self, target, workload_key, has_complex_op, dag): key = (str(target), workload_key) if key in self.memory: return self.memory[key] - if not self.silent: + if self.verbose == 2 or (has_complex_op and self.verbose == 1): msg = ( - "Cannot find tuned schedules for target=%s, workload_key=%s. " - "A fallback schedule is used, " - "which may bring great performance regression." % (target, workload_key) + "Cannot find tuned schedules for target=%s, workload_key=%s, compute:\n%s" + "A fallback TOPI schedule is used, " + "which may bring great performance regression or even compilation failure." + % (target, workload_key, dag) ) if msg not in self.messages: self.messages.add(msg) diff --git a/python/tvm/auto_scheduler/env.py b/python/tvm/auto_scheduler/env.py deleted file mode 100644 index 95c7ccf971a2..000000000000 --- a/python/tvm/auto_scheduler/env.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""The scope to store global environmental variables of the auto-scheduler""" - - -class AutoSchedulerGlobalScope(object): - """The global scope to store environmental variables of the auot-scheduler""" - - def __init__(self): - self.enable_relay_integration = False - - -GLOBAL_SCOPE = AutoSchedulerGlobalScope() - - -def is_relay_integration_enabled(): - """Return whether the relay integration is enabled - - Returns - ------- - enabled: bool - Whether the relay integration is enabled - """ - return GLOBAL_SCOPE.enable_relay_integration - - -def enable_relay_integration(new_value=True): - """Set the relay integration - - Parameters - --------- - new_value: bool = True - The new setting of relay integration - - Returns - ------- - old_value: bool - The old setting. - """ - old_value = GLOBAL_SCOPE.enable_relay_integration - GLOBAL_SCOPE.enable_relay_integration = new_value - return old_value diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0b0157c421b5..283d8bf7db45 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -25,7 +25,7 @@ import threading import tvm -from tvm import te, transform +from tvm import autotvm, te, transform from tvm.te.tensor import ComputeOp, PlaceholderOp from .compute_dag import ComputeDAG from .dispatcher import DispatchContext @@ -34,18 +34,26 @@ def call_all_topi_funcs(mod, params, target): - """Call all TOPI compute + schedule to extract tasks in a relay program""" + """Call all TOPI compute to extract auto_scheduler tasks in a Relay program""" # pylint: disable=import-outside-toplevel from tvm import relay from tvm.relay.backend import graph_runtime_codegen - with transform.PassContext(opt_level=3): + # Turn off AutoTVM config not found warnings + old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent + autotvm.GLOBAL_SCOPE.silent = True + + with transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}): opt_mod, _ = relay.optimize(mod, target, params) grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) grc.codegen(opt_mod["main"]) + autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent + -def extract_tasks(mod, params, target, target_host=None, hardware_params=None): +def extract_tasks( + mod, params, target, target_host=None, hardware_params=None, include_simple_tasks=False +): """Extract tuning tasks from a relay program. Parameters @@ -60,6 +68,8 @@ def extract_tasks(mod, params, target, target_host=None, hardware_params=None): The host compilation target hardware_params : Optional[HardwareParams] Hardware parameters used for the search tasks + include_simple_tasks: bool + Whether to extract simple tasks that do not include complicated ops. Returns ------- @@ -77,7 +87,9 @@ def extract_tasks(mod, params, target, target_host=None, hardware_params=None): target_host = tvm.target.Target(target_host) # Run the compiler to collect all TOPI calls during compilation. - env = TracingEnvironment(TracingMode.EXTRACT_TASK) + env = TracingEnvironment( + TracingMode.EXTRACT_TASK if include_simple_tasks else TracingMode.EXTRACT_COMPLEX_TASK_ONLY + ) with env: # Wrap build call in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool @@ -109,7 +121,8 @@ class TracingMode: """Two modes for tracing""" EXTRACT_TASK = 0 # trace all topi calls to extract tasks - PREPARE_LAYOUT_REWRITE = 1 # trace topi calls to prepare layout rewrite + EXTRACT_COMPLEX_TASK_ONLY = 1 # same as EXTRACT_TASK but ignore the task without complex ops + PREPARE_LAYOUT_REWRITE = 2 # trace topi calls to prepare layout rewrite class TracingEnvironment: @@ -181,11 +194,8 @@ def traverse(t): return inputs + list(outs), has_layout_free -# The suffix of implementations that use the auto-scheduler in the OpStrategy. -auto_schedule_impl_suffix = ".auto_scheduler" - - -def auto_schedule_topi(outs): +@tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute") +def auto_schedule_topi(outs, has_complex_op): """Use auto-scheduler to schedule any topi compute function. Note: This is used internally for relay integration. Do @@ -195,35 +205,40 @@ def auto_schedule_topi(outs): ---------- outs: List[Tensor] The output tensors of topi compute functions + has_complex_op: bool + Whether the topi compute function includes at least one complex op. Returns ------- - sch: te.Schedule - A topi schedule function + sch: Optional[te.Schedule] + A tuned schedule or none (if not tuned) in the final build mode; + An initial schdule in the tracing mode. """ # pylint: disable=import-outside-toplevel from tvm import relay io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) key = register_workload_tensors(io_tensors) + if key is None: # skip this compute if failed to register the workload + return None # only enable layout rewrite for cpu backend enable_layout_rewrite = "cpu" in tvm.target.Target.current().keys env = TracingEnvironment.current if env is None: # in the final build mode - state = DispatchContext.current.query(tvm.target.Target.current(), key) + dag = ComputeDAG(io_tensors) + state = DispatchContext.current.query(tvm.target.Target.current(), key, has_complex_op, dag) if state is None: - if "gpu" in tvm.target.Target.current().keys: - raise RuntimeError("Cannot compile for GPU targets if no valid schedule is found.") - return te.create_schedule([x.op for x in outs]) + return None - dag = ComputeDAG(io_tensors) schedule, _ = dag.apply_steps_from_state(state) - elif env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode - engine = relay.backend.compile_engine.get() - ccache_key = engine.get_current_ccache_key() - env.add_workload_key(key, ccache_key) + elif env.tracing_mode in [TracingMode.EXTRACT_TASK, TracingMode.EXTRACT_COMPLEX_TASK_ONLY]: + # in the task extraction mode + if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK: + engine = relay.backend.compile_engine.get() + ccache_key = engine.get_current_ccache_key() + env.add_workload_key(key, ccache_key) schedule = te.create_schedule([x.op for x in outs]) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # todo(merrymercy, minminsun): port layout rewrite diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 8a42c5f9b83a..6a4809b1796c 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """ Workload registration and serialization. @@ -29,12 +30,14 @@ When we need the dag, we decode the string and call the function, which will return the dag. """ +import logging import pickle import json import tvm._ffi from .utils import serialize_args, deserialize_args, get_func_name +logger = logging.getLogger("auto_scheduler") # Global workload function and hash key registry # It stores two types of workload: @@ -105,13 +108,18 @@ def register_workload_tensors(tensors): Returns ------- - key: str - The workload key + key: Optional[str] + The workload key, or None if failed to create a compute DAG. """ # pylint: disable=import-outside-toplevel from .compute_dag import ComputeDAG - key = ComputeDAG(tensors).hash_key() + try: + key = ComputeDAG(tensors).hash_key() + except tvm.error.TVMError as err: + logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err)) + return None + WORKLOAD_FUNC_REGISTRY[key] = tensors return json.dumps((key,)) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index d874732d6fa0..28f2ac6d489b 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -21,7 +21,7 @@ import logging import numpy as np import tvm -from tvm import te, autotvm, auto_scheduler +from tvm import te, autotvm from tvm.runtime import Object from tvm.support import libinfo from tvm.target import Target @@ -196,25 +196,13 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) outs = best_plevel_impl.compute(attrs, inputs, out_type) return best_plevel_impl, outs - # If auto-scheduler is enabled for Relay, always prefer auto-scheduler - if auto_scheduler.is_relay_integration_enabled(): - auto_scheduler_impls = [] - for impl in all_impls: - if impl.name.endswith(auto_scheduler.relay_integration.auto_schedule_impl_suffix): - auto_scheduler_impls.append(impl) - - if auto_scheduler_impls: - assert len(auto_scheduler_impls) == 1 - impl = auto_scheduler_impls[0] - outs = impl.compute(attrs, inputs, out_type) - return impl, outs - # Otherwise, try autotvm templates outputs = {} workloads = {} best_autotvm_impl = None best_cfg = None dispatch_ctx = autotvm.task.DispatchContext.current + old_silent = autotvm.GLOBAL_SCOPE.silent autotvm.GLOBAL_SCOPE.silent = True for impl in all_impls: outs = impl.compute(attrs, inputs, out_type) @@ -232,7 +220,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) if best_cfg is None or best_cfg.cost > cfg.cost: best_autotvm_impl = impl best_cfg = cfg - autotvm.GLOBAL_SCOPE.silent = False + autotvm.GLOBAL_SCOPE.silent = old_silent if best_autotvm_impl: # The best autotvm implementation definitely doesn't use fallback config @@ -251,7 +239,10 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) "is used, which may bring great performance regression." % (target, workloads[best_plevel_impl]) ) - if msg not in autotvm.task.DispatchContext.warning_messages: + if ( + not autotvm.env.GLOBAL_SCOPE.silent + and msg not in autotvm.task.DispatchContext.warning_messages + ): autotvm.task.DispatchContext.warning_messages.add(msg) autotvm_logger.warning(msg) logger.info( @@ -300,7 +291,6 @@ def lower_call(call, inputs, target): best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target) else: # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes. - # Currently, we just use the implementation with highest plevel best_impl, outputs = select_implementation( op, call.attrs, inputs, ret_type, target, use_autotvm=False ) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 35bd8e6d3d4d..cba97c43b25a 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -23,6 +23,7 @@ from tvm.ir import IRModule +from tvm.ir.transform import PassContext from tvm.tir import expr as tvm_expr from .. import nd as _nd, autotvm from ..target import Target @@ -123,8 +124,20 @@ def build(self, mod, target=None, target_host=None, params=None): # Setup the params. if params: self._set_params(params) - # Build the IR module + + # Build the IR module. If auto_scheduler is not enabled, + # then use the TOPI-defined schedule. + use_auto_scheduler = PassContext.current().config.get( + "relay.backend.use_auto_scheduler", False + ) + + # Turn off AutoTVM config not found warnings if auto_scheduler is enabled. + old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent + autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler + self._build(mod, target, target_host) + autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent + # Get artifacts graph_json = self.get_json() mod = self.get_module() diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index fa420c4e71a3..d4d20b3ebc4a 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -18,7 +18,6 @@ """The base node types for the Relay language.""" import tvm._ffi import tvm.ir -from tvm.auto_scheduler.relay_integration import auto_schedule_topi, auto_schedule_impl_suffix from tvm.driver import lower, build from tvm.target import get_native_generic_func, GenericFunc from tvm.runtime import Object @@ -144,30 +143,6 @@ def add_implementation(self, compute, schedule, name="default", plevel=10): """ _OpStrategyAddImplementation(self, compute, schedule, name, plevel) - def add_auto_scheduler(self, compute, name, plevel=10): - """Add an implementation using the auto-scheduler. - - Parameters - ---------- - compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type) - -> List[Tensor] - The compute function. - - name : str - The name of implementation. - - plevel : int - The priority level of implementation. - """ - - def wrap_schedule(attrs, outs, target): - with target: - return auto_schedule_topi(outs) - - self.add_implementation( - compute, wrap_schedule, name=name + auto_schedule_impl_suffix, plevel=plevel - ) - def _wrap_default_fstrategy(compute, schedule, name): def _fstrategy(attrs, inputs, out_type, target): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index f4ce61b8fa39..105f50116c3e 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import from tvm import topi import tvm +from tvm.ir.transform import PassContext from tvm.te import SpecializedCondition from tvm.contrib import nvcc from tvm._ffi import get_global_func @@ -142,10 +143,6 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): name="conv2d_nchw_winograd.cuda", plevel=5, ) - - strategy.add_auto_scheduler( - wrap_compute_conv2d(topi.nn.conv2d_nchw), name="conv2d_nchw" - ) elif layout == "HWCN": assert kernel_layout == "HWIO" strategy.add_implementation( @@ -221,13 +218,15 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ) # register auto-scheduler implementations - if judge_winograd_auto_scheduler: - strategy.add_auto_scheduler( - wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), name="conv2d_nhwc.winograd" - ) - else: - strategy.add_auto_scheduler( - wrap_compute_conv2d(topi.nn.conv2d_nhwc), name="conv2d_nhwc" + use_auto_scheduler = PassContext.current().config.get( + "relay.backend.use_auto_scheduler", False + ) + if use_auto_scheduler and judge_winograd_auto_scheduler: + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), + wrap_topi_schedule(tvm.te.create_schedule), + name="conv2d_nhwc.winograd", + plevel=15, ) elif layout == "HWNC": @@ -286,11 +285,6 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.cuda", ) - - strategy.add_auto_scheduler( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.cuda", - ) elif layout == "NHWC": assert kernel_layout == "HWOI" strategy.add_implementation( @@ -298,11 +292,6 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.cuda", ) - - strategy.add_auto_scheduler( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.cuda", - ) else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) else: # group_conv2d @@ -459,11 +448,13 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda", ) - # register auto-scheduler implementations - strategy.add_auto_scheduler( - wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform), - name="conv2d_nhwc_winograd_without_weight_transform", - ) + if PassContext.current().config.get("relay.backend.use_auto_scheduler", False): + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform), + wrap_topi_schedule(tvm.te.create_schedule), + name="conv2d_nhwc_winograd_without_weight_transform", + plevel=15, + ) else: raise RuntimeError( "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) @@ -553,11 +544,6 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target): name="conv3d_ncdhw_winograd.cuda", plevel=5, ) - - strategy.add_auto_scheduler( - wrap_compute_conv3d(topi.nn.conv3d_ncdhw), - name="conv3d_ncdhw.cuda", - ) else: # layout == "NDHWC": strategy.add_implementation( wrap_compute_conv3d(topi.cuda.conv3d_ndhwc), @@ -581,11 +567,6 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target): plevel=20, ) - strategy.add_auto_scheduler( - wrap_compute_conv3d(topi.nn.conv3d_ndhwc), - name="conv3d_ndhwc.cuda", - ) - if target.kind.name == "cuda" and "cudnn" in target.libs: strategy.add_implementation( wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True), @@ -681,11 +662,6 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): name="dense_small_batch.cuda", ) - strategy.add_auto_scheduler( - wrap_compute_dense(topi.nn.dense), - name="dense", - ) - with SpecializedCondition(b >= 32): strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_large_batch), diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 3a58d40cb847..ad6635de0116 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -19,7 +19,7 @@ import logging import tvm -from tvm import te, relay, autotvm, auto_scheduler +from tvm import te, relay, autotvm from .. import nn from ..utils import get_const_tuple @@ -52,9 +52,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): # The best implementation is not an AutoTVM template. # It may be from the auto-scheduler - if impl.name == ( - "conv2d_nhwc.winograd" + auto_scheduler.relay_integration.auto_schedule_impl_suffix - ): + if impl.name.find("winograd") != -1: if dilation != (1, 1): logger.warning("Does not support weight pre-transform for dilated convolution.") return None diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index c8327de94232..1559d7edf35f 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -99,7 +99,12 @@ Array GetShape(const Array& shape) { class ScheduleGetter : public backend::MemoizedExprTranslator> { public: explicit ScheduleGetter(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) {} + : target_(target), device_copy_op_(Op::Get("device_copy")) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = transform::PassContext::Current() + ->GetConfig("relay.backend.use_auto_scheduler", Bool(false)) + .value(); + } CachedFunc Create(const Function& prim_func) { auto cache_node = make_object(); @@ -145,11 +150,27 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> tensor_outs.push_back(tensor); } } + te::Schedule schedule; // No need to register schedule for device copy op. if (anchor_attrs_.as() == nullptr) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + bool has_complex_op = anchor_op_pattern_ >= kCommReduce; + ObjectRef obj = (*fauto_schedule)(tensor_outs, has_complex_op); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + + // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + if (!schedule.defined()) { + ICHECK(anchor_implementation_.defined()); + schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + } for (const auto& scalar : scalars_) { if (schedule->Contain(scalar)) { schedule[scalar].compute_inline(); @@ -228,9 +249,9 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> } int op_pattern = fpattern[op]; - if (op_pattern >= kCommReduce) { + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Two complicated op in a primitive function " + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" << " anchor=" << anchor_op_ << " current=" << op; } if (op_pattern >= anchor_op_pattern_) { @@ -295,6 +316,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> OpImplementation anchor_implementation_; std::ostringstream readable_name_stream_; Array scalars_; + bool use_auto_scheduler_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; @@ -812,6 +834,8 @@ CompileEngine& CompileEngine::Global() { return *inst; } +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); + TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") .set_body_typed([](tvm::Array outputs, OpImplementation impl) { return LoweredOutput(outputs, impl); diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py index 4ca2ddb3cf10..1899f9521013 100644 --- a/tests/python/relay/test_auto_scheduler_task_extraction.py +++ b/tests/python/relay/test_auto_scheduler_task_extraction.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Test task extraction for auto-scheduler""" +import pytest + import tvm.relay.testing import tvm.testing from tvm import auto_scheduler, relay @@ -45,7 +47,6 @@ def get_network(name, batch_size=1, layout="NHWC"): ) elif name == "winograd-test": input_shape = [1, 7, 7, 64] - output_shape = input_shape data = relay.var("data", shape=input_shape, dtype="float32") net = relay.testing.layers.conv2d( @@ -96,7 +97,6 @@ def get_network(name, batch_size=1, layout="NHWC"): @tvm.testing.requires_cuda def test_task_extraction_cuda(): - auto_scheduler.enable_relay_integration() target = tvm.target.Target("cuda") mod, params = get_network("mlp") @@ -108,24 +108,122 @@ def test_task_extraction_cuda(): mod, params = get_network("resnet-18", layout=layout) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - assert len(tasks) == 21 - assert sum(task_weights) == 22 + assert len(tasks) == 24 + assert sum(task_weights) == 25 mod, params = get_network("mobilenet", layout=layout) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - assert len(tasks) == 20 - assert sum(task_weights) == 28 + assert len(tasks) == 22 + assert sum(task_weights) == 30 for layout in ["NCDHW", "NDHWC"]: mod, params = get_network("resnet3d-18", layout=layout) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - assert len(tasks) == 21 - assert sum(task_weights) == 22 + assert len(tasks) == 23 + assert sum(task_weights) == 24, sum(task_weights) + + +def test_task_extraction(): + ishape = (1, 3, 224, 224) + w1shape = (32, 3, 3, 3) + w2shape = (32, 32, 3, 3) + dtype = "float32" + target = tvm.target.Target("llvm") + + def get_func(): + data = relay.var("data", shape=(ishape), dtype=dtype) + weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype) + weight2 = relay.var("weight2", shape=(w2shape), dtype=dtype) + + conv2d = relay.nn.conv2d(data, weight1, kernel_size=(3, 3), padding=(1, 1)) + relu = relay.nn.relu(conv2d) + conv2d = relay.nn.conv2d(relu, weight2, kernel_size=(3, 3), padding=(1, 1)) + out = relay.nn.relu(conv2d) + return relay.Function([data, weight1, weight2], out) + + def get_fused_func(): + data = relay.var("data", shape=(ishape), dtype=dtype) + weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype) + weight2 = relay.var("weight2", shape=(w2shape), dtype=dtype) + + fused_func = get_func() + + # Set to primitive to keep fuse_ops untouch. + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + call = relay.Call(fused_func, [data, weight1, weight2]) + return relay.Function([data, weight1, weight2], call) + + def get_simple_func(): + data = relay.var("data", relay.TensorType((1, 2, 3), "float32")) + out = relay.image.affine_grid(data, (150, 150)) + return relay.Function([data], out) + + def get_func_with_unsupported_op(): + def get_postproc_func(): + data = relay.var("data", shape=((1, 3, 6)), dtype=dtype) + out = relay.nn.relu(data) + func = relay.Function([data], out) + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + return func + + cls_prob = relay.var("cls_prob", relay.ty.TensorType((1, 3, 3), "float32")) + loc_pred = relay.var("loc_pred", relay.ty.TensorType((1, 3 * 4), "float32")) + anchors = relay.var("anchors", relay.ty.TensorType((1, 3, 4), "float32")) + + mtl = relay.vision.multibox_transform_loc( + cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors + ) + nms = relay.vision.non_max_suppression(mtl[0], mtl[1], mtl[0], return_indices=False) + out = relay.Call(get_postproc_func(), [nms]) + return relay.Function([cls_prob, loc_pred, anchors], out) + + func = get_func() + mod = tvm.IRModule.from_expr(func) + tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], None, target) + + # Relay FuseOps puts two conv2ds to separate functions and results in two tasks. + assert len(tasks) == 2 + assert len(task_weights) == 2 + + func = get_fused_func() + mod = tvm.IRModule.from_expr(func) + tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], None, target) + + # By setting the function to primitive, Relay FuseOps will not break it and result in one task. + assert len(tasks) == 1 + assert len(task_weights) == 1 + + func = get_simple_func() + mod = tvm.IRModule.from_expr(func) + tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], None, target) - auto_scheduler.enable_relay_integration(False) + # The Relay function without complex ops will not form a task by default. + assert len(tasks) == 0 + assert len(task_weights) == 0 + + tasks, task_weights = auto_scheduler.extract_tasks( + mod["main"], None, target, include_simple_tasks=True + ) + + # Every Relay function becomes a task regardless what ops in its body. + assert len(tasks) == 1 + assert len(task_weights) == 1 + + # Func1 (with NMS) -> Func2 (injective). + func = get_func_with_unsupported_op() + mod = tvm.IRModule.from_expr(func) + tasks, task_weights = auto_scheduler.extract_tasks( + mod["main"], None, target, include_simple_tasks=True + ) + + # The function with NMS should fail, but the other function with ReLU should be a task. + assert len(tasks) == 1 + assert len(task_weights) == 1 if __name__ == "__main__": test_task_extraction_cuda() + test_task_extraction() diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 089f51cdf047..d42373c86626 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -24,8 +24,6 @@ def tune_network(network, target): - auto_scheduler.enable_relay_integration() - # Extract tasks mod, params = get_network(network) target = tvm.target.Target(target) @@ -50,15 +48,15 @@ def tune_network(network, target): # Compile with the history best with auto_scheduler.ApplyHistoryBest(log_file): - with tvm.transform.PassContext(opt_level=3): + with tvm.transform.PassContext( + opt_level=3, config={"relay.backend.use_auto_scheduler": True} + ): lib = relay.build(mod, target=target, params=params) # Todo(merrymercy): when the cpu backend is upstreamed, do the following things: # 1. compile without history to test the fallback mechanism # 2. check the correctness of layout rewrite / winograd pre-transform - auto_scheduler.enable_relay_integration(False) - @tvm.testing.requires_cuda def test_tuning_cuda(): diff --git a/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json index 37a129844390..41b6c0e554ed 100644 --- a/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json +++ b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1.json @@ -1,23 +1,26 @@ # Provide valid schedules for resnet-18. # This is used to run the tutorial on the documentation web server. -{"i": [["[\"d09dc1a6bb90d59c91b68989ad3492ff\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [2, 5, 2, 1], 1], ["SP", 2, 10, 512, [1, 16], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 6], ["FU", 6, [0, 1]], ["AN", 6, 0, 5], ["FU", 6, [1, 2]], ["AN", 6, 1, 4], ["FU", 6, [2, 3]], ["AN", 6, 2, 6], ["FU", 3, [0, 1]], ["SP", 3, 0, 2, [1], 1], ["AN", 3, 1, 2], ["FFSP", 3, 0, [1, 0], 1, 1], ["AN", 3, 1, 6], ["FU", 1, [0, 1]], ["SP", 1, 0, 1, [1], 1], ["AN", 1, 1, 2], ["FFSP", 1, 0, [1, 0], 1, 1], ["AN", 1, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"]]]], "r": [[7.2561e-05], 0, 1.93892, 1605186325], "v": "v0.3"} -{"i": [["[\"8d5a93959138dc7b2ee1f1b3219dfa14\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [2, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 4, 1], 1], ["SP", 6, 10, 16, [4, 2, 1, 1], 1], ["SP", 6, 15, 512, [1, 16, 1, 1], 1], ["SP", 6, 20, 512, [2, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 13, 3], ["FSP", 7, 4, 14, 3], ["FSP", 7, 8, 15, 3], ["FSP", 7, 12, 16, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [8], 1], ["SP", 4, 4, 512, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 19, [0, 1, 2, 3]], ["SP", 19, 0, 25088, [32], 1], ["AN", 19, 0, 5], ["AN", 19, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [16, 15, 14, 13], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [16, 15, 14, 13], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$0"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[0.000195701], 0, 2.67988, 1605186412], "v": "v0.3"} -{"i": [["[\"ac6920940de3797cc3f9f9c260675e5d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 2, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 16, [1, 16, 1, 1], 1], ["SP", 6, 15, 512, [2, 1, 4, 1], 1], ["SP", 6, 20, 512, [32, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [16], 1], ["SP", 4, 4, 512, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 25088, [32], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 128, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$512"]]]], "r": [[0.000162045], 0, 2.32406, 1605186499], "v": "v0.3"} -{"i": [["[\"7e83a2ee5cd5d50282ed19310700046a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [1, 1, 8, 1], 1], ["SP", 6, 15, 512, [2, 64, 1, 1], 1], ["SP", 6, 20, 512, [16, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [4], 1], ["SP", 4, 4, 512, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 25088, [32], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 128, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [2], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$512"]]]], "r": [[0.000102843], 0, 2.42044, 1605186574], "v": "v0.3"} -{"i": [["[\"424ba83160af31badc0b098136e1a3b0\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [49], 1], ["SP", 8, 4, 256, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 2, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 49, [1, 7, 1, 7], 1], ["SP", 6, 15, 256, [1, 8, 1, 2], 1], ["SP", 6, 20, 256, [2, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [1], 1], ["SP", 4, 4, 256, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 50176, [32], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 112, [2], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[9.61516e-05], 0, 2.69389, 1605186690], "v": "v0.3"} -{"i": [["[\"a169cd0053d3a7ca82998fcb62e42c58\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 2], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 49, [1, 7, 1, 1], 1], ["SP", 6, 15, 256, [1, 4, 8, 1], 1], ["SP", 6, 20, 256, [1, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [49], 1], ["SP", 4, 4, 256, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 50176, [32], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 2, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000156995], 0, 2.11666, 1605186772], "v": "v0.3"} -{"i": [["[\"0141ffc4fbabc10cc5a94c954419055b\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [49], 1], ["SP", 8, 4, 256, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 4, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 4, 1], 1], ["SP", 6, 10, 49, [1, 49, 1, 1], 1], ["SP", 6, 15, 256, [4, 2, 1, 1], 1], ["SP", 6, 20, 256, [1, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [7], 1], ["SP", 4, 4, 256, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 50176, [32], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [4], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[0.000131082], 0, 2.24166, 1605186844], "v": "v0.3"} -{"i": [["[\"c7a6b56bdc04b94c829fb2ef9874019e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 128, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [2, 1, 1, 1], 1], ["SP", 6, 5, 4, [2, 2, 1, 1], 1], ["SP", 6, 10, 196, [2, 7, 2, 1], 1], ["SP", 6, 15, 128, [1, 32, 1, 4], 1], ["SP", 6, 20, 128, [4, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [2], 1], ["SP", 4, 4, 128, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 100352, [16], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [16], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 16, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000166673], 0, 2.43832, 1605186977], "v": "v0.3"} -{"i": [["[\"c035cc8b0568a8e054d06bd7f4950550\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 2, 1], 1], ["SP", 6, 5, 4, [1, 2, 1, 1], 1], ["SP", 6, 10, 196, [2, 49, 1, 1], 1], ["SP", 6, 15, 128, [1, 1, 4, 8], 1], ["SP", 6, 20, 128, [2, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [49], 1], ["SP", 4, 4, 128, [8], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 100352, [32], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 1024, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[0.000108367], 0, 3.89975, 1605187058], "v": "v0.3"} -{"i": [["[\"c5ee3e05edd9754492d0763aa41fd025\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 2, 2], 1], ["SP", 6, 10, 196, [1, 4, 7, 1], 1], ["SP", 6, 15, 128, [2, 16, 2, 1], 1], ["SP", 6, 20, 128, [4, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [2], 1], ["SP", 4, 4, 128, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 100352, [32], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 112, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[8.0137e-05], 0, 2.28468, 1605187134], "v": "v0.3"} -{"i": [["[\"022ebb6b7c55c5ed030421380ec83a04\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 2, 2, 1], 1], ["SP", 3, 10, 28, [1, 14, 1, 1], 1], ["SP", 3, 15, 128, [1, 2, 16, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 64, [1, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 384, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 24, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$16"]]]], "r": [[9.74847e-05], 0, 1.97907, 1605187182], "v": "v0.3"} -{"i": [["[\"de0df0893e01892cfe69f7bc2c24111f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 64, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 1, 1, 1], 1], ["SP", 6, 5, 6, [1, 1, 2, 1], 1], ["SP", 6, 10, 196, [1, 7, 14, 1], 1], ["SP", 6, 15, 64, [2, 4, 2, 1], 1], ["SP", 6, 20, 64, [1, 2], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [1], 1], ["SP", 4, 4, 64, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 200704, [32], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [16], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 8, [8], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 56, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[8.09982e-05], 0, 3.52776, 1605187295], "v": "v0.3"} -{"i": [["[\"f2e3c09a00e7d0a9897f70497e089f1e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [7], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 1], 1], ["SP", 6, 5, 6, [1, 3, 1, 1], 1], ["SP", 6, 10, 196, [1, 14, 1, 2], 1], ["SP", 6, 15, 64, [1, 2, 8, 2], 1], ["SP", 6, 20, 64, [4, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [4], 1], ["SP", 4, 4, 64, [4], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 200704, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 512, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[7.15745e-05], 0, 3.73944, 1605187404], "v": "v0.3"} -{"i": [["[\"fa26946d7ac51126bfa859cb183f9ca1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [7], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 3], 1], ["SP", 6, 5, 6, [1, 2, 3, 1], 1], ["SP", 6, 10, 196, [1, 4, 1, 7], 1], ["SP", 6, 15, 64, [1, 8, 2, 1], 1], ["SP", 6, 20, 64, [2, 2], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [4], 1], ["SP", 4, 4, 64, [16], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 200704, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 144, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 252, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[6.79478e-05], 0, 5.10446, 1605187506], "v": "v0.3"} -{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [2, 14, 1, 1], 1], ["SP", 3, 10, 112, [1, 8, 2, 1], 1], ["SP", 3, 15, 64, [2, 2, 2, 2], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [7, 1], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 1176, [21], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 189, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[5.53397e-05], 0, 2.2607, 1605187548], "v": "v0.3"} -{"i": [["[\"bf78a7bf0209980f72953637dfd14a6f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [2, 28, 1, 1], 1], ["SP", 3, 10, 56, [1, 2, 2, 1], 1], ["SP", 3, 15, 64, [2, 16, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [1, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 16, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[8.11163e-06], 0, 1.93343, 1605187596], "v": "v0.3"} -{"i": [["[\"6630936c26852f2b89dbfa2ff37fbb9c\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [2, 2, 1, 1], 1], ["SP", 3, 10, 28, [1, 2, 1, 1], 1], ["SP", 3, 15, 128, [2, 8, 4, 2], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [4, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 256, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 96, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$64"]]]], "r": [[1.40126e-05], 0, 1.82931, 1605187624], "v": "v0.3"} -{"i": [["[\"ba5f918733ccbbd4a1d7fd3724665a2f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 7, 1, 2], 1], ["SP", 3, 10, 14, [1, 1, 1, 2], 1], ["SP", 3, 15, 256, [4, 64, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 32, [16], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 324, [6], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$64"]]]], "r": [[2.35384e-05], 0, 1.78652, 1605187663], "v": "v0.3"} -{"i": [["[\"21ad409d72953de188314010134e3acd\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 32, 4, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [1, 64], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 16, [4], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 4, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$0"]]]], "r": [[3.09105e-05], 0, 1.85659, 1605187687], "v": "v0.3"} -{"i": [["[\"1f6cd3637ec856bf5cf5010a623eed05\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 7, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 8, 2, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 256, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 96, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 48, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000154153], 0, 2.18601, 1605187723], "v": "v0.3"} -{"i": [["[\"81aae4b8e2c076a4014d403e8a2c70a1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 2], 1], ["SP", 3, 10, 14, [1, 14, 1, 1], 1], ["SP", 3, 15, 256, [1, 32, 1, 2], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [3, 1], 1], ["SP", 3, 26, 128, [2, 4], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 144, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 72, [24], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[5.97747e-05], 0, 2.13918, 1605187759], "v": "v0.3"} +{"i": [["[\"b32ed43fb351136894c322ee49097a1a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 4, 1, 1000, [50], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["FSP", 1, 1, 0, 1], ["AN", 1, 2, 6], ["CA", 1, 4, 0], ["AN", 4, 0, 5], ["PR", 1, 0, "auto_unroll_max_step$0"], ["PR", 3, 0, "auto_unroll_max_step$1024"]]]], "r": [[4.54041e-06], 0, 1.27943, 1605490839], "v": "v0.3"} +{"i": [["[\"d09dc1a6bb90d59c91b68989ad3492ff\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [1, 50, 1, 1], 1], ["SP", 2, 10, 512, [1, 4], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 6], ["FU", 6, [0, 1]], ["AN", 6, 0, 5], ["FU", 6, [1, 2]], ["AN", 6, 1, 4], ["FU", 6, [2, 3]], ["AN", 6, 2, 6], ["FU", 3, [0, 1]], ["SP", 3, 0, 4, [4], 1], ["AN", 3, 1, 2], ["FFSP", 3, 0, [1, 0], 1, 1], ["AN", 3, 1, 6], ["FU", 1, [0, 1]], ["SP", 1, 0, 4, [2], 1], ["AN", 1, 1, 2], ["FFSP", 1, 0, [1, 0], 1, 1], ["AN", 1, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[1.03431e-05], 0, 2.09134, 1605490924], "v": "v0.3"} +{"i": [["[\"7de313da0ca29a8c63f647791692430d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [8], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$16"]]]], "r": [[5.51259e-06], 0, 1.30207, 1605491060], "v": "v0.3"} +{"i": [["[\"944921d3fd999ba7aa9ffe5a592a9241\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [64], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [56], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$512"]]]], "r": [[2.24305e-05], 0, 1.60311, 1605493879], "v": "v0.3"} +{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [2, 1, 1, 8], 1], ["SP", 3, 10, 112, [1, 8, 1, 1], 1], ["SP", 3, 15, 64, [2, 16, 2, 1], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [1, 7], 1], ["SP", 3, 26, 3, [1, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 294, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 441, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[7.63468e-05], 0, 2.59544, 1605493932], "v": "v0.3"} +{"i": [["[\"bf78a7bf0209980f72953637dfd14a6f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [7, 4, 2, 1], 1], ["SP", 3, 10, 56, [1, 2, 2, 1], 1], ["SP", 3, 15, 64, [2, 16, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [8, 4], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 32, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 128, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[1.26775e-05], 0, 1.94247, 1605494103], "v": "v0.3"} +{"i": [["[\"6630936c26852f2b89dbfa2ff37fbb9c\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 1, 2], 1], ["SP", 3, 10, 28, [1, 1, 2, 1], 1], ["SP", 3, 15, 128, [1, 16, 1, 8], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 128, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 144, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[1.13004e-05], 0, 1.86312, 1605494224], "v": "v0.3"} +{"i": [["[\"ba5f918733ccbbd4a1d7fd3724665a2f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 2, 1], 1], ["SP", 3, 10, 14, [1, 14, 1, 1], 1], ["SP", 3, 15, 256, [1, 8, 4, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 64, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 48, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[1.29425e-05], 0, 1.70493, 1605494303], "v": "v0.3"} +{"i": [["[\"21ad409d72953de188314010134e3acd\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 7, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 512, [2, 16, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [2, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 16, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 16, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[2.04683e-05], 0, 1.80217, 1605494406], "v": "v0.3"} +{"i": [["[\"022ebb6b7c55c5ed030421380ec83a04\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 1, 1, 7], 1], ["SP", 3, 10, 28, [1, 4, 1, 1], 1], ["SP", 3, 15, 128, [1, 32, 2, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [3, 1], 1], ["SP", 3, 26, 64, [1, 4], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 72, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 348, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[4.93528e-05], 0, 1.74125, 1605498773], "v": "v0.3"} +{"i": [["[\"ac6920940de3797cc3f9f9c260675e5d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [8], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 2, 1, 1], 1], ["SP", 6, 10, 16, [2, 1, 8, 1], 1], ["SP", 6, 15, 512, [1, 32, 2, 1], 1], ["SP", 6, 20, 512, [8, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [16], 1], ["SP", 4, 4, 512, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 25088, [49], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 256, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.000129562], 0, 3.40317, 1605500470], "v": "v0.3"} +{"i": [["[\"1f6cd3637ec856bf5cf5010a623eed05\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 1, 7], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 16, 1, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [3, 1], 1], ["SP", 3, 26, 256, [4, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 288, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 1440, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[7.57476e-05], 0, 2.59558, 1605501054], "v": "v0.3"} +{"i": [["[\"c5ee3e05edd9754492d0763aa41fd025\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 2, 2, 1], 1], ["SP", 6, 10, 196, [4, 1, 1, 7], 1], ["SP", 6, 15, 128, [2, 32, 1, 1], 1], ["SP", 6, 20, 128, [2, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [14], 1], ["SP", 4, 4, 128, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 100352, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [49], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 8, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 56, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$1024"]]]], "r": [[6.77244e-05], 0, 2.67201, 1605501438], "v": "v0.3"} +{"i": [["[\"c035cc8b0568a8e054d06bd7f4950550\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 128, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 196, [1, 7, 7, 1], 1], ["SP", 6, 15, 128, [8, 16, 1, 1], 1], ["SP", 6, 20, 128, [1, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [14], 1], ["SP", 4, 4, 128, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 100352, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [16], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 8, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 8, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [32], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[6.23875e-05], 0, 1.93274, 1605501606], "v": "v0.3"} +{"i": [["[\"f2e3c09a00e7d0a9897f70497e089f1e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 1, 2, 1], 1], ["SP", 6, 5, 6, [1, 1, 1, 1], 1], ["SP", 6, 10, 196, [1, 7, 1, 4], 1], ["SP", 6, 15, 64, [2, 16, 1, 1], 1], ["SP", 6, 20, 64, [1, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [28], 1], ["SP", 4, 4, 64, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 200704, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$512"], ["PR", 8, 0, "auto_unroll_max_step$64"], ["PR", 11, 0, "auto_unroll_max_step$512"]]]], "r": [[6.65448e-05], 0, 2.94376, 1605501803], "v": "v0.3"} +{"i": [["[\"81aae4b8e2c076a4014d403e8a2c70a1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 2], 1], ["SP", 3, 10, 14, [2, 7, 1, 1], 1], ["SP", 3, 15, 256, [1, 32, 2, 1], 1], ["SP", 3, 20, 3, [1, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 128, [2, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 6, 0, 1, 3], ["FSP", 6, 4, 2, 3], ["FSP", 6, 8, 3, 3], ["FSP", 6, 12, 4, 3], ["RE", 6, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 6, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 8, [0, 1, 2, 3]], ["AN", 8, 0, 5], ["FU", 8, [1, 2, 3, 4]], ["AN", 8, 1, 4], ["FU", 8, [2, 3, 4, 5]], ["AN", 8, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 192, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 240, [1], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$1024"]]]], "r": [[6.31245e-05], 0, 1.9322, 1605501903], "v": "v0.3"} +{"i": [["[\"7e83a2ee5cd5d50282ed19310700046a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [1, 2, 4, 2], 1], ["SP", 6, 15, 512, [2, 32, 1, 1], 1], ["SP", 6, 20, 512, [16, 1], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [8], 1], ["SP", 4, 4, 512, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 25088, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 128, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[0.000143154], 0, 2.20107, 1605502293], "v": "v0.3"} +{"i": [["[\"424ba83160af31badc0b098136e1a3b0\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 49, 1, 1], 1], ["SP", 6, 15, 256, [8, 2, 2, 2], 1], ["SP", 6, 20, 256, [2, 16], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [1], 1], ["SP", 4, 4, 256, [1], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 50176, [64], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 128, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 32, [4], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$64"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$16"]]]], "r": [[0.000115017], 0, 3.89122, 1605502608], "v": "v0.3"} +{"i": [["[\"c7a6b56bdc04b94c829fb2ef9874019e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [4], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [2, 1, 1, 1], 1], ["SP", 6, 10, 196, [1, 1, 2, 14], 1], ["SP", 6, 15, 128, [1, 32, 1, 2], 1], ["SP", 6, 20, 128, [1, 8], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [1], 1], ["SP", 4, 4, 128, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 100352, [64], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 25088, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 224, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 25088, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[7.20936e-05], 0, 3.36582, 1605502968], "v": "v0.3"} +{"i": [["[\"0141ffc4fbabc10cc5a94c954419055b\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 49, 1, 1], 1], ["SP", 6, 15, 256, [8, 1, 2, 2], 1], ["SP", 6, 20, 256, [1, 32], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [1], 1], ["SP", 4, 4, 256, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 50176, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 128, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 16, [2], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[0.000122349], 0, 4.2774, 1605503135], "v": "v0.3"} +{"i": [["[\"a169cd0053d3a7ca82998fcb62e42c58\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 7, 1, 7], 1], ["SP", 6, 15, 256, [8, 4, 1, 1], 1], ["SP", 6, 20, 256, [1, 16], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 11, 3], ["FSP", 7, 4, 12, 3], ["FSP", 7, 8, 13, 3], ["FSP", 7, 12, 14, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 49, [7], 1], ["SP", 4, 4, 256, [2], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 15, [0, 1, 2, 3]], ["SP", 15, 0, 50176, [64], 1], ["AN", 15, 0, 5], ["AN", 15, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [64], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 256, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [14, 13, 12, 11], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [14, 13, 12, 11], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[7.9277e-05], 0, 3.07064, 1605503350], "v": "v0.3"} +{"i": [["[\"fa26946d7ac51126bfa859cb183f9ca1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 1], 1], ["SP", 6, 5, 6, [1, 2, 1, 1], 1], ["SP", 6, 10, 196, [7, 7, 1, 4], 1], ["SP", 6, 15, 64, [1, 8, 4, 1], 1], ["SP", 6, 20, 64, [4, 2], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 10, 3], ["FSP", 7, 4, 11, 3], ["FSP", 7, 8, 12, 3], ["FSP", 7, 12, 13, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [28], 1], ["SP", 4, 4, 64, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 14, [0, 1, 2, 3]], ["SP", 14, 0, 200704, [64], 1], ["AN", 14, 0, 5], ["AN", 14, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 32, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [13, 12, 11, 10], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 16, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [13, 12, 11, 10], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$16"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$0"]]]], "r": [[7.64176e-05], 0, 5.45091, 1605503568], "v": "v0.3"} +{"i": [["[\"de0df0893e01892cfe69f7bc2c24111f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [1], 1], ["SP", 8, 4, 64, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 3, 1, 1], 1], ["SP", 6, 5, 6, [1, 1, 1, 1], 1], ["SP", 6, 10, 196, [14, 7, 1, 2], 1], ["SP", 6, 15, 64, [1, 16, 1, 2], 1], ["SP", 6, 20, 64, [1, 4], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 12, 3], ["FSP", 7, 4, 13, 3], ["FSP", 7, 8, 14, 3], ["FSP", 7, 12, 15, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 196, [2], 1], ["SP", 4, 4, 64, [64], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 17, [0, 1, 2, 3]], ["SP", 17, 0, 200704, [64], 1], ["AN", 17, 0, 5], ["AN", 17, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 12544, [32], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 16, [4], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [15, 14, 13, 12], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 4, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [15, 14, 13, 12], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 12544, [64], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$0"], ["PR", 8, 0, "auto_unroll_max_step$512"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[7.60496e-05], 0, 3.00771, 1605503805], "v": "v0.3"} +{"i": [["[\"8d5a93959138dc7b2ee1f1b3219dfa14\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [16], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [1, 1, 4, 4], 1], ["SP", 6, 15, 512, [1, 64, 1, 1], 1], ["SP", 6, 20, 512, [1, 32], 1], ["RE", 6, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 21, 3, 8, 13, 18, 22, 4, 9, 14, 19]], ["FSP", 7, 0, 13, 3], ["FSP", 7, 4, 14, 3], ["FSP", 7, 8, 15, 3], ["FSP", 7, 12, 16, 3], ["RE", 7, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 6, 7, 11], ["CHR", 5, "shared", [6]], ["CA", 6, 7, 12], ["CHR", 4, "shared", [7]], ["CA", 5, 8, 12], ["AN", 4, 0, 1], ["AN", 4, 1, 1], ["SP", 4, 2, 16, [8], 1], ["SP", 4, 4, 512, [32], 1], ["AN", 4, 6, 1], ["AN", 4, 7, 1], ["RE", 4, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 3], ["CA", 2, 4, 3], ["CI", 1], ["FU", 19, [0, 1, 2, 3]], ["SP", 19, 0, 25088, [32], 1], ["AN", 19, 0, 5], ["AN", 19, 1, 6], ["FU", 11, [0, 1, 2, 3]], ["SP", 11, 0, 8192, [16], 1], ["AN", 11, 0, 5], ["AN", 11, 1, 6], ["FU", 9, [0, 1, 2, 3]], ["AN", 9, 0, 5], ["FU", 9, [1, 2, 3, 4]], ["AN", 9, 1, 4], ["FU", 9, [2, 3, 4, 5]], ["AN", 9, 2, 6], ["FU", 7, [0, 1, 2, 3]], ["SP", 7, 0, 64, [1], 1], ["AN", 7, 1, 2], ["FFSP", 7, 0, [16, 15, 14, 13], 1, 1], ["AN", 7, 1, 6], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 64, [1], 1], ["AN", 5, 1, 2], ["FFSP", 5, 0, [16, 15, 14, 13], 1, 1], ["AN", 5, 1, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 8192, [16], 1], ["AN", 4, 0, 5], ["AN", 4, 1, 6], ["PR", 4, 0, "auto_unroll_max_step$0"], ["PR", 8, 0, "auto_unroll_max_step$1024"], ["PR", 11, 0, "auto_unroll_max_step$64"]]]], "r": [[0.000135079], 0, 2.40957, 1605504233], "v": "v0.3"} diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py index 4756ea390b5c..723b8d15ea88 100644 --- a/tutorials/auto_scheduler/tune_network_cuda.py +++ b/tutorials/auto_scheduler/tune_network_cuda.py @@ -102,10 +102,10 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape ) elif name == "squeezenet_v1.1": + assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" mod, params = relay.testing.squeezenet.get_workload( version="1.1", batch_size=batch_size, - layout=layout, dtype=dtype, image_shape=image_shape, ) @@ -148,9 +148,6 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # latency of a task and :code:`weight[t]` is the weight of the task. # The task scheduler will just optimize this objective. -# Enable auto-scheduler in relay -auto_scheduler.enable_relay_integration() - # Extract tasks from the network print("Extract tasks...") mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) @@ -219,29 +216,32 @@ def run_tuning(): # ---------------------------------------------------------------------- # | ID | Latency (ms) | Speed (GFLOPS) | Trials | # ------------------------------------------------- -# | 0 | 0.014 | 72.07 | 64 | -# | 1 | 0.185 | 1250.68 | 128 | -# | 2 | 0.142 | 1626.36 | 192 | -# | 3 | 0.137 | 1689.42 | 128 | -# | 4 | 0.097 | 1189.75 | 128 | -# | 5 | 0.092 | 2505.25 | 128 | -# | 6 | 0.080 | 2893.08 | 128 | -# | 7 | 0.119 | 1947.84 | 128 | -# | 8 | 0.090 | 1292.62 | 64 | -# | 9 | 0.107 | 2172.30 | 64 | -# | 10 | 0.095 | 2439.36 | 64 | -# | 11 | 0.077 | 3003.22 | 64 | -# | 12 | 0.068 | 1695.13 | 64 | -# | 13 | 0.058 | 3979.29 | 64 | -# | 14 | 0.048 | 4859.95 | 128 | -# | 15 | 0.073 | 3151.76 | 64 | -# | 16 | 0.056 | 4265.94 | 64 | -# | 17 | 0.009 | 2754.90 | 64 | -# | 18 | 0.011 | 1156.08 | 64 | -# | 19 | 0.013 | 955.80 | 64 | -# | 20 | 0.029 | 437.71 | 64 | +# | 0 | 0.005 | 0.88 | 64 | +# | 1 | 0.010 | 99.10 | 64 | +# | 2 | 0.006 | 0.00 | 64 | +# | 3 | 0.145 | 979.78 | 384 | +# | 4 | 0.130 | 1097.02 | 384 | +# | 5 | 0.143 | 992.69 | 384 | +# | 6 | 0.076 | 1526.86 | 192 | +# | 7 | 0.115 | 999.44 | 320 | +# | 8 | 0.079 | 1449.39 | 320 | +# | 9 | 0.122 | 938.73 | 384 | +# | 10 | 0.063 | 1832.98 | 192 | +# | 11 | 0.072 | 1763.62 | 256 | +# | 12 | 0.062 | 2036.40 | 192 | +# | 13 | 0.068 | 1874.44 | 192 | +# | 14 | 0.049 | 2346.50 | 128 | +# | 15 | 0.076 | 1694.31 | 256 | +# | 16 | 0.067 | 1933.30 | 448 | +# | 17 | 0.076 | 1680.90 | 256 | +# | 18 | 0.022 | 98.43 | 64 | +# | 19 | 0.076 | 3112.55 | 192 | +# | 20 | 0.013 | 2026.44 | 64 | +# | 21 | 0.011 | 1136.69 | 64 | +# | 22 | 0.013 | 992.47 | 64 | +# | 23 | 0.020 | 627.56 | 64 | # ------------------------------------------------- -# Estimated total latency: 1.649 ms Trials: 1920 Used time : 3598 s Next ID: 9 +# Estimated total latency: 1.587 ms Trials: 4992 Used time : 13296 s Next ID: 3 # # This table lists the latency and (estimated) speed of all tasks. # It also lists the allocation of measurement trials for all tasks. @@ -276,7 +276,7 @@ def run_tuning(): # Compile with the history best print("Compile...") with auto_scheduler.ApplyHistoryBest(log_file): - with tvm.transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}): lib = relay.build(mod, target=target, params=params) # Create graph runtime