From a2cbc8bfaa8fe0ca5c7ab23ea6f269178fee7713 Mon Sep 17 00:00:00 2001
From: monklof
Date: Wed, 11 Nov 2020 13:56:09 +0800
Subject: [PATCH 1/3] add ShapeFunc for tanh

---
 python/tvm/relay/op/_tensor.py | 2 ++
 python/tvm/topi/cuda/dense.py  | 4 ++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 0c875045032f..dae233a67745 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -276,3 +276,5 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("clip", False, elemwise_shape_func)
 register_shape_func("log2", False, elemwise_shape_func)
 register_shape_func("sigmoid", False, elemwise_shape_func)
+register_shape_func("tanh", False, elemwise_shape_func)
+
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 47b9db4f390a..cbd917580dd4 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -77,8 +77,8 @@ def _callback(op):
 
 
 def _schedule_dense_small_batch(cfg, s, C):
-    A, _ = C.op.input_tensors
-    _, in_dim = get_const_tuple(A.shape)
+    A, weights = C.op.input_tensors
+    _, in_dim = get_const_tuple(weights.shape)
     cfg.define_split("tile_k", in_dim, num_outputs=2)
     if cfg.is_fallback:
         cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])

From aadf4c9a7bf7250c294d6a456225dc1edf1e96a2 Mon Sep 17 00:00:00 2001
From: monklof
Date: Fri, 22 Jan 2021 15:27:16 +0800
Subject: [PATCH 2/3] _schedule_dense_small_batch turn autotvm off when
 dense's inner dim is unknown

---
 python/tvm/topi/cuda/dense.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index cbd917580dd4..5108dc12e48b 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -78,12 +78,25 @@ def _callback(op):
 
 def _schedule_dense_small_batch(cfg, s, C):
     A, weights = C.op.input_tensors
-    _, in_dim = get_const_tuple(weights.shape)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
+    _, in_dim_weights = get_const_tuple(weights.shape)
+    _, in_dim_A = get_const_tuple(A.shape)
+
+    if isinstance(in_dim_A, int):
+        in_dim = in_dim_A
+    elif isinstance(in_dim_weights, int):
+        in_dim = in_dim_weights
+    else:
+        in_dim = None
+
+    if in_dim != None:
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
+        _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
+    else:
+        tile_k = 64
+        _, kf = s[C].split(C.op.reduce_axis[0], tile_k)
 
-    _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
     CF = s.rfactor(C, kf)
 
     if C.op in s.outputs:
@@ -102,6 +115,7 @@ def _schedule_dense_small_batch(cfg, s, C):
 
     s[Out].set_store_predicate(thread_x.var.equal(0))
 
+
 @autotvm.register_topi_compute("dense_large_batch.cuda")
 def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator on CUDA"""
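
A minimal sketch (not part of the patch series) of why the in_dim fallback in
patch 2 is needed, assuming the tvm.te/tvm.topi APIs of this era; the
placeholder names and sizes are made up for illustration. A dynamic dimension
passes through get_const_tuple() as a symbolic tir.Var rather than a Python
int, so the isinstance(..., int) checks above fail and the schedule must split
the reduce axis with a fixed factor instead of an AutoTVM tile_k knob:

    # Illustrative sketch only: a dynamic inner dim survives get_const_tuple()
    # as a tir.Var, so it cannot feed cfg.define_split().
    import tvm
    from tvm import te
    from tvm.topi.utils import get_const_tuple

    k = te.size_var("k")                    # unknown (dynamic) inner dim
    A = te.placeholder((1, k), name="A")    # data: batch 1, dynamic in_dim
    W = te.placeholder((128, k), name="W")  # weights share the dynamic dim

    _, in_dim_A = get_const_tuple(A.shape)
    _, in_dim_weights = get_const_tuple(W.shape)
    print(isinstance(in_dim_A, int))        # False: it is a tir.Var
    print(isinstance(in_dim_weights, int))  # False: fallback path is taken,
                                            # i.e. s[C].split(..., 64)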

From ee1a74965ea5cffc7cbfa5995057ed7356d88a60 Mon Sep 17 00:00:00 2001
From: monklof
Date: Tue, 23 Feb 2021 15:09:39 +0800
Subject: [PATCH 3/3] fix CI pylint

---
 python/tvm/relay/op/_tensor.py | 1 -
 python/tvm/topi/cuda/dense.py  | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index dae233a67745..499f48618b86 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -277,4 +277,3 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("log2", False, elemwise_shape_func)
 register_shape_func("sigmoid", False, elemwise_shape_func)
 register_shape_func("tanh", False, elemwise_shape_func)
-
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 5108dc12e48b..a718f5ecf695 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -88,7 +88,7 @@ def _schedule_dense_small_batch(cfg, s, C):
     else:
         in_dim = None
 
-    if in_dim != None:
+    if in_dim is not None:
         cfg.define_split("tile_k", in_dim, num_outputs=2)
         if cfg.is_fallback:
             cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
@@ -115,7 +115,6 @@ def _schedule_dense_small_batch(cfg, s, C):
 
     s[Out].set_store_predicate(thread_x.var.equal(0))
 
-
 @autotvm.register_topi_compute("dense_large_batch.cuda")
 def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator on CUDA"""
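
For context, a hedged end-to-end sketch of what the tanh shape function added
in patch 1 enables, assuming the Relay VM APIs of this era: every op on a
dynamically shaped path needs a registered shape func, and tanh previously had
none. The graph, shapes, and llvm target below are illustrative (targeting
cuda would additionally exercise the dense.py schedule changes):

    # Illustrative sketch only: a dense+tanh graph with a dynamic batch,
    # compiled through the Relay VM (relay.build does not handle relay.Any()).
    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(relay.Any(), 784), dtype="float32")
    w = relay.var("w", shape=(128, 784), dtype="float32")
    y = relay.tanh(relay.nn.dense(x, w))  # tanh needs a shape func here
    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

    exe = relay.vm.compile(mod, target="llvm")
    vm = tvm.runtime.vm.VirtualMachine(exe, tvm.cpu())
    out = vm.invoke(
        "main",
        np.random.rand(3, 784).astype("float32"),
        np.random.rand(128, 784).astype("float32"),
    )
    print(out.asnumpy().shape)  # (3, 128): batch resolved at run time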