From 05d2240da856342ba1a88cc9dc2df64ce5a84879 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Wed, 15 Nov 2023 09:30:45 -0500
Subject: [PATCH] [Runtime] Parallel-for with threading backend

This PR introduces a runtime parallel-for helper function in C++ built on
the threading backend in TVM.

The existing
[parallel-for](https://github.com/apache/tvm/blob/bd67d2e5ebde1aec18bcfa74c087516579bda1ae/include/tvm/support/parallel_for.h#L48-L68)
in TVM does not persist its worker threads, so we cannot keep persistent
TLS for each thread. The parallel-for-with-threading-backend function
introduced here leverages the threading backend in TVM, which keeps its
worker threads alive across launches.
---
 include/tvm/runtime/threading_backend.h | 70 +++++++++++++++++++++++++
 tests/cpp/threading_backend_test.cc     |  9 ++++
 2 files changed, 79 insertions(+)

diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h
index 77d6730c096e..3122b000e048 100644
--- a/include/tvm/runtime/threading_backend.h
+++ b/include/tvm/runtime/threading_backend.h
@@ -24,6 +24,9 @@
 #ifndef TVM_RUNTIME_THREADING_BACKEND_H_
 #define TVM_RUNTIME_THREADING_BACKEND_H_
 
+#include <tvm/runtime/c_backend_api.h>
+
+#include <algorithm>
 #include <functional>
 #include <memory>
 #include <vector>
@@ -147,6 +150,73 @@ TVM_DLL void Configure(tvm::runtime::threading::ThreadGroup::AffinityMode mode,
 int32_t NumThreads();
 
 }  // namespace threading
+
+/*!
+ * \brief Execute the given lambda function in parallel with
+ * the threading backend in TVM.
+ * \tparam T The type of the lambda: "void (int i)".
+ * \param flambda The lambda to be executed in parallel.
+ * It should have the signature "void (int i)".
+ * \param begin The start index of this parallel loop (inclusive).
+ * \param end The end index of this parallel loop (exclusive).
+ * \example
+ *
+ * The for loop
+ *   for (int i = 0; i < 10; i++) {
+ *     a[i] = i;
+ *   }
+ * should work the same as:
+ *   parallel_for_with_threading_backend([&a](int i) {
+ *     a[i] = i;
+ *   }, 0, 10);
+ */
+template <typename T>
+inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end);
+
+namespace detail {
+
+// The detailed implementation of `parallel_for_with_threading_backend`.
+// Because the implementation is templated and must be visible at the call
+// site, it cannot be placed in .cc files.
+
+template <typename T>
+struct ParallelForWithThreadingBackendLambdaInvoker {
+  static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
+    int num_task = penv->num_task;
+    // Convert void* back to the lambda type.
+    T* lambda_ptr = static_cast<T*>(cdata);
+    // Invoke the lambda with the task id (thread id) and the number of tasks.
+    (*lambda_ptr)(task_id, num_task);
+    return 0;
+  }
+};
+
+template <typename T>
+inline void parallel_launch_with_threading_backend(T flambda) {
+  // Launch the lambda by passing its address.
+  void* cdata = &flambda;
+  TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker<T>::TVMParallelLambdaInvoke,
+                           cdata, /*num_task=*/0);
+}
+
+}  // namespace detail
+
+template <typename T>
+inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end) {
+  auto flaunch = [begin, end, flambda](int task_id, int num_task) {
+    // For each thread, statically divide the loop range and call into flambda.
+    int64_t total_len = end - begin;
+    int64_t step = (total_len + num_task - 1) / num_task;
+    int64_t local_begin = std::min(begin + step * task_id, end);
+    int64_t local_end = std::min(local_begin + step, end);
+    for (int64_t i = local_begin; i < local_end; ++i) {
+      flambda(i);
+    }
+  };
+  // Launch with all threads.
+  detail::parallel_launch_with_threading_backend(flaunch);
+}
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/cpp/threading_backend_test.cc b/tests/cpp/threading_backend_test.cc
index 5adf1f9ae36c..b156eec8ab3a 100644
--- a/tests/cpp/threading_backend_test.cc
+++ b/tests/cpp/threading_backend_test.cc
@@ -185,3 +185,12 @@ TEST(ThreadingBackend, TVMBackendAffinityConfigure) {
     t->join();
   }
 }
+
+TEST(ThreadingBackend, TVMBackendParallelForWithThreadingBackend) {
+  int n = 100;
+  std::vector<int> vec(/*size=*/n, /*value=*/0);
+  tvm::runtime::parallel_for_with_threading_backend([&vec](int i) { vec[i] = i; }, 0, n);
+  for (int i = 0; i < n; ++i) {
+    EXPECT_EQ(vec[i], i);
+  }
+}
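
Usage sketch (not part of the patch): the snippet below illustrates why the thread persistence mentioned in the commit message matters for per-thread TLS. It assumes the header is included as <tvm/runtime/threading_backend.h> as in this diff; the function ScaleInPlace and its thread_local scratch buffer are hypothetical names used only for illustration.

    // Hypothetical usage sketch, not part of the patch: per-thread state kept
    // in thread_local storage survives across parallel launches because the
    // threading backend reuses its worker threads.
    #include <tvm/runtime/threading_backend.h>

    #include <cstdint>
    #include <vector>

    void ScaleInPlace(std::vector<double>* data, double factor) {
      tvm::runtime::parallel_for_with_threading_backend(
          [data, factor](int64_t i) {
            // Allocated once per worker thread and reused by later launches.
            thread_local std::vector<double> scratch(16, 0.0);
            (void)scratch;  // Stand-in for real per-thread scratch usage.
            (*data)[i] *= factor;
          },
          /*begin=*/0, /*end=*/static_cast<int64_t>(data->size()));
    }

Each worker thread receives a statically scheduled contiguous chunk of [begin, end), so the lambda only needs to be safe for concurrent element-wise writes.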