11 changes: 9 additions & 2 deletions pyproject.toml
@@ -9,8 +9,6 @@ dependencies = [
"typer>=0.15.2",
"litellm>=1.71.1",
"weave>=0.52.23",
"tinker>=0.8.1",
"tinker-cookbook>=0.1.0",
"polars>=1.26.0",
"tblib>=3.0.0",
"nest-asyncio>=1.6.0",
@@ -115,6 +113,9 @@ unused-ignore-comment = "ignore"
# Allow unresolved imports for optional dependencies that may not be installed locally.
# In CI, we install all optional deps so these will be resolved and type-checked.
allowed-unresolved-imports = [
# tinker deps
"tinker.**",
"tinker_cookbook.**",
# backend deps
"accelerate.**",
"awscli.**",
@@ -165,6 +166,12 @@ dev = [
"pyarrow>=15.0.0",
"prek>=0.2.29",
]
tinker = [
"fastapi>=0.128.0",
"tinker>=0.8.1",
"tinker-cookbook>=0.1.0",
"uvicorn>=0.35.0",
]

[tool.uv.sources]
panza = { git = "https://github.com/corbt/panza.git" }
11 changes: 9 additions & 2 deletions src/art/__init__.py
@@ -57,7 +57,13 @@ def __init__(self, **kwargs):
from .local import LocalBackend
from .model import Model, TrainableModel
from .serverless import ServerlessBackend
from .tinker import TinkerBackend

try:
    from .tinker import TinkerBackend
    from .tinker_native import TinkerNativeBackend
except ModuleNotFoundError:
    TinkerBackend = None  # type: ignore[assignment]
    TinkerNativeBackend = None  # type: ignore[assignment]
from .trajectories import Trajectory, TrajectoryGroup
from .types import (
    LocalTrainResult,
@@ -91,9 +97,10 @@ def __init__(self, **kwargs):
"retry",
"TrainConfig",
"TrainResult",
"TinkerBackend",
"Trajectory",
"TrajectoryGroup",
"capture_yielded_trajectory",
"yield_trajectory",
]
if TinkerBackend is not None:
__all__.extend(["TinkerBackend", "TinkerNativeBackend"])
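Because `TinkerBackend` is only a real class when the optional tinker dependencies are installed (and is bound to `None` otherwise), downstream code can guard on it before use. A minimal sketch, with constructor arguments omitted since they are not shown in this diff:

```python
import art

# TinkerBackend is bound to None when the optional "tinker" extra is missing,
# so check it before constructing; fall back to the always-available LocalBackend.
if getattr(art, "TinkerBackend", None) is not None:
    backend = art.TinkerBackend()  # hypothetical: real constructor args omitted
else:
    backend = art.LocalBackend()
```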
2 changes: 2 additions & 0 deletions src/art/dev/__init__.py
@@ -4,6 +4,7 @@
    InternalModelConfig,
    PeftArgs,
    TinkerArgs,
    TinkerNativeArgs,
    TinkerTrainingClientArgs,
    TrainerArgs,
)
@@ -16,6 +17,7 @@
"InitArgs",
"PeftArgs",
"TinkerArgs",
"TinkerNativeArgs",
"TinkerTrainingClientArgs",
"TrainerArgs",
"get_openai_server_config",
6 changes: 6 additions & 0 deletions src/art/dev/model.py
@@ -121,6 +121,7 @@ class InternalModelConfig(TypedDict, total=False):
    engine_args: "EngineArgs"
    peft_args: "PeftArgs"
    tinker_args: "TinkerArgs | None"
    tinker_native_args: "TinkerNativeArgs | None"
    trainer_args: "TrainerArgs"


@@ -129,6 +130,11 @@ class TinkerArgs(TypedDict, total=False):
    training_client_args: "TinkerTrainingClientArgs"


class TinkerNativeArgs(TypedDict, total=False):
    renderer_name: Required[str]
    training_client_args: "TinkerTrainingClientArgs"


class TinkerTrainingClientArgs(TypedDict, total=False):
    rank: int
    seed: int | None
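A rough sketch of how the new `tinker_native_args` field might be filled in. The renderer name and rank are purely illustrative, and the import assumes these TypedDicts are re-exported from `art.dev` as the `__init__.py` change above suggests:

```python
from art.dev import InternalModelConfig, TinkerNativeArgs

# renderer_name is the only required key on TinkerNativeArgs; everything else
# (including training_client_args) is optional because total=False.
config: InternalModelConfig = {
    "tinker_native_args": TinkerNativeArgs(
        renderer_name="llama3",  # illustrative renderer name
        training_client_args={"rank": 8},
    ),
}
```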
2 changes: 2 additions & 0 deletions src/art/local/backend.py
@@ -102,6 +102,8 @@ async def register(
        Args:
            model: An art.Model instance.
        """
        # Ensure model state/logging uses the backend path
        model.base_path = self._path
        output_dir = get_model_dir(model=model, art_path=self._path)
        os.makedirs(output_dir, exist_ok=True)
        with open(f"{output_dir}/model.json", "w") as f:
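For context, a sketch of why setting `model.base_path` here matters: after `register()`, the state helpers on `Model` read and write under the backend's path. The constructor arguments below are illustrative, not taken from this diff:

```python
import asyncio

import art


async def main() -> None:
    backend = art.LocalBackend()  # assumes the default local art path
    model = art.TrainableModel(
        name="demo-model",  # illustrative model/project names
        project="demo-project",
        base_model="Qwen/Qwen2.5-7B-Instruct",
    )
    await backend.register(model)  # sets model.base_path to the backend's path
    # State now lands under <backend path>/demo-project/models/demo-model/state.json
    model.merge_state({"step": 0})


asyncio.run(main())
```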
49 changes: 46 additions & 3 deletions src/art/model.py
@@ -261,18 +261,22 @@ def _get_output_dir(self) -> str:
"""Get the output directory for this model."""
return f"{self.base_path}/{self.project}/models/{self.name}"

    def write_state(self, state: StateType) -> None:
        """Write persistent state to the model directory as JSON.
    def overwrite_state(self, state: StateType) -> None:
        """Overwrite persistent state in the model directory as JSON.

        This state is stored in `state.json` within the model's output directory
        and can be used to track training progress, dataset position, or any
        other information that should persist across runs.

        Warning:
            This overwrites the entire state file. Prefer `merge_state()` unless
            you intentionally want to replace all existing keys.

        Args:
            state: A dictionary of JSON-serializable values to persist.

        Example:
            model.write_state({
            model.overwrite_state({
                "step": 5,
                "dataset_offset": 100,
                "last_checkpoint_time": "2024-01-15T10:30:00",
@@ -283,6 +287,45 @@
with open(f"{output_dir}/state.json", "w") as f:
json.dump(state, f, indent=2)

    def write_state(self, state: StateType) -> None:
        """Deprecated: use `overwrite_state()` or `merge_state()` instead."""
        warnings.warn(
            "write_state() is deprecated. Use overwrite_state() or merge_state() instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.overwrite_state(state)

    def merge_state(self, state: StateType) -> StateType:
        """Deep-merge state into the existing state and persist it.

        Args:
            state: A dictionary of JSON-serializable values to merge.

        Returns:
            The merged state dictionary that was persisted.
        """
        existing = self.read_state() or {}
        merged = self._deep_merge_dicts(existing, state)
        self.overwrite_state(merged)
        return cast(StateType, merged)

    @staticmethod
    def _deep_merge_dicts(
        base: dict[str, Any], updates: dict[str, Any]
    ) -> dict[str, Any]:
        merged = dict(base)
        for key, value in updates.items():
            if (
                key in merged
                and isinstance(merged[key], dict)
                and isinstance(value, dict)
            ):
                merged[key] = Model._deep_merge_dicts(merged[key], value)
            else:
                merged[key] = value
        return merged

    def read_state(self) -> StateType | None:
        """Read persistent state from the model directory.

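To make the difference between the two write paths concrete, a small sketch (continuing from a registered `model` like the one sketched under the local backend change above; the key names are made up):

```python
# overwrite_state() replaces the whole state file: previously stored keys are dropped.
model.overwrite_state({"step": 5, "eval": {"accuracy": 0.71}})

# merge_state() deep-merges into the existing state: nested dicts are merged
# key by key, so "accuracy" survives alongside the newly added "loss".
model.merge_state({"eval": {"loss": 0.42}})

assert model.read_state() == {"step": 5, "eval": {"accuracy": 0.71, "loss": 0.42}}
```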
13 changes: 13 additions & 0 deletions src/art/pipeline_trainer/__init__.py
@@ -0,0 +1,13 @@
from .status import StatusReporter
from .trainer import PipelineTrainer, make_group_rollout_fn
from .types import EvalFn, RolloutFn, ScenarioT, SingleRolloutFn

__all__ = [
"PipelineTrainer",
"make_group_rollout_fn",
"StatusReporter",
"RolloutFn",
"SingleRolloutFn",
"EvalFn",
"ScenarioT",
]