
Expand rotary, attention, and embedding layer tests into parametrized suites#504

Merged
jlamypoirier merged 4 commits into main from jlp_test-improvements on May 1, 2026

Conversation


@jlamypoirier commented on May 1, 2026

Summary

  • test_rotary.py: rewrites the procedural test as a parametrized, config-driven suite covering default, llama3, yarn, and 2D rotary across multiple head sizes and sequence lengths. Expected values come from independent 1D/2D reference implementations in plain PyTorch (no Fast-LLM kernel calls). The query is cloned before calling forward, because triton_rotary_ writes its results in-place and would otherwise corrupt the reference.
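    For orientation, a kernel-free rotate-half reference along these lines is roughly what such a suite checks against (a sketch only; the actual reference follows Fast-LLM's rotary conventions for interleaving, frequency scaling, and the 2D variant):

    ```python
    import torch


    def reference_rotary_1d(query: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
        """Rotate-half RoPE on a (batch, seq, heads, head_dim) tensor, head_dim even."""
        _, seq_len, _, head_dim = query.shape
        inv_freq = theta ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
        angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq[None, :]
        cos = angles.cos()[None, :, None, :]  # broadcast over batch and heads
        sin = angles.sin()[None, :, None, :]
        q1, q2 = query[..., : head_dim // 2], query[..., head_dim // 2 :]
        return torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)
    ```

    The cloning point matters precisely because the kernel mutates its input: the test hands the layer `query.clone()` and keeps the original tensor untouched for the comparison.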

  • test_attention.py: replaces sparse per-feature stubs with a single combined test. Each configuration exercises three checks:

    • Independent einsum reference (plain F.linear + per-head matmul, no Fast-LLM internals)
    • Packing equivalence: the packed output must match the per-sequence forward and backward
    • Flash equivalence: the packed flash output must match the bfloat16 backup (non-flash) reference

    Configurations: causal, non-causal, sliding-window, MQA (1 KV head), MHA (head_groups == heads), default-rotary.
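
    As a rough picture of the reference check, an einsum expectation over already-projected q/k/v could look like the following (an illustrative sketch, not the suite's code; the exact mask and head-grouping conventions, and the plain F.linear projections around it, follow the actual test):

    ```python
    import math

    import torch


    def reference_attention(
        q: torch.Tensor,  # (batch, seq, heads, head_dim)
        k: torch.Tensor,  # (batch, seq, kv_heads, head_dim)
        v: torch.Tensor,  # (batch, seq, kv_heads, head_dim)
        causal: bool = True,
        window: int | None = None,
    ) -> torch.Tensor:
        batch, seq, heads, head_dim = q.shape
        # Repeat KV heads so MQA (1 KV head) and GQA collapse to plain MHA.
        k = k.repeat_interleave(heads // k.shape[2], dim=2)
        v = v.repeat_interleave(heads // v.shape[2], dim=2)
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(head_dim)
        masked = torch.zeros(seq, seq, dtype=torch.bool)
        if causal:
            masked |= torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
        if window is not None:  # sliding window: drop keys further than `window` behind the query
            masked |= torch.tril(torch.ones(seq, seq, dtype=torch.bool), diagonal=-window)
        scores = scores.masked_fill(masked, float("-inf"))
        return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)
    ```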

  • test_embedding.py (new): parametrized coverage for LanguageModelEmbedding. Base cases (default, padding, position embeddings) × variants (float32, bfloat16, full-precision residual). Reference in plain PyTorch, independent of Fast-LLM embedding internals.
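
    The embedding reference can stay similarly small; the sketch below assumes a token lookup plus optional learned positions and a padding mask (the exact masking-vs-position ordering and the dtype handling for the bfloat16 / full-precision-residual variants are what the real cases pin down):

    ```python
    import torch


    def reference_embedding(
        token_ids: torch.Tensor,                           # (batch, seq) int64
        word_embeddings: torch.Tensor,                     # (vocab, hidden)
        position_embeddings: torch.Tensor | None = None,   # (max_seq, hidden)
        token_mask: torch.Tensor | None = None,            # (batch, seq) bool, False for padding
    ) -> torch.Tensor:
        hidden = torch.nn.functional.embedding(token_ids, word_embeddings)
        if position_embeddings is not None:
            hidden = hidden + position_embeddings[: token_ids.shape[1]]
        if token_mask is not None:
            hidden = hidden * token_mask.unsqueeze(-1).to(hidden.dtype)
        return hidden
    ```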

Test plan

  • pytest -v tests/layers/test_rotary.py tests/layers/test_attention.py tests/layers/test_embedding.py

@jlamypoirier changed the title from "Expand test_rotary and test_attention into parametrized suites" to "Expand rotary, attention, and embedding layer tests into parametrized suites" on May 1, 2026
jlamypoirier and others added 3 commits May 1, 2026 03:15
test_rotary.py:
- Rewrites the procedural test as a parametrized config-driven suite
  covering default, llama3, yarn, and 2D rotary across multiple head sizes
  and sequence lengths.
- Independent 1D/2D reference implementations in plain PyTorch (no Fast-LLM
  kernel calls) make the expected values auditable and kernel-agnostic.
- Clones query before calling forward: triton_rotary_ writes results
  in-place, corrupting the reference if both share storage.

test_attention.py:
- Replaces the per-feature test stubs with a single combined test that
  exercises an independent einsum reference, packing equivalence
  (packed == per-sequence forward and backward), and flash equivalence.
- Covers causal, non-causal, sliding-window, MQA (1 KV head), MHA
  (head_groups == heads), and default-rotary configurations.
- TF32 disabled for the reference check to keep summation-order differences
  below 1e-7; packing and flash checks use looser tolerances.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
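
For context on the TF32 note above, the toggle generally amounts to flipping the two PyTorch backend flags around the strict float32 comparison (a sketch of the pattern; the suite inlines it in test_attention rather than keeping a named helper):

```python
import contextlib

import torch


@contextlib.contextmanager
def tf32_disabled():
    # Save and restore the backend flags so only the reference check runs without TF32.
    matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    cudnn_tf32 = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = matmul_tf32
        torch.backends.cudnn.allow_tf32 = cudnn_tf32
```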
Covers padding (masked tokens), position embeddings, bfloat16, and
full-precision residual across all combinations. Reference implementation
in plain PyTorch, independent of Fast-LLM embedding internals.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Drop redundant `_test_attention_impl` wrapper; collapse `_no_tf32`
  context into `test_attention` directly.
- Gate the per-sequence backward in `_run_per_seq_reference` behind a
  `with_backward` flag; the bf16 reference doesn't consume gradients.
- Parameterize rotary theta on `AttentionTestConfig` so the reference
  and the attention layer can't desync on theta.
- Switch `Assert.rms_close` (flash) and `torch.testing.assert_close`
  (embedding) to `Assert.rms_close_relative` for consistency with the
  other layer tests.
- Drop redundant comments restating obvious code; trim stale comment
  about Triton CPU failures (q/k norm not exercised here).
- Use `float32` as the default-variant name in test_embedding so the
  empty-string special case can be removed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier force-pushed the jlp_test-improvements branch from 3f75028 to 4c0da7a on May 1, 2026 07:15
…ding+position case

- Add @pytest.mark.slow to test_attention and test_rotary (consistent with test_embedding and test_ssm)
- Fix misleading comment on _attention_rotary_cases: packing equivalence does run for single-doc inputs
- Replace _add_configs side-effectful helper with a declarative list comprehension
- Add padding+position_embeddings base case to test_embedding to catch masking/position ordering bugs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
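
To make the marking and config-list points concrete, the general pattern looks something like this (hypothetical names; EmbeddingCase and its field set stand in for the suite's actual config objects):

```python
import dataclasses
import itertools

import pytest


@dataclasses.dataclass(frozen=True)
class EmbeddingCase:  # hypothetical stand-in for the suite's config objects
    name: str
    padding: bool = False
    position_embeddings: bool = False


BASE_CASES = [
    EmbeddingCase("default"),
    EmbeddingCase("padding", padding=True),
    EmbeddingCase("position", position_embeddings=True),
    EmbeddingCase("padding_position", padding=True, position_embeddings=True),
]
VARIANTS = ["float32", "bfloat16", "full_precision_residual"]

# Declarative cross-product instead of a helper that appends to a module-level list.
CASES = [
    pytest.param(case, variant, id=f"{case.name}-{variant}")
    for case, variant in itertools.product(BASE_CASES, VARIANTS)
]


@pytest.mark.slow
@pytest.mark.parametrize(("case", "variant"), CASES)
def test_embedding(case: EmbeddingCase, variant: str) -> None:
    ...  # build the layer for `case`/`variant` and compare against the plain-PyTorch reference
```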
@jlamypoirier merged commit a0b2063 into main on May 1, 2026
3 checks passed
@jlamypoirier deleted the jlp_test-improvements branch on May 1, 2026 07:48