## Summary
Propose adding support for simultaneous inference on multiple model checkpoints, enabling pipelined training where training continues on new checkpoints while validation runs on older ones.
This is also related to #512, which allows eval metrics to be logged out of order.
## Motivation
Currently, ART's backend architecture holds only a single sampling client/LoRA at a time. After each training step, the old checkpoint is replaced:
**TinkerService:**

```python
state.sampler_client = await self._save_checkpoint(...)  # Replaces old client
```

**UnslothService:**

```python
await llm.remove_lora(1)  # Removes old LoRA
await llm.add_lora(LoRARequest(...))  # Adds new one
```

This prevents pipelined approaches like running eval on checkpoint N while training checkpoint N+1.
In my separate training codebase (openpipe-simple), I've implemented a 3-stage pipeline (inference → training → eval) that keeps multiple sampling clients alive simultaneously. This allows eval to run asynchronously on old checkpoints while training progresses. I'd like to bring similar capabilities to ART.
## Proposed Design
### 1. Model Name Convention: `name@step`
Use a suffix to specify which checkpoint to use for inference:
- `my-model` → latest checkpoint (default, backwards compatible)
- `my-model@5` → step 5 checkpoint
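For illustration, a minimal sketch of how a backend could split this suffix off (the `parse_inference_name` helper is hypothetical, not part of ART today):

```python
# Hypothetical helper illustrating the name@step convention.
def parse_inference_name(name: str) -> tuple[str, int | None]:
    """Split "my-model@5" into ("my-model", 5); plain names return (name, None)."""
    base, sep, suffix = name.rpartition("@")
    if sep and suffix.isdigit():
        return base, int(suffix)
    return name, None
```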
### 2. Update `Model.get_inference_name()` to accept an optional `step`
```python
def get_inference_name(self, step: int | None = None) -> str:
    """Return the name for the inference endpoint.

    Args:
        step: If provided, returns the name for a specific checkpoint.
            If None, returns the name for the latest checkpoint.
    """
    base_name = self.inference_model_name or self.name
    if step is not None:
        return f"{base_name}@{step}"
    return base_name
```

### 3. TinkerService: Dict of sampling clients
```python
@dataclass
class TinkerState:
    # Change from single client to dict
    sampler_clients: dict[int, tinker.SamplingClient] = field(default_factory=dict)
    latest_step: int = 0

    def get_sampler_client(self, step: int | None = None) -> tinker.SamplingClient:
        if step is None:
            step = self.latest_step
        return self.sampler_clients[step]
```

The OpenAI server parses `@step` from the model name and routes to the appropriate client.
### 4. UnslothService: Keep multiple LoRAs loaded
vLLM already supports multiple concurrent LoRAs. Instead of removing the old LoRA after training, keep it loaded:
```python
# After training step N
lora_name = f"{self.model_name}@{next_step}"
await llm.add_lora(LoRARequest(lora_name=lora_name, lora_int_id=next_id, ...))
# Don't remove old LoRAs - vLLM routes by model name automatically
```

### 5. Cleanup via `delete_checkpoints()`
The existing `delete_checkpoints()` method handles cleanup; just extend it to also remove sampling clients/LoRAs for deleted steps.
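For the UnslothService side, a rough sketch of the step→LoRA bookkeeping that cleanup would need (`LoraRegistry` and its methods are hypothetical; `llm.remove_lora` mirrors the existing call shown above):

```python
# Hypothetical registry mapping checkpoint step -> vLLM LoRA id, so that
# delete_checkpoints() can also unload the LoRAs for deleted steps.
class LoraRegistry:
    def __init__(self) -> None:
        self.step_to_lora_id: dict[int, int] = {}

    def register(self, step: int, lora_id: int) -> None:
        self.step_to_lora_id[step] = lora_id

    async def remove_steps(self, llm, steps: list[int]) -> None:
        for step in steps:
            lora_id = self.step_to_lora_id.pop(step, None)
            if lora_id is not None:
                await llm.remove_lora(lora_id)
```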
### 6. Memory limits
Add configurable limits on concurrent checkpoints to prevent unbounded memory growth.
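One possible shape for such a limit, applied to the `sampler_clients` dict above (the constant and eviction helper are illustrative, not existing ART settings):

```python
# Illustrative eviction policy: drop the oldest checkpoints once the limit is hit.
MAX_CONCURRENT_CHECKPOINTS = 4  # tune per available memory

def evict_if_needed(sampler_clients: dict[int, "tinker.SamplingClient"]) -> None:
    while len(sampler_clients) > MAX_CONCURRENT_CHECKPOINTS:
        oldest_step = min(sampler_clients)  # dict keys are checkpoint steps
        del sampler_clients[oldest_step]
```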
## Usage Example
```python
async def pipelined_training():
    eval_queue: asyncio.Queue[int] = asyncio.Queue()

    async def eval_worker():
        while True:
            step = await eval_queue.get()
            # Eval runs on a specific old checkpoint
            name = model.get_inference_name(step=step)
            await run_eval(model, name)

    asyncio.create_task(eval_worker())

    for batch in train_iterator:
        # Rollout uses latest checkpoint
        groups = await rollout(model)  # get_inference_name() -> latest
        await model.train(groups)

        # Queue async eval on the checkpoint we just finished
        current_step = await model.get_step()
        await eval_queue.put(current_step)
        # Training continues immediately while eval runs in background
```

## Backends Affected
| Backend | Changes Needed |
|---|---|
| TinkerService | sampler_clients dict, request routing by step |
| UnslothService | Keep multiple LoRAs, track step→lora_id mapping |
| ServerlessBackend | May already work via wandb artifact versioning |
## Backwards Compatibility
- Default behavior unchanged: `get_inference_name()` with no args returns the latest checkpoint
- Existing code continues to work without modification
- The `@step` suffix is opt-in
## Questions for Discussion
- Is the `@step` naming convention intuitive? Alternatives: `name:step`, `name/step`, query param?
- Should there be a default limit on concurrent checkpoints, or leave it unbounded?
- For UnslothService, should we expose vLLM's `max_loras` config?
- Any concerns about memory management in Tinker's sampling clients?
## Related
- #512: Use training_step for W&B x-axis to allow out-of-order logging (enables async eval logging)