Skip to content

Requires for torch.tensor before casting#31755

Merged
LysandreJik merged 1 commit intohuggingface:mainfrom
echarlaix:fix-onnx
Jul 3, 2024
Merged

Requires for torch.tensor before casting#31755
LysandreJik merged 1 commit intohuggingface:mainfrom
echarlaix:fix-onnx

Conversation

@echarlaix
Copy link
Copy Markdown
Collaborator

@echarlaix echarlaix commented Jul 2, 2024

Fixes ONNX export for swin, swin-donut and clap models

self.shift_size = torch_int(0)

coming from :

self.shift_size = torch_int(0)

introduced in #31311

as torch_int is expecting a torch.Tensor :

return x.to(torch.int64) if torch.jit.is_tracing() else int(x)

also I think we should be able to have here

self.shift_size = 0

cc @merveenoyan @xenova

@echarlaix echarlaix requested a review from amyeroberts July 2, 2024 17:16
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

import torch

return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you can use the torch_int utility instead

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean ensuring shift_size is a torch.Tensor when jit tracing in the modeling directly ?

self.shift_size = torch_int(0)

import torch

return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here torch_float

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think since we're explicitly calling torch_float where float was used to be it should be fine, no?

@echarlaix
Copy link
Copy Markdown
Collaborator Author

echarlaix commented Jul 3, 2024

@amyeroberts do you think this fix could be included in a patch release ?

Copy link
Copy Markdown
Contributor

@merveenoyan merveenoyan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@LysandreJik LysandreJik merged commit dc72fd7 into huggingface:main Jul 3, 2024
@LysandreJik
Copy link
Copy Markdown
Member

@echarlaix I'm happy to include it in a patch release towards the end of the week

@echarlaix echarlaix deleted the fix-onnx branch July 3, 2024 09:18
@echarlaix
Copy link
Copy Markdown
Collaborator Author

@echarlaix I'm happy to include it in a patch release towards the end of the week

thanks a lot!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants