diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index 89bac4fec212..c7fb7c2c7aeb 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -246,7 +246,7 @@ class ChatArguments: default="main", metadata={"help": "Specific model version to use (can be a branch name, tag name or commit id)."}, ) - device: str = field(default="cpu", metadata={"help": "Device to use for inference."}) + device: str = field(default="auto", metadata={"help": "Device to use for inference."}) torch_dtype: Optional[str] = field( default="auto", metadata={