[RFC] Multi-Checkpoint Inference Support for Pipelined Training #513

@corbt

Summary

This RFC proposes adding support for simultaneous inference on multiple model checkpoints, enabling pipelined training: training continues on new checkpoints while validation runs on older ones.

This is also related to #512, which allows eval metrics to be logged out of order.

Motivation

Currently, ART's backend architecture holds only a single sampling client/LoRA at a time. After each training step, the old checkpoint is replaced:

TinkerService:

state.sampler_client = await self._save_checkpoint(...)  # Replaces old client

UnslothService:

await llm.remove_lora(1)  # Removes old LoRA
await llm.add_lora(LoRARequest(...))  # Adds new one

This prevents pipelined approaches like running eval on checkpoint N while training checkpoint N+1.

In my separate training codebase (openpipe-simple), I've implemented a 3-stage pipeline (inference → training → eval) that keeps multiple sampling clients alive simultaneously. This allows eval to run asynchronously on old checkpoints while training progresses. I'd like to bring similar capabilities to ART.

Proposed Design

1. Model Name Convention: name@step

Use a suffix to specify which checkpoint to use for inference:

  • my-model → latest checkpoint (default, backwards compatible)
  • my-model@5 → step 5 checkpoint
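
A backend that receives such a name needs to split it back into a base name and step. A minimal sketch of that parsing (parse_inference_name is an illustrative helper, not existing ART code):

def parse_inference_name(name: str) -> tuple[str, int | None]:
    """Split "my-model@5" into ("my-model", 5); plain names map to (name, None)."""
    base, sep, suffix = name.rpartition("@")
    if sep and suffix.isdigit():
        return base, int(suffix)
    return name, None  # no @step suffix -> latest checkpoint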

2. Update Model.get_inference_name() to accept optional step

def get_inference_name(self, step: int | None = None) -> str:
    """Return the name for the inference endpoint.
    
    Args:
        step: If provided, returns name for specific checkpoint.
              If None, returns name for latest checkpoint.
    """
    base_name = self.inference_model_name or self.name
    if step is not None:
        return f"{base_name}@{step}"
    return base_name
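
For example, assuming a model named my-model with no inference_model_name override:

model.get_inference_name()        # "my-model"   -> latest checkpoint
model.get_inference_name(step=5)  # "my-model@5" -> checkpoint at step 5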

3. TinkerService: Dict of sampling clients

@dataclass
class TinkerState:
    # Change from single client to dict
    sampler_clients: dict[int, tinker.SamplingClient] = field(default_factory=dict)
    latest_step: int = 0
    
    def get_sampler_client(self, step: int | None = None) -> tinker.SamplingClient:
        if step is None:
            step = self.latest_step
        return self.sampler_clients[step]

The OpenAI server parses @step from the model name and routes to the appropriate client.
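
A rough sketch of that routing step, reusing the hypothetical parse_inference_name helper above (the function shape is illustrative, not ART's actual server code):

def resolve_sampler_client(state: TinkerState, model_name: str) -> tinker.SamplingClient:
    # "my-model@5" -> ("my-model", 5); "my-model" -> ("my-model", None)
    _base, step = parse_inference_name(model_name)
    # step=None falls back to state.latest_step inside get_sampler_client()
    return state.get_sampler_client(step)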

4. UnslothService: Keep multiple LoRAs loaded

vLLM already supports multiple concurrent LoRAs. Instead of removing the old LoRA after training, keep it loaded:

# After training step N
lora_name = f"{self.model_name}@{next_step}"
await llm.add_lora(LoRARequest(lora_name=lora_name, lora_int_id=next_id, ...))
# Don't remove old LoRAs - vLLM routes by model name automatically
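
Since vLLM identifies adapters by integer lora_int_id while requests arrive with name@step strings, the service also needs some bookkeeping for the step→lora_id mapping. An illustrative sketch (class and method names are hypothetical):

class LoraRegistry:
    """Track which lora_int_id serves which training step."""

    def __init__(self) -> None:
        self._step_to_id: dict[int, int] = {}
        self._next_id = 1

    def register(self, step: int) -> int:
        """Assign a fresh lora_int_id for a newly saved checkpoint."""
        lora_id = self._next_id
        self._next_id += 1
        self._step_to_id[step] = lora_id
        return lora_id

    def id_for(self, step: int) -> int:
        """Adapter id to route a name@step request to."""
        return self._step_to_id[step]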

5. Cleanup via delete_checkpoints()

The existing delete_checkpoints() method handles cleanup - just extend it to also remove sampling clients/LoRAs for deleted steps.
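
A rough sketch of that extension on the Tinker side (the hook name _release_inference_resources is hypothetical; the UnslothService equivalent would call llm.remove_lora with the mapped adapter id):

async def _release_inference_resources(self, deleted_steps: list[int]) -> None:
    """Drop in-memory inference state for checkpoints removed from disk."""
    for step in deleted_steps:
        self.state.sampler_clients.pop(step, None)  # no-op if the step was never cached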

6. Memory limits

Add configurable limits on concurrent checkpoints to prevent unbounded memory growth.
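
One possible policy (purely illustrative; the limit and helper are assumptions, not a concrete proposal): keep the N most recent checkpoints live and evict the rest.

MAX_LIVE_CHECKPOINTS = 4  # would come from backend config

def steps_to_evict(live_steps: list[int], latest_step: int) -> list[int]:
    """Return steps whose sampling clients / LoRAs should be released."""
    keep = set(sorted(live_steps)[-MAX_LIVE_CHECKPOINTS:])
    keep.add(latest_step)  # never evict the checkpoint training resumes from
    return [step for step in live_steps if step not in keep]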

Usage Example

import asyncio

async def pipelined_training():
    eval_queue: asyncio.Queue[int] = asyncio.Queue()
    
    async def eval_worker():
        while True:
            step = await eval_queue.get()
            # Eval runs on specific old checkpoint
            name = model.get_inference_name(step=step)
            await run_eval(model, name)
    
    eval_task = asyncio.create_task(eval_worker())  # keep a reference so the task isn't garbage-collected
    
    for batch in train_iterator:
        # Rollout uses latest checkpoint
        groups = await rollout(model)  # get_inference_name() -> latest
        
        await model.train(groups)
        
        # Queue async eval on checkpoint we just finished
        current_step = await model.get_step()
        await eval_queue.put(current_step)
        # Training continues immediately while eval runs in background

Backends Affected

  Backend             Changes Needed
  TinkerService       sampler_clients dict, request routing by step
  UnslothService      Keep multiple LoRAs loaded, track step→lora_id mapping
  ServerlessBackend   May already work via wandb artifact versioning

Backwards Compatibility

  • Default behavior unchanged: get_inference_name() with no args returns latest
  • Existing code continues to work without modification
  • The @step suffix is opt-in

Questions for Discussion

  1. Is the @step naming convention intuitive? Alternatives: name:step, name/step, query param?
  2. Should there be a default limit on concurrent checkpoints, or leave it unbounded?
  3. For UnslothService, should we expose vLLM's max_loras config?
  4. Any concerns about memory management in Tinker's sampling clients?

Related

  • #512: allow eval metrics to be logged out of order
