14 changes: 11 additions & 3 deletions src/art/dev/openai_server.py
@@ -12,15 +12,23 @@ def get_openai_server_config(
lora_path: str | None = None,
config: "OpenAIServerConfig | None" = None,
) -> "OpenAIServerConfig":
import os

if config is None:
config = OpenAIServerConfig()
log_file = config.get("log_file", log_file)

# Extract step from lora_path for multi-checkpoint support
# lora_path format is: {output_dir}/checkpoints/{step:04d}
lora_name = model_name
if lora_path:
step = int(os.path.basename(lora_path))
lora_name = f"{model_name}@{step}"

server_args = ServerArgs(
api_key="default",
lora_modules=(
- [f'{{"name": "{model_name}", "path": "{lora_path}"}}']
- if lora_path
- else None
+ [f'{{"name": "{lora_name}", "path": "{lora_path}"}}'] if lora_path else None
),
return_tokens_as_token_ids=True,
enable_auto_tool_choice=True,
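The served adapter name now encodes the checkpoint step recovered from the directory's zero-padded basename, so one vLLM server can expose several checkpoints side by side. A minimal sketch of that naming convention, assuming the `{output_dir}/checkpoints/{step:04d}` layout from the comment above (the helper and paths are illustrative, not part of the PR):

```python
import json
import os


def lora_module_arg(model_name: str, lora_path: str | None) -> str | None:
    """Build a lora_modules entry whose name is tagged with the checkpoint step."""
    if lora_path is None:
        return None
    # The path is assumed to follow the {output_dir}/checkpoints/{step:04d} layout.
    step = int(os.path.basename(lora_path))
    return json.dumps({"name": f"{model_name}@{step}", "path": lora_path})


# ".../checkpoints/0005" -> {"name": "my-model@5", "path": ".../checkpoints/0005"}
print(lora_module_arg("my-model", "/tmp/art/my-model/checkpoints/0005"))
```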
20 changes: 14 additions & 6 deletions src/art/local/backend.py
@@ -484,6 +484,13 @@ async def _train_model(
f"Advanced step from {current_step} to {next_step} (no training occurred)"
)

# Register the renamed checkpoint as a new LoRA adapter
# so it's available for inference at the new step
from ..unsloth.service import UnslothService

if isinstance(service, UnslothService):
await service.register_lora_for_step(next_step, next_checkpoint_dir)

# Log metrics showing no groups were trainable
self._log_metrics(
model,
@@ -658,12 +665,7 @@ def _log_metrics(

# If we have a W&B run, log the data there
if run := self._get_wandb_run(model):
- # Mark the step metric itself as hidden so W&B doesn't create an automatic chart for it
- wandb.define_metric("training_step", hidden=True)
-
- # Enabling the following line will cause W&B to use the training_step metric as the x-axis for all metrics
- # wandb.define_metric(f"{split}/*", step_metric="training_step")
- run.log({"training_step": step, **metrics}, step=step)
+ run.log({"training_step": step, **metrics})

def _get_wandb_run(self, model: Model) -> Run | None:
if "WANDB_API_KEY" not in os.environ:
@@ -688,6 +690,12 @@ def _get_wandb_run(self, model: Model) -> Run | None:
),
)
self._wandb_runs[model.name] = run

# Define training_step as the x-axis for all metrics.
# This allows out-of-order logging (e.g., async validation for previous steps).
wandb.define_metric("training_step")
wandb.define_metric("train/*", step_metric="training_step")
wandb.define_metric("val/*", step_metric="training_step")
os.environ["WEAVE_PRINT_CALL_LINK"] = os.getenv(
"WEAVE_PRINT_CALL_LINK", "False"
)
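Defining `training_step` as its own x-axis and dropping the explicit `step=` argument from `run.log` is what allows out-of-order logging: W&B's internal step must increase monotonically, but a custom step metric does not. A rough standalone sketch of the pattern, assuming a configured W&B account and an illustrative project name:

```python
import wandb

run = wandb.init(project="art-demo")  # illustrative project name

# Route all train/* and val/* metrics onto the training_step axis.
wandb.define_metric("training_step")
wandb.define_metric("train/*", step_metric="training_step")
wandb.define_metric("val/*", step_metric="training_step")

# Training metrics arrive in order...
run.log({"training_step": 1, "train/loss": 0.9})
run.log({"training_step": 2, "train/loss": 0.7})

# ...and async validation for an earlier step can still land afterwards,
# because we no longer pass the monotonic step= argument.
run.log({"training_step": 1, "val/reward": 0.42})

run.finish()
```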
24 changes: 19 additions & 5 deletions src/art/model.py
@@ -187,9 +187,15 @@ def openai_client(
)
return self._openai_client

- def litellm_completion_params(self) -> dict:
- """Return the parameters that should be sent to litellm.completion."""
- model_name = self.inference_model_name
+ def litellm_completion_params(self, step: int | None = None) -> dict:
+ """Return the parameters that should be sent to litellm.completion.
+
+ Args:
+ step: If provided, returns params for specific checkpoint using
+ the `name@step` convention. If None, returns params for
+ latest checkpoint (default, backwards compatible).
+ """
+ model_name = self.get_inference_name(step)
if self.trainable:
model_name = f"hosted_vllm/{model_name}"
return {
@@ -203,13 +209,21 @@ def litellm_completion_params(self) -> dict:
# Inference name helpers
# ------------------------------------------------------------------

- def get_inference_name(self) -> str:
+ def get_inference_name(self, step: int | None = None) -> str:
"""Return the name that should be sent to the inference endpoint.

If `inference_model_name` is provided we use that, otherwise we fall
back to the model's own `name`.

Args:
step: If provided, returns name for specific checkpoint using
the `name@step` convention. If None, returns name for
latest checkpoint (default, backwards compatible).
"""
- return self.inference_model_name or self.name
+ base_name = self.inference_model_name or self.name
+ if step is not None:
+ return f"{base_name}@{step}"
+ return base_name

async def log(
self,
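Callers can now pin a rollout to a specific checkpoint while the default path stays backwards compatible. A hypothetical usage sketch, assuming the usual `art.TrainableModel` constructor and that the returned params expose the model name under `"model"` (all names below are placeholders):

```python
import art

model = art.TrainableModel(
    name="agent-007",
    project="demo",
    base_model="Qwen/Qwen2.5-7B-Instruct",
)

print(model.get_inference_name())        # "agent-007"    (latest checkpoint)
print(model.get_inference_name(step=5))  # "agent-007@5"  (pinned checkpoint)
print(model.litellm_completion_params(step=5)["model"])  # "hosted_vllm/agent-007@5"
```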
17 changes: 15 additions & 2 deletions src/art/serverless/backend.py
@@ -74,9 +74,22 @@ async def delete(
assert model.id is not None, "Model ID is required"
await self._client.models.delete(model_id=model.id)

- def _model_inference_name(self, model: "TrainableModel") -> str:
+ def _model_inference_name(
+ self, model: "TrainableModel", step: int | None = None
+ ) -> str:
"""Return the inference name for a model checkpoint.

Args:
model: The trainable model.
step: If provided, returns name for specific checkpoint using
W&B artifact versioning (e.g., :step5). If None, returns
name for latest checkpoint (default, backwards compatible).
"""
assert model.entity is not None, "Model entity is required"
- return f"wandb-artifact:///{model.entity}/{model.project}/{model.name}"
+ base_name = f"wandb-artifact:///{model.entity}/{model.project}/{model.name}"
+ if step is not None:
+ return f"{base_name}:step{step}"
+ return base_name

async def _get_step(self, model: "Model") -> int:
if model.trainable:
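For the serverless backend the same idea maps onto W&B artifact versioning: a pinned step becomes a `:step{N}` suffix on the artifact URI instead of a vLLM LoRA name. A small illustrative sketch of the resulting strings, with made-up entity, project, and model values:

```python
def model_inference_name(entity: str, project: str, name: str, step: int | None = None) -> str:
    """Mirror of the naming scheme above, for illustration only."""
    base_name = f"wandb-artifact:///{entity}/{project}/{name}"
    return f"{base_name}:step{step}" if step is not None else base_name


print(model_inference_name("my-entity", "my-project", "agent-007"))
# wandb-artifact:///my-entity/my-project/agent-007
print(model_inference_name("my-entity", "my-project", "agent-007", step=5))
# wandb-artifact:///my-entity/my-project/agent-007:step5
```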
56 changes: 48 additions & 8 deletions src/art/tinker/service.py
@@ -175,20 +175,37 @@ def custom_loss_fn(
}
last_checkpoint_dir = self._get_last_checkpoint_dir()
assert last_checkpoint_dir is not None, "No checkpoint found"
- state.sampler_client = await self._save_checkpoint(
- last_checkpoint_dir.with_name(f"{int(last_checkpoint_dir.name) + 1:04d}"),
+ next_step = int(last_checkpoint_dir.name) + 1
+ new_sampler_client = await self._save_checkpoint(
+ last_checkpoint_dir.with_name(f"{next_step:04d}"),
state.training_client,
)
# Add new sampler client to the dict and update latest step
state.sampler_clients[next_step] = new_sampler_client
state.latest_step = next_step

async def delete_checkpoints(self, steps_to_keep: list[int]) -> None:
state = await self._state_task
# Find steps to delete
steps_to_delete = [
int(checkpoint_dir.name)
for checkpoint_dir in self._checkpoints_path.iterdir()
if int(checkpoint_dir.name) not in steps_to_keep
]
# Delete checkpoints from disk and Tinker
await asyncio.gather(
*[
- delete_checkpoint(checkpoint_dir, state.rest_client)
- for checkpoint_dir in self._checkpoints_path.iterdir()
- if int(checkpoint_dir.name) not in steps_to_keep
+ delete_checkpoint(
+ self._checkpoints_path / f"{step:04d}", state.rest_client
+ )
+ for step in steps_to_delete
]
)
# Also remove corresponding sampler clients from state
for step in steps_to_delete:
if step in state.sampler_clients:
del state.sampler_clients[step]
print(f"Removed sampler client for step {step}")

@cached_property
def _state_task(self) -> asyncio.Task["TinkerState"]:
@@ -201,6 +218,7 @@ async def _get_state(self) -> "TinkerState":
rest_client = service_client.create_rest_client()
checkpoint_dir = self._get_last_checkpoint_dir()
if checkpoint_dir:
current_step = int(checkpoint_dir.name)
info = yaml.safe_load(open(checkpoint_dir / "info.yaml", "r"))
with log_timing("Creating Tinker training client from checkpoint"):
training_client = await service_client.create_training_client_from_state_with_optimizer_async(
@@ -212,6 +230,7 @@ async def _get_state(self) -> "TinkerState":
model_path=info["sampler_weights_path"],
)
else:
current_step = 0
with log_timing("Creating Tinker training client"):
training_client_args = config.get("training_client_args", {})
if "rank" not in training_client_args:
@@ -231,7 +250,8 @@ async def _get_state(self) -> "TinkerState":
service_client=service_client,
rest_client=rest_client,
training_client=training_client,
- sampler_client=sampler_client,
+ sampler_clients={current_step: sampler_client},
+ latest_step=current_step,
renderer=renderers.get_renderer(
name=config["renderer_name"],
tokenizer=tokenizer_utils.get_tokenizer(self.base_model),
@@ -296,14 +316,23 @@ async def completions() -> dict:
async def chat_completions(
request: Request, body: CompletionCreateParams
) -> ChatCompletion:
# Parse model name to extract optional @step suffix
model_name = body.get("model", self.model_name)
step: int | None = None
if "@" in str(model_name):
base_name, step_str = str(model_name).rsplit("@", 1)
step = int(step_str)

sampler_client = state.get_sampler_client(step)

prompt = tinker.ModelInput.from_ints(
tokens=state.renderer.tokenizer.apply_chat_template(
list(body["messages"]), # type: ignore
tools=body.get("tools"), # type: ignore
add_generation_prompt=True,
)
)
- sample_response = await state.sampler_client.sample_async(
+ sample_response = await sampler_client.sample_async(
prompt=prompt,
num_samples=body.get("n") or 1,
sampling_params=tinker.SamplingParams(
@@ -417,5 +446,16 @@ class TinkerState:
service_client: tinker.ServiceClient
rest_client: TinkerRestClient
training_client: tinker.TrainingClient
- sampler_client: tinker.SamplingClient
+ sampler_clients: dict[int, tinker.SamplingClient]
+ latest_step: int
renderer: renderers.Renderer

def get_sampler_client(self, step: int | None = None) -> tinker.SamplingClient:
if step is None:
step = self.latest_step
if step not in self.sampler_clients:
available = sorted(self.sampler_clients.keys())
raise ValueError(
f"No sampler client for step {step}. Available steps: {available}"
)
return self.sampler_clients[step]
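On the serving side, the handler splits an optional `@step` suffix off the requested model name and looks that step up in `sampler_clients`, defaulting to the latest step. A standalone sketch of the routing logic under the same `name@step` convention, with plain Python stand-ins for the Tinker clients:

```python
def parse_model_and_step(model_name: str) -> tuple[str, int | None]:
    """Split an optional '@step' suffix off the requested model name."""
    if "@" in model_name:
        base_name, step_str = model_name.rsplit("@", 1)
        return base_name, int(step_str)
    return model_name, None


def pick_sampler(sampler_clients: dict[int, str], latest_step: int, step: int | None) -> str:
    """Return the sampler for the requested step, defaulting to the latest one."""
    if step is None:
        step = latest_step
    if step not in sampler_clients:
        raise ValueError(
            f"No sampler client for step {step}. Available steps: {sorted(sampler_clients)}"
        )
    return sampler_clients[step]


clients = {0: "sampler-0", 3: "sampler-3"}
name, step = parse_model_and_step("agent-007@3")
print(name, pick_sampler(clients, latest_step=3, step=step))  # agent-007 sampler-3
```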
69 changes: 56 additions & 13 deletions src/art/unsloth/service.py
@@ -261,6 +261,13 @@ class UnslothService:
config: dev.InternalModelConfig
output_dir: str
_is_sleeping: bool = False
_latest_step: int = 0
_lora_id_counter: int = 1 # Start from 1 since 0 is reserved

def _next_lora_id(self) -> int:
"""Return a new unique LoRA ID to avoid collisions in vLLM."""
self._lora_id_counter += 1
return self._lora_id_counter

async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> None:
lora_path = get_last_checkpoint_dir(self.output_dir)
@@ -269,24 +276,50 @@ async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> None:
lora_path = get_step_checkpoint_dir(self.output_dir, 0)
os.makedirs(os.path.dirname(lora_path), exist_ok=True)
self._state.trainer.save_model(lora_path)
self._latest_step = 0
else:
# Extract step from checkpoint path
self._latest_step = get_step_from_dir(self.output_dir)

# Offload training model to CPU before vLLM starts to free GPU memory
self._state.offload_to_cpu()

+ server_config = dev.get_openai_server_config(
+ model_name=self.model_name,
+ base_model=self.base_model,
+ log_file=f"{self.output_dir}/logs/vllm.log",
+ lora_path=lora_path,
+ config=config,
+ )
await openai_server_task(
engine=await self.llm,
- config=dev.get_openai_server_config(
- model_name=self.model_name,
- base_model=self.base_model,
- log_file=f"{self.output_dir}/logs/vllm.log",
- lora_path=lora_path,
- config=config,
- ),
+ config=server_config,
)

async def vllm_engine_is_sleeping(self) -> bool:
return self._is_sleeping

async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
"""Register a LoRA adapter for a specific checkpoint step.

This is called when training is skipped but the checkpoint is renamed.
"""
llm = await self.llm
await llm.pause_generation()
added = await llm.add_lora(
LoRARequest(
lora_name=f"{self.model_name}@{step}",
lora_int_id=self._next_lora_id(),
lora_path=checkpoint_dir,
)
)
if not added:
raise RuntimeError(
f"Failed to add LoRA adapter for step {step} at {checkpoint_dir}"
)
self._latest_step = step
await llm.resume_generation()

async def train(
self,
disk_packed_tensors: DiskPackedTensors,
@@ -371,17 +404,26 @@ async def train(
await run_on_workers(llm, do_wake_up)
self._is_sleeping = False

- # Swap out the LoRA adapter with the newly trained checkpoint
- await llm.remove_lora(1)
- await llm.add_lora(
+ # Determine the new step from the checkpoint directory
+ # checkpoint_dir format is: {output_dir}/checkpoints/{step:04d}
+ new_step = int(os.path.basename(checkpoint_dir))
+
+ # Add the new LoRA adapter
+ # We keep old LoRAs loaded - vLLM will page them out as needed
+ added = await llm.add_lora(
LoRARequest(
- lora_name=self.model_name,
- lora_int_id=1,
+ lora_name=f"{self.model_name}@{new_step}",
+ lora_int_id=self._next_lora_id(),
lora_path=checkpoint_dir,
)
)
if not added:
raise RuntimeError(
f"Failed to add LoRA adapter for step {new_step} at {checkpoint_dir}"
)
self._latest_step = new_step

- # Resume generation after LoRA swap is complete
+ # Resume generation after LoRA add is complete
await llm.resume_generation()

if verbose:
@@ -461,6 +503,7 @@ def llm(self) -> asyncio.Task[AsyncLLM]:
engine_args = {
**self.config.get("engine_args", {}),
"enable_lora": True,
"max_loras": self.config.get("engine_args", {}).get("max_loras", 2),
}
# Remove boolean flags that vLLM's argparse doesn't accept as =False
for key in ["enable_log_requests", "disable_log_requests"]:
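Taken together with the `@step` LoRA names registered above, an OpenAI-compatible client can pin a request to a particular checkpoint just by naming it. A hypothetical request against the local server, where the port, model name, and step are placeholders and `api_key="default"` mirrors the ServerArgs earlier in this diff:

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="default")

    # "agent-007" would hit the latest checkpoint; "agent-007@5" targets the LoRA
    # adapter registered for step 5 (if that checkpoint is loaded on the server).
    completion = await client.chat.completions.create(
        model="agent-007@5",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(completion.choices[0].message.content)


asyncio.run(main())
```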