Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions explainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,24 @@ The WebNN API is a specification for constructing and executing computational gr

``` JavaScript
const operandType = {type: 'float32', dimensions: [2, 2]};
const context = navigator.ml.getNeuralNetworkContext();
const builder = context.createModelBuilder();
// 1. Create a model of the computational graph 'C = 0.2 * A + B'.
const context = navigator.ml.createContext();
const builder = new MLGraphBuilder(context);
// 1. Create a computational graph 'C = 0.2 * A + B'.
const constant = builder.constant(0.2);
const A = builder.input('A', operandType);
const B = builder.input('B', operandType);
const C = builder.add(builder.mul(A, constant), B);
const model = builder.createModel({'C': C});
// 2. Compile the model into executable.
const compilation = await model.compile();
// 3. Bind inputs to the model and execute for the result.
// 2. Compile it into an executable.
const graph = await builder.build({'C': C});
// 3. Bind inputs to the graph and execute for the result.
const bufferA = new Float32Array(4).fill(1.0);
const bufferB = new Float32Array(4).fill(0.8);
const inputs = {'A': {buffer: bufferA}, 'B': {buffer: bufferB}};
const outputs = await compilation.compute(inputs);
const inputs = {'A': {data: bufferA}, 'B': {data: bufferB}};
const outputs = await graph.compute(inputs);
// The computed result of [[1, 1], [1, 1]] is in the buffer associated with
// the output operand.
console.log('Output shape: ' + outputs.C.dimensions);
console.log('Output value: ' + outputs.C.buffer);
console.log('Output value: ' + outputs.C.data);
```

Check it out in [WebNN Code Editor](https://webmachinelearning.github.io/webnn-samples/code/?example=mul_add.js).
Expand Down Expand Up @@ -80,15 +79,14 @@ There are many important [application use cases](https://webmachinelearning.gith
// Noise Suppression Net 2 (NSNet2) Baseline Model for Deep Noise Suppression Challenge (DNS) 2021.
export class NSNet2 {
constructor() {
this.model = null;
this.compiledModel = null;
this.graph = null;
this.frameSize = 161;
this.hiddenSize = 400;
}

async load(baseUrl, batchSize, frames) {
const nn = navigator.ml.getNeuralNetworkContext();
const builder = nn.createModelBuilder();
const context = navigator.ml.createContext();
const builder = new MLGraphBuilder(context);
// Create constants by loading pre-trained data from .npy files.
const weight172 = await buildConstantByNpy(builder, baseUrl + '172.npy');
const biasFcIn0 = await buildConstantByNpy(builder, baseUrl + 'fc_in_0_bias.npy');
Expand Down Expand Up @@ -122,20 +120,20 @@ export class NSNet2 {
const relu163 = builder.relu(builder.add(builder.matmul(transpose159, weight215), biasFcOut0));
const relu167 = builder.relu(builder.add(builder.matmul(relu163, weight216), biasFcOut2));
const output = builder.sigmoid(builder.add(builder.matmul(relu167, weight217), biasFcOut4));
this.model = builder.createModel({output, gru94, gru157});
this.builder = builder;
}

async compile(options) {
this.compiledModel = await this.model.compile(options);
async build() {
this.graph = await this.builder.build({output, gru94, gru157});
}

async compute(inputBuffer, initialState92Buffer, initialState155Buffer) {
const inputs = {
input: {buffer: inputBuffer},
initialState92: {buffer: initialState92Buffer},
initialState155: {buffer: initialState155Buffer},
input: {data: inputBuffer},
initialState92: {data: initialState92Buffer},
initialState155: {data: initialState155Buffer},
};
return await this.compiledModel.compute(inputs);
return await this.graph.compute(inputs);
}
}
```
Expand Down
Loading