Handle PyTorch to Flax conversion of 1D convolutions #15519

patil-suraj merged 1 commit into huggingface:master from sanchit-gandhi:flax-wav2vec2
Conversation
The documentation is not available anymore as the PR was closed or merged.
patil-suraj left a comment
LGTM! Thanks for handling this!
Note for the future: We should try to find a more robust way of detecting Conv layers.
Also cc @patrickvonplaten
I'm a bit surprised that we needed that. We already had 1D Conv layers in Flax in Wav2Vec2 and the conversion worked.
    # conv1d layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
What we do here is equivalent to what is done below for `# linear layer` - I don't understand why we've added this. Was some code failing before?

`pt_tensor.transpose(2, 1, 0)` is the same as `pt_tensor.T` for a 3D tensor, and if `pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key)` is true, then `pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key)` is also true.

We should avoid at all cost adding more complexity to these weight conversion statements and keep things as simple as possible. Remember that all PT<->Flax conversions depend on this code, so we should not clutter it and should be extra careful here. If we add a new statement like this, there has to be at least a test that ensures the change is needed. At the moment I cannot think of a single use case where this is the case -> we already had 1D-conv layer conversions working for Flax Wav2Vec2 <-> PT Wav2Vec2.

=> IMO we should revert this PR. The comments could indeed be improved, however, so I'm happy to change the comment `# conv layer` to `# 2D conv layer` and `# linear layer` to `# linear and 1D conv layer`.
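The equivalence claimed above is easy to check numerically. A minimal sketch using NumPy (standing in for the PyTorch tensor, since both share the same `.T` and `transpose` semantics for N-D arrays):

```python
import numpy as np

# PyTorch stores a Conv1d weight with shape (out_channels, in_channels, kernel_size).
# For a 3D array, transpose(2, 1, 0) reverses the axis order, which is exactly
# what .T does - so the dedicated conv1d branch and the generic linear-layer
# branch produce identical results.
pt_tensor = np.arange(2 * 3 * 4).reshape(2, 3, 4)

conv1d_branch = pt_tensor.transpose(2, 1, 0)  # explicit axis permutation
linear_branch = pt_tensor.T                   # full transpose (reverses all axes)

same = np.array_equal(conv1d_branch, linear_branch)  # → True
```

Since `.T` reverses all axes regardless of rank, the `ndim == 3` check adds a branch without changing the output.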
Completely agree, we actually don't need this. My bad, not my best review. Thanks a lot!
Revert "Handle PyTorch to Flax conversion of 1D convolutions (huggingface#15519)" (huggingface#15540)

This reverts commit 854a0d5.
What does this PR do?
Currently, only 2-dimensional convolutional layers are renamed and reshaped in the PyTorch to Flax conversion script. This PR handles the case of 1-dimensional convolution layers in an entirely equivalent way to their 2-dimensional counterparts.
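For illustration, a simplified sketch of what such a renaming branch looks like. The function name mirrors the helper in the transformers conversion script, but `flax_keys` and the plain membership check are hypothetical stand-ins for the real `is_key_or_prefix_key_in_dict` logic:

```python
import numpy as np

def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, flax_keys):
    """Simplified sketch of the PT -> Flax weight renaming/reshaping.

    `flax_keys` is a set of expected Flax parameter key tuples, a stand-in
    for the `is_key_or_prefix_key_in_dict` check in the real script.
    """
    renamed = pt_tuple_key[:-1] + ("kernel",)
    needs_rename = pt_tuple_key[-1] == "weight" and pt_tuple_key not in flax_keys

    # conv1d layer: PT (out_channels, in_channels, kernel) -> Flax (kernel, in, out)
    # NB: as noted in the review above, pt_tensor.T alone would already handle
    # the 3D case, which is why this dedicated branch was reverted.
    if needs_rename and pt_tensor.ndim == 3:
        return renamed, pt_tensor.transpose(2, 1, 0)

    # linear layer: PT (out_features, in_features) -> Flax (in, out)
    if needs_rename:
        return renamed, pt_tensor.T

    return pt_tuple_key, pt_tensor

# Usage: a conv1d weight of shape (512, 10, 5)
key, tensor = rename_key_and_reshape_tensor(
    ("conv", "weight"), np.zeros((512, 10, 5)), flax_keys={("conv", "kernel")}
)
# key          -> ("conv", "kernel")
# tensor.shape -> (5, 10, 512)
```

Both branches rename `weight` to `kernel`; the only difference is the explicit axis permutation, which for a 3D tensor coincides with the full transpose.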