diff --git a/bittensor/_cli/__init__.py b/bittensor/_cli/__init__.py index d8c64be787..870b442e6a 100644 --- a/bittensor/_cli/__init__.py +++ b/bittensor/_cli/__init__.py @@ -834,14 +834,17 @@ def _check_for_cuda_reg_config( config: 'bittensor.Config' ) -> None: for i, device in enumerate(devices): choices_str += (" {}: {}\n".format(device, device_names[i])) console.print(choices_str) - dev_id = IntListPrompt.ask("Which GPU(s) would you like to use?", choices=devices, default=str(bittensor.defaults.subtensor.register.cuda.dev_id)) - try: - # replace the commas with spaces then split over whitespace., - # then strip the whitespace and convert to ints. - dev_id = [int(dev_id.strip()) for dev_id in dev_id.replace(',', ' ').split()] - except ValueError: - console.error(":cross_mark:[red]Invalid GPU device[/red] [bold white]{}[/bold white]\nAvailable CUDA devices:{}".format(dev_id, choices_str)) - sys.exit(1) + dev_id = IntListPrompt.ask("Which GPU(s) would you like to use? Please list one, or comma-separated", choices=devices, default='All') + if dev_id == 'All': + dev_id = list(range(torch.cuda.device_count())) + else: + try: + # replace the commas with spaces then split over whitespace., + # then strip the whitespace and convert to ints. + dev_id = [int(dev_id.strip()) for dev_id in dev_id.replace(',', ' ').split()] + except ValueError: + console.error(":cross_mark:[red]Invalid GPU device[/red] [bold white]{}[/bold white]\nAvailable CUDA devices:{}".format(dev_id, choices_str)) + sys.exit(1) config.subtensor.register.cuda.dev_id = dev_id def check_register_config( config: 'bittensor.Config' ): @@ -954,5 +957,6 @@ class IntListPrompt(PromptBase): def check_choice( self, value: str ) -> bool: assert self.choices is not None # check if value is a valid choice or all the values in a list of ints are valid choices - return value in self.choices or \ + return value == "All" or \ + value in self.choices or \ all( val.strip() in self.choices for val in value.replace(',', ' ').split( ))