Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/reference/profiles.yml.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ The profile configuration supports many properties. See below.
max_price:
type: 'Optional[float]'

### `retry_policy`
### `retry`

#SCHEMA# dstack._internal.core.models.profiles.ProfileRetryPolicy
#SCHEMA# dstack._internal.core.models.profiles.ProfileRetry
overrides:
show_root_heading: false
19 changes: 10 additions & 9 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import List

from rich.markup import escape
from rich.table import Table

from dstack._internal.cli.utils.common import add_row_from_dict, console
from dstack._internal.core.models.instances import InstanceAvailability
from dstack._internal.core.models.profiles import TerminationPolicy
from dstack._internal.core.models.runs import RunPlan
from dstack._internal.utils.common import pretty_date
from dstack._internal.utils.common import format_pretty_duration, pretty_date
from dstack.api import Run


Expand All @@ -23,18 +24,18 @@ def print_run_plan(run_plan: RunPlan, offers_limit: int = 3):
max_duration = (
f"{job_plan.job_spec.max_duration / 3600:g}h" if job_plan.job_spec.max_duration else "-"
)
retry_policy = job_plan.job_spec.retry_policy
retry_policy = (
(f"{retry_policy.duration / 3600:g}h" if retry_policy.duration else "yes")
if retry_policy.retry
else "no"
)
if job_plan.job_spec.retry is None:
retry = "no"
else:
retry = escape(job_plan.job_spec.retry.pretty_format())

profile = run_plan.run_spec.merged_profile
creation_policy = profile.creation_policy
termination_policy = profile.termination_policy
termination_idle_time = f"{profile.termination_idle_time}s"
if termination_policy == TerminationPolicy.DONT_DESTROY:
termination_idle_time = "-"
else:
termination_idle_time = format_pretty_duration(profile.termination_idle_time)

