Skip to content
Merged
22 changes: 13 additions & 9 deletions bittensor/_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,14 +834,17 @@ def _check_for_cuda_reg_config( config: 'bittensor.Config' ) -> None:
for i, device in enumerate(devices):
choices_str += (" {}: {}\n".format(device, device_names[i]))
console.print(choices_str)
dev_id = IntListPrompt.ask("Which GPU(s) would you like to use?", choices=devices, default=str(bittensor.defaults.subtensor.register.cuda.dev_id))
try:
# replace the commas with spaces then split over whitespace.,
# then strip the whitespace and convert to ints.
dev_id = [int(dev_id.strip()) for dev_id in dev_id.replace(',', ' ').split()]
except ValueError:
console.error(":cross_mark:[red]Invalid GPU device[/red] [bold white]{}[/bold white]\nAvailable CUDA devices:{}".format(dev_id, choices_str))
sys.exit(1)
dev_id = IntListPrompt.ask("Which GPU(s) would you like to use? Please list one, or comma-separated", choices=devices, default='All')
if dev_id == 'All':
dev_id = list(range(torch.cuda.device_count()))
else:
try:
# replace the commas with spaces then split over whitespace.,
# then strip the whitespace and convert to ints.
dev_id = [int(dev_id.strip()) for dev_id in dev_id.replace(',', ' ').split()]
except ValueError:
console.error(":cross_mark:[red]Invalid GPU device[/red] [bold white]{}[/bold white]\nAvailable CUDA devices:{}".format(dev_id, choices_str))
sys.exit(1)
config.subtensor.register.cuda.dev_id = dev_id

def check_register_config( config: 'bittensor.Config' ):
Expand Down Expand Up @@ -954,5 +957,6 @@ class IntListPrompt(PromptBase):
def check_choice( self, value: str ) -> bool:
assert self.choices is not None
# check if value is a valid choice or all the values in a list of ints are valid choices
return value in self.choices or \
return value == "All" or \
value in self.choices or \
all( val.strip() in self.choices for val in value.replace(',', ' ').split( ))