-
Notifications
You must be signed in to change notification settings - Fork 2k
[wasm] Add flag WASM_THREAD_POOL_SIZE #4942
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,12 @@ export class BackendWasm extends KernelBackend { | |
| constructor(public wasm: BackendWasmModule | BackendWasmThreadedSimdModule) { | ||
| super(); | ||
| this.wasm.tfjs.init(); | ||
| // Register the used thread pool size flag. Done it here to avoid circular | ||
| // import. | ||
| env().registerFlag('WASM_THREAD_POOL_SIZE', () => { | ||
| return this.getThreadPoolSize(); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually this is not the size of pool, but count of threads in pool. According to https://github.com/Maratyszcza/pthreadpool/blob/master/include/pthreadpool.h#L77, please change the name to threads_count.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename done |
||
| }); | ||
|
|
||
| this.dataIdMap = new DataStorage(this, engine()); | ||
| } | ||
|
|
||
|
|
@@ -158,6 +164,11 @@ export class BackendWasm extends KernelBackend { | |
| return this.dataIdMap.get(dataId).memoryOffset; | ||
| } | ||
|
|
||
| // Return the used thread pool size. | ||
| getThreadPoolSize(): number { | ||
| return this.wasm.tfjs.getThreadPoolSize(); | ||
| } | ||
|
|
||
| dispose() { | ||
| this.wasm.tfjs.dispose(); | ||
| if ('PThread' in this.wasm) { | ||
|
|
@@ -346,6 +357,7 @@ export async function init(): Promise<{wasm: BackendWasmModule}> { | |
| // Using the tfjs namespace to avoid conflict with emscripten's API. | ||
| module.tfjs = { | ||
| init: module.cwrap('init', null, []), | ||
| getThreadPoolSize: module.cwrap('get_thread_pool_size', 'number', []), | ||
| registerTensor: module.cwrap( | ||
| 'register_tensor', null, | ||
| [ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually, we register flag in
flags_wasm.ts. And callenv().set('WASM_THREAD_POOL_SIZE', this.getThreadPoolSize());. But I am not sure if it is always the rules.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently backend_wasm.ts deppends on flags_wasm.ts.
If we put those code in flags_wasm.ts, flags_wasm.ts will also depends on backend_wasm.ts, the test case will pass, but the yarn lint will fail due to circular imports.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a good way to pass variable in .ts file to .cc? I'm not sure if this can work: move calculation logic to flags_wasm.ts, and register flag there. In backend.cc, provide a set method. Then in backend_wasm.ts, set the threads count by calling the set method.