Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 38 additions & 42 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -1666,24 +1666,24 @@ partial interface MLGraphBuilder {
let currentRecurrentBias = [];

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, -1, -1]), { axes: [0] }));
currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, -1, -1]), { axes: [0] }));
currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, -1]), { axes: [0] })) : null);
currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, 3 * hidden_size, input_size]), { axes: [0] }));
currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, 3 * hidden_size, hidden_size]), { axes: [0] }));
currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, 3 * hidden_size]), { axes: [0] })) : null);
currentRecurrentBias.push(options.recurrentBias ?
(builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, -1]), { axes: [0] })) : null);
(builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, 3 * hidden_size]), { axes: [0] })) : null);
}

for (let step = 0; step < steps; ++step) {
let currentHidden = [];
let currentOutput = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, -1, -1]), { axes: [0] }));
currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, batch_size, hidden_size]), { axes: [0] }));
}

for (let dir = 0; dir < numDirections; ++dir) {
let slice = (dir == 1 || options.direction == "backward" ? steps - step - 1 : step);
let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, -1, -1]), { axes: [0] });
let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, batch_size, input_size]), { axes: [0] });

let result = builder.reshape(
builder.gruCell(
Expand Down Expand Up @@ -1758,11 +1758,11 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, input_size]))
),
builder.matmul(
hiddenState,
builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, hidden_size]))
)
)
)
Expand All @@ -1778,11 +1778,11 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, input_size]))
),
builder.matmul(
hiddenState,
builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, hidden_size]))
)
)
)
Expand All @@ -1797,15 +1797,15 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, input_size]))
),
builder.mul(
r,
builder.add(
(options.recurrentBias ? builder.slice(options.recurrentBias, [2 * hiddenSize], [hiddenSize]) : zero),
builder.matmul(
hiddenState,
builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, hidden_size]))
)
)
)
Expand All @@ -1823,11 +1823,11 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, input_size]))
),
builder.matmul(
builder.mul(r, hiddenState),
builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, hidden_size]))
)
)
)
Expand Down Expand Up @@ -2121,13 +2121,13 @@ partial interface MLGraphBuilder {
let currentPeepholeWeight = [];

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, -1, -1]), { axes: [0] }));
currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, -1, -1]), { axes: [0] }));
currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, -1]), { axes: [0] })) : null);
currentWeight.push(builder.squeeze(builder.slice(weight, [dir, 0, 0], [1, 4 * hidden_size, input_size]), { axes: [0] }));
currentRecurrentWeight.push(builder.squeeze(builder.slice(recurrentWeight, [dir, 0, 0], [1, 4 * hidden_size, hidden_size]), { axes: [0] }));
currentBias.push(options.bias ? (builder.squeeze(builder.slice(options.bias, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
currentRecurrentBias.push(options.recurrentBias ?
(builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, -1]), { axes: [0] })) : null);
(builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
currentPeepholeWeight.push(options.peepholeWeight ?
(builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, -1]), { axes: [0] })) : null);
(builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
}

for (let step = 0; step < steps; ++step) {
Expand All @@ -2137,13 +2137,13 @@ partial interface MLGraphBuilder {
let nextCell = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, -1, -1]), { axes: [0] }));
currentCell.push(builder.squeeze(builder.slice(cellState, [dir, 0, 0], [1, -1, -1]), { axes: [0] }));
currentHidden.push(builder.squeeze(builder.slice(hiddenState, [dir, 0, 0], [1, batch_size, hidden_size]), { axes: [0] }));
currentCell.push(builder.squeeze(builder.slice(cellState, [dir, 0, 0], [1, batch_size, hidden_size]), { axes: [0] }));
}

for (let dir = 0; dir < numDirections; ++dir) {
let slice = (dir == 1 || options.direction == "backward" ? steps - step - 1 : step);
let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, -1, -1]), { axes: [0] });
let currentInput = builder.squeeze(builder.slice(input, [slice, 0, 0], [1, batch_size, input_size]), { axes: [0] });

