2 changes: 1 addition & 1 deletion python/tvm/contrib/target/onnx.py
@@ -655,7 +655,7 @@ def convert_attributes(cls, attrs):


class Cast(OpConverter):
""" Operator converter for Cast."""
"""Operator converter for Cast."""

@classmethod
def convert_attributes(cls, attrs):
125 changes: 125 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -624,3 +624,128 @@ def to_int_list(np_array):
cause problems in relay/TOPI.
"""
return [int(x) for x in np_array]


def unbind(data, axis=0):
"""
Unbind was taken from the PyTorch frontend. The operation removes a tensor dimension
and returns a tuple of all slices along the given axis, with that axis removed.
TODO (vvchernov): a dedicated operation on the relay side is needed to avoid the
overhead of the per-slice squeeze operations.

Parameters
----------
data : relay.Expr
Input tensor
axis : int
Axis along which tensor is split.
Returns
-------
result : List[relay.Expr]
The sequence of computed tensors
"""
shape = infer_shape(data)
if axis >= len(shape):
msg = "Please check input dim, it shouldn't be greater than or equal to rank."
raise AttributeError(msg)

selections = shape[axis]
res_split = _op.split(data, selections, axis)
ret = []
for i in range(selections):
ret.append(_op.squeeze(res_split[i], axis=[axis]))
return _expr.TupleWrapper(_expr.Tuple(ret), selections)
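
A minimal usage sketch of unbind (the shape below is made up for illustration; it assumes unbind and infer_shape are importable from tvm.relay.frontend.common as in this patch):

from tvm import relay
from tvm.relay.frontend.common import infer_shape, unbind

x = relay.var("x", shape=(3, 4, 5), dtype="float32")
slices = unbind(x, axis=0)  # TupleWrapper holding 3 expressions
# The unbound axis is removed from every slice: (4, 5)
print(infer_shape(slices[0]))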


def lstm_cell(
input_seqs,
hidden_state,
cell_state,
w_inp,
w_hid,
b_inp=None,
b_hid=None,
proj=None,
p_i=None,
p_f=None,
p_o=None,
f_act=_op.sigmoid,
g_act=_op.tanh,
h_act=_op.tanh,
backwards=False,
):
"""
Common implementation of the LSTM cell used by the TVM frontends.
TODO (vvchernov): currently it is used by the ONNX and PyTorch frontends; extend it to other frontends

Parameters
----------
input_seqs : List[relay.Expr]
The sequence of input tensors.
Each input tensor should be 2-D until issue #8412 is resolved.
Shape = (batch, feature_size)
hidden_state : relay.Expr
Hidden state. shape = (batch, hidden_size)
cell_state : relay.Expr
Cell state. shape = (batch, hidden_size)
w_inp, w_hid : relay.Expr
weight matrices. wi shape = (4 * hidden_size, feature_size)
wh shape = (4 * hidden_size, hidden_size or proj_size)
NOTE: wi = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates.
The order is important for correct LSTM calculation!
b_inp, b_hid : relay.Expr
bias vectors. The internal parts follow the same order as the weights. shape = (4 * hidden_size)
proj : relay.Expr
projection matrix. shape = (proj_size, hidden_size)
p_i, p_f, p_o : relay.Expr
peephole LSTM matrices. shape = (batch, hidden_size)
f_act, g_act, h_act : relay.op
activation functions
backwards : bool
Flag for reverse pass of LSTM

Returns
-------
result : List[relay.Expr], relay.Expr, relay.Expr
The sequence of computed results, the final hidden state, and the final cell state
"""

outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):
# x_t shape = (batch, feature_size), step shape = (batch, feature_size + hidden_size)
step = _op.concatenate([x_t, hidden_state], axis=1)
cat_w = _op.concatenate([w_inp, w_hid], axis=1)
# Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid)
# nn.dense(step, cat_w) is used
# gates shape = (batch, 4 * hidden_size)
gates = _op.nn.dense(step, cat_w)
# Add biases
if b_inp is not None:
gates += b_inp
if b_hid is not None:
gates += b_hid
# any gate shape = (batch, hidden_size)
inp_gate, fgt_gate, cell_gate, otp_gate = _op.split(gates, 4, axis=-1)

if p_i is not None and p_f is not None:
inp_gate = f_act(inp_gate + p_i * cell_state)
fgt_gate = f_act(fgt_gate + p_f * cell_state)
else:
inp_gate = f_act(inp_gate)
fgt_gate = f_act(fgt_gate)

cell_gate = g_act(cell_gate)
cell_state = fgt_gate * cell_state + inp_gate * cell_gate
if p_o is not None:
otp_gate = f_act(otp_gate + p_o * cell_state)
else:
otp_gate = f_act(otp_gate)

hidden_state = otp_gate * h_act(cell_state)

if proj is not None:
hidden_state = _op.nn.dense(hidden_state, proj)

outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]

return outputs_list, hidden_state, cell_state
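
A minimal usage sketch of lstm_cell for a single direction (the sizes and variable names below are made up for illustration; biases, projection and peephole weights are left at their None defaults):

from tvm import relay
from tvm.relay.frontend.common import lstm_cell

batch, feature_size, hidden_size, seq_len = 2, 8, 16, 5
X_steps = [
    relay.var("x_%d" % t, shape=(batch, feature_size), dtype="float32")
    for t in range(seq_len)
]
H0 = relay.var("h0", shape=(batch, hidden_size), dtype="float32")
C0 = relay.var("c0", shape=(batch, hidden_size), dtype="float32")
Wi = relay.var("w_inp", shape=(4 * hidden_size, feature_size), dtype="float32")
Wh = relay.var("w_hid", shape=(4 * hidden_size, hidden_size), dtype="float32")

# outputs is a list of seq_len tensors of shape (batch, hidden_size)
outputs, H_t, C_t = lstm_cell(X_steps, H0, C0, Wi, Wh)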
180 changes: 88 additions & 92 deletions python/tvm/relay/frontend/onnx.py
@@ -46,6 +46,8 @@
infer_type,
infer_value,
new_var,
unbind,
lstm_cell,
)

__all__ = ["from_onnx"]
@@ -2142,58 +2144,44 @@ class LSTM(RNN):
"""Operator converter for LSTM"""

@classmethod
def generate_lstm(
cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False
def bidir_lstm_cell(
cls,
input_seqs,
weight_dicts,
acts,
):
"""Create an unrolled lstm loop.

See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
"""
h_list = []
seq_length = len(X_steps)
for i in range(seq_length):
step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
step = _op.squeeze(step, axis=[0])
gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
if B is not None:
WB, RB = _op.split(B, 2)
gates += WB + RB
i, o, f, c = _op.split(gates, 4, axis=-1)

if p_i != 0:
i = f_act(i + p_i * C_t)
else:
i = f_act(i)

if p_f != 0:
f = f_act(f + p_f * C_t)
else:
f = f_act(f)

c = g_act(c)
C = f * C_t + i * c
if p_o != 0:
o = f_act(o + p_o * C)
else:
o = f_act(o)

H = o * h_act(C)

H_t = H
C_t = C
h_list.append(_op.expand_dims(H, axis=0))
Bidirectional LSTM cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t, fw_C_t = lstm_cell(
input_seqs,
**weight_dicts[0],
f_act=acts[0],
g_act=acts[1],
h_act=acts[2],
)

if backwards:
# Canonical view is hidden states from the first token not last
h_list = h_list[::-1]
reverse_outputs, rev_H_t, rev_C_t = lstm_cell(
input_seqs,
**weight_dicts[1],
f_act=acts[3],
g_act=acts[4],
h_act=acts[5],
backwards=True,
)

# Concatenate outputs and add back in direction axis.
concatenated = _op.concatenate(h_list, 0)
output = _op.expand_dims(concatenated, axis=1)
H_t = _op.expand_dims(H_t, axis=0)
C_t = _op.expand_dims(C_t, axis=0)
final_outputs = []
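# The backward pass emits hidden states from the last time step to the first,
# so index seq_len - 1 - i realigns them with time step i of the forward pass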
for i in range(seq_len):
final_outputs.append(
_op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
)

return output, H_t, C_t
return (
_op.stack(final_outputs, axis=0),
_op.stack([fw_H_t, rev_H_t], axis=0),
_op.stack([fw_C_t, rev_C_t], axis=0),
)

@classmethod
def _impl_v7(cls, inputs, attr, params):
@@ -2224,12 +2212,6 @@ def _impl_v7(cls, inputs, attr, params):
Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
if Cp_0 is None:
Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
if Bp is None:
Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype)
if Pp is not None:
p_i, p_o, p_f = _op.split(Pp, 3, axis=1)
else:
p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype)

if "activations" in attr:
activations = attr["activations"]
@@ -2260,53 +2242,67 @@ def _impl_v7(cls, inputs, attr, params):
else:
acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions

X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
result_output = []
result_H = []
result_C = []
# TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
X_steps = unbind(X, axis=0)

H_ts = _op.split(Hp_0, num_directions)
C_ts = _op.split(Cp_0, num_directions)
Ws = _op.split(Wp, num_directions)
Rs = _op.split(Rp, num_directions)
Bs = _op.split(Bp, num_directions)
p_is = _op.split(p_i, num_directions)
p_fs = _op.split(p_f, num_directions)
p_os = _op.split(p_o, num_directions)
for i in range(num_directions):
H_t = _op.squeeze(H_ts[i], axis=[0])
C_t = _op.squeeze(C_ts[i], axis=[0])
W = _op.squeeze(Ws[i], axis=[0])
R = _op.squeeze(Rs[i], axis=[0])
B = _op.squeeze(Bs[i], axis=[0])
p_i = _op.squeeze(p_is[i], axis=[0])
p_f = _op.squeeze(p_fs[i], axis=[0])
p_o = _op.squeeze(p_os[i], axis=[0])

f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3]
output, H, C = LSTM.generate_lstm(
X_steps=X_steps,
H_t=H_t,
C_t=C_t,
W=W,
R=R,
B=B,
p_i=p_i,
p_f=p_f,
p_o=p_o,
f_act=f_act,
g_act=g_act,
h_act=h_act,
backwards=i == 1,
)
if Bp is not None:
Bs = _op.split(Bp, num_directions)
if Pp is not None:
p_i, p_o, p_f = _op.split(Pp, 3, axis=1)

result_output.append(output)
result_H.append(H)
result_C.append(C)
p_is = _op.split(p_i, num_directions)
p_fs = _op.split(p_f, num_directions)
p_os = _op.split(p_o, num_directions)

output = _op.concatenate(result_output, axis=1)
H = _op.concatenate(result_H, axis=0)
C = _op.concatenate(result_C, axis=0)
weights_dicts = []
for i in range(num_directions):
weights_dict = {}

weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
weights_dict["cell_state"] = _op.squeeze(C_ts[i], axis=[0])

# Weights permutation: ONNX uses the i-o-f-c gate order, while lstm_cell expects i-f-c-o
mati, mato, matf, matc = _op.split(_op.squeeze(Ws[i], axis=[0]), 4)
weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4)
weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
if Bp is not None:
Bi, Bh = _op.split(Bs[i], 2, -1)
mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4)
weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
mati, mato, matf, matc = _op.split(_op.squeeze(Bh, axis=[0]), 4)
weights_dict["b_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
if Pp is not None:
weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0])
weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0])
weights_dict["p_o"] = _op.squeeze(p_os[i], axis=[0])
weights_dicts.append(weights_dict)

if num_directions == 2:
output, H, C = LSTM.bidir_lstm_cell(
input_seqs=X_steps,
weight_dicts=weights_dicts,
acts=acts,
)
else:
# outputs shape = [seqs_num, (batch_size, hidden_size)]
outputs, H, C = lstm_cell(
input_seqs=X_steps,
**weights_dicts[0],
f_act=acts[0],
g_act=acts[1],
h_act=acts[2],
)

# output shape = (seqs_num, num_directions, batch_size, hidden_size)
output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
H = _op.expand_dims(H, axis=0)
C = _op.expand_dims(C, axis=0)

return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3)
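
For illustration only (not part of the patch), a small numpy sketch of the gate reordering performed above: ONNX packs the stacked gate weights as (i|o|f|c) along the first axis, while lstm_cell expects (i|f|c|o); the sizes are made up:

import numpy as np

hidden_size, feature_size = 3, 2
# Blocks filled with 1, 2, 3, 4 stand in for the i, o, f, c gate weights
W_onnx = np.concatenate(
    [np.full((hidden_size, feature_size), v) for v in (1, 2, 3, 4)]
)
i, o, f, c = np.split(W_onnx, 4)
W_cell = np.concatenate([i, f, c, o])  # order expected by lstm_cell
assert (W_cell[:hidden_size] == 1).all() and (W_cell[-hidden_size:] == 2).all()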
