Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/art/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal, TypeAlias
import warnings

import httpx
Expand All @@ -20,6 +20,10 @@
if TYPE_CHECKING:
from .model import Model, TrainableModel

# Type aliases for models with any config/state type (for backend method signatures)
AnyModel: TypeAlias = "Model[Any, Any]"
AnyTrainableModel: TypeAlias = "TrainableModel[Any, Any]"


class Backend:
def __init__(
Expand All @@ -39,7 +43,7 @@ async def close(self) -> None:

async def register(
self,
model: "Model",
model: AnyModel,
) -> None:
"""
Registers a model with the Backend for logging and/or training.
Expand All @@ -50,14 +54,14 @@ async def register(
response = await self._client.post("/register", json=model.safe_model_dump())
response.raise_for_status()

async def _get_step(self, model: "TrainableModel") -> int:
async def _get_step(self, model: AnyTrainableModel) -> int:
response = await self._client.post("/_get_step", json=model.safe_model_dump())
response.raise_for_status()
return response.json()

async def _delete_checkpoint_files(
self,
model: "TrainableModel",
model: AnyTrainableModel,
steps_to_keep: list[int],
) -> None:
response = await self._client.post(
Expand All @@ -68,7 +72,7 @@ async def _delete_checkpoint_files(

async def _prepare_backend_for_training(
self,
model: "TrainableModel",
model: AnyTrainableModel,
config: dev.OpenAIServerConfig | None,
) -> tuple[str, str]:
response = await self._client.post(
Expand All @@ -80,7 +84,7 @@ async def _prepare_backend_for_training(
base_url, api_key = tuple(response.json())
return base_url, api_key

def _model_inference_name(self, model: "Model", step: int | None = None) -> str:
def _model_inference_name(self, model: AnyModel, step: int | None = None) -> str:
"""Return the inference name for a model checkpoint.

Override in subclasses to provide backend-specific naming.
Expand All @@ -93,7 +97,7 @@ def _model_inference_name(self, model: "Model", step: int | None = None) -> str:

async def train(
self,
model: "TrainableModel",
model: AnyTrainableModel,
trajectory_groups: Iterable[TrajectoryGroup],
**kwargs: Any,
) -> TrainResult:
Expand All @@ -114,7 +118,7 @@ async def train(

async def _train_model(
self,
model: "TrainableModel",
model: AnyTrainableModel,
trajectory_groups: list[TrajectoryGroup],
config: TrainConfig,
dev_config: dev.TrainConfig,
Expand Down Expand Up @@ -152,7 +156,7 @@ async def _train_model(
@log_http_errors
async def _experimental_pull_from_s3(
self,
model: "Model",
model: AnyModel,
*,
s3_bucket: str | None = None,
prefix: str | None = None,
Expand Down Expand Up @@ -191,7 +195,7 @@ async def _experimental_pull_from_s3(
@log_http_errors
async def _experimental_push_to_s3(
self,
model: "Model",
model: AnyModel,
*,
s3_bucket: str | None = None,
prefix: str | None = None,
Expand All @@ -215,7 +219,7 @@ async def _experimental_push_to_s3(
@log_http_errors
async def _experimental_fork_checkpoint(
self,
model: "Model",
model: AnyModel,
from_model: str,
from_project: str | None = None,
from_s3_bucket: str | None = None,
Expand Down
68 changes: 55 additions & 13 deletions src/art/model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from datetime import datetime
import json
import os
from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Generic, Iterable, Optional, cast, overload
import warnings

import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
import polars as pl
from pydantic import BaseModel
from typing_extensions import Never
from typing_extensions import Never, TypeVar

from . import dev
from .trajectories import Trajectory, TrajectoryGroup
Expand All @@ -23,11 +23,12 @@


ModelConfig = TypeVar("ModelConfig", bound=BaseModel | None)
StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any])


class Model(
BaseModel,
Generic[ModelConfig],
Generic[ModelConfig, StateType],
):
"""
A model is an object that can be passed to your `rollout` function, and used
Expand Down Expand Up @@ -129,7 +130,7 @@ def __new__(
inference_model_name: str | None = None,
base_path: str = ".art",
report_metrics: list[str] | None = None,
) -> "Model[None]": ...
) -> "Model[None, dict[str, Any]]": ...

@overload
def __new__(
Expand All @@ -145,14 +146,14 @@ def __new__(
inference_model_name: str | None = None,
base_path: str = ".art",
report_metrics: list[str] | None = None,
) -> "Model[ModelConfig]": ...
) -> "Model[ModelConfig, dict[str, Any]]": ...

def __new__(
def __new__( # pyright: ignore[reportInconsistentOverload]
cls,
*args,
**kwargs,
) -> "Model[ModelConfig] | Model[None]":
return super().__new__(cls)
) -> "Model[ModelConfig, StateType]":
return super().__new__(cls) # type: ignore[return-value]

def safe_model_dump(self, *args, **kwargs) -> dict:
"""
Expand Down Expand Up @@ -260,6 +261,47 @@ def _get_output_dir(self) -> str:
"""Get the output directory for this model."""
return f"{self.base_path}/{self.project}/models/{self.name}"

def write_state(self, state: StateType) -> None:
    """Write persistent state to the model directory as JSON.

    This state is stored in `state.json` within the model's output directory
    and can be used to track training progress, dataset position, or any
    other information that should persist across runs.

    The file is written atomically: the JSON is first dumped to a temporary
    file alongside the target and then moved into place with `os.replace`,
    so an interrupted run can never leave a truncated/corrupt `state.json`
    behind (which would make a later `read_state` raise).

    Args:
        state: A dictionary of JSON-serializable values to persist.

    Example:
        model.write_state({
            "step": 5,
            "dataset_offset": 100,
            "last_checkpoint_time": "2024-01-15T10:30:00",
        })
    """
    output_dir = self._get_output_dir()
    os.makedirs(output_dir, exist_ok=True)
    state_path = f"{output_dir}/state.json"
    tmp_path = f"{state_path}.tmp"
    # Write to a sibling temp file first; os.replace is atomic on the
    # same filesystem, so readers either see the old state or the new one.
    with open(tmp_path, "w") as f:
        json.dump(state, f, indent=2)
    os.replace(tmp_path, state_path)

def read_state(self) -> StateType | None:
    """Read persistent state previously saved with `write_state`.

    Returns:
        The state dictionary loaded from `state.json` in the model's
        output directory, or None if no state has been saved yet.

    Example:
        state = model.read_state()
        if state:
            start_step = state["step"]
            dataset_offset = state["dataset_offset"]
    """
    state_path = f"{self._get_output_dir()}/state.json"
    # EAFP: attempt the read and treat a missing file as "no state yet".
    try:
        with open(state_path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return None

def _get_wandb_run(self) -> Optional["Run"]:
"""Get or create the wandb run for this model."""
import wandb
Expand Down Expand Up @@ -429,7 +471,7 @@ async def get_step(self) -> int:
# ---------------------------------------------------------------------------


class TrainableModel(Model[ModelConfig], Generic[ModelConfig]):
class TrainableModel(Model[ModelConfig, StateType], Generic[ModelConfig, StateType]):
base_model: str
# Override discriminator field for FastAPI serialization
trainable: bool = True
Expand Down Expand Up @@ -480,7 +522,7 @@ def __new__(
base_path: str = ".art",
report_metrics: list[str] | None = None,
_internal_config: dev.InternalModelConfig | None = None,
) -> "TrainableModel[None]": ...
) -> "TrainableModel[None, dict[str, Any]]": ...

@overload
def __new__(
Expand All @@ -495,13 +537,13 @@ def __new__(
base_path: str = ".art",
report_metrics: list[str] | None = None,
_internal_config: dev.InternalModelConfig | None = None,
) -> "TrainableModel[ModelConfig]": ...
) -> "TrainableModel[ModelConfig, dict[str, Any]]": ...

def __new__(
def __new__( # pyright: ignore[reportInconsistentOverload]
cls,
*args,
**kwargs,
) -> "TrainableModel[ModelConfig] | TrainableModel[None]":
) -> "TrainableModel[ModelConfig, StateType]":
return super().__new__(cls) # type: ignore

def model_dump(self, *args, **kwargs) -> dict:
Expand Down