diff --git a/e2e/benchmarks/local-benchmark/index.js b/e2e/benchmarks/local-benchmark/index.js index 5e90662d27..4328b129fa 100644 --- a/e2e/benchmarks/local-benchmark/index.js +++ b/e2e/benchmarks/local-benchmark/index.js @@ -21,6 +21,7 @@ const BACKEND_FLAGS_MAP = { wasm: [ 'WASM_HAS_SIMD_SUPPORT', 'WASM_HAS_MULTITHREAD_SUPPORT', + 'WASM_THREAD_POOL_SIZE', 'CHECK_COMPUTATION_FOR_ERRORS', ], webgl: [ diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index b229ad83f4..fafe486f3c 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -45,6 +45,12 @@ export class BackendWasm extends KernelBackend { constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) { super(); this.wasm.tfjs.init(); + // Register the used thread pool size flag. Done it here to avoid circular + // import. + env().registerFlag('WASM_THREAD_POOL_SIZE', () => { + return this.getThreadPoolSize(); + }); + this.dataIdMap = new DataStorage(this, engine()); } @@ -158,6 +164,11 @@ export class BackendWasm extends KernelBackend { return this.dataIdMap.get(dataId).memoryOffset; } + // Return the used thread pool size. + getThreadPoolSize(): number { + return this.wasm.tfjs.getThreadPoolSize(); + } + dispose() { this.wasm.tfjs.dispose(); if ('PThread' in this.wasm) { @@ -346,6 +357,7 @@ export async function init(): Promise<{wasm: BackendWasmModule}> { // Using the tfjs namespace to avoid conflict with emscripten's API. module.tfjs = { init: module.cwrap('init', null, []), + getThreadPoolSize: module.cwrap('get_thread_pool_size', 'number', []), registerTensor: module.cwrap( 'register_tensor', null, [ diff --git a/tfjs-backend-wasm/src/cc/backend.cc b/tfjs-backend-wasm/src/cc/backend.cc index 72f2428c7f..a3cac7c24c 100644 --- a/tfjs-backend-wasm/src/cc/backend.cc +++ b/tfjs-backend-wasm/src/cc/backend.cc @@ -62,8 +62,9 @@ int num_cores = 1; int min_num_threads = 1; int max_num_threads = 4; -pthreadpool *threadpool = pthreadpool_create( - std::min(std::max(num_cores, min_num_threads), max_num_threads)); +int thread_pool_size = + std::min(std::max(num_cores, min_num_threads), max_num_threads); +pthreadpool *threadpool = pthreadpool_create(thread_pool_size); // Registers a disposal callback for a tensor id with a given callback function. void register_disposal_callback(const size_t tensor_id, @@ -90,6 +91,11 @@ EMSCRIPTEN_KEEPALIVE #endif void init() { xnn_initialize(nullptr); } +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +const size_t get_thread_pool_size() { return backend::thread_pool_size; } + #ifdef __EMSCRIPTEN__ EMSCRIPTEN_KEEPALIVE #endif diff --git a/tfjs-backend-wasm/src/cc/backend.h b/tfjs-backend-wasm/src/cc/backend.h index e67d3b0601..70af87ccd2 100644 --- a/tfjs-backend-wasm/src/cc/backend.h +++ b/tfjs-backend-wasm/src/cc/backend.h @@ -103,6 +103,9 @@ void init(); void register_tensor(const size_t tensor_id, const size_t size, void *memory_offset); +// Return the used thread pool size. +const size_t get_thread_pool_size(); + // Disposes the internal bookeeping for a given tensor ID. void dispose_data(const size_t tensor_id); diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index ce0ef4f612..433abbcafe 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -28,6 +28,13 @@ import {BackendWasm, setWasmPath, setWasmPaths} from './index'; * 'wasm' so that they are always included in the test runner. See * `env.specFilter` in `setup_test.ts` for details. */ +describeWithFlags('wasm thread pool', ALL_ENVS, () => { + it('thread pool size', async () => { + const threadPoolSize = tf.env().getNumber('WASM_THREAD_POOL_SIZE'); + expect(threadPoolSize).toBeGreaterThanOrEqual(1); + }); +}); + describeWithFlags('wasm read/write', ALL_ENVS, () => { it('write and read values', async () => { const x = tf.tensor1d([1, 2, 3]); diff --git a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts index 2af757f8bc..bbcbf36c49 100644 --- a/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts +++ b/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts @@ -19,6 +19,8 @@ export interface BackendWasmModule extends EmscriptenModule { // Using the tfjs namespace to avoid conflict with emscripten's API. tfjs: { init(): void, + // Return the used thread pool size. + getThreadPoolSize(): number, registerTensor(id: number, size: number, memoryOffset: number): void, // Disposes the data behind the data bucket. disposeData(id: number): void,