Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.10
3.11
2 changes: 1 addition & 1 deletion dev/yes-no-maybe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.13"
}
},
"nbformat": 4,
Expand Down
22 changes: 14 additions & 8 deletions dev/yes-no-maybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,25 @@ def with_quotes(w: str) -> str:
async def main():
load_dotenv()

backend = LocalBackend()
backend = LocalBackend(in_process=True)
global model
base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
model = art.TrainableModel(
name=os.environ.get("MODEL_NAME", "011"),
name=os.environ.get("MODEL_NAME", "012"),
project="yes-no-maybe",
base_model=base_model,
_internal_config=art.dev.InternalModelConfig(
engine_args=art.dev.EngineArgs(
max_lora_rank=1,
),
peft_args=art.dev.PeftArgs(
r=1,
# engine_args=art.dev.EngineArgs(
# max_lora_rank=1,
# ),
# peft_args=art.dev.PeftArgs(
# r=1,
# ),
tinker_args=art.dev.TinkerArgs(
renderer_name="qwen3_instruct",
training_client_args=art.dev.TinkerTrainingClientArgs(
rank=1,
),
),
),
)
Expand All @@ -68,7 +74,7 @@ async def main():
]

openai_client = model.openai_client()
max_steps = int(os.environ.get("NUM_STEPS", "4"))
max_steps = int(os.environ.get("NUM_STEPS", "20"))
start_step = await model.get_step()
for _ in range(start_step, start_step + max_steps):
train_groups = await art.gather_trajectory_groups(
Expand Down
11 changes: 7 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ name = "openpipe-art"
version = "0.5.4"
description = "The OpenPipe Agent Reinforcement Training (ART) library"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
dependencies = [
"openai>=1.65.5",
"typer>=0.15.2",
"litellm==1.74.1",
"weave>=0.51.51",
"tinker>=0.7.0",
"tinker-cookbook>=0.1.0",
"polars>=1.26.0",
"tblib>=3.0.0",
]

[project.optional-dependencies]
Expand All @@ -27,10 +31,9 @@ backend = [
"accelerate==1.7.0",
"awscli>=1.38.1",
"setproctitle>=1.3.6",
"tblib>=3.0.0",

"setuptools>=78.1.0",
"wandb==0.22.1",
"polars>=1.26.0",
"transformers>=4.55.2,<=4.57.3",
"duckdb>=1.0.0",
"pyarrow>=15.0.0",
Expand Down Expand Up @@ -91,7 +94,7 @@ select = ["I"]
[tool.ruff.lint.isort]
case-sensitive = false
known-first-party = ["art"]
known-third-party = ["wandb"]
known-third-party = ["tinker", "wandb"]
force-sort-within-sections = true

[tool.pytest.ini_options]
Expand Down
4 changes: 4 additions & 0 deletions src/art/dev/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
InitArgs,
InternalModelConfig,
PeftArgs,
TinkerArgs,
TinkerTrainingClientArgs,
TrainerArgs,
)
from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config
Expand All @@ -14,6 +16,8 @@
"InternalModelConfig",
"InitArgs",
"PeftArgs",
"TinkerArgs",
"TinkerTrainingClientArgs",
"TrainerArgs",
"get_openai_server_config",
"OpenAIServerConfig",
Expand Down
1 change: 1 addition & 0 deletions src/art/dev/get_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def get_model_config(
init_args=init_args,
engine_args=engine_args,
peft_args=peft_args,
tinker_args=config.get("tinker_args"),
trainer_args=trainer_args,
torchtune_args=torchtune_args,
)
22 changes: 20 additions & 2 deletions src/art/dev/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum

from typing_extensions import TypedDict
from typing_extensions import Required, TypedDict

from .engine import EngineArgs
from .torchtune import TorchtuneArgs
Expand Down Expand Up @@ -112,17 +112,35 @@ class InternalModelConfig(TypedDict, total=False):

Args:
init: Arguments for initializing an Unsloth FastLanguageModel.
engine: Arguments for the vLLM engine.
peft: Arguments for creating an Unsloth PEFT model wrapper.
train: Arguments for the GRPO trainer.
tinker: Arguments for the Tinker training client.
trainer: Arguments for the GRPO trainer.
torchtune: Arguments for TorchTune.
"""

init_args: "InitArgs"
engine_args: "EngineArgs"
peft_args: "PeftArgs"
tinker_args: "TinkerArgs | None"
trainer_args: "TrainerArgs"
torchtune_args: TorchtuneArgs | None


class TinkerArgs(TypedDict, total=False):
    """Configuration selecting and parameterizing the Tinker training service.

    Setting this on an ``InternalModelConfig`` causes the local backend to
    use the Tinker service instead of Torchtune/Unsloth (see
    ``LocalBackend._get_service``).
    """

    # Chat-template renderer name, e.g. "qwen3_instruct". Marked Required
    # so it must be supplied even though the TypedDict is total=False.
    renderer_name: Required[str]
    # Optional keyword arguments forwarded to the Tinker training client.
    training_client_args: "TinkerTrainingClientArgs"


class TinkerTrainingClientArgs(TypedDict, total=False):
    """Optional keyword arguments for constructing a Tinker training client.

    All fields are optional (``total=False``); omitted fields fall back to
    the Tinker client's own defaults.
    """

    # LoRA rank — presumably the adapter rank passed to the client
    # (callers in this PR pass rank=1); TODO confirm against Tinker docs.
    rank: int
    # Random seed; None presumably lets the client pick — verify.
    seed: int | None
    # NOTE(review): the train_* flags appear to toggle which weight groups
    # are trained (MLP / attention / unembedding) — confirm with the
    # Tinker client API before relying on this.
    train_mlp: bool
    train_attn: bool
    train_unembed: bool
    # Arbitrary string metadata attached to the training run.
    user_metadata: dict[str, str] | None


class InitArgs(TypedDict, total=False):
model_name: str
max_seq_length: int
Expand Down
29 changes: 21 additions & 8 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,29 @@ async def register(

async def _get_service(self, model: TrainableModel) -> ModelService:
from ..dev.get_model_config import get_model_config
from ..torchtune.service import TorchtuneService
from ..unsloth.service import UnslothService

if model.name not in self._services:
config = get_model_config(
base_model=model.base_model,
output_dir=get_model_dir(model=model, art_path=self._path),
config=model._internal_config,
)
if config.get("torchtune_args") is not None:
is_tinker = config.get("tinker_args") is not None
if is_tinker:
from ..tinker.service import TinkerService

service_class = TinkerService
elif config.get("torchtune_args") is not None:
from ..torchtune.service import TorchtuneService

service_class = TorchtuneService
else:
from ..unsloth.service import UnslothService

service_class = UnslothService
# When moving the service to a child process, import unsloth
# early to maximize optimizations
os.environ["IMPORT_UNSLOTH"] = "1"
self._services[model.name] = service_class(
model_name=model.name,
base_model=model.base_model,
Expand All @@ -151,12 +161,9 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
if not self._in_process:
# Kill all "model-service" processes to free up GPU memory
subprocess.run(["pkill", "-9", "model-service"])
# When moving the service to a child process, import unsloth
# early to maximize optimizations
os.environ["IMPORT_UNSLOTH"] = "1"
self._services[model.name] = move_to_child_process(
self._services[model.name],
process_name="model-service",
process_name="tinker-service" if is_tinker else "model-service",
)
return self._services[model.name]

Expand Down Expand Up @@ -242,6 +249,8 @@ async def _delete_checkpoints(
benchmark: str,
benchmark_smoothing: float,
) -> None:
from ..tinker.service import TinkerService

output_dir = get_model_dir(model=model, art_path=self._path)
# Keep the latest step
steps_to_keep = [get_model_step(model, self._path)]
Expand All @@ -261,7 +270,11 @@ async def _delete_checkpoints(
print(f'"{output_dir}/history.jsonl" not found')
except pl.exceptions.ColumnNotFoundError:
print(f'No "{benchmark}" metric found in history')
delete_checkpoints(output_dir, steps_to_keep)
service = await self._get_service(model)
if isinstance(service, TinkerService):
await service.delete_checkpoints(steps_to_keep)
else:
delete_checkpoints(output_dir, steps_to_keep)

async def _prepare_backend_for_training(
self,
Expand Down
2 changes: 2 additions & 0 deletions src/art/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Loss(BaseModel):
mean_policy_loss: torch.Tensor
mean_kl: torch.Tensor
mean_entropy: torch.Tensor | None
policy_loss_sum: torch.Tensor
probs_corr: torch.Tensor


Expand Down Expand Up @@ -135,6 +136,7 @@ def loss_fn(
mean_policy_loss=mean_policy_loss,
mean_kl=mean_kl,
mean_entropy=mean_entropy,
policy_loss_sum=policy_loss.sum(),
probs_corr=probs_corr,
)

Expand Down
50 changes: 50 additions & 0 deletions src/art/preprocessing/inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING

import torch

from .pack import PackedTensors

if TYPE_CHECKING:
from .. import dev, types


class TrainInputs(PackedTensors):
    """Packed training tensors for one batch, with training config attached.

    Produced by ``create_train_inputs``; extends ``PackedTensors`` with the
    configuration needed to run a train step.
    """

    # Public training configuration (art.types.TrainConfig).
    config: "types.TrainConfig"
    # Internal/dev training configuration (art.dev.TrainConfig).
    _config: "dev.TrainConfig"
    # Whether the trainer should return freshly computed logprobs;
    # create_train_inputs always sets this to False.
    return_new_logprobs: bool


def create_train_inputs(
    packed_tensors: PackedTensors,
    offset: int,
    config: "types.TrainConfig",
    _config: "dev.TrainConfig",
    warmup: bool,
) -> TrainInputs:
    """Create TrainInputs for a single batch offset.

    Slices every tensor in ``packed_tensors`` down to the single row at
    ``offset``. In warmup mode, multi-dimensional tensors are additionally
    truncated to the first 1024 positions, image inputs are blanked out,
    and the config's learning signal is effectively zeroed so the step has
    no training effect.
    """
    row = slice(offset, offset + 1)

    # Keep only tensor entries, sliced to the single requested row.
    tensor_fields: dict = {}
    for key, value in packed_tensors.items():
        if not isinstance(value, torch.Tensor):
            continue
        if warmup and value.dim() > 1:
            # Warmup: cap sequence-like dimensions at 1024 positions.
            tensor_fields[key] = value[row, :1024]
        else:
            tensor_fields[key] = value[row]

    if warmup:
        pixel_values = [None]
        image_grid_thw = [None]
        # Neutralize the optimizer signal for the warmup step.
        train_config = config.model_copy(
            update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0}
        )
    else:
        pixel_values = packed_tensors["pixel_values"][row]
        image_grid_thw = packed_tensors["image_grid_thw"][row]
        train_config = config

    return TrainInputs(
        **tensor_fields,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        config=train_config,
        _config=_config,
        return_new_logprobs=False,
    )
Empty file added src/art/tinker/__init__.py
Empty file.
Loading