From 177efd9a7e8faef764c1575646c50cbdce777bad Mon Sep 17 00:00:00 2001
From: Zihao
Date: Sun, 8 Jan 2023 04:28:20 -0800
Subject: [PATCH 01/12] flush_l2

---
 src/runtime/cuda/l2_cache_flush.cc            | 81 +++++++++++++++++++
 src/runtime/profiling.cc                      | 19 +++--
 .../unittest/test_evaluator_flush_l2_cache.py | 63 +++++++++++++++
 3 files changed, 155 insertions(+), 8 deletions(-)
 create mode 100644 src/runtime/cuda/l2_cache_flush.cc
 create mode 100644 tests/python/unittest/test_evaluator_flush_l2_cache.py

diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc
new file mode 100644
index 000000000000..a98498ec134f
--- /dev/null
+++ b/src/runtime/cuda/l2_cache_flush.cc
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  explicit L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {
+      // initialize l2_buffer_ and l2_size_
+      initialized_ = true;
+      int device_id;
+      CUDA_CALL(cudaGetDevice(&device_id));
+      CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
+      if (l2_size_ > 0) {
+        void* buffer = l2_buffer_;
+        CUDA_CALL(cudaMalloc(&buffer, l2_size_));
+        l2_buffer_ = reinterpret_cast<int*>(buffer);
+      }
+    }
+    cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaMemsetAsync(l2_buffer_, 0, l2_size_, stream));
+    }
+  }
+
+  static L2Flush* ThreadLocal();
+
+ private:
+  bool initialized_ = false;
+  int l2_size_;
+  int* l2_buffer_;
+};
+
+typedef dmlc::ThreadLocalStore<L2Flush> L2FlushStore;
+
+L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); }
+
+TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
+  ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal does not exist.";
+  L2Flush::ThreadLocal()->Flush();
+});
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc
index 168441d1708d..6333360d1277 100644
--- a/src/runtime/profiling.cc
+++ b/src/runtime/profiling.cc
@@ -882,9 +882,6 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
     DeviceAPI::Get(dev)->StreamSync(dev, nullptr);
 
     for (int i = 0; i < repeat; ++i) {
-      if (f_preproc != nullptr) {
-        f_preproc.CallPacked(args, &temp);
-      }
       double duration_ms = 0.0;
       int absolute_zero_times = 0;
       do {
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);
+          }
+          Timer t = Timer::Start(dev);
           pf.CallPacked(args, &temp);
+          t->Stop();
+          int64_t t_nanos = t->SyncAndGetElapsedNanos();
+          accum_t_nanos += t_nanos;
         }
-        t->Stop();
-        int64_t t_nanos = t->SyncAndGetElapsedNanos();
-        if (t_nanos == 0) absolute_zero_times++;
-        duration_ms = t_nanos / 1e6;
+        if (accum_t_nanos == 0) absolute_zero_times++;
+        duration_ms = accum_t_nanos / 1e6;
       } while (duration_ms < min_repeat_ms && absolute_zero_times < limit_zero_time_iterations);
 
       double speed = duration_ms / 1e3 / number;
diff --git a/tests/python/unittest/test_evaluator_flush_l2_cache.py b/tests/python/unittest/test_evaluator_flush_l2_cache.py
new file mode 100644
index 000000000000..0bbb33f0fd4d
--- /dev/null
+++ b/tests/python/unittest/test_evaluator_flush_l2_cache.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.testing.requires_cuda
+def test_evaluator_flush_l2_cache():
+    mod = tvm.IRModule.from_expr(matmul)
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("matmul")
+    i, j, k = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(k, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target=tvm.target.cuda(arch="sm_86"))
+    dev = tvm.cuda(0)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, number=100)
+
+    a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
+    args = [a, b, c]
+    print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
+
+    evaluator_with_flush = f.time_evaluator(
+        f.entry_name, dev, number=100, f_preproc="l2_cache_flush_cuda"
+    )
+    print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))
+
+
+if __name__ == "__main__":
+    test_evaluator_flush_l2_cache()

From 43580b5bcbef8f8fc5949b471012a2f9bbe3e6bb Mon Sep 17 00:00:00 2001
From: Zihao
Date: Sun, 8 Jan 2023 05:05:19 -0800
Subject: [PATCH 02/12] not necessarily sm_86

---
 tests/python/unittest/test_evaluator_flush_l2_cache.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_evaluator_flush_l2_cache.py b/tests/python/unittest/test_evaluator_flush_l2_cache.py
index 0bbb33f0fd4d..57f2b3cbdd29 100644
--- a/tests/python/unittest/test_evaluator_flush_l2_cache.py
+++ b/tests/python/unittest/test_evaluator_flush_l2_cache.py
@@ -43,7 +43,7 @@ def test_evaluator_flush_l2_cache():
     i, j, k = sch.get_loops(blk)
     sch.bind(i, "blockIdx.x")
     sch.bind(k, "threadIdx.x")
-    f = tvm.build(sch.mod["main"], target=tvm.target.cuda(arch="sm_86"))
+    f = tvm.build(sch.mod["main"], target="cuda")
     dev = tvm.cuda(0)
     evaluator_no_flush = f.time_evaluator(f.entry_name, dev, number=100)
 

From dd32136fa447c5419088ed8cf75ed63c3623d6c4 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Sun, 8 Jan 2023 05:26:47 -0800
Subject: [PATCH 03/12] fix lint and test

---
 src/runtime/cuda/l2_cache_flush.cc                     | 2 +-
 tests/python/unittest/test_evaluator_flush_l2_cache.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc
index a98498ec134f..7b0fc3c3fcf1 100644
--- a/src/runtime/cuda/l2_cache_flush.cc
+++ b/src/runtime/cuda/l2_cache_flush.cc
@@ -33,7 +33,7 @@ namespace runtime {
 
 class L2Flush {
  public:
-  explicit L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
 
   ~L2Flush() {
     if (l2_size_ > 0) {
diff --git a/tests/python/unittest/test_evaluator_flush_l2_cache.py b/tests/python/unittest/test_evaluator_flush_l2_cache.py
index 57f2b3cbdd29..241c7572a4c9 100644
--- a/tests/python/unittest/test_evaluator_flush_l2_cache.py
+++ b/tests/python/unittest/test_evaluator_flush_l2_cache.py
@@ -45,7 +45,7 @@ def test_evaluator_flush_l2_cache():
     sch.bind(k, "threadIdx.x")
     f = tvm.build(sch.mod["main"], target="cuda")
     dev = tvm.cuda(0)
-    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, number=100)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=100)
 
     a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
     b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
@@ -54,7 +54,7 @@ def test_evaluator_flush_l2_cache():
     print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
 
     evaluator_with_flush = f.time_evaluator(
-        f.entry_name, dev, number=100, f_preproc="l2_cache_flush_cuda"
+        f.entry_name, dev, repeat=100, f_preproc="l2_cache_flush_cuda"
     )
     print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

From f1237c44ece70b0c43415824aafe0dfa234a21f9 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Sun, 8 Jan 2023 05:37:33 -0800
Subject: [PATCH 04/12] fix

---
 tests/python/unittest/test_evaluator_flush_l2_cache.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/python/unittest/test_evaluator_flush_l2_cache.py b/tests/python/unittest/test_evaluator_flush_l2_cache.py
index 241c7572a4c9..0b9731c58134 100644
--- a/tests/python/unittest/test_evaluator_flush_l2_cache.py
+++ b/tests/python/unittest/test_evaluator_flush_l2_cache.py
@@ -32,7 +32,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
             vi, vj, vk = T.axis.remap("SSR", [i, j, k])
             with T.init():
                 C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
 
 @tvm.testing.requires_cuda
@@ -42,10 +42,10 @@ def test_evaluator_flush_l2_cache():
     blk = sch.get_block("matmul")
     i, j, k = sch.get_loops(blk)
     sch.bind(i, "blockIdx.x")
-    sch.bind(k, "threadIdx.x")
+    sch.bind(j, "threadIdx.x")
     f = tvm.build(sch.mod["main"], target="cuda")
     dev = tvm.cuda(0)
-    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=100)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000)
 
     a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
     b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
@@ -54,7 +54,7 @@ def test_evaluator_flush_l2_cache():
     print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
 
     evaluator_with_flush = f.time_evaluator(
-        f.entry_name, dev, repeat=100, f_preproc="l2_cache_flush_cuda"
+        f.entry_name, dev, repeat=1000, f_preproc="l2_cache_flush_cuda"
     )
     print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

From 530fabfbbec9294279cbf0798c60688247e555b5 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Sun, 8 Jan 2023 05:46:03 -0800
Subject: [PATCH 05/12] revert profiling

---
 src/runtime/profiling.cc | 19 ++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc
index 6333360d1277..168441d1708d 100644
--- a/src/runtime/profiling.cc
+++ b/src/runtime/profiling.cc
@@ -882,6 +882,9 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
     DeviceAPI::Get(dev)->StreamSync(dev, nullptr);
 
     for (int i = 0; i < repeat; ++i) {
+      if (f_preproc != nullptr) {
+        f_preproc.CallPacked(args, &temp);
+      }
       double duration_ms = 0.0;
       int absolute_zero_times = 0;
       do {
@@ -891,21 +894,15 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
-        int64_t accum_t_nanos = 0;
         // start timing
+        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
-          // call preprocessing function
-          if (f_preproc != nullptr) {
-            f_preproc.CallPacked(args, &temp);
-          }
-          Timer t = Timer::Start(dev);
           pf.CallPacked(args, &temp);
-          t->Stop();
-          int64_t t_nanos = t->SyncAndGetElapsedNanos();
-          accum_t_nanos += t_nanos;
         }
+        t->Stop();
+        int64_t t_nanos = t->SyncAndGetElapsedNanos();
+        if (t_nanos == 0) absolute_zero_times++;
+        duration_ms = t_nanos / 1e6;
-        if (accum_t_nanos == 0) absolute_zero_times++;
-        duration_ms = accum_t_nanos / 1e6;
       } while (duration_ms < min_repeat_ms && absolute_zero_times < limit_zero_time_iterations);
 
       double speed = duration_ms / 1e3 / number;

From d65069526c2915c0bc16ad01abeadb6f92b101f6 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Sun, 8 Jan 2023 05:47:01 -0800
Subject: [PATCH 06/12] number=1

---
 tests/python/unittest/test_evaluator_flush_l2_cache.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_evaluator_flush_l2_cache.py b/tests/python/unittest/test_evaluator_flush_l2_cache.py
index 0b9731c58134..b41f0aa68496 100644
--- a/tests/python/unittest/test_evaluator_flush_l2_cache.py
+++ b/tests/python/unittest/test_evaluator_flush_l2_cache.py
@@ -45,7 +45,7 @@ def test_evaluator_flush_l2_cache():
     sch.bind(j, "threadIdx.x")
     f = tvm.build(sch.mod["main"], target="cuda")
     dev = tvm.cuda(0)
-    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
 
     a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
     b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
@@ -54,7 +54,7 @@ def test_evaluator_flush_l2_cache():
     print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
 
     evaluator_with_flush = f.time_evaluator(
-        f.entry_name, dev, repeat=1000, f_preproc="l2_cache_flush_cuda"
+        f.entry_name, dev, repeat=1000, number=1, f_preproc="l2_cache_flush_cuda"
     )
     print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

From d56c8b8751ef9bfb7ec286195c9d44b3c4e45b17 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Mon, 9 Jan 2023 05:08:13 -0800
Subject: [PATCH 07/12] use parametrize

---
 ...l2_cache.py => test_evaluator_with_preproc.py} | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)
 rename tests/python/unittest/{test_evaluator_flush_l2_cache.py => test_evaluator_with_preproc.py} (79%)

diff --git a/tests/python/unittest/test_evaluator_flush_l2_cache.py b/tests/python/unittest/test_evaluator_with_preproc.py
similarity index 79%
rename from tests/python/unittest/test_evaluator_flush_l2_cache.py
rename to tests/python/unittest/test_evaluator_with_preproc.py
index b41f0aa68496..fc6eec25b8da 100644
--- a/tests/python/unittest/test_evaluator_flush_l2_cache.py
+++ b/tests/python/unittest/test_evaluator_with_preproc.py
@@ -20,6 +20,7 @@
 from tvm.script import tir as T
 import tvm.testing
 import numpy as np
+import pytest
 
 
 @T.prim_func
@@ -36,7 +37,8 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
 
 
 @tvm.testing.requires_cuda
-def test_evaluator_flush_l2_cache():
+@pytest.mark.parametrize("f_preproc", ["", "l2_cache_flush_cuda"])
+def test_time_evaluator_with_preproc(f_preproc: str):
     mod = tvm.IRModule.from_expr(matmul)
     sch = tvm.tir.Schedule(mod)
     blk = sch.get_block("matmul")
@@ -45,19 +47,14 @@ def test_evaluator_flush_l2_cache():
     sch.bind(j, "threadIdx.x")
     f = tvm.build(sch.mod["main"], target="cuda")
     dev = tvm.cuda(0)
-    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
+    evaluator = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1, f_preproc=f_preproc)
 
     a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
     b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
     c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
     args = [a, b, c]
-    print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
-
-    evaluator_with_flush = f.time_evaluator(
-        f.entry_name, dev, repeat=1000, number=1, f_preproc="l2_cache_flush_cuda"
-    )
-    print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))
+    print("Evaluator (f_preproc={}):\t{:.5f}ms".format(f_preproc, evaluator(*args).mean * 1000))
 
 
 if __name__ == "__main__":
-    test_evaluator_flush_l2_cache()
+    test_time_evaluator_with_preproc("l2_cache_flush_cuda")

From 486288c3e01ec667a3a47ca5e3717d4f0b997128 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Mon, 9 Jan 2023 06:41:15 -0800
Subject: [PATCH 08/12] use (void**)

---
 src/runtime/cuda/l2_cache_flush.cc | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc
index 7b0fc3c3fcf1..58e02c707547 100644
--- a/src/runtime/cuda/l2_cache_flush.cc
+++ b/src/runtime/cuda/l2_cache_flush.cc
@@ -49,9 +49,7 @@ class L2Flush {
       CUDA_CALL(cudaGetDevice(&device_id));
       CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
       if (l2_size_ > 0) {
-        void* buffer = l2_buffer_;
-        CUDA_CALL(cudaMalloc(&buffer, l2_size_));
-        l2_buffer_ = reinterpret_cast<int*>(buffer);
+        CUDA_CALL(cudaMalloc((void**)&l2_buffer_, l2_size_));
       }
     }
     cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;

From 5fd7717244e379745b48d6cced6881d7cc47c2e9 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Mon, 9 Jan 2023 07:06:44 -0800
Subject: [PATCH 09/12] use reinterpret_cast for lint

---
 src/runtime/cuda/l2_cache_flush.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc
index 58e02c707547..ca86cfdbb052 100644
--- a/src/runtime/cuda/l2_cache_flush.cc
+++ b/src/runtime/cuda/l2_cache_flush.cc
@@ -49,7 +49,7 @@ class L2Flush {
       CUDA_CALL(cudaGetDevice(&device_id));
       CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
       if (l2_size_ > 0) {
-        CUDA_CALL(cudaMalloc((void**)&l2_buffer_, l2_size_));
+        CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(&l2_buffer_), l2_size_));
       }
     }
     cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;

From fcb0f3f1f6b858131ef0bbfde20d5df650afa318 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Tue, 10 Jan 2023 01:14:08 -0800
Subject: [PATCH 10/12] refactor and add license

---
 3rdparty/nvbench/l2_cache_flush.h   |  74 ++++++++++
 LICENSE                             |   2 +-
 licenses/LICENSE.l2_cache_flush.txt | 218 ++++++++++++++++++++++++++++
 src/runtime/cuda/l2_cache_flush.cc  |  48 +-----
 4 files changed, 298 insertions(+), 44 deletions(-)
 create mode 100644 3rdparty/nvbench/l2_cache_flush.h
 create mode 100644 licenses/LICENSE.l2_cache_flush.txt

diff --git a/3rdparty/nvbench/l2_cache_flush.h b/3rdparty/nvbench/l2_cache_flush.h
new file mode 100644
index 000000000000..9aff6e3f72aa
--- /dev/null
+++ b/3rdparty/nvbench/l2_cache_flush.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2021 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 with the LLVM exception
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
+ *
+ * You may obtain a copy of the License at
+ *
+ *   http://llvm.org/foundation/relicensing/LICENSE.txt
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * \file l2_cache_flush.h
+ * \brief Functions to flush L2 cache using CUDA's API, adopted from nvbench.
+ */
+#ifndef L2_CACHE_FLUSH_H_
+#define L2_CACHE_FLUSH_H_
+
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/logging.h>
+
+namespace tvm {
+namespace runtime {
+
+#define CUDA_CALL(func)                                       \
+  {                                                           \
+    cudaError_t e = (func);                                   \
+    ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
+        << "CUDA: " << cudaGetErrorString(e);                 \
+  }
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush(cudaStream_t stream) {
+    if (!initialized_) {
+      // initialize l2_buffer_ and l2_size_
+      initialized_ = true;
+      int device_id;
+      CUDA_CALL(cudaGetDevice(&device_id));
+      CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
+      if (l2_size_ > 0) {
+        CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(&l2_buffer_), l2_size_));
+      }
+    }
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaMemsetAsync(l2_buffer_, 0, l2_size_, stream));
+    }
+  }
+
+  static L2Flush* ThreadLocal();
+
+ private:
+  bool initialized_ = false;
+  int l2_size_;
+  int* l2_buffer_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // L2_CACHE_FLUSH_H_
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index 6524d530deca..fbc11be2deb5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -212,6 +212,7 @@ Apache Software Foundation License 2.0
 3rdparty/dlpack
 3rdparty/dmlc-core
 3rdparty/OpenCL-Headers
+3rdparty/nvbench (with LLVM exception)
 
 
 BSD 2-clause License
@@ -234,7 +235,6 @@ MIT License
 3rdparty/cma
 3rdparty/compiler-rt/builtin_fp16.h
 
-
 
 The Unlicense
 -------------
diff --git a/licenses/LICENSE.l2_cache_flush.txt b/licenses/LICENSE.l2_cache_flush.txt
new file mode 100644
index 000000000000..bd8b243dfa02
--- /dev/null
+++ b/licenses/LICENSE.l2_cache_flush.txt
@@ -0,0 +1,218 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+
+--- LLVM Exceptions to the Apache 2.0 License ----
+
+As an exception, if, as a result of your compiling your source code, portions
+of this Software are embedded into an Object form of such source code, you
+may redistribute such embedded portions in such Object form without complying
+with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
+
+In addition, if you combine or link compiled forms of this Software with
+software that is licensed under the GPLv2 ("Combined Software") and if a
+court of competent jurisdiction determines that the patent provision (Section
+3), the indemnity provision (Section 9) or other Section of the License
+conflicts with the conditions of the GPLv2, you may retroactively and
+prospectively choose to deem waived or otherwise exclude such Section(s) of
+the License, but only in their entirety and only with respect to the Combined
+Software.
diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc
index ca86cfdbb052..eeacb5c3dcd1 100644
--- a/src/runtime/cuda/l2_cache_flush.cc
+++ b/src/runtime/cuda/l2_cache_flush.cc
@@ -16,63 +16,25 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-// Acknowledgement: l2flush struct in nvbench project.
-// Reference:
-// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
-#include <cuda_runtime.h>
-#include <dmlc/thread_local.h>
+#include "../../../3rdparty/nvbench/l2_cache_flush.h"
+#include "cuda_common.h"
+
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
-#include "cuda_common.h"
-
 namespace tvm {
 
 namespace runtime {
 
-class L2Flush {
- public:
-  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
-
-  ~L2Flush() {
-    if (l2_size_ > 0) {
-      CUDA_CALL(cudaFree(l2_buffer_));
-    }
-  }
-
-  void Flush() {
-    if (!initialized_) {
-      // initialize l2_buffer_ and l2_size_
-      initialized_ = true;
-      int device_id;
-      CUDA_CALL(cudaGetDevice(&device_id));
-      CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
-      if (l2_size_ > 0) {
-        CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(&l2_buffer_), l2_size_));
-      }
-    }
-    cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
-    if (l2_size_ > 0) {
-      CUDA_CALL(cudaMemsetAsync(l2_buffer_, 0, l2_size_, stream));
-    }
-  }
-
-  static L2Flush* ThreadLocal();
-
- private:
-  bool initialized_ = false;
-  int l2_size_;
-  int* l2_buffer_;
-};
-
 typedef dmlc::ThreadLocalStore<L2Flush> L2FlushStore;
 
 L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); }
 
 TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
   ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal does not exist.";
-  L2Flush::ThreadLocal()->Flush();
+  cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
+  L2Flush::ThreadLocal()->Flush(stream);
 });
 
 }  // namespace runtime
 }  // namespace tvm

From 284b12576dae40ee1fbee9277f4d7ed691be5dfd Mon Sep 17 00:00:00 2001
From: Zihao
Date: Tue, 10 Jan 2023 01:18:01 -0800
Subject: [PATCH 11/12] empty line for lint

---
 3rdparty/nvbench/l2_cache_flush.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/3rdparty/nvbench/l2_cache_flush.h b/3rdparty/nvbench/l2_cache_flush.h
index 9aff6e3f72aa..3d0211564535 100644
--- a/3rdparty/nvbench/l2_cache_flush.h
+++ b/3rdparty/nvbench/l2_cache_flush.h
@@ -71,4 +71,4 @@ class L2Flush {
 }  // namespace runtime
 }  // namespace tvm
 
-#endif  // L2_CACHE_FLUSH_H_
\ No newline at end of file
+#endif  // L2_CACHE_FLUSH_H_

From 71eb46d76c4afb0798af1f272f8df0f072154e9a Mon Sep 17 00:00:00 2001
From: Zihao
Date: Tue, 10 Jan 2023 01:33:41 -0800
Subject: [PATCH 12/12] header order

---
 src/runtime/cuda/l2_cache_flush.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc
index eeacb5c3dcd1..6b2c4665301c 100644
--- a/src/runtime/cuda/l2_cache_flush.cc
+++ b/src/runtime/cuda/l2_cache_flush.cc
@@ -17,12 +17,13 @@
  * under the License.
  */
 #include "../../../3rdparty/nvbench/l2_cache_flush.h"
-#include "cuda_common.h"
 
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
+#include "cuda_common.h"
+
 namespace tvm {
 
 namespace runtime {
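
A minimal usage sketch of the merged result. The series registers the packed function "l2_cache_flush_cuda", and time_evaluator's f_preproc argument wires it in so each timed repetition starts from a cold L2 cache. The vector-add kernel, schedule, and sizes below are illustrative assumptions and not part of the patches; only the f_preproc="l2_cache_flush_cuda" hookup comes from this series.

    import numpy as np
    import tvm
    from tvm import te

    # A throwaway vector-add kernel; any Module built with target="cuda"
    # works the same way.
    n = 1 << 20
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=128)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))
    f = tvm.build(s, [A, B], target="cuda")

    dev = tvm.cuda(0)
    a = tvm.nd.array(np.random.rand(n).astype("float32"), device=dev)
    b = tvm.nd.array(np.zeros(n, dtype="float32"), device=dev)

    # With number=1, the preprocessing hook runs before every timed launch:
    # the registered global zeroes an L2-sized buffer, so each repetition
    # starts from a cold L2 cache.
    cold = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1,
                            f_preproc="l2_cache_flush_cuda")
    warm = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
    print("cold L2: {:.5f} ms".format(cold(a, b).mean * 1000))
    print("warm L2: {:.5f} ms".format(warm(a, b).mean * 1000))

The gap between the two means is typically largest when the kernel's working set fits in L2, which is exactly the case the flush is meant to guard against when benchmarking.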