diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 0c875045032f..499f48618b86 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -276,3 +276,4 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("clip", False, elemwise_shape_func)
 register_shape_func("log2", False, elemwise_shape_func)
 register_shape_func("sigmoid", False, elemwise_shape_func)
+register_shape_func("tanh", False, elemwise_shape_func)
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 47b9db4f390a..a718f5ecf695 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -77,13 +77,26 @@ def _callback(op):
 
 
 def _schedule_dense_small_batch(cfg, s, C):
-    A, _ = C.op.input_tensors
-    _, in_dim = get_const_tuple(A.shape)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
+    A, weights = C.op.input_tensors
+    _, in_dim_weights = get_const_tuple(weights.shape)
+    _, in_dim_A = get_const_tuple(A.shape)
+
+    if isinstance(in_dim_A, int):
+        in_dim = in_dim_A
+    elif isinstance(in_dim_weights, int):
+        in_dim = in_dim_weights
+    else:
+        in_dim = None
+
+    if in_dim is not None:
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
+        _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
+    else:
+        tile_k = 64
+        _, kf = s[C].split(C.op.reduce_axis[0], tile_k)
 
-    _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
     CF = s.rfactor(C, kf)
 
     if C.op in s.outputs:
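
A minimal sketch (not part of the patch) of how the dynamic-shape path above could be exercised; the shapes, variable names, and use of relay.vm.compile here are illustrative assumptions, not taken from the PR:

# Illustrative sketch only, assuming TVM's Relay VM flow; not part of the patch.
import tvm
from tvm import relay

# Dense with a dynamic reduction dim on the data side: get_const_tuple on the
# data shape yields a non-int there, so the schedule above reads in_dim from
# the weight shape instead; if both dims were dynamic, it would fall back to
# the fixed split of 64.
data = relay.var("data", shape=(1, relay.Any()), dtype="float32")
weight = relay.var("weight", shape=(16, 32), dtype="float32")
# tanh over a dynamically shaped tensor relies on the shape func registered above.
out = relay.tanh(relay.nn.dense(data, weight))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# The VM compiler handles dynamic shapes via the registered shape functions.
with tvm.transform.PassContext(opt_level=3):
    exe = relay.vm.compile(mod, target="cuda")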