[_unmask_unattended] Refactor #29356
Conversation
```python
is_tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
is_tracing |= isinstance(input_tensor, torch.fx.Proxy)
```
You can leave this as a one-liner. I believe that in Python, `a or b` never evaluates `b` if `a` is already `True`.
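A minimal stand-alone sketch of the short-circuit behaviour this comment relies on. The stub functions below are hypothetical placeholders for the real `torch.jit.is_tracing()` and `torch._dynamo.is_compiling()` checks, just to make the evaluation order observable:

```python
def demo_short_circuit():
    calls = []

    def jit_is_tracing():
        # Hypothetical stand-in for torch.jit.is_tracing()
        calls.append("jit")
        return True

    def dynamo_is_compiling():
        # Hypothetical stand-in for torch._dynamo.is_compiling()
        calls.append("dynamo")
        return True

    # `or` short-circuits: once the left operand is truthy, the right
    # operand is never evaluated, so folding both checks into a single
    # expression never runs the second check unnecessarily.
    is_tracing = jit_is_tracing() or dynamo_is_compiling()
    return is_tracing, calls

result, calls = demo_short_circuit()
```

Here `calls` ends up containing only `"jit"`, confirming the right-hand operand was skipped.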
Can you pass `is_tracing` as an `Optional[bool]`? This variable is already computed in the `_prepare_xxx` methods, so compute it here only if it is `None`.
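A minimal sketch of the suggested `Optional[bool]` pattern, not the actual transformers code; `unmask_unattended` and `detect_tracing` are hypothetical names standing in for the real method and the real torch checks:

```python
from typing import Optional


def detect_tracing() -> bool:
    # Hypothetical stand-in for the real torch.jit / torch._dynamo /
    # torch.fx.Proxy checks; returns False in this sketch.
    return False


def unmask_unattended(expanded_mask, is_tracing: Optional[bool] = None) -> bool:
    # Recompute the flag only when the caller did not pass it in; the
    # _prepare_xxx methods would forward their already-computed value.
    if is_tracing is None:
        is_tracing = detect_tracing()
    return is_tracing


recomputed = unmask_unattended(expanded_mask=[])                   # falls back to detect_tracing()
forwarded = unmask_unattended(expanded_mask=[], is_tracing=True)   # caller's value wins
```

This avoids redoing the (potentially non-trivial) tracing detection on every call while keeping the method usable on its own.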
```python
def _unmask_unattended(
    expanded_mask: torch.FloatTensor,
    attention_mask: torch.FloatTensor,
    input_tensor: torch.FloatTensor,
```
I don't think we need `input_tensor`, do we? The FX check should be doable on `attention_mask`.
But sometimes the attention mask is `None` even when tracing, no?
I meant `expanded_mask`. See the type hint.
This method is never intended to be called with `None` input; if you want to change that, you could just return early and edit the type hint.
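A minimal sketch of the early-return alternative mentioned above, with placeholder names and a plain list standing in for the tensor; this is an illustration of the pattern, not the actual transformers implementation:

```python
from typing import Optional


def unmask_unattended_sketch(expanded_mask: Optional[list]) -> Optional[list]:
    # If the method were allowed to receive None, the type hint would
    # become Optional[...] and the function would simply return early.
    if expanded_mask is None:
        return None
    # ... the actual unmasking logic would go here; this sketch just
    # passes the mask through unchanged ...
    return expanded_mask


skipped = unmask_unattended_sketch(None)
passed_through = unmask_unattended_sketch([0, 1])
```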
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Follow up on #29318: if we keep the `AttentionMaskConverter`, then let's put the checks in `_unmask_unattended`, since this function is bound to disappear once the bug is fixed in torch.
@fxmarty let's make this easier to refactor later on, and more readable!