From 42add5bf5014c6c29687bf9e4922d688d177046d Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Thu, 19 Dec 2019 23:25:08 +0000 Subject: [PATCH 1/6] cublaslt added --- python/tvm/contrib/cublaslt.py | 51 ++++++++++ src/runtime/contrib/cublas/cublas.cc | 111 +++++++++++++++++++++- src/runtime/contrib/cublas/cublas_utils.h | 6 ++ tests/python/contrib/test_cublas.py | 65 ++++++++++++- 4 files changed, 230 insertions(+), 3 deletions(-) create mode 100644 python/tvm/contrib/cublaslt.py diff --git a/python/tvm/contrib/cublaslt.py b/python/tvm/contrib/cublaslt.py new file mode 100644 index 000000000000..5470fd0b4c18 --- /dev/null +++ b/python/tvm/contrib/cublaslt.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""External function interface to cuBLASlt libraries.""" +from __future__ import absolute_import as _abs + +from .. import api as _api +from .. import intrin as _intrin + +def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None): + """Create an extern op that compute matrix mult of A and rhs with cuBLAS + + Parameters + ---------- + lhs : Tensor + The left matrix operand + rhs : Tensor + The right matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + + Returns + ------- + C : Tensor + The result tensor. + """ + if n == 0: + n = lhs.shape[1] if transa else lhs.shape[0] + if m == 0: + m = rhs.shape[0] if transb else rhs.shape[1] + dtype = dtype if dtype is not None else lhs.dtype + return _api.extern( + (n, m), [lhs, rhs], + lambda ins, outs: _intrin.call_packed( + "tvm.contrib.cublaslt.matmul", + ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C") diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index bbb2d2e952cc..7d9a5449f601 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -181,6 +181,93 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s } } +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + +#if CUDART_VERSION >= 10010 +inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? !transb : transb; + int M = ColumnCount(B, transb); + int N = RowCount(A, transa); + int K = ColumnCount(A, transa); + int N_out = ColumnCount(C, false); + int m = M; + int n = m; + int k = m; + int lda = M * K / (roundoff(K, 32) / 32); + int ldb = K * N / (roundoff(K, 32) / 32); + int ldc = M * N_out / (roundoff(N_out, 32) / 32); + CHECK_EQ(A->ndim, 2); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 2); + + CHECK_EQ(ElementStride(A), 1); + CHECK_EQ(ElementStride(B), 1); + CHECK_EQ(ElementStride(C), 1); + + CHECK(TypeEqual(A->dtype, B->dtype)); + CHECK(TypeMatch(A->dtype, kDLInt, 8)); + CHECK(TypeMatch(C->dtype, kDLInt, 32)); + + CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; + int32_t alpha = args.size() > 5 ? args[5] : 1; + int32_t beta = args.size() > 6 ? args[6] : 0; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + cublasOperation_t opTranspose = CUBLAS_OP_T; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; + cublasLtMatmulDesc_t operationDesc = nullptr; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose))); + cublasOperation_t opTransA = BooleanToTranspose(transa); + cublasOperation_t opTransB = BooleanToTranspose(transb); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB))); + // Create descriptors for the original matrices + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k , opTransA == CUBLAS_OP_N ? k : m, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n , opTransB == CUBLAS_OP_N ? n : k, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); + + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, + operationDesc, + &alpha, + B_data, + Adesc, + A_data, + Bdesc, + &beta, + C_data, + Cdesc, + C_data, + Cdesc, + NULL, + NULL, + 0, + 0)); + +} +#endif + inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { DLTensor *A = args[0]; DLTensor *B = args[1]; @@ -342,13 +429,33 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") } }); +TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") +.set_body([](TVMArgs args, TVMRetValue *ret) { + DLTensor* A = args[0]; + + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + + TryEnableTensorCore(entry_ptr->handle); + + int version; + CHECK_CUBLAS_ERROR(cublasGetVersion(entry_ptr->handle, &version)); + #if CUDART_VERSION >= 10010 + if (TypeMatch(A->dtype, kDLInt, 8) && version >= 10100) { + cublasLtHandle_t ltHandle; + CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); + CallLtIgemm(args, ret, ltHandle); + CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); + } else + #endif // CUDART_VERSION >= 10010 + LOG(FATAL) << "Cublas version needs to be equal or larger than 10.1, but currently is " + << version; +}); + TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") .set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; DLTensor* C = args[2]; - - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); TryEnableTensorCore(entry_ptr->handle); diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 17e123219089..e28b8ce7760f 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -27,6 +27,12 @@ #include #include #include +#include +#include +#include +#if CUDART_VERSION >= 10010 +#include +#endif // CUDART_VERSION >= 10010 namespace tvm { namespace contrib { diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index 85268b95a7a8..8658de25c2e2 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -17,6 +17,7 @@ import tvm import numpy as np from tvm.contrib import cublas +from tvm.contrib import cublaslt def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5): n = 1024 @@ -44,6 +45,64 @@ def verify(target="cuda"): c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol) verify() +def roundoff(v, d): + return int(np.floor((v + d - 1) / d) * d) + +def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5): + n = 1024 + l = 1024 + m = 1024 + L = roundoff(l, 32) + N = roundoff(n, 8) + N_out = roundoff(n, 32) + + A = tvm.placeholder((N, L), name='A', dtype=in_dtype) + B = tvm.placeholder((m, L), name='B', dtype=in_dtype) + # C has CUBLASLT_ORDER_COL32 layout, thus a different shape + C = cublaslt.matmul(A, B, False, True, m, N_out, dtype=out_dtype) + s = tvm.create_schedule(C.op) + + def verify(target="cuda"): + if not tvm.module.enabled(target): + print("skip because %s is not enabled..." % target) + return + if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): + print("skip because extern function is not available") + return + ctx = tvm.gpu(0) + f = tvm.build(s, [A, B, C], target) + a_old = np.random.uniform(0, 128, size=(n, l)) + b_old = np.random.uniform(0, 128, size=(l, m)) + + # Transform a to become CUBLASLT_ORDER_COL4_4R2_8C layout + a_new = np.hstack((a_old.astype(A.dtype), np.zeros([n, L-l]))) + a_new = np.vstack((a_new.astype(A.dtype), np.zeros([N-n, L]))) + a_even = np.vsplit(a_new[::2], N / 8) + a_odd = np.vsplit(a_new[1::2], N / 8) + a_new = [None]*(len(a_even) + len(a_odd)) + a_new[::2] = a_even + a_new[1::2] = a_odd + a_new = np.vstack(a_new) + a_new = np.vstack(np.vstack(np.vstack(np.hsplit(i, 8)).reshape([4, 32]) for i in np.vsplit(j, N/4)) for j in np.hsplit(a_new, L/32)) + a_new = a_new.reshape([N, L]) + # Transform b to become CUBLASLT_ORDER_COL32 layout + b_new = np.vstack(np.hsplit(np.hstack((b_old.T.astype(B.dtype), np.zeros([m, L - l]))), L / 32)) + b_new = b_new.reshape([m, L]) + + a = tvm.nd.array(a_new.astype(A.dtype), ctx) + b = tvm.nd.array(b_new.astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((m, N_out), dtype=C.dtype), ctx) + f(a, b, c) + # Transform output c from layout CUBLASLT_ORDER_COL32 to row major layout + c_out = c.asnumpy() + c_out = c_out.reshape([int(m * N_out / 32), 32]) + c_out = np.hstack(np.vsplit(c_out, int(N_out / 32))) + c_out = c_out[:, :n] + c_out = c_out.T + tvm.testing.assert_allclose( + c_out, np.dot(a_old.astype(C.dtype), b_old.astype(C.dtype)), rtol=rtol) + verify() + def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5): j = 16 n = 1024 @@ -73,11 +132,14 @@ def verify(target="cuda"): verify() def test_matmul_add(): - verify_matmul_add('float', 'float') + verify_matmul_add('float', 'float', rtol=1e-3) verify_matmul_add('float16', 'float') verify_matmul_add('float16', 'float16', rtol=1e-2) verify_matmul_add('int8', 'int32') +def test_matmul_add_igemm(): + verify_matmul_add_igemm('int8', 'int32') + def test_batch_matmul(): verify_batch_matmul('float', 'float') verify_batch_matmul('float16', 'float') @@ -86,4 +148,5 @@ def test_batch_matmul(): if __name__ == "__main__": test_matmul_add() test_batch_matmul() + test_matmul_add_igemm() From a944991ad1394815a81c45a09c62af6130e51bf9 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 20 Dec 2019 23:18:43 +0000 Subject: [PATCH 2/6] fix lint --- src/runtime/contrib/cublas/cublas.cc | 24 ++++++++++++++--------- src/runtime/contrib/cublas/cublas_utils.h | 2 +- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 7d9a5449f601..2278f6ff32c3 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -182,7 +182,7 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s } int roundoff(int v, int d) { - return (v + d - 1) / d * d; + return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 @@ -234,11 +234,17 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose))); cublasOperation_t opTransA = BooleanToTranspose(transa); cublasOperation_t opTransB = BooleanToTranspose(transb); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB))); // Create descriptors for the original matrices - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k , opTransA == CUBLAS_OP_N ? k : m, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n , opTransB == CUBLAS_OP_N ? n : k, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k , + opTransA == CUBLAS_OP_N ? k : m, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n , + opTransB == CUBLAS_OP_N ? n : k, ldb)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( @@ -264,7 +270,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { NULL, 0, 0)); - } #endif @@ -439,16 +444,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") int version; CHECK_CUBLAS_ERROR(cublasGetVersion(entry_ptr->handle, &version)); - #if CUDART_VERSION >= 10010 if (TypeMatch(A->dtype, kDLInt, 8) && version >= 10100) { + #if CUDART_VERSION >= 10010 cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); CallLtIgemm(args, ret, ltHandle); CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); - } else - #endif // CUDART_VERSION >= 10010 + #endif // CUDART_VERSION >= 10010 + } else { LOG(FATAL) << "Cublas version needs to be equal or larger than 10.1, but currently is " << version; + } }); TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index e28b8ce7760f..2e553e28493b 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -32,7 +32,7 @@ #include #if CUDART_VERSION >= 10010 #include -#endif // CUDART_VERSION >= 10010 +#endif // CUDART_VERSION >= 10010 namespace tvm { namespace contrib { From 32a5ac10203b4adb288be43f2d147cdb6348a809 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Wed, 25 Dec 2019 20:31:39 +0000 Subject: [PATCH 3/6] address comments --- src/runtime/contrib/cublas/cublas.cc | 12 ++++++------ tests/python/contrib/test_cublas.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 2278f6ff32c3..53ff49231b98 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -221,9 +221,9 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { int32_t alpha = args.size() > 5 ? args[5] : 1; int32_t beta = args.size() > 6 ? args[6] : 0; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); cublasOperation_t opTranspose = CUBLAS_OP_T; cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; @@ -434,8 +434,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") } }); +#if CUDART_VERSION >= 10010 TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); @@ -445,17 +446,16 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") int version; CHECK_CUBLAS_ERROR(cublasGetVersion(entry_ptr->handle, &version)); if (TypeMatch(A->dtype, kDLInt, 8) && version >= 10100) { - #if CUDART_VERSION >= 10010 cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); CallLtIgemm(args, ret, ltHandle); CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); - #endif // CUDART_VERSION >= 10010 } else { LOG(FATAL) << "Cublas version needs to be equal or larger than 10.1, but currently is " << version; } }); +#endif // CUDART_VERSION >= 10010 TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index 8658de25c2e2..4d4789663a9f 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -66,7 +66,7 @@ def verify(target="cuda"): if not tvm.module.enabled(target): print("skip because %s is not enabled..." % target) return - if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): + if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True): print("skip because extern function is not available") return ctx = tvm.gpu(0) From ab2450accccc67273e3a7ab0672bec790cbb4148 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 27 Dec 2019 23:47:48 +0000 Subject: [PATCH 4/6] address more comments --- src/runtime/contrib/cublas/cublas.cc | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 53ff49231b98..2cb677729654 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -443,17 +443,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") TryEnableTensorCore(entry_ptr->handle); - int version; - CHECK_CUBLAS_ERROR(cublasGetVersion(entry_ptr->handle, &version)); - if (TypeMatch(A->dtype, kDLInt, 8) && version >= 10100) { - cublasLtHandle_t ltHandle; - CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); - CallLtIgemm(args, ret, ltHandle); - CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); - } else { - LOG(FATAL) << "Cublas version needs to be equal or larger than 10.1, but currently is " - << version; - } + CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + cublasLtHandle_t ltHandle; + CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); + CallLtIgemm(args, ret, ltHandle); + CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); #endif // CUDART_VERSION >= 10010 From b4d4fab496e580f002b32a24275a6b580fa2a2f9 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sat, 28 Dec 2019 06:15:56 +0000 Subject: [PATCH 5/6] Trigger CI From 0fe0c7ff2d9c65da9a3376c559a48abbed261cca Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sun, 29 Dec 2019 20:50:10 +0000 Subject: [PATCH 6/6] Trigger CI