Fix ONNX exports for Optimum compatible models #31311

amyeroberts merged 13 commits into huggingface:main from
Conversation
```python
def safe_int(x):
    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)

old_grid_size = safe_int(posemb_grid.size(0) ** 0.5)
```
```python
new_height = (torch.ceil(orig_height / patch_height) * patch_height).to(torch.int64)
new_width = (torch.ceil(orig_width / patch_width) * patch_width).to(torch.int64)
```
Same comment as above - doesn't interpolate require (int, int) when not tracing?
I'll check tracing, thanks for the heads up
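For context on the question above: in eager (non-tracing) mode, `nn.functional.interpolate` expects plain Python ints in `size`, while under tracing the equivalent values can be kept as int64 tensors so they stay symbolic in the graph. A minimal sketch (the tensor shapes here are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)  # dummy feature map, just for illustration

# Eager mode: `size` is given as plain Python ints.
out = F.interpolate(x, size=(16, 16), mode="bicubic")
print(tuple(out.shape))  # (1, 3, 16, 16)
```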
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
amyeroberts
left a comment
Very nice! Thanks for fixing this for all these models ❤️
Just a few small comments
```python
raise TypeError(f"Could not infer framework from class {model_class}.")

def safe_int(x):
```
Docstrings would be helpful here e.g. for inspecting in IDEs: what does it mean for an int to be safe?
Indeed a better name is probably a good idea 😅 I called it safe_int in a way to "safely cast some value (which could be a python number or tensor) to an integer in a way that respects tracing"
I'll swap with torch_int and torch_float
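A minimal sketch of what such helpers could look like (the exact implementation that landed in `transformers` may differ):

```python
import torch

def torch_int(x):
    # Cast to an int64 tensor while tracing (keeps the value symbolic in
    # the exported graph), and to a plain Python int otherwise.
    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)

def torch_float(x):
    # Same idea for floats.
    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)

# Outside of tracing, both behave like the Python builtins:
print(torch_int(7.9))   # 7
print(torch_float(3))   # 3.0
```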
```python
new_width = int(math.ceil(orig_width / patch_width) * patch_width)
new_height = (
    safe_float(torch.ceil(orig_height / patch_height) * patch_height)
    if torch.jit.is_tracing()
```
Do we need the conditional here? This is already handled in the safe_float and safe_int functions
I think it's required for torch.ceil no?
tbh, I don't know, is there a reason we couldn't use torch.ceil directly?
if I'm passing an int or float, torch.ceil will be called first and it will fail because torch.ceil can only be called with tensors AFAIK
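This is easy to verify: `torch.ceil` only accepts tensors, which is why a plain Python number has to go through `math.ceil` in the non-tracing branch. A quick illustration:

```python
import math
import torch

# torch.ceil rejects plain Python numbers with a TypeError.
try:
    torch.ceil(3.5)
except TypeError:
    print("torch.ceil rejects Python floats")

print(math.ceil(3.5))                        # 4
print(torch.ceil(torch.tensor(3.5)).item())  # 4.0
```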
Only other Q here then is why do we use a float when tracing and int otherwise?
sorry I think I was mistaken with that one, you're right, I fixed it :)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@amyeroberts the failing tests seem irrelevant to this PR, I can't re-run them, can you re-run?
@merveenoyan yes yes - done!
amyeroberts
left a comment
Thanks for fixing all of these!
Just for my own understanding - is there any reason to not use the torch compatible float/int when we're not tracing?
@amyeroberts to my understanding, torch ONNX export internally calls
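A toy illustration of why the tensor path matters during export: calling Python `int()` on a traced value bakes it into the graph as a constant, so the trace only stays correct for the example input's shape (hypothetical module, just for demonstration):

```python
import torch

class Reshaper(torch.nn.Module):
    def forward(self, x):
        # int() collapses the traced shape to a constant, so the traced
        # graph silently reuses the width of the example input.
        w = int(x.shape[-1])
        return x.reshape(-1, w)

traced = torch.jit.trace(Reshaper(), torch.zeros(2, 4))
print(tuple(traced(torch.zeros(2, 4)).shape))  # (2, 4) - matches the trace input
print(tuple(traced(torch.zeros(2, 8)).shape))  # (4, 4) - silently wrong: 4 was baked in
```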
@amyeroberts can you merge if you think it's ok?
Right, I see why we need to do it for the onnx export, but for day-to-day use could we just use torch primitives instead of a python
@amyeroberts I guess if it's just torch modelling code then yes. Would you like me to swap everything?
also asking the same question to @xenova
@merveenoyan Yes please! This will be cleaner and easier to follow in the code :)
I agree with @amyeroberts - if there is a way to "do everything in torch land", that's the best solution! However, there are cases where I'm not entirely sure how to do this. For example, with

See here for example code (DinoV2 backbone):

```python
if torch.jit.is_tracing():
    sqrt_N = N ** 0.5
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, (sqrt_N).to(torch.int64), (sqrt_N).to(torch.int64), dim).permute(0, 3, 1, 2),
        size=(w0, h0),
        mode="bicubic",
        antialias=self.interpolate_antialias,
    )
else:
    sqrt_N = math.sqrt(N)
    sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
        scale_factor=(sx, sy),
        mode="bicubic",
        antialias=self.interpolate_antialias,
    )
```

Very ugly... I know :/
@xenova sounds good, very glad to work with you! tbh I didn't know that it would be required in inference.
@merveenoyan My understanding from above was that the PR would be updated to remove all the if/else structures wherever possible (though, as @xenova points out, that unfortunately isn't possible everywhere)
@amyeroberts from what I understood we should still keep the if/else so as not to break inference (I'm also scared of edge cases, if there are any), so I'd rather keep them. What I can do is test all of them to see whether they break when everything is a tensor, and remove the conditional where it doesn't have to be a python type
@merveenoyan OK. Let's just merge then and we can follow up in future PRs 👍

@amyeroberts as discussed and also pinging @xenova for review :') (who also fixed DPT)
I prioritized Optimum compatible ones because I'm launching a project where there are Optimum examples for vision models. I will have a separate PR for the models that aren't compatible with Optimum. The rest of the Optimum compatible models export well without a problem.