diff --git a/src/art/backend.py b/src/art/backend.py index 8d07e7153..c9a8bed81 100644 --- a/src/art/backend.py +++ b/src/art/backend.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal, TypeAlias import warnings import httpx @@ -20,6 +20,10 @@ if TYPE_CHECKING: from .model import Model, TrainableModel +# Type aliases for models with any config/state type (for backend method signatures) +AnyModel: TypeAlias = "Model[Any, Any]" +AnyTrainableModel: TypeAlias = "TrainableModel[Any, Any]" + class Backend: def __init__( @@ -39,7 +43,7 @@ async def close(self) -> None: async def register( self, - model: "Model", + model: AnyModel, ) -> None: """ Registers a model with the Backend for logging and/or training. @@ -50,14 +54,14 @@ async def register( response = await self._client.post("/register", json=model.safe_model_dump()) response.raise_for_status() - async def _get_step(self, model: "TrainableModel") -> int: + async def _get_step(self, model: AnyTrainableModel) -> int: response = await self._client.post("/_get_step", json=model.safe_model_dump()) response.raise_for_status() return response.json() async def _delete_checkpoint_files( self, - model: "TrainableModel", + model: AnyTrainableModel, steps_to_keep: list[int], ) -> None: response = await self._client.post( @@ -68,7 +72,7 @@ async def _delete_checkpoint_files( async def _prepare_backend_for_training( self, - model: "TrainableModel", + model: AnyTrainableModel, config: dev.OpenAIServerConfig | None, ) -> tuple[str, str]: response = await self._client.post( @@ -80,7 +84,7 @@ async def _prepare_backend_for_training( base_url, api_key = tuple(response.json()) return base_url, api_key - def _model_inference_name(self, model: "Model", step: int | None = None) -> str: + def _model_inference_name(self, model: AnyModel, step: int | None = None) -> str: """Return the inference name for a model checkpoint. Override in subclasses to provide backend-specific naming. @@ -93,7 +97,7 @@ def _model_inference_name(self, model: "Model", step: int | None = None) -> str: async def train( self, - model: "TrainableModel", + model: AnyTrainableModel, trajectory_groups: Iterable[TrajectoryGroup], **kwargs: Any, ) -> TrainResult: @@ -114,7 +118,7 @@ async def train( async def _train_model( self, - model: "TrainableModel", + model: AnyTrainableModel, trajectory_groups: list[TrajectoryGroup], config: TrainConfig, dev_config: dev.TrainConfig, @@ -152,7 +156,7 @@ async def _train_model( @log_http_errors async def _experimental_pull_from_s3( self, - model: "Model", + model: AnyModel, *, s3_bucket: str | None = None, prefix: str | None = None, @@ -191,7 +195,7 @@ async def _experimental_pull_from_s3( @log_http_errors async def _experimental_push_to_s3( self, - model: "Model", + model: AnyModel, *, s3_bucket: str | None = None, prefix: str | None = None, @@ -215,7 +219,7 @@ async def _experimental_push_to_s3( @log_http_errors async def _experimental_fork_checkpoint( self, - model: "Model", + model: AnyModel, from_model: str, from_project: str | None = None, from_s3_bucket: str | None = None, diff --git a/src/art/model.py b/src/art/model.py index 3e473cd83..1afd407e2 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -1,14 +1,14 @@ from datetime import datetime import json import os -from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, Iterable, Optional, cast, overload import warnings import httpx from openai import AsyncOpenAI, DefaultAsyncHttpxClient import polars as pl from pydantic import BaseModel -from typing_extensions import Never +from typing_extensions import Never, TypeVar from . import dev from .trajectories import Trajectory, TrajectoryGroup @@ -23,11 +23,12 @@ ModelConfig = TypeVar("ModelConfig", bound=BaseModel | None) +StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any]) class Model( BaseModel, - Generic[ModelConfig], + Generic[ModelConfig, StateType], ): """ A model is an object that can be passed to your `rollout` function, and used @@ -129,7 +130,7 @@ def __new__( inference_model_name: str | None = None, base_path: str = ".art", report_metrics: list[str] | None = None, - ) -> "Model[None]": ... + ) -> "Model[None, dict[str, Any]]": ... @overload def __new__( @@ -145,14 +146,14 @@ def __new__( inference_model_name: str | None = None, base_path: str = ".art", report_metrics: list[str] | None = None, - ) -> "Model[ModelConfig]": ... + ) -> "Model[ModelConfig, dict[str, Any]]": ... - def __new__( + def __new__( # pyright: ignore[reportInconsistentOverload] cls, *args, **kwargs, - ) -> "Model[ModelConfig] | Model[None]": - return super().__new__(cls) + ) -> "Model[ModelConfig, StateType]": + return super().__new__(cls) # type: ignore[return-value] def safe_model_dump(self, *args, **kwargs) -> dict: """ @@ -260,6 +261,47 @@ def _get_output_dir(self) -> str: """Get the output directory for this model.""" return f"{self.base_path}/{self.project}/models/{self.name}" + def write_state(self, state: StateType) -> None: + """Write persistent state to the model directory as JSON. + + This state is stored in `state.json` within the model's output directory + and can be used to track training progress, dataset position, or any + other information that should persist across runs. + + Args: + state: A dictionary of JSON-serializable values to persist. + + Example: + model.write_state({ + "step": 5, + "dataset_offset": 100, + "last_checkpoint_time": "2024-01-15T10:30:00", + }) + """ + output_dir = self._get_output_dir() + os.makedirs(output_dir, exist_ok=True) + with open(f"{output_dir}/state.json", "w") as f: + json.dump(state, f, indent=2) + + def read_state(self) -> StateType | None: + """Read persistent state from the model directory. + + Returns: + The state dictionary if it exists, or None if no state has been saved. + + Example: + state = model.read_state() + if state: + start_step = state["step"] + dataset_offset = state["dataset_offset"] + """ + output_dir = self._get_output_dir() + state_path = f"{output_dir}/state.json" + if not os.path.exists(state_path): + return None + with open(state_path, "r") as f: + return json.load(f) + def _get_wandb_run(self) -> Optional["Run"]: """Get or create the wandb run for this model.""" import wandb @@ -429,7 +471,7 @@ async def get_step(self) -> int: # --------------------------------------------------------------------------- -class TrainableModel(Model[ModelConfig], Generic[ModelConfig]): +class TrainableModel(Model[ModelConfig, StateType], Generic[ModelConfig, StateType]): base_model: str # Override discriminator field for FastAPI serialization trainable: bool = True @@ -480,7 +522,7 @@ def __new__( base_path: str = ".art", report_metrics: list[str] | None = None, _internal_config: dev.InternalModelConfig | None = None, - ) -> "TrainableModel[None]": ... + ) -> "TrainableModel[None, dict[str, Any]]": ... @overload def __new__( @@ -495,13 +537,13 @@ def __new__( base_path: str = ".art", report_metrics: list[str] | None = None, _internal_config: dev.InternalModelConfig | None = None, - ) -> "TrainableModel[ModelConfig]": ... + ) -> "TrainableModel[ModelConfig, dict[str, Any]]": ... - def __new__( + def __new__( # pyright: ignore[reportInconsistentOverload] cls, *args, **kwargs, - ) -> "TrainableModel[ModelConfig] | TrainableModel[None]": + ) -> "TrainableModel[ModelConfig, StateType]": return super().__new__(cls) # type: ignore def model_dump(self, *args, **kwargs) -> dict: