diff --git a/docs/ADRs/finetuning/013-axolotl-training-implementation.md b/docs/ADRs/finetuning/013-axolotl-training-implementation.md
new file mode 100644
index 0000000..cef583d
--- /dev/null
+++ b/docs/ADRs/finetuning/013-axolotl-training-implementation.md
@@ -0,0 +1,60 @@
+# ADR-013. Axolotl Training Implementation
+
+Date: 2026-04-27
+Status: Proposed
+
+## Context
+
+With the service skeleton established in ADR-012, the next decisions concerned the actual training pipeline: how Axolotl is invoked, how GPU utilisation is maximised on H100 hardware, how evaluation is handled, and how optional integrations (Weights & Biases) are exposed to users.
+
+## Decision
+
+### Axolotl invocation
+
+Axolotl is invoked as a subprocess (`axolotl train <config-path>`) rather than imported as a library. This keeps Axolotl's CUDA environment self-contained.
+
+A per-job Axolotl config is generated at runtime by merging a service-level base config (`config.yaml`, mounted read-only) with the user's job parameters. The result is written to a temp directory (`/tmp/finetuning/{job_id}/`) and cleaned up unconditionally in a `finally` block after the subprocess exits.
+
+### Worker as a separate container
+
+The worker runs in its own container using the Axolotl base image (`axolotlai/axolotl:main-py3.12-cu130-2.10.0`). The API uses a separate lightweight Python image. Merging them would require the API to carry the full Axolotl image (several GB) for no benefit.
+
+### Flash Attention 4
+
+The Axolotl image ships Flash Attention 2 (FA2). On CUDA 13 / H100 hardware, FA2 produced a `CUBLAS_STATUS_INVALID_VALUE` error in the RoPE computation during evaluation, crashing jobs before training began. Installing Flash Attention 4 (`flash-attn-4[cu13]==4.0.0b10`) resolved this. FA4 is the architecturally correct choice for Hopper GPUs (H100) and is explicitly recommended by Axolotl in its startup logs for this hardware. The `[cu13]` extra selects the CUDA 13 wheel.
+
+FA4 is pinned to a specific beta version. Upgrading it is a deliberate decision, mirroring how we pin the Axolotl base image.
+
+### Sample packing
+
+Sample packing is enabled by default (`sample_packing: true`) and disabled for evaluation (`eval_sample_packing: false`). Without it, sequences are padded individually to `sequence_len`, resulting in approximately 55% padding waste on typical instruction-tuning datasets. With sample packing, multiple conversations are packed end-to-end into each sequence slot, using Flash Attention masking to prevent cross-conversation attention. This improved trainable token density to ~65% and GPU throughput by ~9x on our test dataset.
+
+Sample packing is not applied during evaluation: the eval set is usually small, and packing adds complexity without meaningfully improving evaluation speed.
+
+### Default sequence length
+
+The default `sequence_len` is 2048. The original default of 512 dropped approximately 37% of training samples from our representative dataset (max sequence length ~1716 tokens). 2048 retains all sequences and, combined with sample packing, allows more conversations per packed slot. Users may override this per job.
+
+### Evaluation control
+
+Evaluation is opt-in. Jobs default to `do_eval: false`, which injects `eval_strategy: "no"` into the generated Axolotl config, overriding the base config's `eval_strategy: epoch`. When `do_eval: true`, the `validation` split of the user's dataset is used and evaluation runs at the end of each epoch.
+
+This is opt-in rather than opt-out because many Hugging Face datasets do not include a `validation` split; silently failing a job because the split is absent is a worse experience than requiring users to explicitly request evaluation. SDK documentation will specify that users must include a `validation` split if they set `do_eval: true`.
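The opt-in switch amounts to a small config-merge rule. A minimal sketch of the injected keys (the helper name is illustrative, not the service's actual code; the dataset keys mirror the worker's generated config):

```python
def eval_section(do_eval: bool, hf_dataset: str) -> dict:
    """Return the evaluation keys merged into a generated job config."""
    if not do_eval:
        # Default path: override the base config's `eval_strategy: epoch`.
        return {"eval_strategy": "no"}
    # Opt-in path: evaluate on the dataset's `validation` split.
    return {
        "test_datasets": [
            {
                "path": hf_dataset,
                "type": "chat_template",
                "split": "validation",
            }
        ]
    }
```

Because the injected keys are merged after the base config, the `do_eval: false` path wins over `eval_strategy: epoch` without mutating the mounted `config.yaml`.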
+
+### Weights & Biases integration
+
+Optional wandb logging is supported via three fields: `wandb_token`, `wandb_project`, and `wandb_entity`. The token is handled identically to the HF token: stored in the database, passed to the Axolotl subprocess as `WANDB_API_KEY`, and cleared to `NULL` on job completion.
+
+Validation rules:
+- `wandb_project` and `wandb_token` must be provided together.
+- `wandb_entity` (a team/organisation name) may be omitted; wandb defaults to the user's personal account.
+- Providing `wandb_entity` without `wandb_project` is rejected.
+
+Omitting `wandb_entity` alone is the only partially specified combination permitted, reflecting wandb's own behaviour.
+
+## Consequences
+
+- Training jobs on H100 / CUDA 13 hardware work with FA4.
+- Sample packing significantly improves GPU utilisation but reduces the effective number of training steps per epoch. For small datasets, users should be aware that a large `micro_batch_size` relative to the dataset size can result in very few optimiser steps per epoch.
+- Evaluation requires users to know their dataset structure. No automatic detection of available splits is performed.
+- Wandb tokens are treated as sensitive credentials and cleared after use, consistent with the HF token policy established in ADR-011.
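The ADR's wandb validation rules reduce to a small predicate. A standalone sketch (not the service's Pydantic validator, which lives in `app/models.py`):

```python
from typing import Optional


def wandb_fields_valid(
    token: Optional[str],
    project: Optional[str],
    entity: Optional[str],
) -> bool:
    """Check a wandb field combination against the rules above."""
    # token and project must be provided together.
    if bool(token) != bool(project):
        return False
    # entity additionally requires project (and therefore token).
    if entity and not project:
        return False
    return True
```

All-empty and token+project (with or without entity) pass; every other combination is rejected, matching the permitted-combinations note above.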
diff --git a/stacks/finetuning-service/.env.example b/stacks/finetuning-service/.env.example index e1b6c0f..bb8146a 100644 --- a/stacks/finetuning-service/.env.example +++ b/stacks/finetuning-service/.env.example @@ -12,6 +12,8 @@ # ----------------------------------------------------------------------------- # Fine-Tuning Service # ----------------------------------------------------------------------------- +DEVICE=3 +# Multi-gpu not currently supported FINETUNING_PORT=8005 MAX_JOB_DURATION_HOURS=4 diff --git a/stacks/finetuning-service/Dockerfile b/stacks/finetuning-service/Dockerfile deleted file mode 100644 index 5b154c5..0000000 --- a/stacks/finetuning-service/Dockerfile +++ /dev/null @@ -1,11 +0,0 @@ -FROM axolotlai/axolotl:main-py3.12-cu130-2.10.0 - -WORKDIR /app - -# Install service dependencies on top of Axolotl's environment -COPY requirements.txt . -RUN pip install -r requirements.txt - -COPY app/ ./app/ - -CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${FINETUNING_PORT}"] diff --git a/stacks/finetuning-service/Dockerfile.api b/stacks/finetuning-service/Dockerfile.api new file mode 100644 index 0000000..dca1ecb --- /dev/null +++ b/stacks/finetuning-service/Dockerfile.api @@ -0,0 +1,19 @@ +# ============================================================================= +# Dockerfile.api - Fine-Tuning Service API +# ============================================================================= +# +# Lightweight image for the FastAPI service only. +# The Axolotl worker uses a separate image (Dockerfile.worker). +# +# ============================================================================= + +FROM python:3.12-slim + +WORKDIR /app + +COPY requirements.api.txt . 
+RUN pip install -r requirements.api.txt + +COPY app/ ./app/ + +CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${FINETUNING_PORT}"] diff --git a/stacks/finetuning-service/Dockerfile.worker b/stacks/finetuning-service/Dockerfile.worker new file mode 100644 index 0000000..0940b38 --- /dev/null +++ b/stacks/finetuning-service/Dockerfile.worker @@ -0,0 +1,22 @@ +# ============================================================================= +# Dockerfile.worker - Fine-Tuning Worker +# ============================================================================= +# +# Axolotl-based image for the queue worker that runs training jobs. +# The FastAPI service uses a separate lightweight image (Dockerfile.api). +# +# Update the tag deliberately when upgrading Axolotl. +# +# ============================================================================= + +FROM axolotlai/axolotl:main-py3.12-cu130-2.10.0 + +WORKDIR /app + +COPY requirements.worker.txt . +RUN pip install -r requirements.worker.txt + +COPY app/ ./app/ +COPY worker.py . + +CMD ["python", "worker.py"] diff --git a/stacks/finetuning-service/app/config.py b/stacks/finetuning-service/app/config.py new file mode 100644 index 0000000..98b7fea --- /dev/null +++ b/stacks/finetuning-service/app/config.py @@ -0,0 +1,26 @@ +"""Service configuration loader.""" + +from pathlib import Path + +import yaml + +CONFIG_PATH = Path("/app/config.yaml") + + +def load_config() -> dict: + """Load the service configuration from disk. + + Returns: + The parsed configuration dictionary. + """ + with CONFIG_PATH.open() as f: + return yaml.safe_load(f) + + +def get_allowed_models() -> list[str]: + """Return the list of models permitted for fine-tuning. + + Returns: + A list of allowed Hugging Face model repo paths. 
+ """ + return load_config().get("allowed_models", []) diff --git a/stacks/finetuning-service/app/database.py b/stacks/finetuning-service/app/database.py index dd940d2..7f56103 100644 --- a/stacks/finetuning-service/app/database.py +++ b/stacks/finetuning-service/app/database.py @@ -4,6 +4,7 @@ import sqlite3 import uuid from contextlib import contextmanager +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Generator, Optional @@ -13,6 +14,24 @@ DB_PATH = Path("/data/jobs.db") +@dataclass +class JobDetail: + """Internal job representation including sensitive fields. + + Not exposed via the API. Used by the worker to access the + HF token and full training config. + """ + + id: str + model: str + hf_dataset: str + hf_token: str + hub_model_id: str + wandb_token: Optional[str] + config: str + status: JobStatus + + @contextmanager def get_connection() -> Generator[sqlite3.Connection, None, None]: """Yield a database connection with row factory configured. 
@@ -43,7 +62,7 @@ def _row_to_job(row: sqlite3.Row) -> JobResponse: status=JobStatus(row["status"]), model=row["model"], hf_dataset=row["hf_dataset"], - suffix=row["suffix"], + hub_model_id=row["hub_model_id"], created_at=datetime.fromisoformat(row["created_at"]), started_at=( datetime.fromisoformat(row["started_at"]) @@ -68,7 +87,9 @@ def init_db() -> None: status TEXT NOT NULL, model TEXT NOT NULL, hf_dataset TEXT NOT NULL, - suffix TEXT, + hf_token TEXT, + hub_model_id TEXT NOT NULL, + wandb_token TEXT, config TEXT NOT NULL, created_at TEXT NOT NULL, started_at TEXT, @@ -105,23 +126,39 @@ def create_job(request: JobSubmitRequest) -> JobResponse: now = datetime.now(timezone.utc).isoformat() config = json.dumps( { - # TODO: Add hyperparameters/config keys + "num_epochs": request.num_epochs, + "learning_rate": request.learning_rate, + "micro_batch_size": request.micro_batch_size, + "gradient_accumulation_steps": ( + request.gradient_accumulation_steps + ), + "sequence_len": request.sequence_len, + "lora_r": request.lora_r, + "lora_dropout": request.lora_dropout, + "lora_target_modules": request.lora_target_modules, + "load_in_4bit": request.load_in_4bit, + "load_in_8bit": request.load_in_8bit, + "do_eval": request.do_eval, + "wandb_project": request.wandb_project, + "wandb_entity": request.wandb_entity, } ) with get_connection() as conn: conn.execute( """ INSERT INTO jobs ( - id, status, model, hf_dataset, - suffix, config, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?) + id, status, model, hf_dataset, hf_token, + hub_model_id, wandb_token, config, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( job_id, JobStatus.QUEUED, request.model, request.hf_dataset, - request.suffix, + request.hf_token, + request.hub_model_id, + request.wandb_token, config, now, ), @@ -131,7 +168,7 @@ def create_job(request: JobSubmitRequest) -> JobResponse: status=JobStatus.QUEUED, model=request.model, hf_dataset=request.hf_dataset, - suffix=request.suffix, + hub_model_id=request.hub_model_id, created_at=datetime.fromisoformat(now), ) @@ -186,3 +223,83 @@ def cancel_job(job_id: str) -> Optional[JobResponse]: (JobStatus.CANCELLED, job_id, JobStatus.QUEUED), ) return get_job(job_id) + + +def get_next_queued_job() -> Optional[JobDetail]: + """Fetch the oldest queued job for the worker to process. + + Returns: + A JobDetail if a queued job exists, otherwise None. + """ + with get_connection() as conn: + row = conn.execute( + """ + SELECT * FROM jobs + WHERE status = ? + ORDER BY created_at ASC + LIMIT 1 + """, + (JobStatus.QUEUED,), + ).fetchone() + if not row: + return None + return JobDetail( + id=row["id"], + model=row["model"], + hf_dataset=row["hf_dataset"], + hf_token=row["hf_token"], + hub_model_id=row["hub_model_id"], + wandb_token=row["wandb_token"], + config=row["config"], + status=JobStatus(row["status"]), + ) + + +def claim_job(job_id: str) -> bool: + """Atomically mark a queued job as running. + + Args: + job_id: The UUID of the job to claim. + + Returns: + True if the job was successfully claimed, False if it + was already claimed by another process. + """ + now = datetime.now(timezone.utc).isoformat() + with get_connection() as conn: + result = conn.execute( + """ + UPDATE jobs SET status = ?, started_at = ? + WHERE id = ? AND status = ? + """, + (JobStatus.RUNNING, now, job_id, JobStatus.QUEUED), + ) + return result.rowcount > 0 + + +def complete_job( + job_id: str, + status: JobStatus, + error_message: Optional[str] = None, +) -> None: + """Mark a job as completed and clear its tokens. + + Args: + job_id: The UUID of the job to complete. 
+        status: The final status (succeeded or failed).
+        error_message: Optional error description for failed jobs.
+    """
+    now = datetime.now(timezone.utc).isoformat()
+    with get_connection() as conn:
+        conn.execute(
+            """
+            UPDATE jobs
+            SET status = ?,
+                completed_at = ?,
+                error_message = ?,
+                hf_token = NULL,
+                wandb_token = NULL
+            WHERE id = ?
+            """,
+            (status, now, error_message, job_id),
+        )
+
+
+def recover_running_jobs() -> None:
+    """Fail any jobs left in the running state by a worker restart.
+
+    Called on worker startup. A job still marked running has no live
+    subprocess behind it, so it is marked failed and its tokens are
+    cleared, matching the complete_job token policy.
+    """
+    now = datetime.now(timezone.utc).isoformat()
+    with get_connection() as conn:
+        conn.execute(
+            """
+            UPDATE jobs
+            SET status = ?,
+                completed_at = ?,
+                error_message = ?,
+                hf_token = NULL,
+                wandb_token = NULL
+            WHERE status = ?
+            """,
+            (
+                JobStatus.FAILED,
+                now,
+                "Worker restarted while the job was running.",
+                JobStatus.RUNNING,
+            ),
+        )
diff --git a/stacks/finetuning-service/app/models.py b/stacks/finetuning-service/app/models.py
index 2009014..822689f 100644
--- a/stacks/finetuning-service/app/models.py
+++ b/stacks/finetuning-service/app/models.py
@@ -4,7 +4,10 @@
 from enum import Enum
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator, model_validator
+from typing_extensions import Self
+
+from .config import get_allowed_models
 
 
 class JobStatus(str, Enum):
@@ -30,12 +33,107 @@ class JobSubmitRequest(BaseModel):
             "and write access to the adapter destination."
         )
     )
-    suffix: Optional[str] = Field(
-        default=None,
-        description="Label appended to the adapter repository name.",
+    hub_model_id: str = Field(
+        description=(
+            "Hugging Face repo path to push the trained adapter to "
+            "(e.g. 'username/my-adapter')."
+ ) ) - # TODO: add hyperparameters/config keys - # (minimal subset of keys from axolotl config reference) + num_epochs: int = Field(ge=1) + learning_rate: float = Field(gt=0) + micro_batch_size: int = Field(default=4, ge=1) + gradient_accumulation_steps: int = Field(default=1, ge=1) + sequence_len: int = Field(default=2048, ge=64) + lora_r: int = Field(ge=1) + lora_dropout: float = Field(default=0.0, ge=0.0, le=1.0) + lora_target_modules: list[str] = Field( + default=["q_proj", "v_proj", "k_proj", "o_proj"] + ) + load_in_4bit: bool = Field(default=False) + load_in_8bit: bool = Field(default=False) + do_eval: bool = Field(default=False) + wandb_token: Optional[str] = Field(default=None) + wandb_project: Optional[str] = Field(default=None) + wandb_entity: Optional[str] = Field(default=None) + + @field_validator("model") + @classmethod + def model_must_be_whitelisted(cls, v: str) -> str: + """Validate the model is on the allowed list. + + Args: + v: The model repo path to validate. + + Returns: + The validated model path. + + Raises: + ValueError: If the model is not on the whitelist. + """ + allowed = get_allowed_models() + if v not in allowed: + raise ValueError( + f"Model '{v}' is not permitted. Allowed models: {allowed}" + ) + return v + + @field_validator("lora_target_modules") + @classmethod + def lora_target_modules_must_not_be_empty(cls, v: list[str]) -> list[str]: + """Validate that at least one LoRA target module is specified. + + Args: + v: The list of target modules. + + Returns: + The validated list. + + Raises: + ValueError: If the list is empty. + """ + if not v: + raise ValueError( + "lora_target_modules must contain at least one module." + ) + return v + + @model_validator(mode="after") + def wandb_fields_are_consistent(self) -> Self: + """Validate wandb field combinations are coherent. + + Returns: + The validated model instance. + + Raises: + ValueError: If wandb fields are provided in an invalid + combination. 
+ """ + has_token = bool(self.wandb_token) + has_project = bool(self.wandb_project) + has_entity = bool(self.wandb_entity) + if has_entity and not has_project: + raise ValueError("wandb_entity requires wandb_project to be set.") + if has_project and not has_token: + raise ValueError("wandb_project requires wandb_token.") + if has_token and not has_project: + raise ValueError("wandb_token requires wandb_project.") + return self + + @model_validator(mode="after") + def quantisation_modes_are_mutually_exclusive(self) -> Self: + """Validate that 4-bit and 8-bit quantisation are not both set. + + Returns: + The validated model instance. + + Raises: + ValueError: If both load_in_4bit and load_in_8bit are True. + """ + if self.load_in_4bit and self.load_in_8bit: + raise ValueError( + "load_in_4bit and load_in_8bit are mutually exclusive." + ) + return self class JobResponse(BaseModel): @@ -45,7 +143,7 @@ class JobResponse(BaseModel): status: JobStatus model: str hf_dataset: str - suffix: Optional[str] = None + hub_model_id: str created_at: datetime started_at: Optional[datetime] = None completed_at: Optional[datetime] = None diff --git a/stacks/finetuning-service/config.yaml b/stacks/finetuning-service/config.yaml new file mode 100644 index 0000000..0ba4f58 --- /dev/null +++ b/stacks/finetuning-service/config.yaml @@ -0,0 +1,37 @@ +# ============================================================================= +# Fine-Tuning Service Configuration +# ============================================================================= +# +# Edit this file and restart the service to update settings +# without rebuilding the image. 
+# +# ============================================================================= + +allowed_models: + - meta-llama/Llama-3.1-8B-Instruct + +axolotl_base_config: + # Adapter + adapter: lora + + # Precision + bf16: true + tf32: true + + # Memory optimisation + flash_attention: true + gradient_checkpointing: true + sample_packing: true + eval_sample_packing: false + + # Evaluation and checkpointing + eval_strategy: epoch + save_strategy: epoch + save_total_limit: 2 + save_safetensors: true + + # Hub push + hub_strategy: end + + # Security + trust_remote_code: false diff --git a/stacks/finetuning-service/docker-compose.yml b/stacks/finetuning-service/docker-compose.yml index 9a28218..01e23e6 100644 --- a/stacks/finetuning-service/docker-compose.yml +++ b/stacks/finetuning-service/docker-compose.yml @@ -11,24 +11,21 @@ # ============================================================================= services: - finetuning-service: - build: . - container_name: finetuning-service + # --------------------------------------------------------------------------- + # API + # --------------------------------------------------------------------------- + api: + build: + context: . 
+ dockerfile: Dockerfile.api + container_name: finetuning-api restart: unless-stopped environment: - NVIDIA_VISIBLE_DEVICES: "3" # TODO: Confirm GPU 3 FINETUNING_PORT: ${FINETUNING_PORT} LITELLM_URL: "http://host.docker.internal:${LITELLM_PORT}" LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY} - MAX_JOB_DURATION_HOURS: ${MAX_JOB_DURATION_HOURS} - deploy: - resources: - reservations: - devices: - - driver: nvidia - device_ids: ['3'] # TODO: Confirm GPU 3 - capabilities: [gpu] volumes: + - ./config.yaml:/app/config.yaml:ro - finetuning_jobs:/data ports: - "127.0.0.1:${FINETUNING_PORT}:${FINETUNING_PORT}" @@ -46,5 +43,33 @@ services: max-size: "50m" max-file: "3" + # --------------------------------------------------------------------------- + # Worker + # --------------------------------------------------------------------------- + worker: + build: + context: . + dockerfile: Dockerfile.worker + container_name: finetuning-worker + restart: unless-stopped + environment: + NVIDIA_VISIBLE_DEVICES: "${DEVICE}" + MAX_JOB_DURATION_HOURS: ${MAX_JOB_DURATION_HOURS} + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ["${DEVICE}"] + capabilities: [gpu] + volumes: + - ./config.yaml:/app/config.yaml:ro + - finetuning_jobs:/data + logging: + driver: json-file + options: + max-size: "50m" + max-file: "3" + volumes: finetuning_jobs: diff --git a/stacks/finetuning-service/requirements.api.txt b/stacks/finetuning-service/requirements.api.txt new file mode 100644 index 0000000..36b7287 --- /dev/null +++ b/stacks/finetuning-service/requirements.api.txt @@ -0,0 +1,4 @@ +uvicorn +fastapi +pyyaml +typing_extensions diff --git a/stacks/finetuning-service/requirements.txt b/stacks/finetuning-service/requirements.txt deleted file mode 100644 index d1f80a3..0000000 --- a/stacks/finetuning-service/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -uvicorn -fastapi \ No newline at end of file diff --git a/stacks/finetuning-service/requirements.worker.txt 
b/stacks/finetuning-service/requirements.worker.txt new file mode 100644 index 0000000..fa9b2de --- /dev/null +++ b/stacks/finetuning-service/requirements.worker.txt @@ -0,0 +1,2 @@ +pyyaml +flash-attn-4[cu13]==4.0.0b10 diff --git a/stacks/finetuning-service/worker.py b/stacks/finetuning-service/worker.py new file mode 100644 index 0000000..58cd6b2 --- /dev/null +++ b/stacks/finetuning-service/worker.py @@ -0,0 +1,214 @@ +"""Queue worker for processing fine-tuning jobs.""" + +import json +import logging +import os +import shutil +import subprocess +import time +from pathlib import Path +from typing import Optional + +import yaml +from app.config import load_config +from app.database import ( + claim_job, + complete_job, + get_next_queued_job, + init_db, + recover_running_jobs, +) +from app.models import JobStatus + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", +) +log = logging.getLogger(__name__) + +POLL_INTERVAL = int(os.getenv("POLL_INTERVAL_SECONDS", "10")) +MAX_DURATION = int(os.getenv("MAX_JOB_DURATION_HOURS", "4")) * 3600 +WORK_DIR = Path("/tmp/finetuning") + + +def build_axolotl_config( + job_id: str, + model: str, + hf_dataset: str, + hub_model_id: str, + user_config: dict, +) -> Path: + """Build an Axolotl config file for a job. + + Merges the service base config with user-provided settings + and writes the result to a per-job temp directory. + + Args: + job_id: The UUID of the job. + model: The HuggingFace model repo path. + hf_dataset: The HuggingFace dataset repo path. + hub_model_id: The destination HuggingFace repo for the adapter. + user_config: The user-provided training configuration. + + Returns: + Path to the generated config file. 
+ """ + job_dir = WORK_DIR / job_id + output_dir = job_dir / "output" + output_dir.mkdir(parents=True, exist_ok=True) + + base = load_config().get("axolotl_base_config", {}) + lora_r = user_config["lora_r"] + + config = { + **base, + "base_model": model, + "output_dir": str(output_dir), + "hub_model_id": hub_model_id, + "datasets": [ + { + "path": hf_dataset, + "type": "chat_template", + "split": "train", + } + ], + **( + { + "test_datasets": [ + { + "path": hf_dataset, + "type": "chat_template", + "split": "validation", + } + ] + } + if user_config.get("do_eval") + else {"eval_strategy": "no"} + ), + **( + { + "use_wandb": True, + "wandb_project": user_config["wandb_project"], + **( + {"wandb_entity": user_config["wandb_entity"]} + if user_config.get("wandb_entity") + else {} + ), + } + if user_config.get("wandb_project") + else {} + ), + "num_epochs": user_config["num_epochs"], + "learning_rate": user_config["learning_rate"], + "micro_batch_size": user_config["micro_batch_size"], + "gradient_accumulation_steps": ( + user_config["gradient_accumulation_steps"] + ), + "sequence_len": user_config["sequence_len"], + "lora_r": lora_r, + "lora_alpha": lora_r * 2, + "lora_dropout": user_config["lora_dropout"], + "lora_target_modules": user_config["lora_target_modules"], + "load_in_4bit": user_config["load_in_4bit"], + "load_in_8bit": user_config["load_in_8bit"], + } + + config_path = job_dir / "config.yaml" + with config_path.open("w") as f: + yaml.dump(config, f) + + log.info("Generated Axolotl config for job %s.", job_id) + return config_path + + +def run_job( + job_id: str, + model: str, + hf_dataset: str, + hf_token: str, + hub_model_id: str, + wandb_token: Optional[str], + user_config: dict, +) -> None: + """Execute a fine-tuning job. + + Generates the Axolotl config, runs the training subprocess, + and cleans up the temp directory on completion or failure. + + Args: + job_id: The UUID of the job. + model: The HuggingFace model repo path. 
+        hf_dataset: The HuggingFace dataset repo path.
+        hf_token: The HuggingFace token for dataset and Hub access.
+        hub_model_id: The destination HuggingFace repo for the adapter.
+        wandb_token: Optional Weights & Biases API key.
+        user_config: The user-provided training configuration.
+    """
+    job_dir = WORK_DIR / job_id
+    error_message: Optional[str] = None
+    final_status = JobStatus.SUCCEEDED
+
+    try:
+        config_path = build_axolotl_config(
+            job_id, model, hf_dataset, hub_model_id, user_config
+        )
+        env = {**os.environ, "HF_TOKEN": hf_token}
+        if wandb_token:
+            env["WANDB_API_KEY"] = wandb_token
+        subprocess.run(
+            ["axolotl", "train", str(config_path)],
+            env=env,
+            timeout=MAX_DURATION,
+            check=True,
+        )
+        log.info("Job %s completed successfully.", job_id)
+    except subprocess.TimeoutExpired:
+        error_message = "Job exceeded maximum wall-clock duration."
+        final_status = JobStatus.FAILED
+        log.error("Job %s timed out.", job_id)
+    except subprocess.CalledProcessError as e:
+        error_message = f"Training failed with exit code {e.returncode}."
+        final_status = JobStatus.FAILED
+        log.error("Job %s failed: %s", job_id, e)
+    except Exception as e:
+        # Without this catch-all, an unexpected error (e.g. during config
+        # generation) would leave final_status as SUCCEEDED and crash the
+        # worker loop.
+        error_message = f"Unexpected error: {e}"
+        final_status = JobStatus.FAILED
+        log.exception("Job %s raised an unexpected error.", job_id)
+    finally:
+        complete_job(job_id, final_status, error_message)
+        if job_dir.exists():
+            shutil.rmtree(job_dir)
+            log.info("Cleaned up temp dir for job %s.", job_id)
+
+
+def main() -> None:
+    """Main worker loop.
+
+    Polls the job queue at a fixed interval and processes
+    one job at a time.
+    """
+    init_db()
+    recover_running_jobs()
+    log.info(
+        "Worker started. Polling every %ds, max job duration %dh.",
+        POLL_INTERVAL,
+        MAX_DURATION // 3600,
+    )
+    while True:
+        job = get_next_queued_job()
+        if job:
+            log.info("Claiming job %s.", job.id)
+            if claim_job(job.id):
+                log.info("Running job %s.", job.id)
+                user_config = json.loads(job.config)
+                run_job(
+                    job.id,
+                    job.model,
+                    job.hf_dataset,
+                    job.hf_token,
+                    job.hub_model_id,
+                    job.wandb_token,
+                    user_config,
+                )
+        else:
+            time.sleep(POLL_INTERVAL)
+
+
+if __name__ == "__main__":
+    main()
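The queued-to-running transition in `claim_job` relies on a conditional `UPDATE` plus `rowcount` rather than a separate read-then-write. A self-contained sketch of that pattern, using an in-memory database and a simplified schema (not the service's actual tables):

```python
import sqlite3


def claim(conn: sqlite3.Connection, job_id: str) -> bool:
    """Claim a job: succeeds only if the row is still queued."""
    cur = conn.execute(
        "UPDATE jobs SET status = 'running' "
        "WHERE id = ? AND status = 'queued'",
        (job_id,),
    )
    # rowcount reveals whether this caller won the transition.
    return cur.rowcount > 0


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE jobs (id TEXT PRIMARY KEY, status TEXT)")
conn.execute("INSERT INTO jobs VALUES ('j1', 'queued')")

print(claim(conn, "j1"))  # first claim: row was queued
print(claim(conn, "j1"))  # second claim: row is no longer queued
```

Because the status check and the status change happen in one statement, two processes racing to claim the same job cannot both see `rowcount > 0`.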