@corbt corbt commented Jan 20, 2026

⚠️ BREAKING CHANGE

Training is now initiated through backend.train() instead of model.train().


Migration Guide

Before (deprecated):

await model.train(trajectory_groups, config=art.TrainConfig(learning_rate=5e-6))

After (recommended):

result = await backend.train(model, trajectory_groups, learning_rate=5e-6)
await model.log(trajectory_groups, metrics=result.metrics, step=result.step, split="train")

Summary

This PR implements Phase 1 of the Backend-First Training API (RFC #519), providing:

New API: backend.train(model, trajectory_groups, **kwargs)

  • Type-safe parameters: Each backend has explicit, documented kwargs instead of generic config objects
  • Structured returns: LocalTrainResult and ServerlessTrainResult with step, metrics, and backend-specific fields (e.g., checkpoint_path)
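
As a rough illustration of the structured-return idea, the sketch below models the result types as dataclasses. The field names (step, metrics, checkpoint_path) follow the bullets above, but the class bodies and the summarize() helper are stand-ins written for this example and are not the actual contents of src/art/types.py.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

# Stand-in definitions for illustration only; not the real art.types classes.


@dataclass
class TrainResult:
    """Fields assumed to be shared by every backend's result."""
    step: int
    metrics: Dict[str, float] = field(default_factory=dict)


@dataclass
class LocalTrainResult(TrainResult):
    """Adds a backend-specific field on top of the shared ones."""
    checkpoint_path: Optional[str] = None


def summarize(result: TrainResult) -> str:
    """Hypothetical helper: one-line summary that works for any result type."""
    parts = [f"step={result.step}"]
    parts += [f"{k}={v:.4g}" for k, v in sorted(result.metrics.items())]
    # Only some result types carry a checkpoint path, so look it up defensively.
    path = getattr(result, "checkpoint_path", None)
    if path is not None:
        parts.append(f"checkpoint={path}")
    return " ".join(parts)
```

Code that only needs step and metrics can accept the base type, while code that needs checkpoint_path can require the local result explicitly.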

Explicit Logging

  • backend.train() does NOT automatically log trajectories or metrics
  • Users call model.log() once after training to log both trajectories and metrics together

Extended model.log()

  • Now accepts metrics and step kwargs for logging training metrics alongside trajectories
  • Single call logs everything: await model.log(groups, metrics=result.metrics, step=result.step, split="train")
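
The train-then-log order described above can be sketched with stub objects, so the flow runs without a GPU or the real library. StubBackend, StubModel, and train_one_step are hypothetical names invented for this example; only the call shapes backend.train(model, groups, learning_rate=...) and model.log(groups, metrics=..., step=..., split="train") are taken from this PR.

```python
import asyncio
from types import SimpleNamespace


class StubBackend:
    """Hypothetical stand-in for a backend; it trains nothing and logs nothing."""

    async def train(self, model, trajectory_groups, **kwargs):
        model.step += 1
        # Fake metrics, just so there is something to log.
        return SimpleNamespace(step=model.step, metrics={"loss": 0.25})


class StubModel:
    """Hypothetical stand-in for a model that records each log() call."""

    def __init__(self):
        self.step = 0
        self.logged = []

    async def log(self, trajectory_groups=None, *, metrics=None, step=None, split):
        self.logged.append(
            {"groups": trajectory_groups, "metrics": metrics, "step": step, "split": split}
        )


async def train_one_step(backend, model, trajectory_groups):
    # train() returns a result but does not log, so logging is the caller's job.
    result = await backend.train(model, trajectory_groups, learning_rate=5e-6)
    # One call records the trajectories and the training metrics together.
    await model.log(trajectory_groups, metrics=result.metrics, step=result.step, split="train")
    return result
```

Running asyncio.run(train_one_step(StubBackend(), StubModel(), groups)) leaves exactly one entry in the stub model's logged list per training step.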

Fixed get_inference_name()

  • Now correctly returns model.name@step for LocalBackend after registration
  • Properly delegates to backend-specific implementation
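
One way to picture the delegation is the sketch below, in which the model asks its registered backend for the inference name. Everything here is an assumption made for illustration: the Sketch* classes, the step argument, the synchronous register(), and the base behavior of returning the plain name are not taken from the PR. Only the name@step format and the _model_inference_name() hook name come from the description.

```python
class SketchBackend:
    """Hypothetical base backend: serves a model under its plain name."""

    def _model_inference_name(self, model, step=None):
        return model.name


class SketchLocalBackend(SketchBackend):
    """Hypothetical local backend: addresses a checkpoint as name@step."""

    def _model_inference_name(self, model, step=None):
        chosen = model.step if step is None else step
        return f"{model.name}@{chosen}"


class SketchModel:
    """Hypothetical model that delegates naming to its backend."""

    def __init__(self, name, step=0):
        self.name = name
        self.step = step
        self._backend = None

    def register(self, backend):
        self._backend = backend

    def get_inference_name(self, step=None):
        # Before registration there is no backend to ask, so use the plain name.
        if self._backend is None:
            return self.name
        # After registration the backend decides the naming scheme.
        return self._backend._model_inference_name(self, step)
```

Keeping the naming rule in the backend means the model class does not need to know which backend it is attached to.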

Deprecation Warning

  • model.train() now emits a DeprecationWarning with migration instructions
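
A minimal sketch of this deprecation pattern, using Python's standard warnings module, is shown below. The Sketch* classes and the warning text are invented for the example, and forwarding the old method to backend.train() is an assumption; the real model.train() may do more, such as logging.

```python
import asyncio
import warnings


class SketchBackend:
    """Hypothetical backend that just echoes the kwargs it received."""

    async def train(self, model, trajectory_groups, **kwargs):
        return {"step": 1, "kwargs": kwargs}


class SketchModel:
    """Hypothetical model whose old train() entry point is deprecated."""

    def __init__(self, backend):
        self._backend = backend

    async def train(self, trajectory_groups, **kwargs):
        # stacklevel=2 attributes the warning to the caller, not to this method.
        warnings.warn(
            "model.train() is deprecated; call "
            "backend.train(model, trajectory_groups, ...) and then model.log(...).",
            DeprecationWarning,
            stacklevel=2,
        )
        # Assumed behavior: forward to the new entry point.
        return await self._backend.train(self, trajectory_groups, **kwargs)
```

In a test suite, calling warnings.simplefilter("error", DeprecationWarning) turns any remaining model.train() call into an exception, which may help locate call sites before the method is removed.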

Phase 2 (Future)

In a future release, we will:

  • Remove model.train() method entirely
  • Remove art.TrainConfig and art.dev.TrainConfig classes

Files Changed

Core:

  • src/art/model.py - Extended log(), deprecation warning, fixed get_inference_name()
  • src/art/backend.py - Added train() stub and _model_inference_name()
  • src/art/local/backend.py - Added train() with full kwargs, _model_inference_name()
  • src/art/serverless/backend.py - Added train() with backend-specific kwargs
  • src/art/types.py - Added TrainResult, LocalTrainResult, ServerlessTrainResult
  • src/art/__init__.py - Exported new types

Examples (all updated to new API):

  • examples/2048/train.py
  • examples/benchmarking_comparison_models.py
  • examples/hn_title_generator/train.py
  • examples/just-the-facts/just_the_facts/train.py
  • examples/mcp-rl/mcp_rl/train.py
  • examples/openenv_echo.py
  • examples/prisoners-dilemma.ipynb
  • examples/rock-paper-tool-use.ipynb
  • examples/temporal_clue/*.py and *.ipynb
  • examples/tic_tac_toe/*.py
  • examples/tic_tac_toe_self_play/*.py

Tests:

  • tests/test_backend_train_api.py (new)
  • tests/integration/test_multi_checkpoint_training.py

Closes #519

BREAKING CHANGE: Training is now initiated through backend.train() instead of model.train()

## Migration Guide

### Before (deprecated):
```python
await model.train(trajectory_groups, config=art.TrainConfig(learning_rate=5e-6))
```

### After (recommended):
```python
await model.log(trajectory_groups, split='train')  # Log trajectories
result = await backend.train(model, trajectory_groups, learning_rate=5e-6)
await model.log(metrics=result.metrics, step=result.step, split='train')  # Log training metrics
```

## Key Changes

- **New API**: `backend.train(model, trajectory_groups, **kwargs)` with explicit, type-safe parameters
- **Explicit logging**: `backend.train()` does NOT automatically log trajectories or metrics
- **Extended model.log()**: Now accepts `metrics` and `step` kwargs for logging training metrics directly
- **Structured returns**: `LocalTrainResult` and `ServerlessTrainResult` with step, metrics, and backend-specific fields
- **Fixed get_inference_name()**: Now correctly returns `model.name@step` for LocalBackend
- **Deprecation warning**: `model.train()` emits a warning with migration instructions

## Phase 2 (Future)

In a future release, we will:
- Remove `model.train()` method entirely
- Remove `art.TrainConfig` and `art.dev.TrainConfig` classes

Closes #519

Combined trajectory and metrics logging into a single call:
  result = await backend.train(model, groups, ...)
  await model.log(groups, metrics=result.metrics, step=result.step, split='train')

Removed redundant comments and pre-training log calls.
@corbt corbt merged commit 5af1d38 into main Jan 20, 2026
2 checks passed