diff --git a/src/gbatchkit/types.py b/src/gbatchkit/types.py index e478722..d12e45f 100644 --- a/src/gbatchkit/types.py +++ b/src/gbatchkit/types.py @@ -24,6 +24,62 @@ class ComputeConfig(BaseModel): ) +# Parse strings in the format: +# machine_type:provisioning_model+accelerator_type:accelerator_count +# where the provisioning model and accelerator are optional. +# +# Defaults: +# - provisioning_model: SPOT +# - accelerator_type: none +# - accelerator_count: 0 (or 1, if accelerator_type is given) +def parse_compute_config(compute_str: str) -> ComputeConfig: + compute_parts = compute_str.split("+") + + if len(compute_parts) == 1: + compute_parts = [compute_parts[0], ""] + elif len(compute_parts) > 2: + raise ValueError("Invalid compute config format") + + if compute_parts[0] == "": + raise ValueError("Machine type is required") + + machine_parts = compute_parts[0].split(":") + if len(machine_parts) == 1: + machine_type = machine_parts[0] + provisioning_model = "SPOT" + elif len(machine_parts) == 2: + machine_type, provisioning_model = machine_parts + else: + raise ValueError(f"Invalid machine type/provisioning model: {machine_parts}") + + accelerator_parts = compute_parts[1].split(":") + if len(accelerator_parts) == 1: + accelerator_type = accelerator_parts[0] + if accelerator_type: + accelerator_count = 1 + else: + accelerator_count = 0 + elif len(accelerator_parts) == 2: + accelerator_type = accelerator_parts[0] + accelerator_count = int(accelerator_parts[1] or 0) + + if accelerator_count > 0 and not accelerator_type: + raise ValueError(f"Accelerator count specified without accelerator type") + else: + raise ValueError(f"Invalid accelerator type/count: {accelerator_parts}") + + # TODO: validate that machine_type is valid + # TODO: validate that accelerator_type is valid + # TODO: validate that machine_type + accelerator_type are valid together + + return ComputeConfig( + machine_type=machine_type, + provisioning_model=provisioning_model or "SPOT", + accelerator_type=accelerator_type or "", + accelerator_count=int(accelerator_count or 0), + ) + + class NetworkInterfaceConfig(BaseModel): network: str = Field( default="", diff --git a/tests/gbatchkit/types_test.py b/tests/gbatchkit/types_test.py new file mode 100644 index 0000000..308ae13 --- /dev/null +++ b/tests/gbatchkit/types_test.py @@ -0,0 +1,42 @@ +import pytest + +from gbatchkit.types import parse_compute_config + + +def test_parse_compute_config(): + # Test with all parts specified + config = parse_compute_config("n1-standard-8:SPOT+nvidia-tesla-t4:1") + assert config.machine_type == "n1-standard-8" + assert config.provisioning_model == "SPOT" + assert config.accelerator_type == "nvidia-tesla-t4" + assert config.accelerator_count == 1 + + config = parse_compute_config("n1-standard-8:SPOT+nvidia-tesla-t4") + assert config.accelerator_type == "nvidia-tesla-t4" + assert config.accelerator_count == 1 + + # Test with only machine type and provisioning model + config = parse_compute_config("n1-standard-8:SPOT") + assert config.machine_type == "n1-standard-8" + assert config.provisioning_model == "SPOT" + assert config.accelerator_type == "" + assert config.accelerator_count == 0 + + # Test with only machine type + config = parse_compute_config("n1-standard-8") + assert config.machine_type == "n1-standard-8" + assert config.provisioning_model == "SPOT" + assert config.accelerator_type == "" + assert config.accelerator_count == 0 + + # Test with empty string + with pytest.raises(ValueError): + parse_compute_config("") + + # Test with invalid format + with pytest.raises(ValueError): + parse_compute_config("invalid+format+string") + with pytest.raises(ValueError): + parse_compute_config("invalid:format:string") + with pytest.raises(ValueError): + parse_compute_config("n1-standard-4:SPOT+gpu:1:1")