60 changes: 60 additions & 0 deletions docs/ADRs/finetuning/013-axolotl-training-implementation.md
@@ -0,0 +1,60 @@
# ADR-013. Axolotl Training Implementation

Date: 2026-04-27
Status: Proposed

## Context

With the service skeleton established in ADR-012, the next decisions concerned the actual training pipeline: how Axolotl is invoked, how GPU utilisation is maximised on H100 hardware, how evaluation is handled, and how optional integrations (Weights & Biases) are exposed to users.

## Decision

### Axolotl invocation

Axolotl is invoked as a subprocess (`axolotl train <config.yaml>`) rather than imported as a library. This keeps Axolotl's CUDA environment self-contained.

A per-job Axolotl config is generated at runtime by merging a service-level base config (`config.yaml`, mounted read-only) with the user's job parameters. The result is written to a temp directory (`/tmp/finetuning/{job_id}/`) and cleaned up unconditionally in a `finally` block after the subprocess exits.
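
As an illustration, a minimal sketch of this flow; the helper name is hypothetical and the real module layout may differ:

```python
import shutil
import subprocess
from pathlib import Path

import yaml

BASE_CONFIG = Path("/app/config.yaml")  # service-level base, mounted read-only


def run_axolotl_job(job_id: str, job_params: dict) -> None:
    """Merge the base config with job parameters, train, always clean up."""
    job_dir = Path("/tmp/finetuning") / job_id
    job_dir.mkdir(parents=True, exist_ok=True)
    config_path = job_dir / "config.yaml"
    try:
        base = yaml.safe_load(BASE_CONFIG.read_text())
        merged = {**base, **job_params}  # job parameters win on conflict
        config_path.write_text(yaml.safe_dump(merged))
        subprocess.run(["axolotl", "train", str(config_path)], check=True)
    finally:
        shutil.rmtree(job_dir, ignore_errors=True)  # unconditional cleanup
```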

### Worker as a separate container

The worker runs in its own container using the Axolotl base image (`axolotlai/axolotl:main-py3.12-cu130-2.10.0`). The API uses a separate lightweight Python image. Merging them would force the API to be built on the full Axolotl image (several GB) for no benefit.

### Flash Attention 4

The Axolotl image ships Flash Attention 2 (FA2). On CUDA 13 / H100 hardware, FA2 produced a `CUBLAS_STATUS_INVALID_VALUE` error in the RoPE computation during evaluation, crashing jobs before training began. Installing Flash Attention 4 (`flash-attn-4[cu13]==4.0.0b10`) resolved this. FA4 is the architecturally correct choice for Hopper GPUs (H100) and is explicitly recommended by Axolotl in its startup logs for this hardware. The `[cu13]` extra selects the CUDA 13 wheel.

FA4 is pinned to a specific beta version. Upgrading it is a deliberate decision, just as it is for the pinned Axolotl base image.
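
Assuming the pin lives in `requirements.worker.txt` (the file the worker image installs; the PR does not show its contents), the line would be:

```
flash-attn-4[cu13]==4.0.0b10
```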

### Sample packing

Sample packing is enabled by default (`sample_packing: true`) and disabled for evaluation (`eval_sample_packing: false`). Without it, sequences are padded individually to `sequence_len`, resulting in approximately 55% padding waste on typical instruction-tuning datasets. With sample packing, multiple conversations are packed end-to-end into each sequence slot, using Flash Attention masking to prevent cross-conversation attention. This improved trainable token density to ~65% and GPU throughput by ~9x on our test dataset.
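
As a rough illustration of the padding arithmetic (the token counts below are hypothetical stand-ins, not measurements from our dataset):

```python
sequence_len = 2048
avg_conversation_tokens = 900  # assumed dataset average

# Without packing, each conversation occupies its own slot and the
# remainder of the slot is padding:
padding_waste = 1 - avg_conversation_tokens / sequence_len
print(f"unpacked padding waste: {padding_waste:.0%}")  # ~56%

# With packing, conversations are laid end-to-end, so only the tail of
# each slot too small to fit another conversation is wasted:
per_slot = sequence_len // avg_conversation_tokens
tail_waste = 1 - per_slot * avg_conversation_tokens / sequence_len
print(f"packed padding waste:   {tail_waste:.0%}")     # ~12%
```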

Sample packing is not applied during evaluation: the eval set is usually small and packing adds complexity without meaningfully improving evaluation speed.

### Default sequence length

The default `sequence_len` is 2048. The original default of 512 dropped approximately 37% of training samples from our representative dataset (max sequence length ~1716 tokens). 2048 retains all sequences and, combined with sample packing, allows more conversations per packed slot. Users may override this per job.

### Evaluation control

Evaluation is opt-in. Jobs default to `do_eval: false`, which injects `eval_strategy: "no"` into the generated Axolotl config, overriding the base config's `eval_strategy: epoch`. When `do_eval: true`, the `validation` split of the user's dataset is used and evaluation runs at the end of each epoch.

This is opt-in rather than opt-out because many Hugging Face datasets do not include a `validation` split; silently failing a job because the split is absent is a worse experience than requiring users to explicitly request evaluation. SDK documentation will specify that users must include a `validation` split if they set `do_eval: true`.
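
A minimal sketch of how this toggle translates into generated config keys (the function is illustrative, not the service's actual code):

```python
def eval_overrides(do_eval: bool) -> dict:
    """Keys merged into the generated Axolotl config."""
    if not do_eval:
        # Default path: explicitly disable evaluation, overriding the
        # base config's `eval_strategy: epoch`.
        return {"eval_strategy": "no"}
    # Opt-in path: keep the base config's per-epoch evaluation; the
    # dataset's `validation` split (which the user must provide) is used.
    return {}
```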

### Weights & Biases integration

Optional wandb logging is supported via three fields: `wandb_token`, `wandb_project`, and `wandb_entity`. The token is handled identically to the HF token: stored in the database, passed to the Axolotl subprocess as `WANDB_API_KEY`, and cleared to `NULL` on job completion.

Validation rules:
- `wandb_project` and `wandb_token` must be provided together.
- `wandb_entity` (a team/organisation name) may be omitted; wandb defaults to the user's personal account.
- Providing `wandb_entity` without `wandb_project` is rejected.

Omitting `wandb_entity` alone is the only partially specified combination that is permitted, mirroring wandb's own defaulting behaviour.
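
A sketch of these rules as a Pydantic validator (the model is illustrative; the service's actual request schema may differ):

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class WandbFields(BaseModel):
    wandb_token: Optional[str] = None
    wandb_project: Optional[str] = None
    wandb_entity: Optional[str] = None

    @model_validator(mode="after")
    def check_combinations(self) -> "WandbFields":
        # token and project must travel together
        if bool(self.wandb_token) != bool(self.wandb_project):
            raise ValueError(
                "wandb_token and wandb_project must be provided together"
            )
        # entity is only meaningful alongside a project
        if self.wandb_entity and not self.wandb_project:
            raise ValueError("wandb_entity requires wandb_project")
        return self
```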

## Consequences

- Training jobs on H100 / CUDA 13 hardware work with FA4.
- Sample packing significantly improves GPU utilisation but compresses the effective number of training steps per epoch. For small datasets, a large `micro_batch_size` relative to the dataset size can result in very few optimiser steps per epoch (see the sketch after this list).
- Evaluation requires users to know their dataset structure. No automatic detection of available splits is performed.
- Wandb tokens are treated as sensitive credentials and cleared after use, consistent with the HF token policy established in ADR-011.
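
A back-of-envelope estimate of the steps-per-epoch effect (all numbers hypothetical):

```python
dataset_tokens = 2_000_000   # total training tokens after tokenisation
sequence_len = 2048
micro_batch_size = 4
gradient_accumulation_steps = 4

packed_sequences = dataset_tokens // sequence_len              # ~976 slots
steps_per_epoch = packed_sequences // (
    micro_batch_size * gradient_accumulation_steps
)
print(steps_per_epoch)  # ~61 optimiser steps per epoch
```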
2 changes: 2 additions & 0 deletions stacks/finetuning-service/.env.example
@@ -12,6 +12,8 @@
# -----------------------------------------------------------------------------
# Fine-Tuning Service
# -----------------------------------------------------------------------------
DEVICE=3
# Multi-gpu not currently supported
FINETUNING_PORT=8005
MAX_JOB_DURATION_HOURS=4

11 changes: 0 additions & 11 deletions stacks/finetuning-service/Dockerfile

This file was deleted.

19 changes: 19 additions & 0 deletions stacks/finetuning-service/Dockerfile.api
@@ -0,0 +1,19 @@
# =============================================================================
# Dockerfile.api - Fine-Tuning Service API
# =============================================================================
#
# Lightweight image for the FastAPI service only.
# The Axolotl worker uses a separate image (Dockerfile.worker).
#
# =============================================================================

FROM python:3.12-slim

WORKDIR /app

COPY requirements.api.txt .
RUN pip install -r requirements.api.txt

COPY app/ ./app/

CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${FINETUNING_PORT}"]
22 changes: 22 additions & 0 deletions stacks/finetuning-service/Dockerfile.worker
@@ -0,0 +1,22 @@
# =============================================================================
# Dockerfile.worker - Fine-Tuning Worker
# =============================================================================
#
# Axolotl-based image for the queue worker that runs training jobs.
# The FastAPI service uses a separate lightweight image (Dockerfile.api).
#
# Update the tag deliberately when upgrading Axolotl.
#
# =============================================================================

FROM axolotlai/axolotl:main-py3.12-cu130-2.10.0

WORKDIR /app

COPY requirements.worker.txt .
RUN pip install -r requirements.worker.txt

COPY app/ ./app/
COPY worker.py .

CMD ["python", "worker.py"]
26 changes: 26 additions & 0 deletions stacks/finetuning-service/app/config.py
@@ -0,0 +1,26 @@
"""Service configuration loader."""

from pathlib import Path

import yaml

CONFIG_PATH = Path("/app/config.yaml")


def load_config() -> dict:
"""Load the service configuration from disk.

Returns:
The parsed configuration dictionary.
"""
with CONFIG_PATH.open() as f:
return yaml.safe_load(f)


def get_allowed_models() -> list[str]:
"""Return the list of models permitted for fine-tuning.

Returns:
A list of allowed Hugging Face model repo paths.
"""
return load_config().get("allowed_models", [])
133 changes: 125 additions & 8 deletions stacks/finetuning-service/app/database.py
@@ -4,6 +4,7 @@
import sqlite3
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Generator, Optional
@@ -13,6 +14,24 @@
DB_PATH = Path("/data/jobs.db")


@dataclass
class JobDetail:
"""Internal job representation including sensitive fields.

Not exposed via the API. Used by the worker to access the
HF token and full training config.
"""

id: str
model: str
hf_dataset: str
hf_token: str
hub_model_id: str
wandb_token: Optional[str]
config: str
status: JobStatus


@contextmanager
def get_connection() -> Generator[sqlite3.Connection, None, None]:
"""Yield a database connection with row factory configured.
@@ -43,7 +62,7 @@ def _row_to_job(row: sqlite3.Row) -> JobResponse:
status=JobStatus(row["status"]),
model=row["model"],
hf_dataset=row["hf_dataset"],
suffix=row["suffix"],
hub_model_id=row["hub_model_id"],
created_at=datetime.fromisoformat(row["created_at"]),
started_at=(
datetime.fromisoformat(row["started_at"])
@@ -68,7 +87,9 @@ def init_db() -> None:
status TEXT NOT NULL,
model TEXT NOT NULL,
hf_dataset TEXT NOT NULL,
suffix TEXT,
hf_token TEXT,
hub_model_id TEXT NOT NULL,
wandb_token TEXT,
config TEXT NOT NULL,
created_at TEXT NOT NULL,
started_at TEXT,
@@ -105,23 +126,39 @@ def create_job(request: JobSubmitRequest) -> JobResponse:
now = datetime.now(timezone.utc).isoformat()
config = json.dumps(
{
# TODO: Add hyperparameters/config keys
"num_epochs": request.num_epochs,
"learning_rate": request.learning_rate,
"micro_batch_size": request.micro_batch_size,
"gradient_accumulation_steps": (
request.gradient_accumulation_steps
),
"sequence_len": request.sequence_len,
"lora_r": request.lora_r,
"lora_dropout": request.lora_dropout,
"lora_target_modules": request.lora_target_modules,
"load_in_4bit": request.load_in_4bit,
"load_in_8bit": request.load_in_8bit,
"do_eval": request.do_eval,
"wandb_project": request.wandb_project,
"wandb_entity": request.wandb_entity,
}
)
with get_connection() as conn:
conn.execute(
"""
INSERT INTO jobs (
id, status, model, hf_dataset,
suffix, config, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
id, status, model, hf_dataset, hf_token,
hub_model_id, wandb_token, config, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
job_id,
JobStatus.QUEUED,
request.model,
request.hf_dataset,
request.suffix,
request.hf_token,
request.hub_model_id,
request.wandb_token,
config,
now,
),
@@ -131,7 +168,7 @@ def create_job(request: JobSubmitRequest) -> JobResponse:
status=JobStatus.QUEUED,
model=request.model,
hf_dataset=request.hf_dataset,
suffix=request.suffix,
hub_model_id=request.hub_model_id,
created_at=datetime.fromisoformat(now),
)

@@ -186,3 +223,83 @@ def cancel_job(job_id: str) -> Optional[JobResponse]:
(JobStatus.CANCELLED, job_id, JobStatus.QUEUED),
)
return get_job(job_id)


def get_next_queued_job() -> Optional[JobDetail]:
"""Fetch the oldest queued job for the worker to process.

Returns:
A JobDetail if a queued job exists, otherwise None.
"""
with get_connection() as conn:
row = conn.execute(
"""
SELECT * FROM jobs
WHERE status = ?
ORDER BY created_at ASC
LIMIT 1
""",
(JobStatus.QUEUED,),
).fetchone()
if not row:
return None
return JobDetail(
id=row["id"],
model=row["model"],
hf_dataset=row["hf_dataset"],
hf_token=row["hf_token"],
hub_model_id=row["hub_model_id"],
wandb_token=row["wandb_token"],
config=row["config"],
status=JobStatus(row["status"]),
)


def claim_job(job_id: str) -> bool:
"""Atomically mark a queued job as running.

Args:
job_id: The UUID of the job to claim.

Returns:
True if the job was successfully claimed, False if it
was already claimed by another process.
"""
now = datetime.now(timezone.utc).isoformat()
with get_connection() as conn:
result = conn.execute(
"""
UPDATE jobs SET status = ?, started_at = ?
WHERE id = ? AND status = ?
""",
(JobStatus.RUNNING, now, job_id, JobStatus.QUEUED),
)
return result.rowcount > 0


def complete_job(
job_id: str,
status: JobStatus,
error_message: Optional[str] = None,
) -> None:
"""Mark a job as completed and clear its tokens.

Args:
job_id: The UUID of the job to complete.
status: The final status (succeeded or failed).
error_message: Optional error description for failed jobs.
"""
now = datetime.now(timezone.utc).isoformat()
with get_connection() as conn:
conn.execute(
"""
UPDATE jobs
SET status = ?,
completed_at = ?,
error_message = ?,
hf_token = NULL,
wandb_token = NULL
WHERE id = ?
""",
(status, now, error_message, job_id),
)