fix bug when getting the real accelerator's device number #2874
faaany wants to merge 3 commits into huggingface:main
Conversation
OK for me.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
SunMarc left a comment
Sounds good! I left a comment about MLU, since I'm not sure it is safe to change the `expected_device_type`.
elif is_mlu_available():
    expected_device_type = "mlu"
The issue you shared @faaany is about the `to()` method for MLU. I'm not sure whether `torch.device(d).type` will really return `cuda` with MLU. Do you have any insights @huismiling, since you added the MLU support? To be safe, I would suggest reverting this change.
Good idea, I will update it. But I am also very curious about the behavior on NPU:
Hi @statelesshz, could you help us verify what `torch.device(0)` returns on NPU? Is it `cuda` or `npu`? Thanks a lot!
@faaany Hi, the MLU device type is `mlu`:
>>> torch.device(0).type
'mlu'
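The thread establishes that resolving a bare integer with `torch.device(int)` gives a backend-dependent type: `'mlu'` on MLU builds, but (per the PR description below) `'cuda'` on XPU and likely NPU builds. A minimal sketch of a safer normalization, assuming a hypothetical helper name (`normalize_device` is not part of accelerate or PyTorch), could avoid relying on that resolution entirely:

```python
# Hypothetical helper: normalize a device spec to (type, index) without
# calling torch.device(int).type, whose result depends on the backend
# build ('cuda' on CUDA/XPU builds, 'mlu' on MLU builds per this thread).
def normalize_device(d, default_type):
    if isinstance(d, int):
        # A bare integer carries no type information, so trust the
        # caller-supplied accelerator type instead of torch's resolution.
        return default_type, d
    dev_type, _, idx = str(d).partition(":")
    # Specs like "xpu:1" carry an explicit index; "cpu" defaults to 0.
    return dev_type, int(idx) if idx else 0
```

For example, `normalize_device(1, "xpu")` yields `("xpu", 1)` regardless of which backend PyTorch was built for.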
That's great @faaany!
Nice :) |
What does this PR do?
This PR is a follow-up fix for PR #2826, and I want to correct my statement in that PR that `torch.device(d).type == "xpu"` is enough to check for the XPU device, just like in the NPU and MLU cases. This was my mistake. In fact, `torch.device(0).type` will always return `"cuda"` on XPU, as can be seen from the PyTorch code and the PyTorch official doc, at least for now. We are working on a PR to support this in a future PyTorch version. Also for the NPU path, I think `torch.device(0).type` will return `cuda`, as can be seen here.

In addition, users might pass a device id that exceeds the available device count. In this case, we should not count that incorrect id toward `num_devices` when calculating the balanced memory. So this PR actually fixes 2 issues:

- `num_devices` for non-CUDA devices will always be 0
- `num_devices` will include device indices that are larger than the available device count

Who can review?
@SunMarc and @muellerzr
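The two issues the PR describes can be sketched together. This is a hypothetical illustration, not the actual accelerate implementation; the function name `count_usable_devices` and its signature are assumptions for the example:

```python
# Sketch of the two fixes described in the PR (not accelerate's real code):
# (1) don't classify a bare integer via torch.device(int).type, which
#     returns 'cuda' on XPU builds even for XPU devices, and
# (2) skip device indices beyond the number of actually available devices.
def count_usable_devices(device_specs, expected_device_type, num_available):
    devices = set()
    for d in device_specs:
        if isinstance(d, int):
            # Fix 1: treat a bare index as the expected accelerator type.
            dev_type, index = expected_device_type, d
        else:
            dev_type, _, idx = str(d).partition(":")
            index = int(idx) if idx else 0
        # Fix 2: only count valid indices of the expected accelerator.
        if dev_type == expected_device_type and index < num_available:
            devices.add(index)
    return len(devices)
```

With this logic, `count_usable_devices([0, 1, 5, "cpu"], "xpu", 2)` counts only indices 0 and 1: the out-of-range index 5 and the `cpu` entry are excluded from `num_devices`.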