Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions e2e/benchmarks/local-benchmark/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const BACKEND_FLAGS_MAP = {
wasm: [
'WASM_HAS_SIMD_SUPPORT',
'WASM_HAS_MULTITHREAD_SUPPORT',
'WASM_THREAD_POOL_SIZE',
'CHECK_COMPUTATION_FOR_ERRORS',
],
webgl: [
Expand Down
12 changes: 12 additions & 0 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ export class BackendWasm extends KernelBackend {
constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) {
super();
this.wasm.tfjs.init();
// Register the used thread pool size flag. Done it here to avoid circular
// import.
env().registerFlag('WASM_THREAD_POOL_SIZE', () => {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually, we register flag in flags_wasm.ts. And call env().set('WASM_THREAD_POOL_SIZE', this.getThreadPoolSize());. But I am not sure if it is always the rules.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently backend_wasm.ts deppends on flags_wasm.ts.
If we put those code in flags_wasm.ts, flags_wasm.ts will also depends on backend_wasm.ts, the test case will pass, but the yarn lint will fail due to circular imports.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a good way to pass variable in .ts file to .cc? I'm not sure if this can work: move calculation logic to flags_wasm.ts, and register flag there. In backend.cc, provide a set method. Then in backend_wasm.ts, set the threads count by calling the set method.

return this.getThreadPoolSize();
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is not the size of pool, but count of threads in pool. According to https://github.com/Maratyszcza/pthreadpool/blob/master/include/pthreadpool.h#L77, please change the name to threads_count.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename done

});

this.dataIdMap = new DataStorage(this, engine());
}

Expand Down Expand Up @@ -158,6 +164,11 @@ export class BackendWasm extends KernelBackend {
return this.dataIdMap.get(dataId).memoryOffset;
}

// Return the used thread pool size.
getThreadPoolSize(): number {
return this.wasm.tfjs.getThreadPoolSize();
}

dispose() {
this.wasm.tfjs.dispose();
if ('PThread' in this.wasm) {
Expand Down Expand Up @@ -346,6 +357,7 @@ export async function init(): Promise<{wasm: BackendWasmModule}> {
// Using the tfjs namespace to avoid conflict with emscripten's API.
module.tfjs = {
init: module.cwrap('init', null, []),
getThreadPoolSize: module.cwrap('get_thread_pool_size', 'number', []),
registerTensor: module.cwrap(
'register_tensor', null,
[
Expand Down
10 changes: 8 additions & 2 deletions tfjs-backend-wasm/src/cc/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ int num_cores = 1;

int min_num_threads = 1;
int max_num_threads = 4;
pthreadpool *threadpool = pthreadpool_create(
std::min(std::max(num_cores, min_num_threads), max_num_threads));
int thread_pool_size =
std::min(std::max(num_cores, min_num_threads), max_num_threads);
pthreadpool *threadpool = pthreadpool_create(thread_pool_size);

// Registers a disposal callback for a tensor id with a given callback function.
void register_disposal_callback(const size_t tensor_id,
Expand All @@ -90,6 +91,11 @@ EMSCRIPTEN_KEEPALIVE
#endif
void init() { xnn_initialize(nullptr); }

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
const size_t get_thread_pool_size() { return backend::thread_pool_size; }

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
Expand Down
3 changes: 3 additions & 0 deletions tfjs-backend-wasm/src/cc/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ void init();
void register_tensor(const size_t tensor_id, const size_t size,
void *memory_offset);

// Return the used thread pool size.
const size_t get_thread_pool_size();

// Disposes the internal bookeeping for a given tensor ID.
void dispose_data(const size_t tensor_id);

Expand Down
7 changes: 7 additions & 0 deletions tfjs-backend-wasm/src/index_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ import {BackendWasm, setWasmPath, setWasmPaths} from './index';
* 'wasm' so that they are always included in the test runner. See
* `env.specFilter` in `setup_test.ts` for details.
*/
describeWithFlags('wasm thread pool', ALL_ENVS, () => {
it('thread pool size', async () => {
const threadPoolSize = tf.env().getNumber('WASM_THREAD_POOL_SIZE');
expect(threadPoolSize).toBeGreaterThanOrEqual(1);
});
});

describeWithFlags('wasm read/write', ALL_ENVS, () => {
it('write and read values', async () => {
const x = tf.tensor1d([1, 2, 3]);
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export interface BackendWasmModule extends EmscriptenModule {
// Using the tfjs namespace to avoid conflict with emscripten's API.
tfjs: {
init(): void,
// Return the used thread pool size.
getThreadPoolSize(): number,
registerTensor(id: number, size: number, memoryOffset: number): void,
// Disposes the data behind the data bucket.
disposeData(id: number): void,
Expand Down