[Relay] Port LSTM to Relay for testing #2011
Conversation
python/tvm/relay/testing/lstm.py
Outdated
h2h = layers.dense_add_bias(data=inputs, weight=h2h_weight,
                            bias=h2h_bias, units=num_hidden * 4)

gates = relay.add(i2h, h2h)
I noticed in the example that the line read simply `i2h + h2h`, which was the only place where a + operator was used; there were several elemwise_add calls otherwise. What was the reason for the distinction? Did the + mean to concatenate, rather than add?
I believe it is addition, not concatenation. My guess is the plus operator is used because there is no need to assign a name to gates, but there is for next_c.
I see, thanks
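For reference, a minimal sketch (shapes assumed, current TVM API) of the point above: the `+` in the original example is plain elementwise addition, which the Relay port spells explicitly as relay.add.

from tvm import relay

i2h = relay.var("i2h", shape=(1, 4 * 128))
h2h = relay.var("h2h", shape=(1, 4 * 128))

gates = relay.add(i2h, h2h)   # explicit addition, as in the ported code
gates_alt = i2h + h2h         # Relay expressions overload `+` to the same add op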
python/tvm/relay/testing/lstm.py
Outdated
from . import layers
from .init import create_workload


def lstm_cell(inputs, states, i2h_weight, h2h_weight,
Duly noted, will change
python/tvm/relay/testing/lstm.py
Outdated
        The result.
    """

    i2h = layers.dense_add_bias(data=inputs, weight=i2h_weight,
Weight and bias should be left out; they will be created automatically by dense_add_bias.
Interesting
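A hedged sketch of that suggestion: let dense_add_bias create the weight and bias variables itself instead of passing them in. The "i2h" name prefix and the shape values here are assumptions for illustration, not taken from the final lstm.py.

from tvm import relay
from tvm.relay.testing import layers

batch_size, num_hidden = 1, 128
inputs = relay.var("inputs", shape=(batch_size, num_hidden))

# With no weight/bias arguments, the helper creates the "i2h_weight" and
# "i2h_bias" variables on its own from the given name prefix.
i2h = layers.dense_add_bias(data=inputs, units=num_hidden * 4, name="i2h")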
python/tvm/relay/testing/lstm.py
Outdated
states = relay.var("states",
                   relay.TupleType([
                       relay.TensorType((batch_size, num_hidden)),
                       relay.TensorType((batch_size, num_hidden))]))
Are there any other explicit type annotations I could have in this function? It would be nice to annotate the return type on the function too, but I am not sure how to state it (do you have a suggestion, @jroesch?)
Something gives one way or another: type unification is inadequate (I had to annotate the types because it could not conclude that tuple(tensor, tensor) unifies with tuple(unknown, unknown) -- there should really be a visitor for unification!), and the error reporting is worse (I know there's a shape mismatch somewhere, but not where).
I will do a substantial refactor using the ScopeBuilder to see whether that helps anything.
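A minimal sketch (shapes assumed, current ScopeBuilder API) of what that refactor looks like: each intermediate value is bound with an explicit name and type annotation through let, which is how the extra annotations get threaded through the function.

from tvm import relay
from tvm.relay import ScopeBuilder

batch_size, num_hidden = 1, 128
ttype = relay.TensorType((batch_size, num_hidden))

x = relay.var("x", ttype)
h = relay.var("h", ttype)

sb = ScopeBuilder()
# let-binding with an explicit (name, type) annotation
total = sb.let(("total", ttype), relay.add(x, h))
sb.ret(total)
func = relay.Function([x, h], sb.get(), ret_type=ttype)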
Force-pushed from 80ccb97 to dc4b027.
I annotated every type and am no longer getting type errors in Relay (this points to some serious shortcomings in type unification and inference), but now it hangs on create_workload, possibly looping infinitely. I would appreciate pointers on what could be done differently.
slice_gates = builder.let(("slice_gates", slice_type),
                          relay.split(gates,
                                      indices_or_sections=4,
                                      axis=1).astuple())
@tqchen The default value for axis in split is zero, but the relation for split rejects an axis of 0. That doesn't seem right -- should the relation be corrected, or the default argument?
Split should be able to support axis=0
Will change it then (very easy)
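For context, a hedged sketch (shapes assumed) of the split call under discussion: the fused gate matrix is cut into four equal slices along axis 1, and astuple() exposes the underlying Tuple expression for the let-binding shown above.

from tvm import relay

batch_size, num_hidden = 1, 128
gates = relay.var("gates", shape=(batch_size, 4 * num_hidden))

split = relay.split(gates, indices_or_sections=4, axis=1)
# indexing the result gives the individual gate slices...
in_gate, forget_gate, in_transform, out_gate = [split[i] for i in range(4)]
# ...while astuple() yields the Tuple expression bound with builder.let above
gates_tuple = split.astuple()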
Force-pushed from 3cbde11 to 2143e21.
@slyubomirsky can you update this PR? Is it ready for another look?
What updates, exactly, are needed? I addressed previous feedback. I suppose I can potentially reduce the number of explicit type annotations, though, so I will investigate whether that's possible.
No, I just want to make sure: if you don't need further updates, we can merge this in as it is.
Ah, in that case I would say it's not waiting on further changes. We can perhaps simplify the Relay implementation in a follow-up PR.
We would like to be able to evaluate Relay's performance on an LSTM, particularly since Relay can directly incorporate control flow. However, I could not find a concise example of an LSTM in NNVM to port over. @merrymercy pointed me to his own prior implementation of an LSTM cell, which is what I ported over.
However, in order to get this to match up with the other examples, I would also need to set up the rest of the network and load in a benchmark; I would appreciate any pointers as to how I could proceed, since I am not familiar with LSTMs. Namely, a pointer to NNVM implementations I could port over would be most helpful for setting up comparisons. (@tqchen, @merrymercy)
I would also appreciate any advice on how to best present this example (e.g., references to include).
Edit: Full disclosure, this variant of an LSTM is still an unrolled loop because we haven't merged in the planned changes for abstract data types in Relay. Relay can currently handle a loop via recursion, but without ADTs it can't take in an arbitrary-length input list.
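A minimal sketch of what "unrolled" means here. The trivial toy_cell below is only a stand-in for the real lstm_cell in this PR (which carries weights and a (cell state, hidden state) tuple), and relay.analysis.free_vars is the current API name; the point is just that the step count is fixed at graph-construction time.

from tvm import relay

def toy_cell(x, h):
    # stand-in for lstm_cell: combine the input with the previous state
    return relay.add(x, h)

def unrolled_net(num_steps, batch_size=1, num_hidden=128):
    h = relay.var("init_h", shape=(batch_size, num_hidden))
    for t in range(num_steps):
        x = relay.var("x_%d" % t, shape=(batch_size, num_hidden))
        h = toy_cell(x, h)
    # num_steps is fixed in Python: without ADTs there is no list type that
    # could accept an arbitrary-length sequence at runtime
    return relay.Function(relay.analysis.free_vars(h), h)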