From 67660f2e82ecc0c522dd351c39a26305d779befd Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 13 Apr 2026 04:33:02 +0000 Subject: [PATCH] release: v0.4.0 - public API, CLI, config, and reward improvements - Expand public API: lazily expose reward helpers (math_verify_reward, format_reward, extract_answer), eval helpers (compare_stages, parse_results, BENCHMARK_PRESETS), and the Trainer protocol. dir(alignrl) now reflects all lazy exports. - Add BaseTrainConfig.to_yaml() for round-tripping configs and writing them to disk. - Validate BaseTrainConfig numeric fields (learning_rate, lora_r, lora_dropout, batch sizes, etc.) and forbid unknown keys so YAML typos fail fast. - CLI: add 'alignrl version' subcommand and top-level -V/--version flag, expose --num-fewshot and --batch-size on 'eval', and --temperature / --max-tokens on 'serve' (forwarded to ModelServer via create_demo). - Rewards: _normalize_numeric now handles thousands separators, currency prefixes, and percent suffixes; extract_answer supports 'final answer:' variants, strips trailing punctuation, and unwraps \text{...} inside \boxed{} groups; _answers_match is now case- insensitive for string comparisons. - Bump version to 0.4.0; add CHANGELOG.md. Test suite grows from 179 to 209 tests, all passing. --- CHANGELOG.md | 64 ++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- src/alignrl/__init__.py | 24 +++++++++++- src/alignrl/cli.py | 57 ++++++++++++++++++++++++++++- src/alignrl/config.py | 52 +++++++++++++++++++------- src/alignrl/demo.py | 6 +++ src/alignrl/rewards.py | 66 +++++++++++++++++++++++++++------ tests/test_cli.py | 81 ++++++++++++++++++++++++++++++++++++++++- tests/test_config.py | 56 ++++++++++++++++++++++++++++ tests/test_init.py | 31 +++++++++++++++- tests/test_rewards.py | 37 +++++++++++++++++++ 11 files changed, 445 insertions(+), 31 deletions(-) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..55744d3 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,64 @@ +# Changelog + +All notable changes to `alignrl` are documented in this file. The format is +based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this +project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [0.4.0] - 2026-04-13 + +### Added + +- **Expanded public API.** `alignrl` now lazily exports reward helpers + (`math_verify_reward`, `format_reward`, `extract_answer`), evaluation + helpers (`compare_stages`, `parse_results`, `BENCHMARK_PRESETS`), and the + `Trainer` protocol. `dir(alignrl)` reflects every lazy export for better + IDE discoverability. +- **`BaseTrainConfig.to_yaml()`** — serialize a validated config back to YAML + for round-tripping from CLI overrides to committed config files. When a + path is given, parent directories are created and the file is written. +- **`alignrl version` subcommand** and a top-level `-V` / `--version` flag + on the CLI. +- **CLI `eval` flags**: `--num-fewshot` and `--batch-size` for configuring + few-shot prompting and lm-eval batch size from the command line. +- **CLI `serve` flags**: `--temperature` and `--max-tokens` are now piped + through to every `ModelServer` in the Gradio comparison demo. +- **Config validation.** `BaseTrainConfig` now uses Pydantic field + constraints (`gt=0`, `ge=0`, etc.) for numeric fields and `extra="forbid"` + so typos in YAML configs fail loudly at load time instead of silently + falling back to defaults. 
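+
+  A quick sketch of the stricter validation together with the new
+  `to_yaml()` round-trip (the paths and values below are illustrative):
+
+  ```python
+  from pathlib import Path
+  from alignrl import BaseTrainConfig
+
+  cfg = BaseTrainConfig(learning_rate=1e-5, lora_r=8)
+  cfg.to_yaml(Path("configs/sft.yaml"))   # writes YAML, creating parent dirs
+  restored = BaseTrainConfig.from_yaml(Path("configs/sft.yaml"))
+  assert restored.learning_rate == 1e-5
+
+  # A typo such as learnign_rate now raises a ValidationError at load time
+  # instead of silently falling back to the default:
+  # BaseTrainConfig(learnign_rate=1e-4)
+  ```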
+ +### Changed + +- **Reward normalization is more robust.** `_normalize_numeric` now handles + thousands separators (`1,234`), currency prefixes (`$42`, `\$42`), + trailing percent (`50%`), and strips trailing periods. `_answers_match` + performs case-insensitive comparison before numeric normalization. +- **`extract_answer` supports more formats.** The regex now matches + `final answer: X` and `answer X` variants, accepts commas inside numeric + answers, and unwraps `\text{...}` inside `\boxed{...}` groups. + +### Fixed + +- Trailing punctuation (`.`, `,`, `;`, `:`) is no longer carried into + extracted answers from `"the answer is …"` / `"= …"` patterns, which + previously caused spurious reward mismatches. + +## [0.3.0] - 2026-03-25 + +### Added + +- Public API lazy-imports surface (`alignrl.SFTConfig`, `alignrl.GRPORunner`, …). +- W&B integration: `detect_wandb`, `log_eval_to_wandb`, CLI `--wandb` flag. +- HuggingFace Hub helpers: `push_adapter`, `merge_and_push`. +- Benchmark presets for `EvalConfig` (`core`, `reasoning`, `leaderboard`). +- Docker support (Dockerfile, docker-compose) for GPU-ready workflows. + +### Fixed + +- Guard against empty `loss_history` in all trainers. +- Copy preset list to prevent aliasing mutation in `EvalConfig`. +- Cache lazy imports in module globals after first resolution. +- Pass LoRA adapter to vLLM via `LoRARequest` instead of silently dropping it. + +[0.4.0]: https://github.com/sacredvoid/alignrl/releases/tag/v0.4.0 +[0.3.0]: https://github.com/sacredvoid/alignrl/releases/tag/v0.3.0 diff --git a/pyproject.toml b/pyproject.toml index 9d5d5c8..b4f8f21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "alignrl" -version = "0.3.0" +version = "0.4.0" description = "LLM post-training playbook: SFT, GRPO, DPO, eval, and inference" readme = "README.md" license = "MIT" diff --git a/src/alignrl/__init__.py b/src/alignrl/__init__.py index 9519e55..73f8331 100644 --- a/src/alignrl/__init__.py +++ b/src/alignrl/__init__.py @@ -4,25 +4,42 @@ import importlib -__version__ = "0.3.0" +__version__ = "0.4.0" _LAZY_IMPORTS: dict[str, str] = { + # Config "BaseTrainConfig": "alignrl.config", + # SFT "SFTConfig": "alignrl.sft", "SFTRunner": "alignrl.sft", + # DPO "DPOConfig": "alignrl.dpo", "DPORunner": "alignrl.dpo", + # GRPO "GRPOConfig": "alignrl.grpo", "GRPORunner": "alignrl.grpo", + # Evaluation "EvalConfig": "alignrl.eval", "EvalRunner": "alignrl.eval", + "compare_stages": "alignrl.eval", + "parse_results": "alignrl.eval", + "BENCHMARK_PRESETS": "alignrl.eval", + # Inference "InferenceConfig": "alignrl.inference", "ModelServer": "alignrl.inference", "build_prompt": "alignrl.inference", + # Shared types / protocols "TrainResult": "alignrl.types", "EvalResult": "alignrl.types", + "Trainer": "alignrl.types", + # Rewards + "math_verify_reward": "alignrl.rewards", + "format_reward": "alignrl.rewards", + "extract_answer": "alignrl.rewards", + # HF Hub helpers "push_adapter": "alignrl.hub", "merge_and_push": "alignrl.hub", + # W&B integration "detect_wandb": "alignrl.callbacks", "log_eval_to_wandb": "alignrl.callbacks", } @@ -37,4 +54,9 @@ def __getattr__(name: str): raise AttributeError(f"module 'alignrl' has no attribute {name!r}") +def __dir__() -> list[str]: + """Expose lazy exports in ``dir(alignrl)`` for discoverability.""" + return sorted([*_LAZY_IMPORTS, "__version__"]) + + __all__ = [*_LAZY_IMPORTS, "__version__"] diff --git a/src/alignrl/cli.py b/src/alignrl/cli.py index 1abae89..3f086f1 100644 
--- a/src/alignrl/cli.py +++ b/src/alignrl/cli.py @@ -8,6 +8,13 @@ from pathlib import Path +def cmd_version(args: argparse.Namespace) -> None: + """Print the installed alignrl version.""" + from alignrl import __version__ + + print(f"alignrl {__version__}") + + def cmd_train(args: argparse.Namespace) -> None: """Run a training pipeline.""" config_path = Path(args.config) @@ -59,6 +66,10 @@ def cmd_eval(args: argparse.Namespace) -> None: config_kwargs["tasks"] = args.tasks.split(",") if args.preset: config_kwargs["preset"] = args.preset + if getattr(args, "num_fewshot", None) is not None: + config_kwargs["num_fewshot"] = args.num_fewshot + if getattr(args, "batch_size", None) is not None: + config_kwargs["batch_size"] = args.batch_size config = EvalConfig(**config_kwargs) runner = EvalRunner(config) @@ -89,14 +100,32 @@ def cmd_serve(args: argparse.Namespace) -> None: name, _, path = spec.partition("=") stages[name] = path if path else None - demo = create_demo(stages=stages, model_name=args.model) + demo_kwargs: dict = {"stages": stages, "model_name": args.model} + if getattr(args, "temperature", None) is not None: + demo_kwargs["temperature"] = args.temperature + if getattr(args, "max_tokens", None) is not None: + demo_kwargs["max_tokens"] = args.max_tokens + + demo = create_demo(**demo_kwargs) demo.launch(server_name="0.0.0.0", server_port=args.port, share=args.share) def main() -> None: + from alignrl import __version__ + parser = argparse.ArgumentParser(prog="alignrl", description="LLM Post-Training Playbook") + parser.add_argument( + "-V", + "--version", + action="version", + version=f"alignrl {__version__}", + ) sub = parser.add_subparsers(dest="command", required=True) + # Version (as a subcommand for scripting use) + version_p = sub.add_parser("version", help="Print the installed alignrl version") + version_p.set_defaults(func=cmd_version) + # Train train_p = sub.add_parser("train", help="Run training pipeline") train_p.add_argument("stage", choices=["sft", "grpo", "dpo"]) @@ -118,6 +147,19 @@ def main() -> None: choices=["core", "reasoning", "leaderboard"], help="Benchmark preset (default: core)", ) + eval_p.add_argument( + "--num-fewshot", + dest="num_fewshot", + type=int, + default=None, + help="Number of few-shot examples (default: 0)", + ) + eval_p.add_argument( + "--batch-size", + dest="batch_size", + default=None, + help="Batch size for evaluation (e.g. 
'auto', 8)", + ) eval_p.add_argument("--limit", type=int, default=None) eval_p.add_argument("--output", default="./results") eval_p.add_argument("--wandb", action="store_true", help="Log results to W&B") @@ -134,6 +176,19 @@ def main() -> None: ) serve_p.add_argument("--port", type=int, default=7860) serve_p.add_argument("--share", action="store_true") + serve_p.add_argument( + "--temperature", + type=float, + default=None, + help="Sampling temperature for generation (default: 0.7)", + ) + serve_p.add_argument( + "--max-tokens", + dest="max_tokens", + type=int, + default=None, + help="Maximum tokens to generate per response (default: 512)", + ) serve_p.set_defaults(func=cmd_serve) args = parser.parse_args() diff --git a/src/alignrl/config.py b/src/alignrl/config.py index f60bb9f..b164e00 100644 --- a/src/alignrl/config.py +++ b/src/alignrl/config.py @@ -6,33 +6,43 @@ from typing import TYPE_CHECKING import yaml -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field if TYPE_CHECKING: from typing_extensions import Self class BaseTrainConfig(BaseModel): - """Shared training configuration.""" + """Shared training configuration. + + All training stage configs (SFT, GRPO, DPO) inherit from this. Fields + are validated at construction time via Pydantic, so malformed YAML + files fail fast rather than partway through a training run. + """ + + # Use a custom config dict to forbid unknown keys. This catches typos + # in YAML files (e.g. ``learnign_rate: 2e-4``) before any training + # begins, which is much friendlier than a silent default. + model_config = ConfigDict(extra="forbid") model_name: str = "Qwen/Qwen2.5-3B" output_dir: Path = Path("./outputs") - max_seq_length: int = 2048 - per_device_train_batch_size: int = 4 - gradient_accumulation_steps: int = 4 - learning_rate: float = 2e-4 - num_train_epochs: int = 1 - max_steps: int = -1 - warmup_steps: int = 10 + max_seq_length: int = Field(default=2048, gt=0) + per_device_train_batch_size: int = Field(default=4, gt=0) + gradient_accumulation_steps: int = Field(default=4, gt=0) + learning_rate: float = Field(default=2e-4, gt=0) + num_train_epochs: int = Field(default=1, ge=0) + max_steps: int = Field(default=-1, ge=-1) + warmup_steps: int = Field(default=10, ge=0) optim: str = "adamw_8bit" - seed: int = 42 + seed: int = Field(default=42, ge=0) report_to: str = "none" - logging_steps: int = 10 + logging_steps: int = Field(default=10, gt=0) # LoRA - lora_r: int = 16 - lora_alpha: int = 32 - lora_dropout: float = 0.0 + lora_r: int = Field(default=16, gt=0) + lora_alpha: int = Field(default=32, gt=0) + lora_dropout: float = Field(default=0.0, ge=0.0, le=1.0) lora_target_modules: list[str] = Field( default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] ) @@ -42,10 +52,24 @@ class BaseTrainConfig(BaseModel): @classmethod def from_yaml(cls, path: Path) -> Self: + """Load a config from a YAML file. Missing keys use defaults.""" with open(path) as f: data = yaml.safe_load(f) return cls(**(data or {})) + def to_yaml(self, path: Path | None = None) -> str: + """Serialize the config to YAML. + + Returns the YAML string. If ``path`` is provided, also writes the + YAML to disk (parent directories are created as needed). 
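+
+        Example (the path below is illustrative)::
+
+            cfg = BaseTrainConfig(learning_rate=1e-5)
+            cfg.to_yaml(Path("configs/sft.yaml"))
+            BaseTrainConfig.from_yaml(Path("configs/sft.yaml"))  # round-trips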
+ """ + data = self.model_dump(mode="json") + text: str = yaml.safe_dump(data, sort_keys=False, default_flow_style=False) + if path is not None: + Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).write_text(text) + return text + # ChatML template used as fallback when the tokenizer doesn't have one set. # This is the standard format for Qwen, Yi, and many other models. diff --git a/src/alignrl/demo.py b/src/alignrl/demo.py index 5b46745..c81d267 100644 --- a/src/alignrl/demo.py +++ b/src/alignrl/demo.py @@ -13,12 +13,16 @@ def create_demo( stages: dict[str, str | None], model_name: str = "Qwen/Qwen2.5-3B", + temperature: float = 0.7, + max_tokens: int = 512, ): """Create a Gradio demo comparing model outputs across training stages. Args: stages: {stage_name: adapter_path_or_None} model_name: base model name + temperature: sampling temperature passed to each backend + max_tokens: maximum number of tokens to generate per response """ import gradio as gr @@ -28,6 +32,8 @@ def create_demo( model_name=model_name, adapter_path=adapter_path, backend="unsloth", + temperature=temperature, + max_tokens=max_tokens, ) server = ModelServer(config) server.load() diff --git a/src/alignrl/rewards.py b/src/alignrl/rewards.py index 6c1391c..f1226b7 100644 --- a/src/alignrl/rewards.py +++ b/src/alignrl/rewards.py @@ -5,8 +5,15 @@ import math import re -_RE_ANSWER = re.compile(r"(?:the answer is|answer:)\s*([^\s.,]+)", re.IGNORECASE) -_RE_EQUALS = re.compile(r"=\s*([^\s.,=]+)") +# Match "the answer is X" / "final answer: X" / "answer: X" variants, +# allowing commas inside numeric answers like "1,234" and an optional +# currency/latex prefix like "$" or "\$". +_RE_ANSWER = re.compile( + r"(?:final\s+answer|the\s+answer\s+is|answer)\s*[:=]?\s*" + r"\*?\*?(\\?\$?-?[\w,./\\{}%]+)", + re.IGNORECASE, +) +_RE_EQUALS = re.compile(r"=\s*([^\s,=]+)") def _extract_boxed_contents(text: str) -> list[str]: @@ -34,30 +41,57 @@ def _extract_boxed_contents(text: str) -> list[str]: return results +_RE_TEXT_WRAPPER = re.compile(r"\\text\{([^{}]*)\}") + + +def _unwrap_latex(s: str) -> str: + """Strip common LaTeX wrappers like ``\\text{...}`` from a boxed answer.""" + prev = None + current = s + while prev != current: + prev = current + current = _RE_TEXT_WRAPPER.sub(r"\1", current) + return current + + def extract_answer(text: str) -> str | None: """Extract the final answer from model output. - Supports: \\boxed{...}, 'the answer is X', 'X = Y' patterns. + Supports ``\\boxed{...}`` (with nested braces and ``\\text{}`` wrappers), + ``the answer is X`` / ``final answer: X``, and ``X = Y`` patterns. Returns the last match found (most likely the final answer). """ boxed = _extract_boxed_contents(text) if boxed: - return boxed[-1].strip() + return _unwrap_latex(boxed[-1]).strip() answer_match = _RE_ANSWER.findall(text) if answer_match: - return answer_match[-1].strip() + return answer_match[-1].strip().rstrip(".,;:") eq_match = _RE_EQUALS.findall(text) if eq_match: - return eq_match[-1].strip() + return eq_match[-1].strip().rstrip(".,;:") return None def _normalize_numeric(s: str) -> str | None: - """Try to parse as a number for comparison.""" - s = s.strip().rstrip(".") + """Try to parse as a number for comparison. 
+ + Handles common wrappers seen in LLM math outputs: + - commas as thousands separators (``1,234``) + - leading currency symbols (``$42``, ``\\$42``) + - trailing percent sign (``50%``) + - trailing period (``42.``) + """ + s = s.strip() + # Strip common latex/currency/format wrappers + if s.startswith("\\$"): + s = s[2:] + s = s.lstrip("$").rstrip("%").rstrip(".") + # Remove thousands separators (commas between digits) + s = s.replace(",", "") try: val = float(s) if not math.isfinite(val): @@ -70,11 +104,19 @@ def _normalize_numeric(s: str) -> str | None: def _answers_match(predicted: str, expected: str) -> bool: - """Check if two answers are equivalent.""" - if predicted.strip() == expected.strip(): + """Check if two answers are equivalent. + + First tries exact string match (case-insensitive), then normalized + numeric match so that ``3.0`` == ``3`` and ``1,234`` == ``1234``. + """ + pred = predicted.strip() + exp = expected.strip() + if pred == exp: + return True + if pred.casefold() == exp.casefold(): return True - norm_pred = _normalize_numeric(predicted) - norm_exp = _normalize_numeric(expected) + norm_pred = _normalize_numeric(pred) + norm_exp = _normalize_numeric(exp) if norm_pred and norm_exp: return norm_pred == norm_exp return False diff --git a/tests/test_cli.py b/tests/test_cli.py index 2015505..f74d1d2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,7 +6,7 @@ import pytest -from alignrl.cli import cmd_eval, cmd_serve, cmd_train, main +from alignrl.cli import cmd_eval, cmd_serve, cmd_train, cmd_version, main class TestCLIParser: @@ -114,6 +114,8 @@ def test_eval_creates_output(self, tmp_path) -> None: stage="base", output=str(tmp_path / "results"), wandb=False, + num_fewshot=None, + batch_size=None, ) with ( @@ -124,6 +126,35 @@ def test_eval_creates_output(self, tmp_path) -> None: mock_runner.evaluate.assert_called_once_with(stage="base") assert (tmp_path / "results" / "eval_base.json").exists() + def test_eval_forwards_num_fewshot_and_batch_size(self, tmp_path) -> None: + mock_result = MagicMock() + mock_result.to_dict.return_value = {"benchmarks": {}} + mock_result.benchmarks = {} + mock_runner = MagicMock() + mock_runner.evaluate.return_value = mock_result + + args = argparse.Namespace( + model="test-model", + adapter=None, + tasks=None, + preset=None, + limit=None, + stage="sft", + output=str(tmp_path / "results"), + wandb=False, + num_fewshot=5, + batch_size="8", + ) + + with ( + patch("alignrl.eval.EvalRunner", return_value=mock_runner), + patch("alignrl.eval.EvalConfig") as mock_cfg_cls, + ): + cmd_eval(args) + call_kwargs = mock_cfg_cls.call_args.kwargs + assert call_kwargs["num_fewshot"] == 5 + assert call_kwargs["batch_size"] == "8" + class TestMainEntry: def test_main_module_execution(self) -> None: @@ -148,6 +179,8 @@ def test_serve_parses_stage_specs(self) -> None: stages=["base", "sft=./outputs/sft/final"], port=7860, share=False, + temperature=None, + max_tokens=None, ) with patch("alignrl.demo.create_demo", return_value=mock_demo) as mock_create: @@ -157,3 +190,49 @@ def test_serve_parses_stage_specs(self) -> None: assert stages["base"] is None assert stages["sft"] == "./outputs/sft/final" mock_demo.launch.assert_called_once() + + def test_serve_forwards_generation_params(self) -> None: + mock_demo = MagicMock() + + args = argparse.Namespace( + model="test-model", + stages=["base"], + port=7860, + share=False, + temperature=0.2, + max_tokens=256, + ) + + with patch("alignrl.demo.create_demo", return_value=mock_demo) as mock_create: + 
cmd_serve(args) + kwargs = mock_create.call_args.kwargs + assert kwargs["temperature"] == 0.2 + assert kwargs["max_tokens"] == 256 + + +class TestCmdVersion: + def test_version_subcommand_prints_version(self, capsys: pytest.CaptureFixture) -> None: + from alignrl import __version__ + + cmd_version(argparse.Namespace()) + captured = capsys.readouterr() + assert __version__ in captured.out + assert "alignrl" in captured.out + + def test_version_via_main(self, capsys: pytest.CaptureFixture) -> None: + from alignrl import __version__ + + sys.argv = ["alignrl", "version"] + main() + captured = capsys.readouterr() + assert __version__ in captured.out + + def test_version_flag(self, capsys: pytest.CaptureFixture) -> None: + from alignrl import __version__ + + sys.argv = ["alignrl", "--version"] + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 0 + captured = capsys.readouterr() + assert __version__ in captured.out diff --git a/tests/test_config.py b/tests/test_config.py index af17b95..72526d8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -64,6 +64,62 @@ def test_custom_lora_modules(self) -> None: cfg = BaseTrainConfig(lora_target_modules=["q_proj", "k_proj"]) assert cfg.lora_target_modules == ["q_proj", "k_proj"] + def test_rejects_negative_learning_rate(self) -> None: + with pytest.raises(ValueError): + BaseTrainConfig(learning_rate=-1e-4) + + def test_rejects_zero_learning_rate(self) -> None: + with pytest.raises(ValueError): + BaseTrainConfig(learning_rate=0.0) + + def test_rejects_nonpositive_lora_r(self) -> None: + with pytest.raises(ValueError): + BaseTrainConfig(lora_r=0) + + def test_rejects_nonpositive_batch_size(self) -> None: + with pytest.raises(ValueError): + BaseTrainConfig(per_device_train_batch_size=0) + + def test_rejects_lora_dropout_gt_one(self) -> None: + with pytest.raises(ValueError): + BaseTrainConfig(lora_dropout=1.5) + + def test_rejects_unknown_field(self) -> None: + # Typos in YAML fields should fail fast, not silently default. 
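+        # Pydantic's ValidationError subclasses ValueError, so the
+        # pytest.raises(ValueError) below also covers extra="forbid" rejections.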
+ with pytest.raises(ValueError): + BaseTrainConfig(learnign_rate=1e-4) # type: ignore[call-arg] + + def test_max_steps_allows_minus_one(self) -> None: + cfg = BaseTrainConfig(max_steps=-1) + assert cfg.max_steps == -1 + + def test_rejects_max_steps_below_minus_one(self) -> None: + with pytest.raises(ValueError): + BaseTrainConfig(max_steps=-2) + + +class TestBaseTrainConfigToYaml: + def test_to_yaml_roundtrips(self, tmp_path: Path) -> None: + cfg = BaseTrainConfig(model_name="round-trip-model", learning_rate=1e-5, lora_r=8) + text = cfg.to_yaml() + assert "round-trip-model" in text + assert "learning_rate" in text + + out_path = tmp_path / "written.yaml" + cfg.to_yaml(out_path) + assert out_path.exists() + + restored = BaseTrainConfig.from_yaml(out_path) + assert restored.model_name == "round-trip-model" + assert restored.learning_rate == 1e-5 + assert restored.lora_r == 8 + + def test_to_yaml_creates_parent_dirs(self, tmp_path: Path) -> None: + cfg = BaseTrainConfig() + out_path = tmp_path / "nested" / "dir" / "config.yaml" + cfg.to_yaml(out_path) + assert out_path.exists() + class TestEnsureChatTemplate: def test_sets_template_when_missing(self) -> None: diff --git a/tests/test_init.py b/tests/test_init.py index 28eae6c..42ba25b 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -5,13 +5,42 @@ class TestPublicAPI: def test_version(self) -> None: - assert alignrl.__version__ == "0.3.0" + assert alignrl.__version__ == "0.4.0" + + def test_version_matches_package_metadata(self) -> None: + import contextlib + from importlib.metadata import PackageNotFoundError, version + + # Running from a source checkout without an installed dist is fine. + with contextlib.suppress(PackageNotFoundError): + assert version("alignrl") == alignrl.__version__ def test_all_exports_listed(self) -> None: assert "__version__" in alignrl.__all__ assert "SFTConfig" in alignrl.__all__ assert "TrainResult" in alignrl.__all__ + def test_lazy_import_rewards(self) -> None: + from alignrl import extract_answer, format_reward, math_verify_reward + + assert callable(math_verify_reward) + assert callable(format_reward) + assert callable(extract_answer) + + def test_lazy_import_eval_helpers(self) -> None: + from alignrl import BENCHMARK_PRESETS, compare_stages, parse_results + + assert isinstance(BENCHMARK_PRESETS, dict) + assert "core" in BENCHMARK_PRESETS + assert callable(compare_stages) + assert callable(parse_results) + + def test_dir_includes_lazy_exports(self) -> None: + names = set(dir(alignrl)) + assert "SFTConfig" in names + assert "math_verify_reward" in names + assert "__version__" in names + def test_lazy_import_sft(self) -> None: from alignrl import SFTConfig, SFTRunner diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 0dcd7af..3d3f6d8 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -28,6 +28,20 @@ def test_no_answer_found(self) -> None: def test_multiple_boxed_takes_last(self) -> None: assert extract_answer(r"\boxed{1} and \boxed{2}") == "2" + def test_final_answer_prefix(self) -> None: + assert extract_answer("So the final answer: 7") == "7" + + def test_answer_with_commas(self) -> None: + # Thousands-separated numeric answers now survive extraction. 
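+        # extract_answer keeps the comma; _normalize_numeric strips it later,
+        # so "1,234" still compares equal to "1234" in _answers_match.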
+ assert extract_answer("The answer is 1,234") == "1,234" + + def test_trailing_punctuation_stripped(self) -> None: + assert extract_answer("The answer is 42.") == "42" + assert extract_answer("The answer is 42;") == "42" + + def test_unwraps_text_wrapper_in_boxed(self) -> None: + assert extract_answer(r"\boxed{\text{42}}") == "42" + class TestMathVerifyReward: def test_correct_integer(self) -> None: @@ -122,6 +136,20 @@ def test_infinity_returns_none(self) -> None: def test_nan_returns_none(self) -> None: assert _normalize_numeric("nan") is None + def test_thousands_separator(self) -> None: + assert _normalize_numeric("1,234") == "1234" + assert _normalize_numeric("1,234,567") == "1234567" + + def test_currency_prefix(self) -> None: + assert _normalize_numeric("$42") == "42" + assert _normalize_numeric("\\$42") == "42" + + def test_percent_suffix(self) -> None: + assert _normalize_numeric("50%") == "50" + + def test_negative_number(self) -> None: + assert _normalize_numeric("-5") == "-5" + class TestAnswersMatch: def test_exact_match(self) -> None: @@ -136,6 +164,15 @@ def test_non_matching_strings(self) -> None: def test_non_matching_numbers(self) -> None: assert _answers_match("3", "4") is False + def test_case_insensitive_string(self) -> None: + assert _answers_match("True", "true") is True + + def test_comma_separated_matches_plain(self) -> None: + assert _answers_match("1,234", "1234") is True + + def test_currency_matches_plain(self) -> None: + assert _answers_match("$42", "42") is True + class TestMathVerifyRewardEdgeCases: def test_empty_completion_list(self) -> None: