diff --git a/index.bs b/index.bs index a9d600b2..5c40f8ab 100644 --- a/index.bs +++ b/index.bs @@ -1666,11 +1666,11 @@ partial interface MLGraphBuilder { let currentRecurrentBias = []; for (let dir = 0; dir < numDirections; ++dir) { - currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, -1, -1]), { axes: [0] })); - currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, -1, -1]), { axes: [0] })); - currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, -1]), { axes: [0] })) : null); + currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, 3 * hidden_size, input_size]), { axes: [0] })); + currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, 3 * hidden_size, hidden_size]), { axes: [0] })); + currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, 3 * hidden_size]), { axes: [0] })) : null); currentRecurrentBias.push(options.recurrentBias ? - (builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, -1]), { axes: [0] })) : null); + (builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, 3 * hidden_size]), { axes: [0] })) : null); } for (let step = 0; step < steps; ++step) { @@ -1678,12 +1678,12 @@ partial interface MLGraphBuilder { let currentOutput = null; for (let dir = 0; dir < numDirections; ++dir) { - currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, -1, -1]), { axes: [0] })); + currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, batch_size, hidden_size]), { axes: [0] })); } for (let dir = 0; dir < numDirections; ++dir) { let slice = (dir == 1 || options.direction == "backward" ? steps - step - 1 : step); - let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, -1, -1]), { axes: [0] }); + let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, batch_size, input_size]), { axes: [0] }); let result = builder.reshape( builder.gruCell( @@ -1758,11 +1758,11 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, input_size])) ), builder.matmul( hiddenState, - builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, hidden_size])) ) ) ) @@ -1778,11 +1778,11 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, input_size])) ), builder.matmul( hiddenState, - builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, hidden_size])) ) ) ) @@ -1797,7 +1797,7 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, input_size])) ), builder.mul( r, @@ -1805,7 +1805,7 @@ partial interface MLGraphBuilder { (options.recurrentBias ? builder.slice(options.recurrentBias, [2 * hiddenSize], [hiddenSize]) : zero), builder.matmul( hiddenState, - builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, hidden_size])) ) ) ) @@ -1823,11 +1823,11 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, input_size])) ), builder.matmul( builder.mul(r, hiddenState), - builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, hidden_size])) ) ) ) @@ -2121,13 +2121,13 @@ partial interface MLGraphBuilder { let currentPeepholeWeight = []; for (let dir = 0; dir < numDirections; ++dir) { - currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, -1, -1]), { axes: [0] })); - currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, -1, -1]), { axes: [0] })); - currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, -1]), { axes: [0] })) : null); + currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, 4 * hidden_size, input_size]), { axes: [0] })); + currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, 4 * hidden_size, hidden_size]), { axes: [0] })); + currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null); currentRecurrentBias.push(options.recurrentBias ? - (builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, -1]), { axes: [0] })) : null); + (builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null); currentPeepholeWeight.push(options.peepholeWeight ? - (builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, -1]), { axes: [0] })) : null); + (builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null); } for (let step = 0; step < steps; ++step) { @@ -2137,13 +2137,13 @@ partial interface MLGraphBuilder { let nextCell = null; for (let dir = 0; dir < numDirections; ++dir) { - currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, -1, -1]), { axes: [0] })); - currentCell.push(builder.squeeze(builder.slice(cellState, [dir, 0, 0], [1, -1, -1]), { axes: [0] })); + currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, batch_size, hidden_size]), { axes: [0] })); + currentCell.push(builder.squeeze(builder.slice(cellState, [dir, 0, 0], [1, batch_size, hidden_size]), { axes: [0] })); } for (let dir = 0; dir < numDirections; ++dir) { let slice = (dir == 1 || options.direction == "backward" ? steps - step - 1 : step); - let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, -1, -1]), { axes: [0] }); + let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, batch_size, input_size]), { axes: [0] }); let results = builder.lstmCell( currentInput, currentWeight[dir], currentRecurrentWeight[dir], @@ -2227,11 +2227,11 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, input_size])) ), builder.matmul( hiddenState, - builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, hidden_size])) ) ) ) @@ -2253,11 +2253,11 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, input_size])) ), builder.matmul( hiddenState, - builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, hidden_size])) ) ) ) @@ -2274,11 +2274,11 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [3 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [3 * hiddenSize, 0], [hiddenSize, input_size])) ), builder.matmul( hiddenState, - builder.transpose(builder.slice(recurrentWeight, [3 * hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [3 * hiddenSize, 0], [hiddenSize, hidden_size])) ) ) ) @@ -2299,11 +2299,11 @@ partial interface MLGraphBuilder { builder.add( builder.matmul( input, - builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, input_size])) ), builder.matmul( hiddenState, - builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, -1])) + builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, hidden_size])) ) ) ) @@ -2690,23 +2690,15 @@ partial interface MLGraphBuilder { ### The slice() method ### {#api-mlgraphbuilder-slice} Produce a slice of the input tensor.
// This sample shows the case that the splits parameter is an array.
const outputs = [];
+ let starts = Array(input_rank).fill(0);
+ let sizes = input_shape;
let start = 0;
for (const size of splits) {
- outputs.push(builder.slice(input, [start], [size], { axes: [options.axis] }));
+ starts[options.axis] = start;
+ sizes[options.axis] = size;
+ outputs.push(builder.slice(input, starts, sizes));
start += size;
}
return outputs;