Restrict the CUDA dense_int8 strategy to int8 inputs with int32 output.

dense_strategy_cuda previously chose the dense_int8 implementation when
out_type.dtype was "int8"; it now requires data.dtype and weights.dtype to
be "int8" and out_type.dtype to be "int32". A task-extraction test is added
that expects "dense_small_batch.cuda" for a float32 dense and
"dense_int8.cuda" for an int8 -> int32 dense on the "cuda" target.

NOTE(review): leading whitespace on context lines was reconstructed from the
Python block structure; confirm against the target tree before applying.

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 85bbab692574..bd25704d2e74 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -655,7 +655,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
     data, weights = inputs
     b, i = get_const_tuple(data.shape)
     o, _ = get_const_tuple(weights.shape)
-    if out_type.dtype == "int8":
+    if data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32":
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_int8),
             wrap_topi_schedule(topi.cuda.schedule_dense_int8),
diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py
index d6bfd8d0ec11..b3f1868969cc 100644
--- a/tests/python/relay/test_autotvm_task_extraction.py
+++ b/tests/python/relay/test_autotvm_task_extraction.py
@@ -102,5 +102,26 @@ def test_task_extraction():
     assert len(tasks) == 31
 
 
+def test_task_extraction_for_dense_int8_cuda():
+    target = "cuda"
+    dense = relay.op.get("nn.dense")
+
+    def get_net(batch, in_dim, out_dim, dtype, out_dtype):
+        data = tvm.relay.var("data", shape=[batch, in_dim], dtype=dtype)
+        weight = tvm.relay.var("weight", shape=[out_dim, in_dim], dtype=dtype)
+        out = relay.nn.dense(data, weight, out_dtype=out_dtype)
+        mod, params = relay.testing.create_workload(out)
+        return mod, params
+
+    mod, params = get_net(1, 16, 32, "float32", "float32")
+    tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
+    assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.cuda"
+
+    mod, params = get_net(1, 16, 32, "int8", "int32")
+    tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
+    assert len(tasks) == 1 and tasks[0].name == "dense_int8.cuda"
+
+
 if __name__ == "__main__":
     test_task_extraction()
+    test_task_extraction_for_dense_int8_cuda()