diff --git a/dev/yes-no-maybe-megatron.py b/dev/yes-no-maybe-megatron.py new file mode 100644 index 000000000..ea7a669fa --- /dev/null +++ b/dev/yes-no-maybe-megatron.py @@ -0,0 +1,79 @@ +import asyncio +from itertools import permutations +import os + +from dotenv import load_dotenv +import openai + +import art +from art.megatron import MegatronBackend + + +async def rollout( + client: openai.AsyncOpenAI, model_name: str, prompt: str +) -> art.Trajectory: + messages: art.Messages = [{"role": "user", "content": prompt}] + chat_completion = await client.chat.completions.create( + messages=messages, model=model_name, max_tokens=100, timeout=100 + ) + choice = chat_completion.choices[0] + content = choice.message.content + assert isinstance(content, str) + if content == "yes": + reward = 0.5 + elif content == "no": + reward = 0.75 + elif content == "maybe": + reward = 1.0 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + +def with_quotes(w: str) -> str: + return f"'{w}'" + + +async def main(): + load_dotenv() + + backend = MegatronBackend() + base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507") + model = art.TrainableModel( + name=os.environ.get("MODEL_NAME", "megatron-001"), + project="yes-no-maybe-megatron", + base_model=base_model, + ) + await model.register(backend) + + prompts = [ + f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}" + for prefix in ["respond", "just respond"] + for use_quotes in [True, False] + for words in ( + list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n) + ) + ] + + openai_client = model.openai_client() + max_steps = int(os.environ.get("NUM_STEPS", "20")) + start_step = await model.get_step() + + for step in range(start_step, start_step + max_steps): + print(f"\n=== Step {step + 1} ===") + train_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup( + rollout(openai_client, model.name, prompt) for _ in range(32) + ) + for prompt in prompts + ) + ) + await model.train( + train_groups, + config=art.TrainConfig(learning_rate=1e-4), + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index f4e0f3239..dcf2ec176 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,6 +152,8 @@ allowed-unresolved-imports = [ # plotting deps "matplotlib.**", "seaborn.**", + # megatron deps + "megatron.**", ] [dependency-groups] diff --git a/scripts/setup.sh b/scripts/setup.sh index d1755efee..d7c8fd209 100755 --- a/scripts/setup.sh +++ b/scripts/setup.sh @@ -13,6 +13,31 @@ if [ -f .env ]; then done < .env fi +if ! command -v sudo >/dev/null 2>&1; then + sudo_path="/usr/local/bin/sudo" + if [ ! 
-w /usr/local/bin ]; then + sudo_path="$HOME/.local/bin/sudo" + mkdir -p "$HOME/.local/bin" + export PATH="$HOME/.local/bin:$PATH" + fi + + cat <<'EOF' > "$sudo_path" +#!/bin/sh +exec "$@" +EOF + chmod +x "$sudo_path" +fi + +need_pkgs=() +command -v git >/dev/null 2>&1 || need_pkgs+=("git") +command -v curl >/dev/null 2>&1 || need_pkgs+=("curl") +command -v tmux >/dev/null 2>&1 || need_pkgs+=("tmux") + +if [ "${#need_pkgs[@]}" -gt 0 ]; then + apt-get update + apt-get install -y "${need_pkgs[@]}" +fi + # Configure git user name and email git config --global user.name "${GIT_USER_NAME}" git config --global user.email "${GIT_USER_EMAIL}" @@ -29,14 +54,17 @@ else fi # Install astral-uv -sudo snap install --classic astral-uv +if ! command -v uv >/dev/null 2>&1; then + if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then + echo "Failed to install uv." >&2 + exit 1 + fi + export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH" +fi # Update uv uv self update -# Install tmux -apt install tmux -y - # Sync the dependencies if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then uv sync --all-extras diff --git a/skypilot-config.yaml b/skypilot-config.yaml index 638faec5b..7b2b3b73a 100644 --- a/skypilot-config.yaml +++ b/skypilot-config.yaml @@ -383,6 +383,7 @@ workdir: . resources: accelerators: ["H100-SXM:1", "H100:1", "A100-80GB:1"] + image_id: docker:pytorch/pytorch:2.9.0-cuda12.8-cudnn9-devel ports: - 7999 # main ART server - 8000 # vLLM server diff --git a/src/art/local/backend.py b/src/art/local/backend.py index dfc5a5c6f..9fb2a3649 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -2,6 +2,7 @@ import json import math import os +import shutil import subprocess from types import TracebackType from typing import AsyncIterator, Iterable, Literal, cast @@ -570,20 +571,22 @@ async def _train_model( get_model_dir(model=model, art_path=self._path), next_step ) - # If the current checkpoint exists, rename it to the next step + # If the current checkpoint exists, copy it to the next step if os.path.exists(current_checkpoint_dir): - os.rename(current_checkpoint_dir, next_checkpoint_dir) + shutil.copytree( + current_checkpoint_dir, + next_checkpoint_dir, + dirs_exist_ok=True, + ) print( f"Advanced step from {current_step} to {next_step} (no training occurred)" ) try: - # Register the renamed checkpoint as a new LoRA adapter + # Register the copied checkpoint as a new LoRA adapter # so it's available for inference at the new step - from ..unsloth.service import UnslothService - - if isinstance(service, UnslothService): - await service.register_lora_for_step( + if hasattr(service, "register_lora_for_step"): + await service.register_lora_for_step( # type: ignore[attr-defined] next_step, next_checkpoint_dir ) except ModuleNotFoundError: diff --git a/src/art/megatron/__init__.py b/src/art/megatron/__init__.py new file mode 100644 index 000000000..07107df61 --- /dev/null +++ b/src/art/megatron/__init__.py @@ -0,0 +1,3 @@ +from .backend import MegatronBackend + +__all__ = ["MegatronBackend"] diff --git a/src/art/megatron/backend.py b/src/art/megatron/backend.py new file mode 100644 index 000000000..1ebdc7a17 --- /dev/null +++ b/src/art/megatron/backend.py @@ -0,0 +1,39 @@ +from mp_actors import move_to_child_process + +from ..local.backend import LocalBackend +from ..local.service import ModelService +from ..model import TrainableModel +from ..utils.output_dirs import get_model_dir + + +class MegatronBackend(LocalBackend): + def __init__( + self, + *, + in_process: bool = False, + path: str | 
None = None, + ) -> None: + super().__init__(in_process=in_process, path=path) + + async def _get_service(self, model: TrainableModel) -> ModelService: + from ..dev.get_model_config import get_model_config + from .service import MegatronService + + if model.name not in self._services: + config = get_model_config( + base_model=model.base_model, + output_dir=get_model_dir(model=model, art_path=self._path), + config=model._internal_config, + ) + self._services[model.name] = MegatronService( + model_name=model.name, + base_model=model.base_model, + config=config, + output_dir=get_model_dir(model=model, art_path=self._path), + ) + if not self._in_process: + self._services[model.name] = move_to_child_process( + self._services[model.name], + process_name="megatron-service", + ) + return self._services[model.name] diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py new file mode 100644 index 000000000..3ba97a771 --- /dev/null +++ b/src/art/megatron/lora.py @@ -0,0 +1,445 @@ +import math +from typing import Sequence + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.core import parallel_state as ps +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TELayerNormColumnParallelLinear, + TERowParallelGroupedLinear, + TERowParallelLinear, +) +from megatron.core.transformer.attention import SelfAttention +from megatron.core.transformer.moe import grouped_gemm_util +from megatron.core.transformer.moe.experts import TEGroupedMLP +from megatron.core.transformer.transformer_layer import TransformerLayer +import torch + + +class LoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + in_features: int, + out_features: int, + rank: int, + alpha: float, + dtype: torch.dtype, + device: torch.device, + num_local_experts: int = 1, + ) -> None: + super().__init__() + assert num_local_experts == 1 or "{expert}" in adapter_model_prefix, ( + "adapter_model_prefix must contain the '{expert}' format placeholder if num_local_experts > 1" + ) + self.adapter_model_prefix = adapter_model_prefix + self.scale = alpha / rank + self.A_T = torch.nn.Parameter( + torch.zeros( + num_local_experts, in_features, rank, dtype=dtype, device=device + ).squeeze(0) + ) + self.B_T = torch.nn.Parameter( + torch.zeros( + num_local_experts, rank, out_features, dtype=dtype, device=device + ).squeeze(0) + ) + self._expert_offset = ps.get_expert_model_parallel_rank() * num_local_experts + self.reset_lora_parameters() + + @property + def num_local_experts(self) -> int: + return self.A_T.shape[0] if self.A_T.ndim == 3 else 1 + + def reset_lora_parameters(self) -> None: + """Initialize LoRA weights (A=Kaiming, B=zeros) like PEFT defaults.""" + if self.A_T.ndim == 3: + for expert in range(self.A_T.shape[0]): + torch.nn.init.kaiming_uniform_(self.A_T[expert].T, a=math.sqrt(5)) + else: + torch.nn.init.kaiming_uniform_(self.A_T.T, a=math.sqrt(5)) + torch.nn.init.zeros_(self.B_T) + + def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: + try: + self.load_weights( + adapter_model, + suffix="lora_A", + into=self.A_T, + ) + self.load_weights( + adapter_model, + suffix="lora_B", + into=self.B_T, + ) + except KeyError: + print("Unable to find LoRA weights for", self.adapter_model_prefix) + self.reset_lora_parameters() + + def load_weights( + self, + adapter_model: dict[str, torch.Tensor], + *, + suffix: str, + into: torch.nn.Parameter, + ) -> None: + self.load_weight( + ( + torch.stack( + [ + adapter_model[ + 
f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{suffix}.weight" + ].T + for expert in range(self.num_local_experts) + ] + ) + if self.num_local_experts > 1 + else adapter_model[f"{self.adapter_model_prefix}.{suffix}.weight"].T + ), + into=into, + ) + + def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: + setattr(into, "sharded", False) + tp_world_size = ps.get_tensor_model_parallel_world_size() + tp_rank = ps.get_tensor_model_parallel_rank() + for axis in (-2, -1): + if weight.shape[axis] == into.shape[axis]: + continue + # assume our param is tensor sharded along this axis + assert weight.shape[axis] // tp_world_size == into.shape[axis], ( + f"Weight shape {weight.shape} does not match into shape {into.shape} along axis {axis}" + ) + s = into.shape[axis] + weight = weight.narrow(axis, tp_rank * s, s) + setattr(into, "sharded", True) + into.data.copy_(weight) + into.requires_grad = True + + def sharded_lora_state_dict(self) -> dict[str, torch.Tensor]: + if self.num_local_experts > 1: + if ps.get_expert_data_parallel_rank() != 0: + return {} + return { + f"{self.adapter_model_prefix.format(expert=expert + self._expert_offset)}.{key}": param.data[ + expert + ].T + for expert in range(self.num_local_experts) + for key, param in ( + ("lora_A.weight", self.A_T), + ("lora_B.weight", self.B_T), + ) + } + if ps.get_data_parallel_rank() != 0 or torch.all(self.A_T == 0): + return {} + return { + f"{self.adapter_model_prefix}.{key}": param.data.T + for key, param in ( + ("lora_A.weight", self.A_T), + ("lora_B.weight", self.B_T), + ) + if getattr(param, "sharded", False) + or ps.get_tensor_model_parallel_rank() == 0 + } + + def forward( + self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor | None = None + ) -> torch.Tensor: + if tokens_per_expert is not None: + assert self.num_local_experts > 1, ( + "tokens_per_expert is only supported if num_local_experts > 1" + ) + bsz = tokens_per_expert + if isinstance(bsz, list): + bsz = torch.tensor(bsz, dtype=torch.int64, device="cpu") + # If no tokens routed locally, return zeros + if isinstance(bsz, torch.Tensor) and int(torch.count_nonzero(bsz)) == 0: + return x.new_zeros((x.shape[0], self.B_T.shape[-1])) + tmp = grouped_gemm_util.ops.gmm(x, self.A_T, bsz, trans_b=False) # type: ignore[attr-defined] + out = grouped_gemm_util.ops.gmm(tmp, self.B_T, bsz, trans_b=False) # type: ignore[attr-defined] + return out * self.scale + else: + return ((x @ self.A_T) @ self.B_T) * self.scale + + +class SelfAttentionLinearProjLoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + linear_proj: TERowParallelLinear, + rank: int, + alpha: float, + provider: GPTModelProvider, + ) -> None: + super().__init__() + self.provider = provider + self.linear_proj = linear_proj + assert isinstance(linear_proj.weight, torch.Tensor) + self.lora = LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=linear_proj.in_features, + out_features=linear_proj.out_features, + rank=rank, + alpha=alpha, + dtype=linear_proj.weight.dtype, + device=linear_proj.weight.device, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + base_output, bias_output = self.linear_proj(x) + assert isinstance(base_output, torch.Tensor) + assert isinstance(bias_output, (torch.Tensor, type(None))) + lora_output = self.lora(x) + if ( + self.provider.sequence_parallel + and self.provider.tensor_model_parallel_size > 1 + ): + tp_rank = ps.get_tensor_model_parallel_rank() + 
tokens_per_rank = base_output.shape[0] + start = tp_rank * tokens_per_rank + end = start + tokens_per_rank + lora_output = lora_output[start:end] + return base_output + lora_output, bias_output + + +class SelfAttentionLinearQKVLoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + linear_qkv: TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + provider: GPTModelProvider, + ) -> None: + super().__init__() + self.provider = provider + linear_qkv.return_layernorm_output = True + linear_qkv.return_layernorm_output_gathered = True + self.linear_qkv = linear_qkv + assert self.provider.kv_channels is not None + assert self.provider.num_query_groups is not None + assert self.provider.num_attention_heads is not None + q_out_features = self.provider.kv_channels * self.provider.num_attention_heads + kv_out_features = self.provider.kv_channels * self.provider.num_query_groups + tp_world_size = ps.get_tensor_model_parallel_world_size() + assert kv_out_features % tp_world_size == 0, ( + "kv_out_features must be divisible by tensor parallel size" + ) + assert q_out_features % tp_world_size == 0, ( + "q_out_features must be divisible by tensor parallel size" + ) + q_out_features_per_rank = q_out_features // tp_world_size + kv_out_features_per_rank = kv_out_features // tp_world_size + assert isinstance(linear_qkv.weight, torch.Tensor) + self.q_proj_lora = LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.q_proj", + in_features=linear_qkv.in_features, + out_features=q_out_features_per_rank, + rank=rank, + alpha=alpha, + dtype=linear_qkv.weight.dtype, + device=linear_qkv.weight.device, + ) + self.k_proj_lora = LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.k_proj", + in_features=linear_qkv.in_features, + out_features=kv_out_features_per_rank, + rank=rank, + alpha=alpha, + dtype=linear_qkv.weight.dtype, + device=linear_qkv.weight.device, + ) + self.v_proj_lora = LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.v_proj", + in_features=linear_qkv.in_features, + out_features=kv_out_features_per_rank, + rank=rank, + alpha=alpha, + dtype=linear_qkv.weight.dtype, + device=linear_qkv.weight.device, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + ( + linear_output_and_layernorm_output, + bias, + ) = self.linear_qkv(x) + linear_output, layernorm_output = linear_output_and_layernorm_output + assert isinstance(linear_output, torch.Tensor) + assert isinstance(layernorm_output, torch.Tensor) + assert isinstance(bias, (torch.Tensor, type(None))) + + query = self.q_proj_lora(layernorm_output) + key = self.k_proj_lora(layernorm_output) + value = self.v_proj_lora(layernorm_output) + + assert isinstance(self.linear_qkv.config.kv_channels, int) + query_4d = query.reshape( + query.shape[0], query.shape[1], -1, self.linear_qkv.config.kv_channels + ) + key_4d = key.reshape( + key.shape[0], key.shape[1], -1, self.linear_qkv.config.kv_channels + ) + value_4d = value.reshape( + value.shape[0], value.shape[1], -1, self.linear_qkv.config.kv_channels + ) + + qkv_4d = torch.cat([query_4d, key_4d, value_4d], dim=2) + adapter_output = qkv_4d.reshape(qkv_4d.shape[0], qkv_4d.shape[1], -1) + + return linear_output + adapter_output, bias + + +class MLPExpertsLinearFC1LoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + linear_fc1: TEColumnParallelGroupedLinear, + rank: int, + alpha: float, + num_local_experts: int, + ) -> None: + super().__init__() + assert linear_fc1 is not None + self.linear_fc1 = linear_fc1 + assert 
isinstance(linear_fc1.weight0, torch.Tensor) + self.gate_lora = LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_proj", + in_features=linear_fc1.in_features, + out_features=linear_fc1.out_features // 2, + rank=rank, + alpha=alpha, + dtype=linear_fc1.weight0.dtype, + device=linear_fc1.weight0.device, + num_local_experts=num_local_experts, + ) + self.up_lora = LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.up_proj", + in_features=linear_fc1.in_features, + out_features=linear_fc1.out_features // 2, + rank=rank, + alpha=alpha, + dtype=linear_fc1.weight0.dtype, + device=linear_fc1.weight0.device, + num_local_experts=num_local_experts, + ) + + def forward( + self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None]: + base_out, bias_out = self.linear_fc1(x, tokens_per_expert) + gate_out = self.gate_lora(x, tokens_per_expert=tokens_per_expert) + up_out = self.up_lora(x, tokens_per_expert=tokens_per_expert) + adapter_out = torch.cat([gate_out, up_out], dim=1) + return base_out + adapter_out, bias_out + + +class MLPExpertsLinearFC2LoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + linear_fc2: TERowParallelGroupedLinear, + rank: int, + alpha: float, + num_local_experts: int, + ) -> None: + super().__init__() + assert linear_fc2 is not None + assert isinstance(linear_fc2.weight0, torch.Tensor) + self.linear_fc2 = linear_fc2 + self.lora = LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.down_proj", + in_features=linear_fc2.in_features, + out_features=linear_fc2.out_features, + rank=rank, + alpha=alpha, + dtype=linear_fc2.weight0.dtype, + device=linear_fc2.weight0.device, + num_local_experts=num_local_experts, + ) + + def forward( + self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None]: + base_out, bias_out = self.linear_fc2(x, tokens_per_expert) + adapter_out = self.lora(x, tokens_per_expert=tokens_per_expert) + return base_out + adapter_out, bias_out + + +def apply_lora_adapters( + model: Sequence[torch.nn.Module], + provider: GPTModelProvider, +) -> None: + with torch.no_grad(): + for chunk in model: + for module in chunk.modules(): + if isinstance(module, TransformerLayer): + adapter_model_prefix = ( + f"base_model.model.model.layers.{module.layer_number - 1}" + ) + assert isinstance(module.self_attention, SelfAttention) + self_attention_linear_proj = module.self_attention.linear_proj + if not isinstance(self_attention_linear_proj, TERowParallelLinear): + self_attention_linear_proj = ( + self_attention_linear_proj.linear_proj + ) + assert isinstance( + self_attention_linear_proj, TERowParallelLinear + ) + module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", + linear_proj=self_attention_linear_proj, + rank=1, + alpha=32, + provider=provider, + ) + self_attention_linear_qkv = module.self_attention.linear_qkv + if not isinstance( + self_attention_linear_qkv, TELayerNormColumnParallelLinear + ): + self_attention_linear_qkv = self_attention_linear_qkv.linear_qkv + assert isinstance( + self_attention_linear_qkv, TELayerNormColumnParallelLinear + ) + module.self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn", + linear_qkv=self_attention_linear_qkv, + rank=1, + alpha=32, + provider=provider, + ) + assert isinstance(module.mlp.experts, TEGroupedMLP) + 
mlp_experts_linear_fc1 = module.mlp.experts.linear_fc1 + if not isinstance( + mlp_experts_linear_fc1, + TEColumnParallelGroupedLinear, # type: ignore + ): + mlp_experts_linear_fc1 = mlp_experts_linear_fc1.linear_fc1 + assert isinstance( + mlp_experts_linear_fc1, + TEColumnParallelGroupedLinear, # type: ignore + ) + module.mlp.experts.linear_fc1 = MLPExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc1=mlp_experts_linear_fc1, + rank=1, + alpha=32, + num_local_experts=module.mlp.experts.num_local_experts, + ) + mlp_experts_linear_fc2 = module.mlp.experts.linear_fc2 + if not isinstance( + mlp_experts_linear_fc2, + TERowParallelGroupedLinear, # type: ignore + ): + mlp_experts_linear_fc2 = mlp_experts_linear_fc2.linear_fc2 + assert isinstance( + mlp_experts_linear_fc2, + TERowParallelGroupedLinear, # type: ignore + ) + module.mlp.experts.linear_fc2 = MLPExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc2=mlp_experts_linear_fc2, + rank=1, + alpha=32, + num_local_experts=module.mlp.experts.num_local_experts, + ) diff --git a/src/art/megatron/offload.py b/src/art/megatron/offload.py new file mode 100644 index 000000000..37e11be23 --- /dev/null +++ b/src/art/megatron/offload.py @@ -0,0 +1,138 @@ +from dataclasses import dataclass, field +import gc +from typing import Any, Sequence + +import torch + + +@dataclass +class OffloadState: + pinned_buffers: dict[str, torch.Tensor] = field(default_factory=dict) + is_offloaded: bool = False + + +def offload_to_cpu( + model: Sequence[torch.nn.Module], + optimizer: Any, + rank: int, + offload_state: OffloadState, +) -> None: + """Offload model params and optimizer state to CPU pinned memory.""" + if offload_state.is_offloaded: + return + pinned_buffers = offload_state.pinned_buffers + + for chunk in model: + for module in chunk.modules(): + for attr in ["A_T", "B_T"]: + if not hasattr(module, attr): + continue + param = getattr(module, attr) + if ( + not isinstance(param, torch.nn.Parameter) + or param.device.type != "cuda" + ): + continue + key = f"{id(module)}_{attr}" + if ( + key not in pinned_buffers + or pinned_buffers[key].shape != param.shape + or pinned_buffers[key].dtype != param.dtype + ): + pinned_buffers[key] = torch.empty( + param.shape, dtype=param.dtype, device="cpu", pin_memory=True + ) + pinned_buffers[key].copy_(param.data, non_blocking=True) + param.data = pinned_buffers[key] + + # Offload remaining model parameters (including base weights). 
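+    # The pinned CPU buffers are cached in offload_state and reused on later cycles,
+    # so the copies below stay non-blocking without reallocating host memory each time.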
+ for chunk in model: + for param in chunk.parameters(): + if not isinstance(param, torch.nn.Parameter) or param.device.type != "cuda": + continue + key = f"param_{id(param)}" + if ( + key not in pinned_buffers + or pinned_buffers[key].shape != param.shape + or pinned_buffers[key].dtype != param.dtype + ): + pinned_buffers[key] = torch.empty( + param.shape, dtype=param.dtype, device="cpu", pin_memory=True + ) + pinned_buffers[key].copy_(param.data, non_blocking=True) + param.data = pinned_buffers[key] + + for param_id, opt_state in optimizer.optimizer.state.items(): + for k, v in opt_state.items(): + if isinstance(v, torch.Tensor) and v.device.type == "cuda": + key = f"opt_{id(param_id)}_{k}" + if ( + key not in pinned_buffers + or pinned_buffers[key].shape != v.shape + or pinned_buffers[key].dtype != v.dtype + ): + pinned_buffers[key] = torch.empty( + v.shape, dtype=v.dtype, device="cpu", pin_memory=True + ) + pinned_buffers[key].copy_(v, non_blocking=True) + opt_state[k] = pinned_buffers[key] + + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + offload_state.is_offloaded = True + if rank == 0: + print("Offloaded model params and optimizer to CPU") + + +def reload_to_gpu( + model: Sequence[torch.nn.Module], + optimizer: Any, + rank: int, + offload_state: OffloadState, + device: torch.device | str | None = None, +) -> None: + """Reload model params and optimizer state to GPU.""" + if not offload_state.is_offloaded: + return + + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + else: + device = torch.device(device) + + for chunk in model: + for module in chunk.modules(): + for attr in ["A_T", "B_T"]: + if not hasattr(module, attr): + continue + param = getattr(module, attr) + if ( + not isinstance(param, torch.nn.Parameter) + or param.device.type != "cpu" + ): + continue + gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device) + gpu_tensor.copy_(param.data, non_blocking=True) + param.data = gpu_tensor + + # Reload remaining model parameters (including base weights). 
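+    # Fresh CUDA tensors are allocated here; the pinned CPU buffers are left in place
+    # so the next offload cycle can reuse them.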
+ for chunk in model: + for param in chunk.parameters(): + if not isinstance(param, torch.nn.Parameter) or param.device.type != "cpu": + continue + gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device) + gpu_tensor.copy_(param.data, non_blocking=True) + param.data = gpu_tensor + + for opt_state in optimizer.optimizer.state.values(): + for k, v in opt_state.items(): + if isinstance(v, torch.Tensor) and v.device.type == "cpu": + gpu_tensor = torch.empty(v.shape, dtype=v.dtype, device=device) + gpu_tensor.copy_(v, non_blocking=True) + opt_state[k] = gpu_tensor + + torch.cuda.synchronize() + offload_state.is_offloaded = False + if rank == 0: + print("Reloaded LoRA params and optimizer to GPU") diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py new file mode 100644 index 000000000..41fb2bf3f --- /dev/null +++ b/src/art/megatron/provider.py @@ -0,0 +1,31 @@ +from megatron.bridge import AutoBridge +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge +from megatron.core.transformer.enums import AttnBackend +import torch + + +def get_provider(model: str) -> GPTModelProvider: + bridge = AutoBridge.from_hf_pretrained( + model, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + ) + assert isinstance(bridge._model_bridge, Qwen3MoEBridge), ( + "Only Qwen3 MoE models are supported" + ) + provider = bridge.to_megatron_provider() + provider.attention_backend = AttnBackend.fused + provider.recompute_granularity = "full" + provider.recompute_method = "uniform" + provider.recompute_num_layers = 1 + provider.tensor_model_parallel_size = min(2, torch.cuda.device_count()) + provider.context_parallel_size = 1 + provider.pipeline_model_parallel_size = 1 + provider.expert_model_parallel_size = torch.cuda.device_count() + provider.expert_tensor_parallel_size = 1 + provider.moe_shared_expert_overlap = True + provider.moe_router_dtype = "fp32" + if provider.tensor_model_parallel_size > 1: + provider.sequence_parallel = True + return provider diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py new file mode 100644 index 000000000..c335d4e2b --- /dev/null +++ b/src/art/megatron/service.py @@ -0,0 +1,345 @@ +import asyncio +from dataclasses import asdict, dataclass +import datetime +from functools import cached_property +import json +import os +from pathlib import Path +import shutil +import subprocess +from typing import AsyncIterator + +from peft.tuners.lora.config import LoraConfig +from pydantic import BaseModel +from safetensors import safe_open +from safetensors.torch import load_file, save_file +import torch +from vllm import AsyncEngineArgs +from vllm.lora.request import LoRARequest +from vllm.v1.engine.async_llm import AsyncLLM + +from .. 
import dev, types +from ..local.checkpoints import get_last_checkpoint_dir +from ..preprocessing.pack import DiskPackedTensors +from ..unsloth.service import do_sleep, do_wake_up, gc_and_empty_cuda_cache +from ..utils.get_model_step import get_step_from_dir +from ..utils.output_dirs import get_step_checkpoint_dir +from ..vllm import get_llm, openai_server_task, run_on_workers + + +class MegatronTrainingJob(BaseModel): + """Job format for communication with train.py""" + + lora_path: str + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dev.TrainConfig + + +@dataclass +class MegatronService: + model_name: str + base_model: str + config: dev.InternalModelConfig + output_dir: str + _is_sleeping: bool = False + _latest_step: int = 0 + _lora_id_counter: int = 1 + _megatron_process: asyncio.subprocess.Process | None = None + _optimizer_state_path: str | None = None + + def _next_lora_id(self) -> int: + self._lora_id_counter += 1 + return self._lora_id_counter + + def _get_optimizer_state_path(self) -> str: + if self._optimizer_state_path is not None: + return self._optimizer_state_path + self._optimizer_state_path = os.path.join(self.output_dir, "optimizer_states") + os.makedirs(self._optimizer_state_path, exist_ok=True) + return self._optimizer_state_path + + def _default_lora_adapter_config(self) -> LoraConfig: + # Keep in sync with LoRA settings in megatron/train.py. + return LoraConfig( + r=1, + lora_alpha=32, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + bias="none", + ) + + def _adapter_has_weights(self, lora_path: str) -> bool: + adapter_path = os.path.join(lora_path, "adapter_model.safetensors") + if not os.path.exists(adapter_path): + return False + try: + with safe_open(adapter_path, framework="pt") as adapter_file: + for key in adapter_file.keys(): + tensor = adapter_file.get_tensor(key) + if torch.any(tensor != 0): + return True + except Exception: + return False + return False + + def _create_identity_lora(self, lora_path: str) -> None: + # Create an identity (zero) LoRA using PEFT so vLLM can load it. + from peft import get_peft_model + from transformers import AutoModelForCausalLM + + lora_config = self._default_lora_adapter_config() + model = AutoModelForCausalLM.from_pretrained( + self.base_model, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) + peft_model = get_peft_model(model, lora_config) + # Keep LoRA A initialized (trainable) and zero only B for identity. 
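+        # With lora_B zeroed the adapter update B @ A is the zero matrix, so inference
+        # through this adapter matches the base model until training writes new weights.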
+ for name, param in peft_model.named_parameters(): + if "lora_B" in name: + param.data.zero_() + os.makedirs(lora_path, exist_ok=True) + peft_model.save_pretrained(lora_path) + del peft_model, model + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def _ensure_identity_lora(self, lora_path: str) -> None: + if self._adapter_has_weights(lora_path): + return + self._create_identity_lora(lora_path) + + def _ensure_lora_adapter_config( + self, lora_path: str, *, source_path: str | None = None + ) -> None: + config_path = os.path.join(lora_path, "adapter_config.json") + if os.path.exists(config_path): + return + os.makedirs(lora_path, exist_ok=True) + if source_path is not None: + source_config = os.path.join(source_path, "adapter_config.json") + if os.path.exists(source_config): + shutil.copy(source_config, config_path) + return + with open(config_path, "w") as f: + json.dump(asdict(self._default_lora_adapter_config()), f) + + async def _add_lora_aliases( + self, llm: AsyncLLM, step: int, checkpoint_dir: str + ) -> None: + added = await llm.add_lora( + LoRARequest( + lora_name=f"{self.model_name}@{step}", + lora_int_id=self._next_lora_id(), + lora_path=checkpoint_dir, + ) + ) + if not added: + raise RuntimeError(f"Failed to add LoRA adapter for step {step}") + added_alias = await llm.add_lora( + LoRARequest( + lora_name=self.model_name, + lora_int_id=self._next_lora_id(), + lora_path=checkpoint_dir, + ) + ) + if not added_alias: + raise RuntimeError( + f"Failed to add LoRA alias for step {step} at {checkpoint_dir}" + ) + self._latest_step = step + + async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: + llm = await self.llm + await llm.pause_generation() + await self._add_lora_aliases(llm, step, checkpoint_dir) + await llm.resume_generation() + + async def _ensure_megatron_running(self) -> None: + """Lazily start Megatron training process if not running.""" + if self._megatron_process is not None: + if self._megatron_process.returncode is None: + return + self._megatron_process = None + + try: + import megatron.bridge # type: ignore + + setup_cmd = "" + except ImportError: + setup_script = Path(__file__).parent / "setup.sh" + setup_cmd = f"bash {setup_script} && " + + subprocess.run(["pkill", "-9", "megatron-service"], check=False) + train_script = Path(__file__).parent / "train.py" + num_gpus = torch.cuda.device_count() + os.environ["MODEL_IDENTIFIER"] = self.base_model + + command = ( + f"{setup_cmd}uv run torchrun --nproc_per_node {num_gpus} {train_script}" + ) + self._megatron_process = await asyncio.create_subprocess_shell(command) + + async def start_openai_server( + self, config: dev.OpenAIServerConfig | None + ) -> tuple[str, int]: + lora_path = get_last_checkpoint_dir(self.output_dir) + if lora_path is None: + lora_path = get_step_checkpoint_dir(self.output_dir, 0) + self._latest_step = 0 + else: + self._latest_step = get_step_from_dir(self.output_dir) + self._ensure_identity_lora(lora_path) + self._ensure_lora_adapter_config(lora_path) + + lora_path_for_server = ( + lora_path if self._adapter_has_weights(lora_path) else None + ) + server_config = dev.get_openai_server_config( + model_name=self.model_name, + base_model=self.base_model, + log_file=f"{self.output_dir}/logs/vllm.log", + lora_path=lora_path_for_server, + config=config, + ) + await openai_server_task(engine=await self.llm, config=server_config) + return ( + server_config.get("server_args", {}).get("host") or "0.0.0.0", + server_config.get("server_args", 
{}).get("port", 8000), + ) + + async def vllm_engine_is_sleeping(self) -> bool: + return self._is_sleeping + + async def train( + self, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + llm = await self.llm + await llm.pause_generation() + await llm.reset_prefix_cache() + await run_on_workers(llm, do_sleep, level=2) + self._is_sleeping = True + gc_and_empty_cuda_cache() + + # Start Megatron after vLLM has freed GPU memory. + await self._ensure_megatron_running() + + lora_path = get_last_checkpoint_dir(self.output_dir) + if lora_path is None: + lora_path = get_step_checkpoint_dir(self.output_dir, 0) + self._ensure_lora_adapter_config(lora_path) + + self._optimizer_state_path = self._get_optimizer_state_path() + + jobs_dir = "/tmp/megatron_training_jobs" + os.makedirs(jobs_dir, exist_ok=True) + for job_name in os.listdir(jobs_dir): + if job_name.endswith(".json"): + os.remove(os.path.join(jobs_dir, job_name)) + job = MegatronTrainingJob( + lora_path=lora_path, + optimizer_state_path=self._optimizer_state_path, + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=_config, + ) + job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json") + with open(job_path, "w") as f: + f.write(job.model_dump_json()) + + num_lines = 0 + while True: + await asyncio.sleep(0.1) + try: + with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: + log_file.seek(0) + lines = log_file.readlines()[num_lines:] + for line in lines: + if line := line.strip(): + if line == "all done": + self._merge_lora_adapter(lora_path) + os.remove("/tmp/megatron_training_log.jsonl") + break + num_lines += 1 + yield json.loads(line) + else: + continue + break + except FileNotFoundError: + continue + + next_step = self._latest_step + 1 + new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) + os.makedirs(new_checkpoint_dir, exist_ok=True) + shutil.copy( + f"{lora_path}/adapter_model.safetensors", + f"{new_checkpoint_dir}/adapter_model.safetensors", + ) + self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) + + wake_lock_path = "/tmp/megatron_vllm_waking" + try: + with open(wake_lock_path, "w") as lock_file: + lock_file.write("waking vllm\n") + await run_on_workers(llm, do_wake_up) + self._is_sleeping = False + finally: + if os.path.exists(wake_lock_path): + os.remove(wake_lock_path) + + await self._add_lora_aliases(llm, next_step, new_checkpoint_dir) + await llm.resume_generation() + + def _merge_lora_adapter(self, lora_path: str) -> None: + """Merge sharded LoRA adapters from distributed training.""" + base_dir = Path(lora_path) + shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors")) + if not shard_filenames: + return + + adapter_model_path = base_dir / "adapter_model.safetensors" + sharded_tensors: dict[str, list[torch.Tensor]] = {} + + for filename in shard_filenames: + with safe_open(filename, framework="pt") as file: + for key in file.keys(): + tensor = file.get_tensor(key) + sharded_tensors.setdefault(key, []).append(tensor) + + adapter_model: dict[str, torch.Tensor] = {} + if adapter_model_path.exists(): + adapter_model = load_file(adapter_model_path) + + for key, tensors in sharded_tensors.items(): + tensor = torch.cat(tensors, dim=1 if "lora_A" in key else 0) + adapter_model[key] = tensor + + save_file(adapter_model, adapter_model_path) + for filename in shard_filenames: + filename.unlink() + + 
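+    # Lazily construct the vLLM AsyncLLM engine exactly once per service. LoRA support
+    # is always enabled so freshly trained adapters can be added between training steps.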
@cached_property + def llm(self) -> asyncio.Task[AsyncLLM]: + engine_args = { + **self.config.get("engine_args", {}), + "enable_lora": True, + "max_loras": self.config.get("engine_args", {}).get("max_loras", 2), + } + for key in ["enable_log_requests", "disable_log_requests"]: + engine_args.pop(key, None) + return asyncio.create_task(get_llm(AsyncEngineArgs(**engine_args))) # type: ignore diff --git a/src/art/megatron/setup.sh b/src/art/megatron/setup.sh new file mode 100644 index 000000000..bc229a98c --- /dev/null +++ b/src/art/megatron/setup.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -euo pipefail + +export CUDA_HOME="/usr/local/cuda-12.8" +export TORCH_CUDA_ARCH_LIST="9.0" +# install missing cudnn headers & ninja build tools +apt-get update +apt-get install -y libcudnn9-headers-cuda-12 ninja-build +# install apex +if [ -d /root/apex ]; then + echo "apex directory already exists, skipping clone" +else + git clone --depth 1 --branch 25.09 https://github.com/NVIDIA/apex.git /root/apex +fi +NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=16 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 APEX_FAST_LAYER_NORM=1 uv pip install --no-build-isolation /root/apex +# install flash attention +# git clone https://github.com/Dao-AILab/flash-attention.git /root/flash-attention +# (cd /root/flash-attention && git checkout 27f501d) +# uv run /root/flash-attention/hopper/setup.py install +# install transformer engine and megatron +# Build transformer-engine-torch from source with --no-build-isolation to use venv's torch headers +# (prevents ABI mismatch with system PyTorch in the container) +echo "transformer-engine>=2.11.0" > /tmp/te-override.txt +uv pip install --no-build-isolation --override /tmp/te-override.txt \ + transformer-engine==2.11.0 \ + transformer-engine-cu12==2.11.0 \ + transformer-engine-torch==2.11.0 \ + megatron-core==0.15.2 \ + megatron-bridge==0.2.0rc6 +rm /tmp/te-override.txt +# silence pynvml warnings +uv pip uninstall pynvml +uv pip install nvidia-ml-py==13.580.82 diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py new file mode 100644 index 000000000..f1083f37d --- /dev/null +++ b/src/art/megatron/train.py @@ -0,0 +1,342 @@ +# isort: off +import os + +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" +# isort: on + +import json +import shutil +import time +from typing import Any, cast + +from megatron.core import parallel_state as ps +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer +from megatron.core.transformer.module import MegatronModule +from pydantic import BaseModel +from safetensors.torch import load_file, save_file +import torch + +from art import dev, types +from art.loss import loss_fn, shift_tensor +from art.megatron.lora import apply_lora_adapters +from art.megatron.offload import OffloadState, offload_to_cpu, reload_to_gpu +from art.megatron.provider import get_provider +from art.preprocessing.pack import ( + DiskPackedTensors, + PackedTensors, + packed_tensors_from_dir, +) + +provider = get_provider( + os.environ.get("MODEL_IDENTIFIER", "Qwen/Qwen3-30B-A3B-Instruct-2507") +) + + +def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: + for module in model_chunks: + for param in module.parameters(): + param.requires_grad = False + return model_chunks + + 
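+# Freeze every base-model parameter before the distributed wrapper is built; the LoRA
+# adapters attached further down are the only weights left requiring gradients.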
+provider.register_pre_wrap_hook(lambda x: freeze_model(x) or x) + +model = provider.provide_distributed_model( + ddp_config=DistributedDataParallelConfig(), + data_parallel_random_init=False, +) + +rank = torch.distributed.get_rank() +world_size = torch.distributed.get_world_size() + +for module in model: + while not isinstance(module, GPTModel) and hasattr(module, "module"): + module = module.module + if isinstance(module, GPTModel): + _preprocess = module._preprocess + + def _preprocess_hook(*args, **kwargs): + preproc_output = list(_preprocess(*args, **kwargs)) + preproc_output[0].requires_grad = True # type: ignore + table = preproc_output[1] # [S,B,1,D] type: ignore + D = table.size(-1) # type: ignore + table_flat = table.view(table.size(0), D) # type: ignore + # position_ids: [B, S] + position_ids = kwargs["position_ids"] + B, S = position_ids.shape + gathered = table_flat.index_select(0, position_ids.reshape(-1)) # [B*S, D] + gathered = gathered.view(B, S, D).permute(1, 0, 2).contiguous() # [S, B, D] + preproc_output[1] = gathered.unsqueeze(2) # [S, B, 1, D] + return tuple(preproc_output) + + module._preprocess = _preprocess_hook # type: ignore[attr-defined] + + +apply_lora_adapters(model, provider) + +optimizer = get_megatron_optimizer( + config=OptimizerConfig( + bf16=True, + lr=5e-6, + adam_beta1=0.9, + adam_beta2=0.99, + clip_grad=0.1, + weight_decay=0.1, + ), + model_chunks=model, # type: ignore +) + +if rank == 0: + # Print the number of parameters in the optimizer, nicely formatted + num_params = sum( + p.numel() + for group in optimizer.param_groups + if not group["is_decoupled_lr"] + for p in group["params"] + ) + print(f"Number of parameters in optimizer: {num_params:,}") + total_params = sum(p.numel() for m in model for p in m.parameters()) + percent = (num_params / total_params) * 100 if total_params > 0 else 0 + print(f"Optimizer parameters as percent of total: {percent:0.2f}%") + + +class TrainingJob(BaseModel): + lora_path: str + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dev.TrainConfig + + +def print0(*values: Any) -> None: + if rank == 0: + print(*values) + + +offload_state = OffloadState() + + +def calculate_mask( + batch_size: int, + seq_len: int, + device: torch.device, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, +) -> torch.Tensor: + causal_mask = ( + torch.tril( + torch.ones( + seq_len, + seq_len, + dtype=torch.bool, + device=device, + ) + ) + .unsqueeze(0) + .expand(batch_size, seq_len, seq_len) + ) + group_mask = group_ids.unsqueeze(2) == group_ids.unsqueeze(1) + parent_mask = parent_ids.unsqueeze(2) == group_ids.unsqueeze(1) + mask = causal_mask & (group_mask | parent_mask) + return mask + + +offload_to_cpu(model, optimizer, rank, offload_state) + +while True: + torch.distributed.barrier() + jobs_dir = "/tmp/megatron_training_jobs" + os.makedirs(jobs_dir, exist_ok=True) + job_names = sorted( + job_name for job_name in os.listdir(jobs_dir) if job_name.endswith(".json") + ) + if not job_names: + time.sleep(1) + continue + + wake_lock_path = "/tmp/megatron_vllm_waking" + while os.path.exists(wake_lock_path): + time.sleep(0.2) + + reload_to_gpu(model, optimizer, rank, offload_state) + + job_name = job_names[0] + job_path = os.path.join(jobs_dir, job_name) + with open(job_path, "rb") as f: + job = TrainingJob.model_validate_json(f.read()) + config = job.config + experimental_config = job.experimental_config + print0("Loaded job from", job_path) + print0("Job:", job) + 
+    # Warm-start the LoRA weights from the previous step's adapter so training accumulates across jobs.
+    adapter_model_path = f"{job.lora_path}/adapter_model.safetensors"
+    if os.path.exists(adapter_model_path):
+        print0("Loading adapter model from", adapter_model_path)
+        adapter_model = load_file(adapter_model_path)
+        with torch.no_grad():
+            for chunk in model:
+                for module in chunk.modules():
+                    if hasattr(module, "load_lora"):
+                        module.load_lora(adapter_model)  # type: ignore
+    else:
+        print0("No adapter model found at", adapter_model_path)
+        adapter_model = {}
+        with torch.no_grad():
+            for chunk in model:
+                for module in chunk.modules():
+                    if hasattr(module, "reset_lora_parameters"):
+                        module.reset_lora_parameters()  # type: ignore
+    optimizer_shard_path = os.path.join(
+        job.optimizer_state_path, f"{rank + 1:02d}-of-{world_size:02d}.pt"
+    )
+    if os.path.exists(optimizer_shard_path):
+        print(
+            "Loading optimizer state from",
+            optimizer_shard_path,
+        )
+        optimizer.load_state_dict(torch.load(optimizer_shard_path))
+    else:
+        # No checkpoint for this run; reset optimizer state to avoid cross-run leakage
+        print(
+            "No optimizer state found at",
+            optimizer_shard_path,
+            "— resetting optimizer for new run",
+        )
+        optimizer.optimizer.state.clear()
+    optimizer.reload_model_params()
+    print0("Loading packed tensors from", job.disk_packed_tensors["dir"])
+    packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors)
+    num_sequences = job.disk_packed_tensors["num_sequences"]
+    dp_rank = ps.get_data_parallel_rank()
+    dp_world_size = ps.get_data_parallel_world_size()
+    indices = list(
+        range(
+            dp_rank,
+            num_sequences,
+            dp_world_size,
+        )
+    )
+    # Pad indices so every data-parallel rank runs the same number of iterations
+    # (ranks at or beyond the remainder received one sequence fewer). Only pad when
+    # the split is actually uneven; otherwise the collectives below would mismatch.
+    if 0 < num_sequences % dp_world_size <= dp_rank:
+        indices.append(
+            (list(range(num_sequences)) * (dp_world_size // num_sequences + 1))[dp_rank]
+        )
+    for index in indices:
+        inputs = PackedTensors(  # type: ignore
+            **{
+                key: value[index : index + 1]
+                for key, value in packed_tensors.items()
+                if isinstance(value, torch.Tensor)
+            },
+            pixel_values=[None],
+            image_grid_thw=[None],
+        )
+        ref_logprobs = None
+        device = next(model[0].parameters()).device
+        for key, value in inputs.items():
+            if isinstance(value, torch.Tensor):
+                inputs[key] = value.to(device)  # type: ignore
+        attention_mask = ~calculate_mask(
+            batch_size=inputs["tokens"].shape[0],
+            seq_len=inputs["tokens"].shape[1],
+            device=device,
+            group_ids=inputs["group_ids"],
+            parent_ids=inputs["parent_ids"],
+        ).unsqueeze(1)  # add head dimension [B, H=1, S, S]
+        attention_bias = torch.where(
+            attention_mask,
+            torch.tensor(
+                float("-inf"), dtype=next(model[0].parameters()).dtype, device=device
+            ),
+            torch.tensor(0.0, dtype=next(model[0].parameters()).dtype, device=device),
+        )
+        new_logprobs: torch.Tensor = -model[0](
+            input_ids=inputs["tokens"],
+            position_ids=inputs["input_pos"],
+            attention_mask=attention_mask,
+            labels=shift_tensor(inputs["tokens"], 0),
+            extra_block_kwargs={"attention_bias": attention_bias},
+        )
+        loss = loss_fn(
+            inputs,  # type: ignore
+            new_logprobs,
+            ref_logprobs,
+            None,
+            experimental_config,
+        )
+        probs_corr = loss.probs_corr.item()
+        print0("Correlation between old and new probabilities:", probs_corr)
+        loss = loss.mean_policy_loss + config.beta * loss.mean_kl
+        loss.backward()
+        # Reduce LoRA grads across data-parallel ranks
+        start = time.perf_counter()
+        num_grads = 0
+        for chunk in model:
+            for param in chunk.parameters():
+                if param.grad is None:
+                    continue
+                torch.distributed.all_reduce(
+                    param.grad,
+                    op=torch.distributed.ReduceOp.AVG,
+                    group=ps.get_data_parallel_group(),
+                )
+                num_grads += 1
+        print0(
+            f"Reduced {num_grads} LoRA grads in 
{(time.perf_counter() - start) * 1e3:.1f} ms" + ) + for param_group in optimizer.param_groups: + param_group["lr"] = config.learning_rate + update_successful, grad_norm, num_zeros_in_grad = cast( + tuple[bool, float, int | None], optimizer.step() + ) + optimizer.zero_grad() + + # Mean reduce loss across all ranks for logging + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + + if rank == 0: + with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: + log_msg = json.dumps( + { + "loss": loss.item(), + "grad_norm": grad_norm, + "probs_corr": probs_corr, + } + ) + print("Logging", log_msg) + log_file.write(log_msg + "\n") + + sharded_state_dict = {} + for chunk in model: + for module in chunk.modules(): + if hasattr(module, "sharded_lora_state_dict"): + module_sharded_lora_state_dict: dict[str, torch.Tensor] = ( + module.sharded_lora_state_dict() # type: ignore + ) + for key, value in module_sharded_lora_state_dict.items(): + target_dtype = ( + adapter_model[key].dtype + if key in adapter_model + else value.dtype + ) + sharded_state_dict[key] = value.to(target_dtype) + shard_path = os.path.join( + job.lora_path, + f"adapter_model-{rank + 1:02d}-of-{world_size:02d}.safetensors", + ) + print("Saving adapter shard to", shard_path) + save_file(sharded_state_dict, shard_path) + print("Saving optimizer shard to", optimizer_shard_path) + os.makedirs(job.optimizer_state_path, exist_ok=True) + torch.save(optimizer.state_dict(), optimizer_shard_path) + offload_to_cpu(model, optimizer, rank, offload_state) + # Ensure all ranks have finished saving before signaling completion + torch.distributed.barrier() + if rank == 0: + os.remove(job_path) + with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: + log_file.write("all done\n") + shutil.rmtree(job.disk_packed_tensors["dir"]) diff --git a/tests/unit/test_trajectory_parquet.py b/tests/unit/test_trajectory_parquet.py index c48608ee0..63b77d4bb 100644 --- a/tests/unit/test_trajectory_parquet.py +++ b/tests/unit/test_trajectory_parquet.py @@ -173,8 +173,7 @@ def test_tool_calls(self, tmp_path: Path): assert tool_calls, "Assistant message should include tool calls" first_call = tool_calls[0] assert first_call["type"] == "function" - function_call = cast(ChatCompletionMessageFunctionToolCallParam, first_call) - assert function_call["function"]["name"] == "search" + assert first_call["function"]["name"] == "search" # Check tool result message tool_result_msg = _ensure_tool_message(traj.messages_and_choices[2])