if req.spot is None:
spot_policy = "auto"
Expand All @@ -54,7 +55,7 @@ def th(s: str) -> str:
props.add_row(th("Max price"), max_price)
props.add_row(th("Max duration"), max_duration)
props.add_row(th("Spot policy"), spot_policy)
props.add_row(th("Retry policy"), retry_policy)
props.add_row(th("Retry policy"), retry)
props.add_row(th("Creation policy"), creation_policy)
props.add_row(th("Termination policy"), termination_policy)
props.add_row(th("Termination idle time"), termination_idle_time)
Expand Down
6 changes: 4 additions & 2 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.gcp.config import GCPConfig
from dstack._internal.core.errors import ComputeResourceNotFoundError, NoCapacityError
from dstack._internal.core.errors import (
ComputeResourceNotFoundError,
NoCapacityError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.instances import (
Expand Down Expand Up @@ -96,7 +99,6 @@ def create_instance(
instance_config: InstanceConfiguration,
) -> JobProvisioningData:
instance_name = instance_config.instance_name

if not gcp_resources.is_valid_resource_name(instance_name):
# In a rare case the instance name is invalid in GCP,
# we better use a random instance name than fail provisioning.
Expand Down
52 changes: 43 additions & 9 deletions src/dstack/_internal/core/models/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,43 @@ class ProfileRetryPolicy(CoreModel):

_validate_duration = validator("duration", pre=True, allow_reuse=True)(parse_duration)

@root_validator()
@classmethod
def _validate_fields(cls, field_values):
if field_values["retry"] and "duration" not in field_values:
field_values["duration"] = DEFAULT_RETRY_DURATION
if field_values.get("duration") is not None:
field_values["retry"] = True
return field_values
@root_validator
def _validate_fields(cls, values):
if values["retry"] and "duration" not in values:
values["duration"] = DEFAULT_RETRY_DURATION
if values.get("duration") is not None:
values["retry"] = True
return values


class RetryEvent(str, Enum):
    # Kinds of events a retry configuration can opt into via `on_events`.
    # The string values are the spellings accepted in the configuration
    # (`no-capacity`, `interruption`, `error`).
    NO_CAPACITY = "no-capacity"
    INTERRUPTION = "interruption"
    ERROR = "error"


class ProfileRetry(CoreModel):
    # User-facing retry configuration: which events trigger a retry and for
    # how long retrying may continue. (Kept as `#` comments rather than a class
    # docstring so the generated model schema is unchanged.)
    on_events: Annotated[
        List[RetryEvent],
        Field(
            description=(
                "The list of events that should be handled with retry."
                " Supported events are `no-capacity`, `interruption`, and `error`"
            )
        ),
    ]
    duration: Annotated[
        Optional[Union[int, str]],
        Field(description="The maximum period of retrying the run, e.g., `4h` or `1d`"),
    ] = None

    # Runs before field validation (`pre=True`) so `duration` goes through `parse_duration`.
    _validate_duration = validator("duration", pre=True, allow_reuse=True)(parse_duration)

    @root_validator
    def _validate_fields(cls, values):
        """Reject a retry configuration with an empty `on_events` list.

        Raises:
            ValueError: if `on_events` was validated successfully but is empty.
        """
        # A post root validator still runs when a field failed its own
        # validation, and that field is then absent from `values`. Guard the
        # lookup so a missing/invalid `on_events` is reported as a regular
        # validation error instead of escaping as a raw KeyError.
        if "on_events" in values and not values["on_events"]:
            raise ValueError("`on_events` cannot be empty")
        return values


class ProfileParams(CoreModel):
Expand All @@ -86,8 +115,13 @@ class ProfileParams(CoreModel):
description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`"
),
]
retry: Annotated[
Optional[Union[ProfileRetry, bool]],
Field(description="The policy for resubmitting the run. Defaults to `false`"),
]
retry_policy: Annotated[
Optional[ProfileRetryPolicy], Field(description="The policy for re-submitting the run")
Optional[ProfileRetryPolicy],
Field(description="The policy for resubmitting the run. Deprecated in favor of `retry`"),
]
max_duration: Annotated[
Optional[Union[Literal["off"], str, int]],
Expand Down
18 changes: 13 additions & 5 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
CreationPolicy,
Profile,
ProfileParams,
RetryEvent,
SpotPolicy,
TerminationPolicy,
)
from dstack._internal.core.models.repos import AnyRunRepoData
from dstack._internal.core.models.resources import ResourcesSpec
from dstack._internal.utils import common as common_utils
from dstack._internal.utils.common import pretty_resources
from dstack._internal.utils.common import format_pretty_duration, pretty_resources


class AppSpec(CoreModel):
Expand Down Expand Up @@ -58,9 +59,14 @@ def is_finished(self):
return self in self.finished_statuses()


class RetryPolicy(CoreModel):
retry: bool
duration: Optional[int]
class Retry(CoreModel):
    # Events for which a retry should be attempted.
    on_events: List[RetryEvent]
    # Retry window as an integer; rendered with `format_pretty_duration` below.
    duration: int

    def pretty_format(self) -> str:
        """Return the formatted duration followed by the event values in square brackets."""
        event_values = [event.value for event in self.on_events]
        return "{}[{}]".format(
            format_pretty_duration(self.duration),
            ", ".join(event_values),
        )


class RunTerminationReason(str, Enum):
Expand Down Expand Up @@ -187,7 +193,7 @@ class JobSpec(CoreModel):
max_duration: Optional[int]
registry_auth: Optional[RegistryAuth]
requirements: Requirements
retry_policy: RetryPolicy
retry: Optional[Retry]
working_dir: Optional[str]


Expand Down Expand Up @@ -225,6 +231,7 @@ class JobSubmission(CoreModel):
id: UUID4
submission_num: int
submitted_at: datetime
last_processed_at: datetime
finished_at: Optional[datetime]
status: JobStatus
termination_reason: Optional[JobTerminationReason]
Expand Down Expand Up @@ -323,6 +330,7 @@ class Run(CoreModel):
project_name: str
user: str
submitted_at: datetime
last_processed_at: datetime
status: RunStatus
termination_reason: Optional[RunTerminationReason]
run_spec: RunSpec
Expand Down
32 changes: 32 additions & 0 deletions src/dstack/_internal/core/services/profiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional

from dstack._internal.core.models.profiles import DEFAULT_RETRY_DURATION, Profile, RetryEvent
from dstack._internal.core.models.runs import Retry


def get_retry(profile: Profile) -> Optional[Retry]:
    """Build the effective `Retry` for a profile, or return None when retry is off.

    `profile.retry` takes precedence. It may be a boolean (`True` enables retry
    on all events for the default duration, `False` disables it) or a detailed
    retry configuration whose missing duration falls back to the default.
    When `profile.retry` is not set, the legacy `profile.retry_policy` is used.
    """
    all_events = [RetryEvent.NO_CAPACITY, RetryEvent.INTERRUPTION, RetryEvent.ERROR]
    retry_conf = profile.retry

    if retry_conf is None:
        # Handle retry_policy before retry was introduced
        # TODO: Remove once retry_policy no longer supported
        legacy_policy = profile.retry_policy
        if legacy_policy is None or not legacy_policy.retry:
            return None
        return Retry(
            on_events=all_events,
            duration=legacy_policy.duration or DEFAULT_RETRY_DURATION,
        )

    if isinstance(retry_conf, bool):
        if not retry_conf:
            return None
        return Retry(on_events=all_events, duration=DEFAULT_RETRY_DURATION)

    # Work on a copy so the profile's own retry settings are left untouched.
    retry_conf = retry_conf.copy()
    if retry_conf.duration is None:
        retry_conf.duration = DEFAULT_RETRY_DURATION
    return Retry.parse_obj(retry_conf)
69 changes: 41 additions & 28 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,18 @@
InstanceRuntime,
RemoteConnectionInfo,
)
from dstack._internal.core.models.profiles import Profile, TerminationPolicy
from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, Requirements
from dstack._internal.core.models.profiles import (
Profile,
RetryEvent,
TerminationPolicy,
)
from dstack._internal.core.models.runs import (
InstanceStatus,
JobProvisioningData,
Requirements,
Retry,
)
from dstack._internal.core.services.profiles import get_retry
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.models import InstanceModel, ProjectModel
from dstack._internal.server.schemas.runner import HealthcheckResponse
Expand Down Expand Up @@ -341,24 +351,6 @@ async def create_instance(instance_id: UUID) -> None:
)
).one()

if instance.retry_policy and instance.retry_policy_duration is not None:
retry_duration_deadline = _get_retry_duration_deadline(instance)
if get_current_datetime() > retry_duration_deadline:
instance.status = InstanceStatus.TERMINATED
instance.deleted = True
instance.deleted_at = get_current_datetime()
instance.termination_reason = "Retry duration expired"
await session.commit()
logger.warning(
"Retry duration expired. Terminate instance %s",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

if instance.last_retry_at is not None:
last_retry = instance.last_retry_at.replace(tzinfo=datetime.timezone.utc)
if get_current_datetime() < last_retry + timedelta(minutes=1):
Expand Down Expand Up @@ -386,10 +378,10 @@ async def create_instance(instance_id: UUID) -> None:
return

try:
profile = Profile.__response__.parse_raw(instance.profile)
requirements = Requirements.__response__.parse_raw(instance.requirements)
instance_configuration = InstanceConfiguration.__response__.parse_raw(
instance.instance_configuration
profile: Profile = Profile.__response__.parse_raw(instance.profile)
requirements: Requirements = Requirements.__response__.parse_raw(instance.requirements)
instance_configuration: InstanceConfiguration = (
InstanceConfiguration.__response__.parse_raw(instance.instance_configuration)
)
except ValidationError as e:
instance.status = InstanceStatus.TERMINATED
Expand All @@ -410,14 +402,35 @@ async def create_instance(instance_id: UUID) -> None:
await session.commit()
return

retry = get_retry(profile)
should_retry = retry is not None and RetryEvent.NO_CAPACITY in retry.on_events

if retry is not None:
retry_duration_deadline = _get_retry_duration_deadline(instance, retry)
if get_current_datetime() > retry_duration_deadline:
instance.status = InstanceStatus.TERMINATED
instance.deleted = True
instance.deleted_at = get_current_datetime()
instance.termination_reason = "Retry duration expired"
await session.commit()
logger.warning(
"Retry duration expired. Terminate instance %s",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

offers = await get_create_instance_offers(
project=instance.project,
profile=profile,
requirements=requirements,
exclude_not_available=True,
)

if not offers and instance.retry_policy:
if not offers and should_retry:
instance.last_retry_at = get_current_datetime()
await session.commit()
logger.debug(
Expand Down Expand Up @@ -479,7 +492,7 @@ async def create_instance(instance_id: UUID) -> None:

instance.last_retry_at = get_current_datetime()

if not instance.retry_policy:
if not should_retry:
instance.status = InstanceStatus.TERMINATED
instance.deleted = True
instance.deleted_at = get_current_datetime()
Expand Down Expand Up @@ -749,9 +762,9 @@ def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta:
return get_current_datetime() - last_time


def _get_retry_duration_deadline(instance: InstanceModel) -> datetime.datetime:
def _get_retry_duration_deadline(instance: InstanceModel, retry: Retry) -> datetime.datetime:
return instance.created_at.replace(tzinfo=datetime.timezone.utc) + timedelta(
seconds=instance.retry_policy_duration
seconds=retry.duration
)


Expand Down
Loading