align xpu's autocast behavior w/ cuda by using device agnostic torch APIs #38284
ydshieh merged 25 commits into huggingface:main
Conversation
cuda Signed-off-by: Matrix Yao <matrix.yao@intel.com>
The CI failure doesn't seem to be caused by my PR.
…magegpt Signed-off-by: Matrix Yao <matrix.yao@intel.com>
```diff
  # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
- with torch.amp.autocast(query.device.type, enabled=False):
+ with torch.autocast(query.device.type, enabled=False):
```
Align to other modeling code in the models directory.
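For reference, a minimal sketch of the device-agnostic upcast pattern under discussion (an illustrative helper of my own, not the PR's exact code): `torch.autocast` takes the device type string as its first argument, so the same context manager works for cuda, xpu, and cpu alike.

```python
import torch

def upcast_softmax(scores: torch.Tensor) -> torch.Tensor:
    # Disable autocast around the softmax so it runs in full precision,
    # using the device-agnostic torch.autocast API instead of the
    # cuda-specific torch.cuda.amp.autocast.
    with torch.autocast(scores.device.type, enabled=False):
        return torch.softmax(scores.float(), dim=-1).to(scores.dtype)

scores = torch.randn(2, 4, 4)
probs = upcast_softmax(scores)
```

Because the device type is read off the tensor itself, the same code path is exercised on every backend, which is the point of the PR.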
@ArthurZucker @IlyasMoutawwakil could you help review and comment? Thx very much
Signed-off-by: Matrix YAO <matrix.yao@intel.com>
The CI failure may be because of an unstable CI env.
CI seems clear now! cc @IlyasMoutawwakil
@IlyasMoutawwakil, could you help review? Thx very much.
```python
input_dtype = query_states.dtype
device_type = (
    query_states.device.type
    if isinstance(query_states.device.type, str) and query_states.device.type != "mps"
```
why would query_states.device.type be anything other than str?
and what's the problem with mps exactly?
I don't know; it's an existing practice in the original code. I reused it because I don't have an mps device, so I just followed the existing behavior. I can see it has also been in the chameleon and recurrent_gemma modeling code since their first PR, so I cannot retrieve the history of why it's used. Maybe @ArthurZucker and @zucchini-nlp know the reason.
Yeah, I think it comes from this comment #29285 (comment). The changes were first added in this PR.
Thx @zucchini-nlp, this PR explains why disabling autocast is needed.
And @IlyasMoutawwakil, PR #29439 explains why mps is excluded: mps doesn't support amp.
As for when query_states.device.type is not a str, I guess it's backward-compatible behavior: before PT 2.0 there were only two values for torch.device.type, "cpu" and "cuda", and for devices like "mps" (supported since PT 1.12), the tensor device type on PT 1.13 may return None, so the guard was needed. But I don't have a device to confirm; it's just my guess.
thanks, so I guess we can remove the str check, since support for torch 1.x was dropped a while ago.
@IlyasMoutawwakil, done, pls help review again, thx.
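With the `isinstance` check dropped (torch 2.x guarantees `Tensor.device.type` is a str), the guard reduces to the mps exclusion only. A sketch of what the simplified selection looks like (illustrative helper name, not the literal PR diff):

```python
import torch

def autocast_device_type(t: torch.Tensor) -> str:
    # torch >= 2.0 always returns a str from Tensor.device.type, so the
    # old isinstance(..., str) check is redundant. Only mps is excluded,
    # because mps does not support autocast/amp, so we fall back to "cpu".
    return t.device.type if t.device.type != "mps" else "cpu"

x = torch.ones(2, 2)  # a cpu tensor in this sketch
device_type = autocast_device_type(x)
```

The returned string can then be passed straight to `torch.autocast(device_type, enabled=False)` on any backend.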
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
@Rocketknight1, do you know who else needs to review this PR after Ilyas's approval? Thx.
run-slow: qwen2_5_omni, gemma, phimoe, qwen2_moe, gpt2, distilbert
This comment contains run-slow, running the specified jobs: models: ['models/distilbert', 'models/gemma', 'models/gpt2', 'models/phimoe', 'models/qwen2_5_omni', 'models/qwen2_moe']
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@ArthurZucker, pls help review, thx very much.