let results = builder.lstmCell(
currentInput, currentWeight[dir], currentRecurrentWeight[dir],
Expand Down Expand Up @@ -2227,11 +2227,11 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [0, 0], [hiddenSize, input_size]))
),
builder.matmul(
hiddenState,
builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [0, 0], [hiddenSize, hidden_size]))
)
)
)
Expand All @@ -2253,11 +2253,11 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [2 * hiddenSize, 0], [hiddenSize, input_size]))
),
builder.matmul(
hiddenState,
builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [2 * hiddenSize, 0], [hiddenSize, hidden_size]))
)
)
)
Expand All @@ -2274,11 +2274,11 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [3 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [3 * hiddenSize, 0], [hiddenSize, input_size]))
),
builder.matmul(
hiddenState,
builder.transpose(builder.slice(recurrentWeight, [3 * hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [3 * hiddenSize, 0], [hiddenSize, hidden_size]))
)
)
)
Expand All @@ -2299,11 +2299,11 @@ partial interface MLGraphBuilder {
builder.add(
builder.matmul(
input,
builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(weight, [hiddenSize, 0], [hiddenSize, input_size]))
),
builder.matmul(
hiddenState,
builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, -1]))
builder.transpose(builder.slice(recurrentWeight, [hiddenSize, 0], [hiddenSize, hidden_size]))
)
)
)
Expand Down Expand Up @@ -2690,23 +2690,15 @@ partial interface MLGraphBuilder {
### The slice() method ### {#api-mlgraphbuilder-slice}
Produce a slice of the input tensor.
<script type=idl>
dictionary MLSliceOptions {
sequence<unsigned long> axes;
};

partial interface MLGraphBuilder {
MLOperand slice(MLOperand input, sequence<long> starts, sequence<long> sizes,
optional MLSliceOptions options = {});
MLOperand slice(MLOperand input, sequence<unsigned long> starts, sequence<unsigned long> sizes);
};
</script>
<div algorithm=slice>
**Arguments:**
- *input*: an {{MLOperand}}. The input tensor.
- *starts*: a sequence of {{long}}. The starting indices to slice of the corresponding axes of the input shape. A negative index value is interpreted as counting back from the end. For example, the value -1 indicates the last element in the given dimension.
- *sizes*: a sequence of {{long}}. The lengths to slice of the corresponding axes of the input shape.
The length value of -1 selects all the remaining elements from the starting index of the given axis.
- *options*: an optional {{MLSliceOptions}}. The optional parameters of the operation.
- *axes*: a sequence of {{unsigned long}}. The dimensions of the input shape to which *starts* and *sizes* apply. The values in the sequence must be in the range [0, N-1] where N is the rank of input tensor. When not specified, the sequence is assumed to be [0, ..., N-1], e.g. [0,1,2] for a 3-D tensor.
- *starts*: a sequence of {{unsigned long}}. The sequence of unsigned integer values indicating the starting index to slice of each input dimension, of length N where N is the rank of the input tensor. For each dimension *d* of *input*, *starts[d]* indicates the starting index to slice in that dimension. The starting index must be in the range [0, input size - 1] in that dimension.
- *sizes*: a sequence of {{unsigned long}}. The sequence of unsigned integer values indicating the number of elements to slice of each input dimension, of length N where N is the rank of the input tensor. For each dimension *d* of *input*, *sizes[d]* indicates the number of elements to slice in that dimension. The size must not be 0 and must satisfy the constraint *starting index + size <= input size* in that dimension.

**Returns:** an {{MLOperand}}. The output tensor of the same rank as the input tensor with tensor values stripped to the specified starting and ending indices in each dimension.
</div>
Expand Down Expand Up @@ -2839,9 +2831,13 @@ partial interface MLGraphBuilder {
<pre highlight="js">
// This sample shows the case that the splits parameter is an array.
const outputs = [];
let starts = Array(input_rank).fill(0);
let sizes = input_shape;
let start = 0;
for (const size of splits) {
outputs.push(builder.slice(input, [start], [size], { axes: [options.axis] }));
starts[options.axis] = start;
sizes[options.axis] = size;
outputs.push(builder.slice(input, starts, sizes));
start += size;
}
return outputs;
Expand Down