From b3e21abac8dd86ed11636ccaaef6603cba638fa5 Mon Sep 17 00:00:00 2001 From: Fin Griffin Date: Thu, 23 Apr 2026 14:24:07 +0100 Subject: [PATCH 1/9] feat(stacks): add lightweight user config for fine-tuning service and pre-axolotl validation --- stacks/finetuning-service/app/config.py | 26 +++++++ stacks/finetuning-service/app/database.py | 13 +++- stacks/finetuning-service/app/models.py | 76 +++++++++++++++++++- stacks/finetuning-service/config.yaml | 10 +++ stacks/finetuning-service/docker-compose.yml | 1 + stacks/finetuning-service/requirements.txt | 3 +- 6 files changed, 124 insertions(+), 5 deletions(-) create mode 100644 stacks/finetuning-service/app/config.py create mode 100644 stacks/finetuning-service/config.yaml diff --git a/stacks/finetuning-service/app/config.py b/stacks/finetuning-service/app/config.py new file mode 100644 index 0000000..98b7fea --- /dev/null +++ b/stacks/finetuning-service/app/config.py @@ -0,0 +1,26 @@ +"""Service configuration loader.""" + +from pathlib import Path + +import yaml + +CONFIG_PATH = Path("/app/config.yaml") + + +def load_config() -> dict: + """Load the service configuration from disk. + + Returns: + The parsed configuration dictionary. + """ + with CONFIG_PATH.open() as f: + return yaml.safe_load(f) + + +def get_allowed_models() -> list[str]: + """Return the list of models permitted for fine-tuning. + + Returns: + A list of allowed Hugging Face model repo paths. + """ + return load_config().get("allowed_models", []) diff --git a/stacks/finetuning-service/app/database.py b/stacks/finetuning-service/app/database.py index dd940d2..8506277 100644 --- a/stacks/finetuning-service/app/database.py +++ b/stacks/finetuning-service/app/database.py @@ -105,7 +105,18 @@ def create_job(request: JobSubmitRequest) -> JobResponse: now = datetime.now(timezone.utc).isoformat() config = json.dumps( { - # TODO: Add hyperparameters/config keys + "num_epochs": request.num_epochs, + "learning_rate": request.learning_rate, + "micro_batch_size": request.micro_batch_size, + "gradient_accumulation_steps": ( + request.gradient_accumulation_steps + ), + "sequence_len": request.sequence_len, + "lora_r": request.lora_r, + "lora_dropout": request.lora_dropout, + "lora_target_modules": request.lora_target_modules, + "load_in_4bit": request.load_in_4bit, + "load_in_8bit": request.load_in_8bit, } ) with get_connection() as conn: diff --git a/stacks/finetuning-service/app/models.py b/stacks/finetuning-service/app/models.py index 2009014..074264b 100644 --- a/stacks/finetuning-service/app/models.py +++ b/stacks/finetuning-service/app/models.py @@ -4,7 +4,10 @@ from enum import Enum from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self + +from .config import get_allowed_models class JobStatus(str, Enum): @@ -34,8 +37,75 @@ class JobSubmitRequest(BaseModel): default=None, description="Label appended to the adapter repository name.", ) - # TODO: add hyperparameters/config keys - # (minimal subset of keys from axolotl config reference) + num_epochs: int = Field(ge=1) + learning_rate: float = Field(gt=0) + micro_batch_size: int = Field(default=4, ge=1) + gradient_accumulation_steps: int = Field(default=1, ge=1) + sequence_len: int = Field(default=512, ge=64) + lora_r: int = Field(ge=1) + lora_dropout: float = Field(default=0.0, ge=0.0, le=1.0) + lora_target_modules: list[str] = Field( + default=["q_proj", "v_proj", "k_proj", "o_proj"] + ) + load_in_4bit: bool = 
Field(default=False) + load_in_8bit: bool = Field(default=False) + + @field_validator("model") + @classmethod + def model_must_be_whitelisted(cls, v: str) -> str: + """Validate the model is on the allowed list. + + Args: + v: The model repo path to validate. + + Returns: + The validated model path. + + Raises: + ValueError: If the model is not on the whitelist. + """ + allowed = get_allowed_models() + if v not in allowed: + raise ValueError( + f"Model '{v}' is not permitted. Allowed models: {allowed}" + ) + return v + + @field_validator("lora_target_modules") + @classmethod + def lora_target_modules_must_not_be_empty(cls, v: list[str]) -> list[str]: + """Validate that at least one LoRA target module is specified. + + Args: + v: The list of target modules. + + Returns: + The validated list. + + Raises: + ValueError: If the list is empty. + """ + if not v: + raise ValueError( + "lora_target_modules must contain at least one module." + ) + return v + + @model_validator(mode="after") + def quantisation_modes_are_mutually_exclusive(self) -> Self: + """Validate that 4-bit and 8-bit quantisation are not both set. + + Returns: + The validated model instance. + + Raises: + ValueError: If both load_in_4bit and load_in_8bit are True. + """ + if self.load_in_4bit and self.load_in_8bit: + raise ValueError( + "load_in_4bit and load_in_8bit are mutually exclusive." + ) + return self class JobResponse(BaseModel): diff --git a/stacks/finetuning-service/config.yaml b/stacks/finetuning-service/config.yaml new file mode 100644 index 0000000..e128183 --- /dev/null +++ b/stacks/finetuning-service/config.yaml @@ -0,0 +1,10 @@ +# ============================================================================= +# Fine-Tuning Service Configuration +# ============================================================================= +# +# Edit this file to update service settings without rebuilding the image. 
+# +# ============================================================================= + +allowed_models: + - meta-llama/Llama-3.1-8B-Instruct diff --git a/stacks/finetuning-service/docker-compose.yml b/stacks/finetuning-service/docker-compose.yml index 9a28218..c24c511 100644 --- a/stacks/finetuning-service/docker-compose.yml +++ b/stacks/finetuning-service/docker-compose.yml @@ -29,6 +29,7 @@ services: device_ids: ['3'] # TODO: Confirm GPU 3 capabilities: [gpu] volumes: + - ./config.yaml:/app/config.yaml:ro - finetuning_jobs:/data ports: - "127.0.0.1:${FINETUNING_PORT}:${FINETUNING_PORT}" diff --git a/stacks/finetuning-service/requirements.txt b/stacks/finetuning-service/requirements.txt index d1f80a3..b407bac 100644 --- a/stacks/finetuning-service/requirements.txt +++ b/stacks/finetuning-service/requirements.txt @@ -1,2 +1,3 @@ uvicorn -fastapi \ No newline at end of file +fastapi +pyyaml From c2c46835187867e30e253fabe5773fea531be7a8 Mon Sep 17 00:00:00 2001 From: Fin Griffin Date: Thu, 23 Apr 2026 15:53:15 +0100 Subject: [PATCH 2/9] feat(stacks): add axolotl subprocess with separate worker --- stacks/finetuning-service/Dockerfile | 11 -- stacks/finetuning-service/Dockerfile.api | 19 ++ stacks/finetuning-service/Dockerfile.worker | 22 +++ stacks/finetuning-service/app/database.py | 112 ++++++++++- stacks/finetuning-service/app/models.py | 10 +- stacks/finetuning-service/config.yaml | 27 ++- stacks/finetuning-service/docker-compose.yml | 48 +++-- ...{requirements.txt => requirements.api.txt} | 0 .../requirements.worker.txt | 1 + stacks/finetuning-service/worker.py | 185 ++++++++++++++++++ 10 files changed, 400 insertions(+), 35 deletions(-) delete mode 100644 stacks/finetuning-service/Dockerfile create mode 100644 stacks/finetuning-service/Dockerfile.api create mode 100644 stacks/finetuning-service/Dockerfile.worker rename stacks/finetuning-service/{requirements.txt => requirements.api.txt} (100%) create mode 100644 stacks/finetuning-service/requirements.worker.txt create mode 100644 stacks/finetuning-service/worker.py diff --git a/stacks/finetuning-service/Dockerfile b/stacks/finetuning-service/Dockerfile deleted file mode 100644 index 5b154c5..0000000 --- a/stacks/finetuning-service/Dockerfile +++ /dev/null @@ -1,11 +0,0 @@ -FROM axolotlai/axolotl:main-py3.12-cu130-2.10.0 - -WORKDIR /app - -# Install service dependencies on top of Axolotl's environment -COPY requirements.txt . -RUN pip install -r requirements.txt - -COPY app/ ./app/ - -CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${FINETUNING_PORT}"] diff --git a/stacks/finetuning-service/Dockerfile.api b/stacks/finetuning-service/Dockerfile.api new file mode 100644 index 0000000..dca1ecb --- /dev/null +++ b/stacks/finetuning-service/Dockerfile.api @@ -0,0 +1,19 @@ +# ============================================================================= +# Dockerfile.api - Fine-Tuning Service API +# ============================================================================= +# +# Lightweight image for the FastAPI service only. +# The Axolotl worker uses a separate image (Dockerfile.worker). +# +# ============================================================================= + +FROM python:3.12-slim + +WORKDIR /app + +COPY requirements.api.txt . 
+RUN pip install -r requirements.api.txt + +COPY app/ ./app/ + +CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${FINETUNING_PORT}"] diff --git a/stacks/finetuning-service/Dockerfile.worker b/stacks/finetuning-service/Dockerfile.worker new file mode 100644 index 0000000..0940b38 --- /dev/null +++ b/stacks/finetuning-service/Dockerfile.worker @@ -0,0 +1,22 @@ +# ============================================================================= +# Dockerfile.worker - Fine-Tuning Worker +# ============================================================================= +# +# Axolotl-based image for the queue worker that runs training jobs. +# The FastAPI service uses a separate lightweight image (Dockerfile.api). +# +# Update the tag deliberately when upgrading Axolotl. +# +# ============================================================================= + +FROM axolotlai/axolotl:main-py3.12-cu130-2.10.0 + +WORKDIR /app + +COPY requirements.worker.txt . +RUN pip install -r requirements.worker.txt + +COPY app/ ./app/ +COPY worker.py . + +CMD ["python", "worker.py"] diff --git a/stacks/finetuning-service/app/database.py b/stacks/finetuning-service/app/database.py index 8506277..cc6d8dd 100644 --- a/stacks/finetuning-service/app/database.py +++ b/stacks/finetuning-service/app/database.py @@ -4,6 +4,7 @@ import sqlite3 import uuid from contextlib import contextmanager +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Generator, Optional @@ -13,6 +14,23 @@ DB_PATH = Path("/data/jobs.db") +@dataclass +class JobDetail: + """Internal job representation including sensitive fields. + + Not exposed via the API. Used by the worker to access the + HF token and full training config. + """ + + id: str + model: str + hf_dataset: str + hf_token: str + hub_model_id: str + config: str + status: JobStatus + + @contextmanager def get_connection() -> Generator[sqlite3.Connection, None, None]: """Yield a database connection with row factory configured. @@ -43,7 +61,7 @@ def _row_to_job(row: sqlite3.Row) -> JobResponse: status=JobStatus(row["status"]), model=row["model"], hf_dataset=row["hf_dataset"], - suffix=row["suffix"], + hub_model_id=row["hub_model_id"], created_at=datetime.fromisoformat(row["created_at"]), started_at=( datetime.fromisoformat(row["started_at"]) @@ -68,7 +86,8 @@ def init_db() -> None: status TEXT NOT NULL, model TEXT NOT NULL, hf_dataset TEXT NOT NULL, - suffix TEXT, + hf_token TEXT, + hub_model_id TEXT NOT NULL, config TEXT NOT NULL, created_at TEXT NOT NULL, started_at TEXT, @@ -123,16 +142,17 @@ def create_job(request: JobSubmitRequest) -> JobResponse: conn.execute( """ INSERT INTO jobs ( - id, status, model, hf_dataset, - suffix, config, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?) + id, status, model, hf_dataset, hf_token, + hub_model_id, config, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
""", ( job_id, JobStatus.QUEUED, request.model, request.hf_dataset, - request.suffix, + request.hf_token, + request.hub_model_id, config, now, ), @@ -142,7 +162,7 @@ def create_job(request: JobSubmitRequest) -> JobResponse: status=JobStatus.QUEUED, model=request.model, hf_dataset=request.hf_dataset, - suffix=request.suffix, + hub_model_id=request.hub_model_id, created_at=datetime.fromisoformat(now), ) @@ -197,3 +217,81 @@ def cancel_job(job_id: str) -> Optional[JobResponse]: (JobStatus.CANCELLED, job_id, JobStatus.QUEUED), ) return get_job(job_id) + + +def get_next_queued_job() -> Optional[JobDetail]: + """Fetch the oldest queued job for the worker to process. + + Returns: + A JobDetail if a queued job exists, otherwise None. + """ + with get_connection() as conn: + row = conn.execute( + """ + SELECT * FROM jobs + WHERE status = ? + ORDER BY created_at ASC + LIMIT 1 + """, + (JobStatus.QUEUED,), + ).fetchone() + if not row: + return None + return JobDetail( + id=row["id"], + model=row["model"], + hf_dataset=row["hf_dataset"], + hf_token=row["hf_token"], + hub_model_id=row["hub_model_id"], + config=row["config"], + status=JobStatus(row["status"]), + ) + + +def claim_job(job_id: str) -> bool: + """Atomically mark a queued job as running. + + Args: + job_id: The UUID of the job to claim. + + Returns: + True if the job was successfully claimed, False if it + was already claimed by another process. + """ + now = datetime.now(timezone.utc).isoformat() + with get_connection() as conn: + result = conn.execute( + """ + UPDATE jobs SET status = ?, started_at = ? + WHERE id = ? AND status = ? + """, + (JobStatus.RUNNING, now, job_id, JobStatus.QUEUED), + ) + return result.rowcount > 0 + + +def complete_job( + job_id: str, + status: JobStatus, + error_message: Optional[str] = None, +) -> None: + """Mark a job as completed and clear its HF token. + + Args: + job_id: The UUID of the job to complete. + status: The final status (succeeded or failed). + error_message: Optional error description for failed jobs. + """ + now = datetime.now(timezone.utc).isoformat() + with get_connection() as conn: + conn.execute( + """ + UPDATE jobs + SET status = ?, + completed_at = ?, + error_message = ?, + hf_token = NULL + WHERE id = ? + """, + (status, now, error_message, job_id), + ) diff --git a/stacks/finetuning-service/app/models.py b/stacks/finetuning-service/app/models.py index 074264b..a2c5f44 100644 --- a/stacks/finetuning-service/app/models.py +++ b/stacks/finetuning-service/app/models.py @@ -33,9 +33,11 @@ class JobSubmitRequest(BaseModel): "and write access to the adapter destination." ) ) - suffix: Optional[str] = Field( - default=None, - description="Label appended to the adapter repository name.", + hub_model_id: str = Field( + description=( + "Hugging Face repo path to push the trained adapter to " + "(e.g. 'username/my-adapter')." 
+ ) ) num_epochs: int = Field(ge=1) learning_rate: float = Field(gt=0) @@ -115,7 +117,7 @@ class JobResponse(BaseModel): status: JobStatus model: str hf_dataset: str - suffix: Optional[str] = None + hub_model_id: str created_at: datetime started_at: Optional[datetime] = None completed_at: Optional[datetime] = None diff --git a/stacks/finetuning-service/config.yaml b/stacks/finetuning-service/config.yaml index e128183..475789c 100644 --- a/stacks/finetuning-service/config.yaml +++ b/stacks/finetuning-service/config.yaml @@ -2,9 +2,34 @@ # Fine-Tuning Service Configuration # ============================================================================= # -# Edit this file to update service settings without rebuilding the image. +# Edit this file and restart the service to update settings +# without rebuilding the image. # # ============================================================================= allowed_models: - meta-llama/Llama-3.1-8B-Instruct + +axolotl_base_config: + # Adapter + adapter: lora + + # Precision + bf16: true + tf32: true + + # Memory optimisation + flash_attention: true + gradient_checkpointing: true + + # Evaluation and checkpointing + eval_strategy: epoch + save_strategy: epoch + save_total_limit: 2 + save_safetensors: true + + # Hub push + hub_strategy: end + + # Security + trust_remote_code: false diff --git a/stacks/finetuning-service/docker-compose.yml b/stacks/finetuning-service/docker-compose.yml index c24c511..68dd9a1 100644 --- a/stacks/finetuning-service/docker-compose.yml +++ b/stacks/finetuning-service/docker-compose.yml @@ -11,23 +11,19 @@ # ============================================================================= services: - finetuning-service: - build: . - container_name: finetuning-service + # --------------------------------------------------------------------------- + # API + # --------------------------------------------------------------------------- + api: + build: + context: . + dockerfile: Dockerfile.api + container_name: finetuning-api restart: unless-stopped environment: - NVIDIA_VISIBLE_DEVICES: "3" # TODO: Confirm GPU 3 FINETUNING_PORT: ${FINETUNING_PORT} LITELLM_URL: "http://host.docker.internal:${LITELLM_PORT}" LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY} - MAX_JOB_DURATION_HOURS: ${MAX_JOB_DURATION_HOURS} - deploy: - resources: - reservations: - devices: - - driver: nvidia - device_ids: ['3'] # TODO: Confirm GPU 3 - capabilities: [gpu] volumes: - ./config.yaml:/app/config.yaml:ro - finetuning_jobs:/data @@ -47,5 +43,33 @@ services: max-size: "50m" max-file: "3" + # --------------------------------------------------------------------------- + # Worker + # --------------------------------------------------------------------------- + worker: + build: + context: . 
+ dockerfile: Dockerfile.worker + container_name: finetuning-worker + restart: unless-stopped + environment: + NVIDIA_VISIBLE_DEVICES: "3" + MAX_JOB_DURATION_HOURS: ${MAX_JOB_DURATION_HOURS} + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['3'] + capabilities: [gpu] + volumes: + - ./config.yaml:/app/config.yaml:ro + - finetuning_jobs:/data + logging: + driver: json-file + options: + max-size: "50m" + max-file: "3" + volumes: finetuning_jobs: diff --git a/stacks/finetuning-service/requirements.txt b/stacks/finetuning-service/requirements.api.txt similarity index 100% rename from stacks/finetuning-service/requirements.txt rename to stacks/finetuning-service/requirements.api.txt diff --git a/stacks/finetuning-service/requirements.worker.txt b/stacks/finetuning-service/requirements.worker.txt new file mode 100644 index 0000000..c3726e8 --- /dev/null +++ b/stacks/finetuning-service/requirements.worker.txt @@ -0,0 +1 @@ +pyyaml diff --git a/stacks/finetuning-service/worker.py b/stacks/finetuning-service/worker.py new file mode 100644 index 0000000..35ca0c3 --- /dev/null +++ b/stacks/finetuning-service/worker.py @@ -0,0 +1,185 @@ +"""Queue worker for processing fine-tuning jobs.""" + +import json +import logging +import os +import shutil +import subprocess +import time +from pathlib import Path +from typing import Optional + +import yaml +from app.config import load_config +from app.database import ( + claim_job, + complete_job, + get_next_queued_job, +) +from app.models import JobStatus + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", +) +log = logging.getLogger(__name__) + +POLL_INTERVAL = int(os.getenv("POLL_INTERVAL_SECONDS", "10")) +MAX_DURATION = int(os.getenv("MAX_JOB_DURATION_HOURS", "4")) * 3600 +WORK_DIR = Path("/tmp/finetuning") + + +def build_axolotl_config( + job_id: str, + model: str, + hf_dataset: str, + hub_model_id: str, + user_config: dict, +) -> Path: + """Build an Axolotl config file for a job. + + Merges the service base config with user-provided settings + and writes the result to a per-job temp directory. + + Args: + job_id: The UUID of the job. + model: The HuggingFace model repo path. + hf_dataset: The HuggingFace dataset repo path. + hub_model_id: The destination HuggingFace repo for the adapter. + user_config: The user-provided training configuration. + + Returns: + Path to the generated config file. 
+ """ + job_dir = WORK_DIR / job_id + output_dir = job_dir / "output" + output_dir.mkdir(parents=True, exist_ok=True) + + base = load_config().get("axolotl_base_config", {}) + lora_r = user_config["lora_r"] + + config = { + **base, + "base_model": model, + "output_dir": str(output_dir), + "hub_model_id": hub_model_id, + "datasets": [ + { + "path": hf_dataset, + "type": "chat_template", + "split": "train", + } + ], + "test_datasets": [ + { + "path": hf_dataset, + "split": "validation", + } + ], + "num_epochs": user_config["num_epochs"], + "learning_rate": user_config["learning_rate"], + "micro_batch_size": user_config["micro_batch_size"], + "gradient_accumulation_steps": ( + user_config["gradient_accumulation_steps"] + ), + "sequence_len": user_config["sequence_len"], + "lora_r": lora_r, + "lora_alpha": lora_r * 2, + "lora_dropout": user_config["lora_dropout"], + "lora_target_modules": user_config["lora_target_modules"], + "load_in_4bit": user_config["load_in_4bit"], + "load_in_8bit": user_config["load_in_8bit"], + } + + config_path = job_dir / "config.yaml" + with config_path.open("w") as f: + yaml.dump(config, f) + + log.info("Generated Axolotl config for job %s.", job_id) + return config_path + + +def run_job( + job_id: str, + model: str, + hf_dataset: str, + hf_token: str, + hub_model_id: str, + user_config: dict, +) -> None: + """Execute a fine-tuning job. + + Generates the Axolotl config, runs the training subprocess, + and cleans up the temp directory on completion or failure. + + Args: + job_id: The UUID of the job. + model: The HuggingFace model repo path. + hf_dataset: The HuggingFace dataset repo path. + hf_token: The HuggingFace token for dataset and Hub access. + hub_model_id: The destination HuggingFace repo for the adapter. + user_config: The user-provided training configuration. + """ + job_dir = WORK_DIR / job_id + error_message: Optional[str] = None + final_status = JobStatus.SUCCEEDED + + try: + config_path = build_axolotl_config( + job_id, model, hf_dataset, hub_model_id, user_config + ) + env = {**os.environ, "HF_TOKEN": hf_token} + subprocess.run( + ["axolotl", "train", str(config_path)], + env=env, + timeout=MAX_DURATION, + check=True, + ) + log.info("Job %s completed successfully.", job_id) + except subprocess.TimeoutExpired: + error_message = "Job exceeded maximum wall-clock duration." + final_status = JobStatus.FAILED + log.error("Job %s timed out.", job_id) + except subprocess.CalledProcessError as e: + error_message = f"Training failed with exit code {e.returncode}." + final_status = JobStatus.FAILED + log.error("Job %s failed: %s", job_id, e) + finally: + complete_job(job_id, final_status, error_message) + if job_dir.exists(): + shutil.rmtree(job_dir) + log.info("Cleaned up temp dir for job %s.", job_id) + + +def main() -> None: + """Main worker loop. + + Polls the job queue at a fixed interval and processes + one job at a time. + """ + log.info( + "Worker started. 
+        "Worker started. Polling every %ds, max job duration %dh.",
+        POLL_INTERVAL,
+        MAX_DURATION // 3600,
+    )
+    while True:
+        job = get_next_queued_job()
+        if job:
+            log.info("Claiming job %s.", job.id)
+            if claim_job(job.id):
+                log.info("Running job %s.", job.id)
+                user_config = json.loads(job.config)
+                run_job(
+                    job.id,
+                    job.model,
+                    job.hf_dataset,
+                    job.hf_token,
+                    job.hub_model_id,
+                    user_config,
+                )
+        else:
+            time.sleep(POLL_INTERVAL)
+
+
+if __name__ == "__main__":
+    main()

From a499968f3ec8cbeb77b1e21a659fdb9a592385a5 Mon Sep 17 00:00:00 2001
From: Fin Griffin
Date: Thu, 23 Apr 2026 16:18:41 +0100
Subject: [PATCH 3/9] chore: add chat template to test dataset, initialise db
 and recover orphaned jobs from axolotl worker

---
 stacks/finetuning-service/app/database.py      | 27 ++++++++++++++++++
 stacks/finetuning-service/requirements.api.txt |  1 +
 stacks/finetuning-service/worker.py            |  5 +++++
 3 files changed, 33 insertions(+)

diff --git a/stacks/finetuning-service/app/database.py b/stacks/finetuning-service/app/database.py
index cc6d8dd..91d04c7 100644
--- a/stacks/finetuning-service/app/database.py
+++ b/stacks/finetuning-service/app/database.py
@@ -295,3 +295,30 @@ def complete_job(
             """,
             (status, now, error_message, job_id),
         )
+
+
+def recover_running_jobs() -> None:
+    """Fail any jobs left in the running state by a previous worker.
+
+    Called on worker startup. A job can only be running while a
+    worker process is alive, so running rows found at startup are
+    orphans from a crash or restart and are marked as failed.
+    """
+    now = datetime.now(timezone.utc).isoformat()
+    with get_connection() as conn:
+        conn.execute(
+            """
+            UPDATE jobs
+            SET status = ?,
+                completed_at = ?,
+                error_message = ?,
+                hf_token = NULL
+            WHERE status = ?
+            """,
+            (
+                JobStatus.FAILED,
+                now,
+                "Worker restarted while the job was in progress.",
+                JobStatus.RUNNING,
+            ),
+        )
diff --git a/stacks/finetuning-service/requirements.api.txt b/stacks/finetuning-service/requirements.api.txt
index b407bac..36b7287 100644
--- a/stacks/finetuning-service/requirements.api.txt
+++ b/stacks/finetuning-service/requirements.api.txt
@@ -1,3 +1,4 @@
 uvicorn
 fastapi
 pyyaml
+typing_extensions
diff --git a/stacks/finetuning-service/worker.py b/stacks/finetuning-service/worker.py
index 35ca0c3..d871f59 100644
--- a/stacks/finetuning-service/worker.py
+++ b/stacks/finetuning-service/worker.py
@@ -15,6 +15,8 @@
     claim_job,
     complete_job,
     get_next_queued_job,
+    init_db,
+    recover_running_jobs,
 )
 from app.models import JobStatus

@@ -73,6 +75,7 @@ def build_axolotl_config(
         "test_datasets": [
             {
                 "path": hf_dataset,
+                "type": "chat_template",
                 "split": "validation",
             }
         ],
@@ -161,6 +164,8 @@ def main() -> None:
     Polls the job queue at a fixed interval and processes
     one job at a time.
     """
+    init_db()
+    recover_running_jobs()
     log.info(
         "Worker started. Polling every %ds, max job duration %dh.",

From f91ff8e82e0e82daa3a844183aa059a695d13c0d Mon Sep 17 00:00:00 2001
From: Fin Griffin
Date: Fri, 24 Apr 2026 13:44:31 +0100
Subject: [PATCH 4/9] feat(stacks): let device for fine-tuning service be an
 environment variable

---
 stacks/finetuning-service/.env.example       | 2 ++
 stacks/finetuning-service/docker-compose.yml | 4 ++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/stacks/finetuning-service/.env.example b/stacks/finetuning-service/.env.example
index e1b6c0f..bb8146a 100644
--- a/stacks/finetuning-service/.env.example
+++ b/stacks/finetuning-service/.env.example
@@ -12,6 +12,8 @@
 # -----------------------------------------------------------------------------
 # Fine-Tuning Service
 # -----------------------------------------------------------------------------
+# Multi-GPU is not currently supported; DEVICE selects a single GPU index.
+DEVICE=3
 FINETUNING_PORT=8005
 MAX_JOB_DURATION_HOURS=4

diff --git a/stacks/finetuning-service/docker-compose.yml b/stacks/finetuning-service/docker-compose.yml
index 68dd9a1..01e23e6 100644
--- a/stacks/finetuning-service/docker-compose.yml
+++ b/stacks/finetuning-service/docker-compose.yml
@@ -53,14 +53,14 @@ services:
     container_name: finetuning-worker
     restart: unless-stopped
     environment:
-      NVIDIA_VISIBLE_DEVICES: "3"
+      NVIDIA_VISIBLE_DEVICES: "${DEVICE}"
       MAX_JOB_DURATION_HOURS: ${MAX_JOB_DURATION_HOURS}
     deploy:
       resources:
         reservations:
           devices:
             - driver: nvidia
-              device_ids: ['3']
+              device_ids: ["${DEVICE}"]
               capabilities: [gpu]
     volumes:

From b6f4eef654c6d2eb5daeb4d05de131a7a51b14f2 Mon Sep 17 00:00:00 2001
From: Fin Griffin
Date: Fri, 24 Apr 2026 14:44:52 +0100
Subject: [PATCH 5/9] feat(deps): add FA4

---
 stacks/finetuning-service/requirements.worker.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/stacks/finetuning-service/requirements.worker.txt b/stacks/finetuning-service/requirements.worker.txt
index c3726e8..56430d2 100644
--- a/stacks/finetuning-service/requirements.worker.txt
+++ b/stacks/finetuning-service/requirements.worker.txt
@@ -1 +1,2 @@
 pyyaml
+flash-attn-4

From 95e0393b4b3803e8a00e0ab097cae510d9b684ea Mon Sep 17 00:00:00 2001
From: Fin Griffin
Date: Fri, 24 Apr 2026 15:40:22 +0100
Subject: [PATCH 6/9] fix: pin FA4 version with cu13 wheels, increase default
 sequence_len, add sample packing (training only)

---
 stacks/finetuning-service/app/models.py           | 2 +-
 stacks/finetuning-service/config.yaml             | 2 ++
 stacks/finetuning-service/requirements.worker.txt | 2 +-
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/stacks/finetuning-service/app/models.py b/stacks/finetuning-service/app/models.py
index a2c5f44..5dd85e1 100644
--- a/stacks/finetuning-service/app/models.py
+++ b/stacks/finetuning-service/app/models.py
@@ -43,7 +43,7 @@ class JobSubmitRequest(BaseModel):
     learning_rate: float = Field(gt=0)
     micro_batch_size: int = Field(default=4, ge=1)
     gradient_accumulation_steps: int = Field(default=1, ge=1)
-    sequence_len: int = Field(default=512, ge=64)
+    sequence_len: int = Field(default=2048, ge=64)
     lora_r: int = Field(ge=1)
     lora_dropout: float = Field(default=0.0, ge=0.0, le=1.0)
     lora_target_modules: list[str] = Field(
diff --git a/stacks/finetuning-service/config.yaml b/stacks/finetuning-service/config.yaml
index 475789c..0ba4f58 100644
--- a/stacks/finetuning-service/config.yaml
+++ b/stacks/finetuning-service/config.yaml
@@ -21,6 +21,8 @@
   # Memory optimisation
   flash_attention: true
   gradient_checkpointing: true
+  sample_packing: true
+  
eval_sample_packing: false # Evaluation and checkpointing eval_strategy: epoch diff --git a/stacks/finetuning-service/requirements.worker.txt b/stacks/finetuning-service/requirements.worker.txt index 56430d2..fa9b2de 100644 --- a/stacks/finetuning-service/requirements.worker.txt +++ b/stacks/finetuning-service/requirements.worker.txt @@ -1,2 +1,2 @@ pyyaml -flash-attn-4 +flash-attn-4[cu13]==4.0.0b10 From d44b181a7e915e0cd6954cea51df61036eda1c24 Mon Sep 17 00:00:00 2001 From: Fin Griffin Date: Mon, 27 Apr 2026 11:57:34 +0100 Subject: [PATCH 7/9] feat(stacks): add do_eval as key in fine-tuning service --- stacks/finetuning-service/app/models.py | 1 + stacks/finetuning-service/worker.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/stacks/finetuning-service/app/models.py b/stacks/finetuning-service/app/models.py index 5dd85e1..535d6a4 100644 --- a/stacks/finetuning-service/app/models.py +++ b/stacks/finetuning-service/app/models.py @@ -51,6 +51,7 @@ class JobSubmitRequest(BaseModel): ) load_in_4bit: bool = Field(default=False) load_in_8bit: bool = Field(default=False) + do_eval: bool = Field(default=False) @field_validator("model") @classmethod diff --git a/stacks/finetuning-service/worker.py b/stacks/finetuning-service/worker.py index d871f59..c71ffb5 100644 --- a/stacks/finetuning-service/worker.py +++ b/stacks/finetuning-service/worker.py @@ -72,13 +72,19 @@ def build_axolotl_config( "split": "train", } ], - "test_datasets": [ + **( { - "path": hf_dataset, - "type": "chat_template", - "split": "validation", + "test_datasets": [ + { + "path": hf_dataset, + "type": "chat_template", + "split": "validation", + } + ] } - ], + if user_config.get("do_eval") + else {"eval_strategy": "no"} + ), "num_epochs": user_config["num_epochs"], "learning_rate": user_config["learning_rate"], "micro_batch_size": user_config["micro_batch_size"], From 3d7e2f3b398208fbf8859950387128e78042e5bb Mon Sep 17 00:00:00 2001 From: Fin Griffin Date: Mon, 27 Apr 2026 12:18:09 +0100 Subject: [PATCH 8/9] feat(stacks): add wandb tracking --- stacks/finetuning-service/app/database.py | 16 +++++++++++---- stacks/finetuning-service/app/models.py | 25 +++++++++++++++++++++++ stacks/finetuning-service/worker.py | 18 ++++++++++++++++ 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/stacks/finetuning-service/app/database.py b/stacks/finetuning-service/app/database.py index cc6d8dd..7f56103 100644 --- a/stacks/finetuning-service/app/database.py +++ b/stacks/finetuning-service/app/database.py @@ -27,6 +27,7 @@ class JobDetail: hf_dataset: str hf_token: str hub_model_id: str + wandb_token: Optional[str] config: str status: JobStatus @@ -88,6 +89,7 @@ def init_db() -> None: hf_dataset TEXT NOT NULL, hf_token TEXT, hub_model_id TEXT NOT NULL, + wandb_token TEXT, config TEXT NOT NULL, created_at TEXT NOT NULL, started_at TEXT, @@ -136,6 +138,9 @@ def create_job(request: JobSubmitRequest) -> JobResponse: "lora_target_modules": request.lora_target_modules, "load_in_4bit": request.load_in_4bit, "load_in_8bit": request.load_in_8bit, + "do_eval": request.do_eval, + "wandb_project": request.wandb_project, + "wandb_entity": request.wandb_entity, } ) with get_connection() as conn: @@ -143,8 +148,8 @@ def create_job(request: JobSubmitRequest) -> JobResponse: """ INSERT INTO jobs ( id, status, model, hf_dataset, hf_token, - hub_model_id, config, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + hub_model_id, wandb_token, config, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( job_id, @@ -153,6 +158,7 @@ def create_job(request: JobSubmitRequest) -> JobResponse: request.hf_dataset, request.hf_token, request.hub_model_id, + request.wandb_token, config, now, ), @@ -243,6 +249,7 @@ def get_next_queued_job() -> Optional[JobDetail]: hf_dataset=row["hf_dataset"], hf_token=row["hf_token"], hub_model_id=row["hub_model_id"], + wandb_token=row["wandb_token"], config=row["config"], status=JobStatus(row["status"]), ) @@ -275,7 +282,7 @@ def complete_job( status: JobStatus, error_message: Optional[str] = None, ) -> None: - """Mark a job as completed and clear its HF token. + """Mark a job as completed and clear its tokens. Args: job_id: The UUID of the job to complete. @@ -290,7 +297,8 @@ def complete_job( SET status = ?, completed_at = ?, error_message = ?, - hf_token = NULL + hf_token = NULL, + wandb_token = NULL WHERE id = ? """, (status, now, error_message, job_id), diff --git a/stacks/finetuning-service/app/models.py b/stacks/finetuning-service/app/models.py index 535d6a4..822689f 100644 --- a/stacks/finetuning-service/app/models.py +++ b/stacks/finetuning-service/app/models.py @@ -52,6 +52,9 @@ class JobSubmitRequest(BaseModel): load_in_4bit: bool = Field(default=False) load_in_8bit: bool = Field(default=False) do_eval: bool = Field(default=False) + wandb_token: Optional[str] = Field(default=None) + wandb_project: Optional[str] = Field(default=None) + wandb_entity: Optional[str] = Field(default=None) @field_validator("model") @classmethod @@ -94,6 +97,28 @@ def lora_target_modules_must_not_be_empty(cls, v: list[str]) -> list[str]: ) return v + @model_validator(mode="after") + def wandb_fields_are_consistent(self) -> Self: + """Validate wandb field combinations are coherent. + + Returns: + The validated model instance. + + Raises: + ValueError: If wandb fields are provided in an invalid + combination. + """ + has_token = bool(self.wandb_token) + has_project = bool(self.wandb_project) + has_entity = bool(self.wandb_entity) + if has_entity and not has_project: + raise ValueError("wandb_entity requires wandb_project to be set.") + if has_project and not has_token: + raise ValueError("wandb_project requires wandb_token.") + if has_token and not has_project: + raise ValueError("wandb_token requires wandb_project.") + return self + @model_validator(mode="after") def quantisation_modes_are_mutually_exclusive(self) -> Self: """Validate that 4-bit and 8-bit quantisation are not both set. diff --git a/stacks/finetuning-service/worker.py b/stacks/finetuning-service/worker.py index c71ffb5..58cd6b2 100644 --- a/stacks/finetuning-service/worker.py +++ b/stacks/finetuning-service/worker.py @@ -85,6 +85,19 @@ def build_axolotl_config( if user_config.get("do_eval") else {"eval_strategy": "no"} ), + **( + { + "use_wandb": True, + "wandb_project": user_config["wandb_project"], + **( + {"wandb_entity": user_config["wandb_entity"]} + if user_config.get("wandb_entity") + else {} + ), + } + if user_config.get("wandb_project") + else {} + ), "num_epochs": user_config["num_epochs"], "learning_rate": user_config["learning_rate"], "micro_batch_size": user_config["micro_batch_size"], @@ -114,6 +127,7 @@ def run_job( hf_dataset: str, hf_token: str, hub_model_id: str, + wandb_token: Optional[str], user_config: dict, ) -> None: """Execute a fine-tuning job. @@ -127,6 +141,7 @@ def run_job( hf_dataset: The HuggingFace dataset repo path. hf_token: The HuggingFace token for dataset and Hub access. hub_model_id: The destination HuggingFace repo for the adapter. 
+        wandb_token: Optional Weights & Biases API key.
         user_config: The user-provided training configuration.
     """
     job_dir = WORK_DIR / job_id
@@ -138,6 +153,8 @@ def run_job(
         config_path = build_axolotl_config(
             job_id, model, hf_dataset, hub_model_id, user_config
         )
         env = {**os.environ, "HF_TOKEN": hf_token}
+        if wandb_token:
+            env["WANDB_API_KEY"] = wandb_token
         subprocess.run(
             ["axolotl", "train", str(config_path)],
             env=env,
@@ -186,6 +203,7 @@ def main() -> None:
             job.hf_dataset,
             job.hf_token,
             job.hub_model_id,
+            job.wandb_token,
             user_config,
         )
         else:
             time.sleep(POLL_INTERVAL)

From 97394f60cdaae097fee471711c2f0feb27ea8fd8 Mon Sep 17 00:00:00 2001
From: Fin Griffin
Date: Mon, 27 Apr 2026 14:46:09 +0100
Subject: [PATCH 9/9] feat(docs): add ADR for axolotl implementation

---
 .../013-axolotl-training-implementation.md | 104 ++++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 docs/ADRs/finetuning/013-axolotl-training-implementation.md

diff --git a/docs/ADRs/finetuning/013-axolotl-training-implementation.md b/docs/ADRs/finetuning/013-axolotl-training-implementation.md
new file mode 100644
index 0000000..cef583d
--- /dev/null
+++ b/docs/ADRs/finetuning/013-axolotl-training-implementation.md
@@ -0,0 +1,104 @@
+# ADR-013. Axolotl Training Implementation
+
+Date: 2026-04-27
+Status: Proposed
+
+## Context
+
+With the service skeleton established in ADR-012, the next decisions concerned the actual training pipeline: how Axolotl is invoked, how GPU utilisation is maximised on H100 hardware, how evaluation is handled, and how optional integrations (Weights & Biases) are exposed to users.
+
+## Decision
+
+### Axolotl invocation
+
+Axolotl is invoked as a subprocess (`axolotl train <config>`) rather than imported as a library. This keeps Axolotl's CUDA environment self-contained.
+
+A per-job Axolotl config is generated at runtime by merging a service-level base config (`config.yaml`, mounted read-only) with the user's job parameters. The result is written to a temp directory (`/tmp/finetuning/{job_id}/`) and cleaned up unconditionally in a `finally` block after the subprocess exits.
+
+### Worker as a separate container
+
+The worker runs in its own container using the Axolotl base image (`axolotlai/axolotl:main-py3.12-cu130-2.10.0`). The API uses a separate lightweight Python image. Merging them would require the API to carry the full Axolotl image (several GB) for no benefit.
+
+### Flash Attention 4
+
+The Axolotl image ships Flash Attention 2 (FA2). On CUDA 13 / H100 hardware, FA2 produced a `CUBLAS_STATUS_INVALID_VALUE` error in the RoPE computation during evaluation, crashing jobs before training began. Installing Flash Attention 4 (`flash-attn-4[cu13]==4.0.0b10`) resolved this. FA4 is the architecturally correct choice for Hopper GPUs (H100) and is explicitly recommended by Axolotl in its startup logs for this hardware. The `[cu13]` extra selects the CUDA 13 wheel.
+
+FA4 is pinned to a specific beta version. Upgrading it is a deliberate decision, exactly as it is for the pinned Axolotl base image tag.
+
+### Sample packing
+
+Sample packing is enabled by default (`sample_packing: true`) and disabled for evaluation (`eval_sample_packing: false`). Without it, sequences are padded individually to `sequence_len`, resulting in approximately 55% padding waste on typical instruction-tuning datasets. With sample packing, multiple conversations are packed end-to-end into each sequence slot, using Flash Attention masking to prevent cross-conversation attention. This improved trainable token density to ~65% and increased GPU throughput roughly 9x on our test dataset.
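+
+As a rough illustration of the padding arithmetic (a sketch with assumed numbers, not a measurement from our dataset):
+
+```python
+# Illustrative only: assumes a hypothetical mean conversation length
+# of ~920 tokens against the default sequence_len of 2048.
+sequence_len = 2048
+mean_sample_tokens = 920
+
+# Without packing, each slot holds exactly one padded sample.
+padding_waste = 1 - mean_sample_tokens / sequence_len
+print(f"padding waste without packing: {padding_waste:.0%}")  # ~55%
+
+# With packing, conversations fill each slot end-to-end, so waste
+# shrinks to the unfilled tail of each slot.
+```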
+
+Sample packing is not applied during evaluation: the eval set is usually small, and packing adds complexity without meaningfully improving evaluation speed.
+
+### Default sequence length
+
+The default `sequence_len` is 2048. The original default of 512 dropped approximately 37% of training samples from our representative dataset (max sequence length ~1716 tokens). A default of 2048 retains every sequence and, combined with sample packing, fits more conversations into each packed slot. Users may override this per job.
+
+### Evaluation control
+
+Evaluation is opt-in. Jobs default to `do_eval: false`, which injects `eval_strategy: "no"` into the generated Axolotl config, overriding the base config's `eval_strategy: epoch`. When `do_eval: true`, the `validation` split of the user's dataset is used and evaluation runs at the end of each epoch.
+
+This is opt-in rather than opt-out because many Hugging Face datasets do not include a `validation` split; silently failing a job because the split is absent would be a worse experience than requiring users to request evaluation explicitly. SDK documentation will specify that users must include a `validation` split if they set `do_eval: true`.
+
+### Weights & Biases integration
+
+Optional wandb logging is supported via three fields: `wandb_token`, `wandb_project`, and `wandb_entity`. The token is handled identically to the HF token: stored in the database, passed to the Axolotl subprocess as `WANDB_API_KEY`, and cleared to `NULL` on job completion.
+
+Validation rules:
+- `wandb_project` and `wandb_token` must be provided together.
+- `wandb_entity` (a team/organisation name) may be omitted; wandb defaults to the user's personal account.
+- Providing `wandb_entity` without `wandb_project` is rejected.
+
+Omitting only `wandb_entity` is thus the single partially specified combination that is permitted, reflecting wandb's own behaviour.
+
+## Consequences
+
+- Training jobs on H100 / CUDA 13 hardware work with FA4.
+- Sample packing significantly improves GPU utilisation but compresses the effective number of training steps per epoch. For small datasets, users should be aware that a large `micro_batch_size` relative to the dataset size can result in very few optimiser steps per epoch.
+- Evaluation requires users to know their dataset structure. No automatic detection of available splits is performed.
+- Wandb tokens are treated as sensitive credentials and cleared after use, consistent with the HF token policy established in ADR-011.
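+
+## Appendix: example job submission (illustrative)
+
+A minimal sketch of a submission exercising the fields discussed above. The route, host, and all repo names and tokens are placeholders rather than part of this ADR; the field names match `JobSubmitRequest`.
+
+```python
+# Sketch only: assumes the API exposes a job-submission endpoint at
+# POST /jobs on FINETUNING_PORT (8005 by default); adjust to the real route.
+import requests
+
+payload = {
+    "model": "meta-llama/Llama-3.1-8B-Instruct",  # must be on the whitelist
+    "hf_dataset": "your-user/your-dataset",
+    "hf_token": "hf_...",
+    "hub_model_id": "your-user/your-adapter",
+    "num_epochs": 3,
+    "learning_rate": 2e-4,
+    "lora_r": 16,  # lora_alpha is derived server-side as 2 * lora_r
+    "do_eval": True,  # the dataset must then include a 'validation' split
+    # wandb_token and wandb_project must be set together;
+    # wandb_entity may be omitted (wandb falls back to the personal account).
+    "wandb_token": "wandb_...",
+    "wandb_project": "finetuning-experiments",
+}
+resp = requests.post("http://127.0.0.1:8005/jobs", json=payload, timeout=30)
+resp.raise_for_status()
+print(resp.json())  # JobResponse payload: id, status, timestamps
+```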