Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/gbatchkit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,62 @@ class ComputeConfig(BaseModel):
)


# Parse strings in the format:
# machine_type:provisioning_model+accelerator_type:accelerator_count
# where the provisioning model and accelerator are optional.
#
# Defaults:
# - provisioning_model: SPOT
# - accelerator_type: none
# - accelerator_count: 0 (or 1, if accelerator_type is given)
def parse_compute_config(compute_str: str) -> ComputeConfig:
compute_parts = compute_str.split("+")

if len(compute_parts) == 1:
compute_parts = [compute_parts[0], ""]
elif len(compute_parts) > 2:
raise ValueError("Invalid compute config format")

if compute_parts[0] == "":
raise ValueError("Machine type is required")

machine_parts = compute_parts[0].split(":")
if len(machine_parts) == 1:
machine_type = machine_parts[0]
provisioning_model = "SPOT"
elif len(machine_parts) == 2:
machine_type, provisioning_model = machine_parts
else:
raise ValueError(f"Invalid machine type/provisioning model: {machine_parts}")

accelerator_parts = compute_parts[1].split(":")
if len(accelerator_parts) == 1:
accelerator_type = accelerator_parts[0]
if accelerator_type:
accelerator_count = 1
else:
accelerator_count = 0
elif len(accelerator_parts) == 2:
accelerator_type = accelerator_parts[0]
accelerator_count = int(accelerator_parts[1] or 0)

if accelerator_count > 0 and not accelerator_type:
raise ValueError(f"Accelerator count specified without accelerator type")
else:
raise ValueError(f"Invalid accelerator type/count: {accelerator_parts}")

# TODO: validate that machine_type is valid
# TODO: validate that accelerator_type is valid
# TODO: validate that machine_type + accelerator_type are valid together

return ComputeConfig(
machine_type=machine_type,
provisioning_model=provisioning_model or "SPOT",
accelerator_type=accelerator_type or "",
accelerator_count=int(accelerator_count or 0),
)


class NetworkInterfaceConfig(BaseModel):
network: str = Field(
default="",
Expand Down
42 changes: 42 additions & 0 deletions tests/gbatchkit/types_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from gbatchkit.types import parse_compute_config


def test_parse_compute_config():
# Test with all parts specified
config = parse_compute_config("n1-standard-8:SPOT+nvidia-tesla-t4:1")
assert config.machine_type == "n1-standard-8"
assert config.provisioning_model == "SPOT"
assert config.accelerator_type == "nvidia-tesla-t4"
assert config.accelerator_count == 1

config = parse_compute_config("n1-standard-8:SPOT+nvidia-tesla-t4")
assert config.accelerator_type == "nvidia-tesla-t4"
assert config.accelerator_count == 1

# Test with only machine type and provisioning model
config = parse_compute_config("n1-standard-8:SPOT")
assert config.machine_type == "n1-standard-8"
assert config.provisioning_model == "SPOT"
assert config.accelerator_type == ""
assert config.accelerator_count == 0

# Test with only machine type
config = parse_compute_config("n1-standard-8")
assert config.machine_type == "n1-standard-8"
assert config.provisioning_model == "SPOT"
assert config.accelerator_type == ""
assert config.accelerator_count == 0

# Test with empty string
with pytest.raises(ValueError):
parse_compute_config("")

# Test with invalid format
with pytest.raises(ValueError):
parse_compute_config("invalid+format+string")
with pytest.raises(ValueError):
parse_compute_config("invalid:format:string")
with pytest.raises(ValueError):
parse_compute_config("n1-standard-4:SPOT+gpu:1:1")