From f9a403ceb0be2d7a30b7d2cf044d89c05318882c Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Fri, 16 Jan 2026 20:34:50 +0000 Subject: [PATCH 1/3] Use training_step for W&B x-axis to allow out-of-order logging --- src/art/local/backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index eefde0d8f..bfa9b143c 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -658,12 +658,7 @@ def _log_metrics( # If we have a W&B run, log the data there if run := self._get_wandb_run(model): - # Mark the step metric itself as hidden so W&B doesn't create an automatic chart for it - wandb.define_metric("training_step", hidden=True) - - # Enabling the following line will cause W&B to use the training_step metric as the x-axis for all metrics - # wandb.define_metric(f"{split}/*", step_metric="training_step") - run.log({"training_step": step, **metrics}, step=step) + run.log({"training_step": step, **metrics}) def _get_wandb_run(self, model: Model) -> Run | None: if "WANDB_API_KEY" not in os.environ: @@ -688,6 +683,12 @@ def _get_wandb_run(self, model: Model) -> Run | None: ), ) self._wandb_runs[model.name] = run + + # Define training_step as the x-axis for all metrics. + # This allows out-of-order logging (e.g., async validation for previous steps). + wandb.define_metric("training_step") + wandb.define_metric("train/*", step_metric="training_step") + wandb.define_metric("val/*", step_metric="training_step") os.environ["WEAVE_PRINT_CALL_LINK"] = os.getenv( "WEAVE_PRINT_CALL_LINK", "False" ) From 092b461ed63b08bc68ca1d800999f91520ac5062 Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Sat, 17 Jan 2026 01:51:23 +0000 Subject: [PATCH 2/3] Implement multi-checkpoint inference for pipelined training --- src/art/dev/openai_server.py | 12 +- src/art/local/backend.py | 7 + src/art/local/service.py | 8 + src/art/model.py | 24 +- src/art/serverless/backend.py | 17 +- src/art/tinker/service.py | 54 ++- src/art/unsloth/service.py | 69 ++- src/art/vllm/server.py | 19 +- tests/integration/__init__.py | 1 + .../test_multi_checkpoint_training.py | 187 ++++++++ tests/unit/test_multi_checkpoint_inference.py | 436 ++++++++++++++++++ 11 files changed, 804 insertions(+), 30 deletions(-) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/test_multi_checkpoint_training.py create mode 100644 tests/unit/test_multi_checkpoint_inference.py diff --git a/src/art/dev/openai_server.py b/src/art/dev/openai_server.py index 9e050e0a1..90b3335d1 100644 --- a/src/art/dev/openai_server.py +++ b/src/art/dev/openai_server.py @@ -12,13 +12,23 @@ def get_openai_server_config( lora_path: str | None = None, config: "OpenAIServerConfig | None" = None, ) -> "OpenAIServerConfig": + import os + if config is None: config = OpenAIServerConfig() log_file = config.get("log_file", log_file) + + # Extract step from lora_path for multi-checkpoint support + # lora_path format is: {output_dir}/checkpoints/{step:04d} + lora_name = model_name + if lora_path: + step = int(os.path.basename(lora_path)) + lora_name = f"{model_name}@{step}" + server_args = ServerArgs( api_key="default", lora_modules=( - [f'{{"name": "{model_name}", "path": "{lora_path}"}}'] + [f'{{"name": "{lora_name}", "path": "{lora_path}"}}'] if lora_path else None ), diff --git a/src/art/local/backend.py b/src/art/local/backend.py index bfa9b143c..acef9b51b 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -484,6 +484,13 @@ async def 
_train_model( f"Advanced step from {current_step} to {next_step} (no training occurred)" ) + # Register the renamed checkpoint as a new LoRA adapter + # so it's available for inference at the new step + try: + await service.register_lora_for_step(next_step, next_checkpoint_dir) + except Exception: + pass # Method may not exist on all service types + # Log metrics showing no groups were trainable self._log_metrics( model, diff --git a/src/art/local/service.py b/src/art/local/service.py index b54ed5703..241bedcbc 100644 --- a/src/art/local/service.py +++ b/src/art/local/service.py @@ -28,3 +28,11 @@ def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: ... + + async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: + """Register a LoRA adapter for a specific checkpoint step. + + This is called when training is skipped (e.g., all rewards are the same) + but the checkpoint directory is renamed to advance the step. + """ + ... diff --git a/src/art/model.py b/src/art/model.py index 43c519b23..b652a2813 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -187,9 +187,15 @@ def openai_client( ) return self._openai_client - def litellm_completion_params(self) -> dict: - """Return the parameters that should be sent to litellm.completion.""" - model_name = self.inference_model_name + def litellm_completion_params(self, step: int | None = None) -> dict: + """Return the parameters that should be sent to litellm.completion. + + Args: + step: If provided, returns params for specific checkpoint using + the `name@step` convention. If None, returns params for + latest checkpoint (default, backwards compatible). + """ + model_name = self.get_inference_name(step) if self.trainable: model_name = f"hosted_vllm/{model_name}" return { @@ -203,13 +209,21 @@ def litellm_completion_params(self) -> dict: # Inference name helpers # ------------------------------------------------------------------ - def get_inference_name(self) -> str: + def get_inference_name(self, step: int | None = None) -> str: """Return the name that should be sent to the inference endpoint. If `inference_model_name` is provided we use that, otherwise we fall back to the model's own `name`. + + Args: + step: If provided, returns name for specific checkpoint using + the `name@step` convention. If None, returns name for + latest checkpoint (default, backwards compatible). """ - return self.inference_model_name or self.name + base_name = self.inference_model_name or self.name + if step is not None: + return f"{base_name}@{step}" + return base_name async def log( self, diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 2d5de8490..b33dda1e6 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -74,9 +74,22 @@ async def delete( assert model.id is not None, "Model ID is required" await self._client.models.delete(model_id=model.id) - def _model_inference_name(self, model: "TrainableModel") -> str: + def _model_inference_name( + self, model: "TrainableModel", step: int | None = None + ) -> str: + """Return the inference name for a model checkpoint. + + Args: + model: The trainable model. + step: If provided, returns name for specific checkpoint using + W&B artifact versioning (e.g., :step5). If None, returns + name for latest checkpoint (default, backwards compatible). 
+ """ assert model.entity is not None, "Model entity is required" - return f"wandb-artifact:///{model.entity}/{model.project}/{model.name}" + base_name = f"wandb-artifact:///{model.entity}/{model.project}/{model.name}" + if step is not None: + return f"{base_name}:step{step}" + return base_name async def _get_step(self, model: "Model") -> int: if model.trainable: diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index a99e61fd0..8824965e3 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -175,20 +175,35 @@ def custom_loss_fn( } last_checkpoint_dir = self._get_last_checkpoint_dir() assert last_checkpoint_dir is not None, "No checkpoint found" - state.sampler_client = await self._save_checkpoint( - last_checkpoint_dir.with_name(f"{int(last_checkpoint_dir.name) + 1:04d}"), + next_step = int(last_checkpoint_dir.name) + 1 + new_sampler_client = await self._save_checkpoint( + last_checkpoint_dir.with_name(f"{next_step:04d}"), state.training_client, ) + # Add new sampler client to the dict and update latest step + state.sampler_clients[next_step] = new_sampler_client + state.latest_step = next_step async def delete_checkpoints(self, steps_to_keep: list[int]) -> None: state = await self._state_task + # Find steps to delete + steps_to_delete = [ + int(checkpoint_dir.name) + for checkpoint_dir in self._checkpoints_path.iterdir() + if int(checkpoint_dir.name) not in steps_to_keep + ] + # Delete checkpoints from disk and Tinker await asyncio.gather( *[ - delete_checkpoint(checkpoint_dir, state.rest_client) - for checkpoint_dir in self._checkpoints_path.iterdir() - if int(checkpoint_dir.name) not in steps_to_keep + delete_checkpoint(self._checkpoints_path / f"{step:04d}", state.rest_client) + for step in steps_to_delete ] ) + # Also remove corresponding sampler clients from state + for step in steps_to_delete: + if step in state.sampler_clients: + del state.sampler_clients[step] + print(f"Removed sampler client for step {step}") @cached_property def _state_task(self) -> asyncio.Task["TinkerState"]: @@ -201,6 +216,7 @@ async def _get_state(self) -> "TinkerState": rest_client = service_client.create_rest_client() checkpoint_dir = self._get_last_checkpoint_dir() if checkpoint_dir: + current_step = int(checkpoint_dir.name) info = yaml.safe_load(open(checkpoint_dir / "info.yaml", "r")) with log_timing("Creating Tinker training client from checkpoint"): training_client = await service_client.create_training_client_from_state_with_optimizer_async( @@ -212,6 +228,7 @@ async def _get_state(self) -> "TinkerState": model_path=info["sampler_weights_path"], ) else: + current_step = 0 with log_timing("Creating Tinker training client"): training_client_args = config.get("training_client_args", {}) if "rank" not in training_client_args: @@ -231,7 +248,8 @@ async def _get_state(self) -> "TinkerState": service_client=service_client, rest_client=rest_client, training_client=training_client, - sampler_client=sampler_client, + sampler_clients={current_step: sampler_client}, + latest_step=current_step, renderer=renderers.get_renderer( name=config["renderer_name"], tokenizer=tokenizer_utils.get_tokenizer(self.base_model), @@ -296,6 +314,15 @@ async def completions() -> dict: async def chat_completions( request: Request, body: CompletionCreateParams ) -> ChatCompletion: + # Parse model name to extract optional @step suffix + model_name = body.get("model", self.model_name) + step: int | None = None + if "@" in str(model_name): + base_name, step_str = str(model_name).rsplit("@", 1) + 
step = int(step_str) + + sampler_client = state.get_sampler_client(step) + prompt = tinker.ModelInput.from_ints( tokens=state.renderer.tokenizer.apply_chat_template( list(body["messages"]), # type: ignore @@ -303,7 +330,7 @@ async def chat_completions( add_generation_prompt=True, ) ) - sample_response = await state.sampler_client.sample_async( + sample_response = await sampler_client.sample_async( prompt=prompt, num_samples=body.get("n") or 1, sampling_params=tinker.SamplingParams( @@ -417,5 +444,16 @@ class TinkerState: service_client: tinker.ServiceClient rest_client: TinkerRestClient training_client: tinker.TrainingClient - sampler_client: tinker.SamplingClient + sampler_clients: dict[int, tinker.SamplingClient] + latest_step: int renderer: renderers.Renderer + + def get_sampler_client(self, step: int | None = None) -> tinker.SamplingClient: + if step is None: + step = self.latest_step + if step not in self.sampler_clients: + available = sorted(self.sampler_clients.keys()) + raise ValueError( + f"No sampler client for step {step}. Available steps: {available}" + ) + return self.sampler_clients[step] diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index d493bb3e2..182831ba8 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -261,6 +261,13 @@ class UnslothService: config: dev.InternalModelConfig output_dir: str _is_sleeping: bool = False + _latest_step: int = 0 + _lora_id_counter: int = 1 # Start from 1 since 0 is reserved + + def _next_lora_id(self) -> int: + """Return a new unique LoRA ID to avoid collisions in vLLM.""" + self._lora_id_counter += 1 + return self._lora_id_counter async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> None: lora_path = get_last_checkpoint_dir(self.output_dir) @@ -269,24 +276,50 @@ async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> No lora_path = get_step_checkpoint_dir(self.output_dir, 0) os.makedirs(os.path.dirname(lora_path), exist_ok=True) self._state.trainer.save_model(lora_path) + self._latest_step = 0 + else: + # Extract step from checkpoint path + self._latest_step = get_step_from_dir(self.output_dir) # Offload training model to CPU before vLLM starts to free GPU memory self._state.offload_to_cpu() + server_config = dev.get_openai_server_config( + model_name=self.model_name, + base_model=self.base_model, + log_file=f"{self.output_dir}/logs/vllm.log", + lora_path=lora_path, + config=config, + ) await openai_server_task( engine=await self.llm, - config=dev.get_openai_server_config( - model_name=self.model_name, - base_model=self.base_model, - log_file=f"{self.output_dir}/logs/vllm.log", - lora_path=lora_path, - config=config, - ), + config=server_config, ) async def vllm_engine_is_sleeping(self) -> bool: return self._is_sleeping + async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: + """Register a LoRA adapter for a specific checkpoint step. + + This is called when training is skipped but the checkpoint is renamed. 
+ """ + llm = await self.llm + await llm.pause_generation() + added = await llm.add_lora( + LoRARequest( + lora_name=f"{self.model_name}@{step}", + lora_int_id=self._next_lora_id(), + lora_path=checkpoint_dir, + ) + ) + if not added: + raise RuntimeError( + f"Failed to add LoRA adapter for step {step} at {checkpoint_dir}" + ) + self._latest_step = step + await llm.resume_generation() + async def train( self, disk_packed_tensors: DiskPackedTensors, @@ -371,17 +404,26 @@ async def train( await run_on_workers(llm, do_wake_up) self._is_sleeping = False - # Swap out the LoRA adapter with the newly trained checkpoint - await llm.remove_lora(1) - await llm.add_lora( + # Determine the new step from the checkpoint directory + # checkpoint_dir format is: {output_dir}/checkpoints/{step:04d} + new_step = int(os.path.basename(checkpoint_dir)) + + # Add the new LoRA adapter + # We keep old LoRAs loaded - vLLM will page them out as needed + added = await llm.add_lora( LoRARequest( - lora_name=self.model_name, - lora_int_id=1, + lora_name=f"{self.model_name}@{new_step}", + lora_int_id=self._next_lora_id(), lora_path=checkpoint_dir, ) ) + if not added: + raise RuntimeError( + f"Failed to add LoRA adapter for step {new_step} at {checkpoint_dir}" + ) + self._latest_step = new_step - # Resume generation after LoRA swap is complete + # Resume generation after LoRA add is complete await llm.resume_generation() if verbose: @@ -461,6 +503,7 @@ def llm(self) -> asyncio.Task[AsyncLLM]: engine_args = { **self.config.get("engine_args", {}), "enable_lora": True, + "max_loras": self.config.get("engine_args", {}).get("max_loras", 2), } # Remove boolean flags that vLLM's argparse doesn't accept as =False for key in ["enable_log_requests", "disable_log_requests"]: diff --git a/src/art/vllm/server.py b/src/art/vllm/server.py index 90e2f7e33..8b8c0a9fd 100644 --- a/src/art/vllm/server.py +++ b/src/art/vllm/server.py @@ -44,6 +44,18 @@ async def openai_server_task( subclass_chat_completion_request() from vllm.entrypoints.openai import api_server + # Capture the OpenAIServingModels instance so dynamically added LoRAs + # are reflected in the model list. 
+ if not hasattr(api_server, "_art_openai_serving_models"): + api_server._art_openai_serving_models = None + original_init = api_server.OpenAIServingModels.__init__ + + def _init(self, *args: Any, **kwargs: Any) -> None: + original_init(self, *args, **kwargs) + api_server._art_openai_serving_models = self + + api_server.OpenAIServingModels.__init__ = _init + patch_listen_for_disconnect() patch_tool_parser_manager() set_vllm_log_file(config.get("log_file", "vllm.log")) @@ -65,7 +77,12 @@ async def _add_lora(lora_request) -> bool: long_lora_max_len=getattr(lora_request, "long_lora_max_len", None), base_model_name=getattr(lora_request, "base_model_name", None), ) - return await add_lora(lora_request) + added = await add_lora(lora_request) + if added: + models = getattr(api_server, "_art_openai_serving_models", None) + if models is not None: + models.lora_requests[lora_request.lora_name] = lora_request + return added engine.add_lora = _add_lora diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..0765cfe34 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for ART multi-checkpoint inference.""" diff --git a/tests/integration/test_multi_checkpoint_training.py b/tests/integration/test_multi_checkpoint_training.py new file mode 100644 index 000000000..981655456 --- /dev/null +++ b/tests/integration/test_multi_checkpoint_training.py @@ -0,0 +1,187 @@ +"""Integration tests for multi-checkpoint inference with real training loops. + +These tests run actual training loops with different backends to verify that +multi-checkpoint inference works end-to-end without crashing. + +Usage: + # Run all integration tests (requires appropriate backend setup) + uv run pytest tests/integration/test_multi_checkpoint_training.py -v -s + +Environment variables: + BASE_MODEL: The base model to use (default: Qwen/Qwen3-0.6B) + WANDB_API_KEY: Required for ServerlessBackend test + TINKER_API_KEY: Required for TinkerBackend test +""" + +import os +import tempfile +import uuid + +import openai +import pytest + +import art +from art.local import LocalBackend + + +# Use a small model for fast testing +DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B" + + +def get_base_model() -> str: + """Get the base model to use for testing.""" + return os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL) + + +async def simple_rollout( + client: openai.AsyncOpenAI, model_name: str, prompt: str +) -> art.Trajectory: + """A simple rollout function for testing.""" + messages: art.Messages = [{"role": "user", "content": prompt}] + chat_completion = await client.chat.completions.create( + messages=messages, + model=model_name, + max_tokens=10, + timeout=60, + temperature=1, + ) + choice = chat_completion.choices[0] + content = (choice.message.content or "").lower() + if "yes" in content: + reward = 1.0 + elif "no" in content: + reward = 0.5 + elif "maybe" in content: + reward = 0.25 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + +async def run_training_loop( + model: art.TrainableModel, + num_steps: int = 1, + rollouts_per_step: int = 4, +) -> list[int]: + """Run a simple training loop and return the step numbers after each train call.""" + openai_client = model.openai_client() + prompts = ["Say yes", "Say no", "Say maybe", "Say hello"] + steps_completed = [] + + async def resolve_model_name(preferred: str, fallback: str) -> str: + try: + available = [m.id async for m in openai_client.models.list()] + 
except Exception: + return preferred + return preferred if preferred in available else fallback + + for _ in range(num_steps): + current_step = await model.get_step() + preferred_name = model.get_inference_name(step=current_step) + model_name = await resolve_model_name( + preferred_name, model.get_inference_name(step=0) + ) + train_groups = await art.gather_trajectory_groups( + [ + art.TrajectoryGroup( + [ + simple_rollout(openai_client, model_name, prompt) + for _ in range(rollouts_per_step) + ] + ) + for prompt in prompts + ] + ) + await model.train( + train_groups, + config=art.TrainConfig(learning_rate=1e-5), + ) + steps_completed.append(await model.get_step()) + + return steps_completed + + +async def _run_inference_on_step( + model: art.TrainableModel, + step: int, +) -> None: + openai_client = model.openai_client() + model_name = model.get_inference_name(step=step) + await openai_client.chat.completions.create( + messages=[{"role": "user", "content": "Say hello"}], + model=model_name, + max_tokens=10, + timeout=30, + ) + + +@pytest.mark.skipif( + "TINKER_API_KEY" not in os.environ, + reason="TINKER_API_KEY not set - skipping TinkerBackend test", +) +async def test_tinker_backend(): + """Test multi-checkpoint inference with TinkerBackend.""" + model_name = f"test-multi-ckpt-tinker-{uuid.uuid4().hex[:8]}" + with tempfile.TemporaryDirectory() as tmpdir: + backend = art.TinkerBackend(path=tmpdir) + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model=get_base_model(), + ) + try: + await model.register(backend) + steps = await run_training_loop(model, num_steps=1, rollouts_per_step=2) + await _run_inference_on_step(model, step=steps[-1]) + await _run_inference_on_step(model, step=0) + finally: + await backend.close() + + +@pytest.mark.skipif( + not os.path.exists("/dev/nvidia0"), + reason="No GPU available - skipping LocalBackend test", +) +async def test_local_backend(): + """Test multi-checkpoint inference with LocalBackend (UnslothService).""" + model_name = f"test-multi-ckpt-local-{uuid.uuid4().hex[:8]}" + with tempfile.TemporaryDirectory() as tmpdir: + backend = LocalBackend(path=tmpdir) + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model=get_base_model(), + ) + try: + await model.register(backend) + steps = await run_training_loop(model, num_steps=1, rollouts_per_step=2) + await _run_inference_on_step(model, step=steps[-1]) + await _run_inference_on_step(model, step=0) + finally: + await backend.close() + + +@pytest.mark.skipif( + "WANDB_API_KEY" not in os.environ, + reason="WANDB_API_KEY not set - skipping ServerlessBackend test", +) +async def test_serverless_backend(): + """Test multi-checkpoint inference with ServerlessBackend.""" + model_name = f"test-multi-ckpt-serverless-{uuid.uuid4().hex[:8]}" + backend = art.ServerlessBackend() + model = art.TrainableModel( + name=model_name, + project="integration-tests", + base_model="meta-llama/Llama-3.1-8B-Instruct", + ) + try: + await model.register(backend) + steps = await run_training_loop(model, num_steps=1, rollouts_per_step=2) + await _run_inference_on_step(model, step=steps[-1]) + await _run_inference_on_step(model, step=0) + finally: + try: + await backend.delete(model) + except Exception: + pass + await backend.close() diff --git a/tests/unit/test_multi_checkpoint_inference.py b/tests/unit/test_multi_checkpoint_inference.py new file mode 100644 index 000000000..d8e12a024 --- /dev/null +++ b/tests/unit/test_multi_checkpoint_inference.py @@ -0,0 
+1,436 @@ +"""Tests for multi-checkpoint inference support (RFC #513). + +This module tests the ability to run inference on multiple model checkpoints +simultaneously, enabling pipelined training where training continues on new +checkpoints while validation runs on older ones. + +The key features tested are: +1. Model.get_inference_name() with optional step parameter +2. TinkerState.get_sampler_client() for step-based routing +3. ServerlessBackend._model_inference_name() with step suffix +4. UnslothService max_loras configuration +""" + +import asyncio +from dataclasses import dataclass +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import art +from art.model import Model, TrainableModel + + +# ============================================================================= +# Model.get_inference_name() Tests +# ============================================================================= + + +class TestModelGetInferenceName: + """Test Model.get_inference_name() with optional step parameter.""" + + def test_get_inference_name_without_step_uses_name(self): + """Without step, should return the model name.""" + model = Model(name="test-model", project="test-project") + assert model.get_inference_name() == "test-model" + + def test_get_inference_name_without_step_uses_inference_model_name(self): + """Without step, should prefer inference_model_name if set.""" + model = Model( + name="test-model", + project="test-project", + inference_model_name="custom-inference-name", + ) + assert model.get_inference_name() == "custom-inference-name" + + def test_get_inference_name_with_step_appends_suffix(self): + """With step, should append @step suffix.""" + model = Model(name="test-model", project="test-project") + assert model.get_inference_name(step=5) == "test-model@5" + assert model.get_inference_name(step=0) == "test-model@0" + assert model.get_inference_name(step=100) == "test-model@100" + + def test_get_inference_name_with_step_uses_inference_model_name(self): + """With step, should use inference_model_name as base if set.""" + model = Model( + name="test-model", + project="test-project", + inference_model_name="custom-inference-name", + ) + assert model.get_inference_name(step=5) == "custom-inference-name@5" + + def test_get_inference_name_none_step_is_same_as_no_step(self): + """Explicitly passing step=None should behave same as no step.""" + model = Model(name="test-model", project="test-project") + assert model.get_inference_name(step=None) == model.get_inference_name() + + +class TestTrainableModelGetInferenceName: + """Test TrainableModel.get_inference_name() with optional step parameter.""" + + def test_trainable_model_get_inference_name_with_step(self): + """TrainableModel should also support step parameter.""" + model = TrainableModel( + name="trainable-model", + project="test-project", + base_model="meta-llama/Llama-3.1-8B", + ) + assert model.get_inference_name() == "trainable-model" + assert model.get_inference_name(step=3) == "trainable-model@3" + + +class TestLitellmCompletionParams: + """Test Model.litellm_completion_params() with optional step parameter.""" + + def test_litellm_completion_params_without_step(self): + """Without step, should use latest checkpoint name.""" + model = Model( + name="test-model", + project="test-project", + inference_model_name="inference-name", + inference_base_url="http://localhost:8000/v1", + inference_api_key="test-key", + ) + params = model.litellm_completion_params() + assert params["model"] == "inference-name" + assert 
params["base_url"] == "http://localhost:8000/v1" + assert params["api_key"] == "test-key" + + def test_litellm_completion_params_with_step(self): + """With step, should append @step suffix to model name.""" + model = Model( + name="test-model", + project="test-project", + inference_model_name="inference-name", + inference_base_url="http://localhost:8000/v1", + inference_api_key="test-key", + ) + params = model.litellm_completion_params(step=5) + assert params["model"] == "inference-name@5" + + def test_litellm_completion_params_trainable_model_with_step(self): + """Trainable model with step should have hosted_vllm/ prefix and @step suffix.""" + model = TrainableModel( + name="trainable-model", + project="test-project", + base_model="meta-llama/Llama-3.1-8B", + ) + # Set inference_model_name as it would be after register() + model.inference_model_name = "trainable-model" + model.inference_base_url = "http://localhost:8000/v1" + model.inference_api_key = "test-key" + + params = model.litellm_completion_params() + assert params["model"] == "hosted_vllm/trainable-model" + + params_with_step = model.litellm_completion_params(step=3) + assert params_with_step["model"] == "hosted_vllm/trainable-model@3" + + +# ============================================================================= +# TinkerState Tests +# ============================================================================= + + +class TestTinkerStateGetSamplerClient: + """Test TinkerState.get_sampler_client() for step-based routing.""" + + @pytest.fixture + def tinker_state_class(self): + """Import TinkerState, skipping if dependencies unavailable.""" + try: + from art.tinker.service import TinkerState + return TinkerState + except ImportError as e: + pytest.skip(f"Tinker dependencies not available: {e}") + + def test_get_sampler_client_without_step_returns_latest(self, tinker_state_class): + """Without step, should return client for latest_step.""" + TinkerState = tinker_state_class + + # Create mock sampler clients + mock_client_0 = MagicMock() + mock_client_5 = MagicMock() + + state = TinkerState( + service_client=MagicMock(), + rest_client=MagicMock(), + training_client=MagicMock(), + sampler_clients={0: mock_client_0, 5: mock_client_5}, + latest_step=5, + renderer=MagicMock(), + ) + + assert state.get_sampler_client() is mock_client_5 + assert state.get_sampler_client(step=None) is mock_client_5 + + def test_get_sampler_client_with_step_returns_specific_client(self, tinker_state_class): + """With step, should return client for that specific step.""" + TinkerState = tinker_state_class + + mock_client_0 = MagicMock() + mock_client_3 = MagicMock() + mock_client_5 = MagicMock() + + state = TinkerState( + service_client=MagicMock(), + rest_client=MagicMock(), + training_client=MagicMock(), + sampler_clients={0: mock_client_0, 3: mock_client_3, 5: mock_client_5}, + latest_step=5, + renderer=MagicMock(), + ) + + assert state.get_sampler_client(step=0) is mock_client_0 + assert state.get_sampler_client(step=3) is mock_client_3 + assert state.get_sampler_client(step=5) is mock_client_5 + + def test_get_sampler_client_invalid_step_raises_error(self, tinker_state_class): + """Invalid step should raise ValueError with available steps.""" + TinkerState = tinker_state_class + + state = TinkerState( + service_client=MagicMock(), + rest_client=MagicMock(), + training_client=MagicMock(), + sampler_clients={0: MagicMock(), 5: MagicMock()}, + latest_step=5, + renderer=MagicMock(), + ) + + with pytest.raises(ValueError) as exc_info: + 
state.get_sampler_client(step=3) + + assert "No sampler client for step 3" in str(exc_info.value) + assert "Available steps: [0, 5]" in str(exc_info.value) + + +# ============================================================================= +# ServerlessBackend Tests +# ============================================================================= + + +class TestServerlessBackendModelInferenceName: + """Test ServerlessBackend._model_inference_name() with step suffix.""" + + def test_model_inference_name_without_step(self): + """Without step, should return base W&B artifact name.""" + from art.serverless.backend import ServerlessBackend + + # Create backend with mock client + with patch("art.serverless.backend.Client"): + backend = ServerlessBackend(api_key="test-key") + + model = TrainableModel( + name="test-model", + project="test-project", + base_model="meta-llama/Llama-3.1-8B", + ) + model.entity = "test-entity" + + result = backend._model_inference_name(model) + assert result == "wandb-artifact:///test-entity/test-project/test-model" + + def test_model_inference_name_with_step(self): + """With step, should append :step{N} suffix.""" + from art.serverless.backend import ServerlessBackend + + with patch("art.serverless.backend.Client"): + backend = ServerlessBackend(api_key="test-key") + + model = TrainableModel( + name="test-model", + project="test-project", + base_model="meta-llama/Llama-3.1-8B", + ) + model.entity = "test-entity" + + result = backend._model_inference_name(model, step=5) + assert result == "wandb-artifact:///test-entity/test-project/test-model:step5" + + result = backend._model_inference_name(model, step=0) + assert result == "wandb-artifact:///test-entity/test-project/test-model:step0" + + def test_model_inference_name_none_step_is_same_as_no_step(self): + """Explicitly passing step=None should behave same as no step.""" + from art.serverless.backend import ServerlessBackend + + with patch("art.serverless.backend.Client"): + backend = ServerlessBackend(api_key="test-key") + + model = TrainableModel( + name="test-model", + project="test-project", + base_model="meta-llama/Llama-3.1-8B", + ) + model.entity = "test-entity" + + assert backend._model_inference_name(model, step=None) == backend._model_inference_name(model) + + +# ============================================================================= +# OpenAI Server Config Tests +# ============================================================================= + + +class TestOpenAIServerConfigLoraName: + """Test that get_openai_server_config uses step-based LoRA naming.""" + + def test_lora_name_includes_step(self): + """LoRA module name should include @step suffix.""" + from art.dev.openai_server import get_openai_server_config + + config = get_openai_server_config( + model_name="my-model", + base_model="meta-llama/Llama-3.1-8B", + log_file="/tmp/test.log", + lora_path="/path/to/checkpoints/0005", + ) + + lora_modules = config.get("server_args", {}).get("lora_modules", []) + assert len(lora_modules) == 1 + assert "my-model@5" in lora_modules[0] + assert "/path/to/checkpoints/0005" in lora_modules[0] + + def test_lora_name_step_zero(self): + """LoRA module name should work with step 0.""" + from art.dev.openai_server import get_openai_server_config + + config = get_openai_server_config( + model_name="my-model", + base_model="meta-llama/Llama-3.1-8B", + log_file="/tmp/test.log", + lora_path="/path/to/checkpoints/0000", + ) + + lora_modules = config.get("server_args", {}).get("lora_modules", []) + assert len(lora_modules) 
== 1 + assert "my-model@0" in lora_modules[0] + + +# ============================================================================= +# Step Parsing Tests +# ============================================================================= + + +class TestStepParsing: + """Test parsing of @step suffix from model names.""" + + def test_parse_step_from_model_name(self): + """Test the step parsing logic used in TinkerService.""" + test_cases = [ + ("model-name", None), # No @ suffix + ("model-name@5", 5), # Valid step + ("model-name@0", 0), # Step 0 + ("model-name@100", 100), # Large step + ("model@name@5", 5), # Multiple @ (use last) + ("model-name@invalid", None), # Invalid step (not a number) + ("model-name@", None), # Empty step + ] + + for model_name, expected_step in test_cases: + step = None + if "@" in str(model_name): + _, step_str = str(model_name).rsplit("@", 1) + try: + step = int(step_str) + except ValueError: + pass + + assert step == expected_step, f"Failed for {model_name}: got {step}, expected {expected_step}" + + +# ============================================================================= +# UnslothService Configuration Tests +# ============================================================================= + + +class TestUnslothServiceMaxLoras: + """Test UnslothService max_loras configuration.""" + + @pytest.fixture + def unsloth_service_class(self): + """Import UnslothService, skipping if dependencies unavailable.""" + try: + from art.unsloth.service import UnslothService + return UnslothService + except ImportError as e: + pytest.skip(f"Unsloth dependencies not available: {e}") + + def test_max_loras_default_is_2(self, unsloth_service_class): + """UnslothService should default to max_loras=2 (one for training, one for validation).""" + UnslothService = unsloth_service_class + + service = UnslothService( + model_name="test-model", + base_model="meta-llama/Llama-3.1-8B", + config={}, + output_dir="/tmp/test", + ) + + # Access the llm cached property to check engine args + # We can't actually create the LLM, but we can check the config logic + engine_args = { + **service.config.get("engine_args", {}), + "enable_lora": True, + "max_loras": service.config.get("engine_args", {}).get("max_loras", 2), + } + + assert engine_args["max_loras"] == 2 + assert engine_args["enable_lora"] is True + + def test_max_loras_can_be_overridden(self, unsloth_service_class): + """max_loras should be configurable via engine_args.""" + UnslothService = unsloth_service_class + + service = UnslothService( + model_name="test-model", + base_model="meta-llama/Llama-3.1-8B", + config={"engine_args": {"max_loras": 8}}, + output_dir="/tmp/test", + ) + + engine_args = { + **service.config.get("engine_args", {}), + "enable_lora": True, + "max_loras": service.config.get("engine_args", {}).get("max_loras", 2), + } + + assert engine_args["max_loras"] == 8 + + +# ============================================================================= +# Pipelined Training Usage Example +# ============================================================================= + + +class TestPipelinedTrainingUsage: + """Test the usage pattern for pipelined training as described in RFC #513.""" + + def test_pipelined_training_pattern(self): + """ + Verify the API supports the pipelined training pattern from RFC #513. + + The pattern is: + 1. Rollout uses latest checkpoint: model.get_inference_name() + 2. 
After training, queue eval on specific checkpoint: model.get_inference_name(step=N) + """ + model = Model(name="my-model", project="test", inference_model_name="my-model") + + # Rollout uses latest checkpoint (no step) + rollout_name = model.get_inference_name() + assert rollout_name == "my-model" + assert "@" not in rollout_name + + # After training step 5, queue eval on that specific checkpoint + eval_name = model.get_inference_name(step=5) + assert eval_name == "my-model@5" + + # Training continues, new checkpoint at step 6 + # Rollout still uses latest (would be step 6 after training) + new_rollout_name = model.get_inference_name() + assert new_rollout_name == "my-model" + + # Previous eval can still reference step 5 + prev_eval_name = model.get_inference_name(step=5) + assert prev_eval_name == "my-model@5" From 67ea0d9b96ba0f896c82f2b433fbf2048b2792e3 Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Sat, 17 Jan 2026 02:06:57 +0000 Subject: [PATCH 3/3] Fix formatting and typing issues --- src/art/dev/openai_server.py | 4 +--- src/art/local/backend.py | 6 ++--- src/art/local/service.py | 8 ------- src/art/tinker/service.py | 4 +++- src/art/vllm/server.py | 24 +++++++++++-------- .../test_multi_checkpoint_training.py | 1 - tests/unit/test_multi_checkpoint_inference.py | 19 ++++++++++----- 7 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/art/dev/openai_server.py b/src/art/dev/openai_server.py index 90b3335d1..2f28560f4 100644 --- a/src/art/dev/openai_server.py +++ b/src/art/dev/openai_server.py @@ -28,9 +28,7 @@ def get_openai_server_config( server_args = ServerArgs( api_key="default", lora_modules=( - [f'{{"name": "{lora_name}", "path": "{lora_path}"}}'] - if lora_path - else None + [f'{{"name": "{lora_name}", "path": "{lora_path}"}}'] if lora_path else None ), return_tokens_as_token_ids=True, enable_auto_tool_choice=True, diff --git a/src/art/local/backend.py b/src/art/local/backend.py index acef9b51b..2251929c8 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -486,10 +486,10 @@ async def _train_model( # Register the renamed checkpoint as a new LoRA adapter # so it's available for inference at the new step - try: + from ..unsloth.service import UnslothService + + if isinstance(service, UnslothService): await service.register_lora_for_step(next_step, next_checkpoint_dir) - except Exception: - pass # Method may not exist on all service types # Log metrics showing no groups were trainable self._log_metrics( diff --git a/src/art/local/service.py b/src/art/local/service.py index 241bedcbc..b54ed5703 100644 --- a/src/art/local/service.py +++ b/src/art/local/service.py @@ -28,11 +28,3 @@ def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: ... - - async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: - """Register a LoRA adapter for a specific checkpoint step. - - This is called when training is skipped (e.g., all rewards are the same) - but the checkpoint directory is renamed to advance the step. - """ - ... 
diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index 8824965e3..d20a25319 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -195,7 +195,9 @@ async def delete_checkpoints(self, steps_to_keep: list[int]) -> None: # Delete checkpoints from disk and Tinker await asyncio.gather( *[ - delete_checkpoint(self._checkpoints_path / f"{step:04d}", state.rest_client) + delete_checkpoint( + self._checkpoints_path / f"{step:04d}", state.rest_client + ) for step in steps_to_delete ] ) diff --git a/src/art/vllm/server.py b/src/art/vllm/server.py index 8b8c0a9fd..988f637fa 100644 --- a/src/art/vllm/server.py +++ b/src/art/vllm/server.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager import logging import os -from typing import Any, AsyncIterator, Coroutine +from typing import Any, AsyncIterator, Coroutine, cast from openai import AsyncOpenAI from uvicorn.config import LOGGING_CONFIG @@ -16,6 +16,8 @@ from ..dev.openai_server import OpenAIServerConfig +_openai_serving_models: Any | None = None + async def openai_server_task( engine: EngineClient, @@ -46,15 +48,19 @@ async def openai_server_task( # Capture the OpenAIServingModels instance so dynamically added LoRAs # are reflected in the model list. - if not hasattr(api_server, "_art_openai_serving_models"): - api_server._art_openai_serving_models = None - original_init = api_server.OpenAIServingModels.__init__ + from vllm.entrypoints.openai import serving_models + + serving_models_any = cast(Any, serving_models) + if not getattr(serving_models_any, "_art_openai_serving_models_patched", False): + serving_models_any._art_openai_serving_models_patched = True + original_init = serving_models.OpenAIServingModels.__init__ def _init(self, *args: Any, **kwargs: Any) -> None: original_init(self, *args, **kwargs) - api_server._art_openai_serving_models = self + global _openai_serving_models + _openai_serving_models = self - api_server.OpenAIServingModels.__init__ = _init + serving_models.OpenAIServingModels.__init__ = _init patch_listen_for_disconnect() patch_tool_parser_manager() @@ -78,10 +84,8 @@ async def _add_lora(lora_request) -> bool: base_model_name=getattr(lora_request, "base_model_name", None), ) added = await add_lora(lora_request) - if added: - models = getattr(api_server, "_art_openai_serving_models", None) - if models is not None: - models.lora_requests[lora_request.lora_name] = lora_request + if added and _openai_serving_models is not None: + _openai_serving_models.lora_requests[lora_request.lora_name] = lora_request return added engine.add_lora = _add_lora diff --git a/tests/integration/test_multi_checkpoint_training.py b/tests/integration/test_multi_checkpoint_training.py index 981655456..fd43f8b21 100644 --- a/tests/integration/test_multi_checkpoint_training.py +++ b/tests/integration/test_multi_checkpoint_training.py @@ -23,7 +23,6 @@ import art from art.local import LocalBackend - # Use a small model for fast testing DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B" diff --git a/tests/unit/test_multi_checkpoint_inference.py b/tests/unit/test_multi_checkpoint_inference.py index d8e12a024..108a7e1c4 100644 --- a/tests/unit/test_multi_checkpoint_inference.py +++ b/tests/unit/test_multi_checkpoint_inference.py @@ -20,7 +20,6 @@ import art from art.model import Model, TrainableModel - # ============================================================================= # Model.get_inference_name() Tests # ============================================================================= @@ -140,6 +139,7 @@ def 
tinker_state_class(self): """Import TinkerState, skipping if dependencies unavailable.""" try: from art.tinker.service import TinkerState + return TinkerState except ImportError as e: pytest.skip(f"Tinker dependencies not available: {e}") @@ -164,7 +164,9 @@ def test_get_sampler_client_without_step_returns_latest(self, tinker_state_class assert state.get_sampler_client() is mock_client_5 assert state.get_sampler_client(step=None) is mock_client_5 - def test_get_sampler_client_with_step_returns_specific_client(self, tinker_state_class): + def test_get_sampler_client_with_step_returns_specific_client( + self, tinker_state_class + ): """With step, should return client for that specific step.""" TinkerState = tinker_state_class @@ -265,7 +267,9 @@ def test_model_inference_name_none_step_is_same_as_no_step(self): ) model.entity = "test-entity" - assert backend._model_inference_name(model, step=None) == backend._model_inference_name(model) + assert backend._model_inference_name( + model, step=None + ) == backend._model_inference_name(model) # ============================================================================= @@ -287,7 +291,7 @@ def test_lora_name_includes_step(self): lora_path="/path/to/checkpoints/0005", ) - lora_modules = config.get("server_args", {}).get("lora_modules", []) + lora_modules = config.get("server_args", {}).get("lora_modules") or [] assert len(lora_modules) == 1 assert "my-model@5" in lora_modules[0] assert "/path/to/checkpoints/0005" in lora_modules[0] @@ -303,7 +307,7 @@ def test_lora_name_step_zero(self): lora_path="/path/to/checkpoints/0000", ) - lora_modules = config.get("server_args", {}).get("lora_modules", []) + lora_modules = config.get("server_args", {}).get("lora_modules") or [] assert len(lora_modules) == 1 assert "my-model@0" in lora_modules[0] @@ -337,7 +341,9 @@ def test_parse_step_from_model_name(self): except ValueError: pass - assert step == expected_step, f"Failed for {model_name}: got {step}, expected {expected_step}" + assert step == expected_step, ( + f"Failed for {model_name}: got {step}, expected {expected_step}" + ) # ============================================================================= @@ -353,6 +359,7 @@ def unsloth_service_class(self): """Import UnslothService, skipping if dependencies unavailable.""" try: from art.unsloth.service import UnslothService + return UnslothService except ImportError as e: pytest.skip(f"Unsloth dependencies not available: {e}")
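
Taken together, the patches above enable the pipelined pattern that TestPipelinedTrainingUsage and the integration tests exercise: rollouts always hit the latest checkpoint, while validation of an earlier checkpoint is queued by the `name@step` convention and keeps running after training moves on. Below is a minimal sketch of that loop, not part of the patch series itself: it assumes the art.TrainableModel surface used in the tests (openai_client, get_step, get_inference_name(step=...), train, gather_trajectory_groups) plus a user-supplied rollout coroutine; pipelined_training and evaluate are illustrative names, not APIs introduced by these patches.

import asyncio
from typing import Awaitable, Callable

import art


async def pipelined_training(
    model: art.TrainableModel,
    rollout: Callable[[str], Awaitable[art.Trajectory]],
    num_steps: int,
) -> None:
    client = model.openai_client()
    pending_evals: list[asyncio.Task[None]] = []

    async def evaluate(step: int) -> None:
        # Address the frozen checkpoint explicitly via the `name@step` convention,
        # even after later train() calls have registered newer LoRA adapters.
        await client.chat.completions.create(
            model=model.get_inference_name(step=step),
            messages=[{"role": "user", "content": "Say hello"}],
            max_tokens=10,
        )

    for _ in range(num_steps):
        # Rollouts target the latest checkpoint: no @step suffix.
        groups = await art.gather_trajectory_groups(
            [
                art.TrajectoryGroup(
                    [rollout(model.get_inference_name()) for _ in range(4)]
                )
            ]
        )
        await model.train(groups, config=art.TrainConfig(learning_rate=1e-5))

        # Queue validation for the checkpoint this train() call just produced;
        # with the default max_loras=2 it stays loaded alongside the next step's
        # adapter, so the eval can finish while training advances.
        pending_evals.append(asyncio.create_task(evaluate(await model.get_step())))

    await asyncio.gather(*pending_evals)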