atom/plugin/register.py (11 additions, 4 deletions)

@@ -23,21 +23,28 @@ def _register_custom_attention_to_sglang() -> None:
     sglang only accepts pre-registered backend names, so we reuse the "aiter"
     name to inject ATOMAttnBackendForSgl without modifying sglang source.
     """
+    import sglang.srt.layers.attention.aiter_backend as sglang_aiter_backend
+
     from sglang.srt.layers.attention.attention_registry import (
         register_attention_backend,
     )
+    from atom.plugin.sglang.attention_backend.sgl_attn_backend import (
+        ATOMAttnBackendForSgl,
+    )

     # Here we register the custom attention backend under the name "aiter",
     # since sglang defines a fixed set of attention backend choices, which
     # must be in-tree.
     logger.info("Register custom attention backend ATOMAttnBackendForSgl to SGLang")

+    # Speculative draft paths instantiate AiterAttnBackend directly inside
+    # AiterMultiStepDraftBackend, bypassing the attention registry. Rebind the
+    # module symbol as well so both registry lookup and direct construction use
+    # the plugin backend.
+    sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl
+
     @register_attention_backend("aiter")
     def create_atom_backend(runner):
-        from atom.plugin.sglang.attention_backend.sgl_attn_backend import (
-            ATOMAttnBackendForSgl,
-        )
-
         return ATOMAttnBackendForSgl(runner)


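The hunk above hooks the plugin in at two points: the registry entry covers callers that resolve the backend by the "aiter" name, while rebinding `sglang_aiter_backend.AiterAttnBackend` covers the speculative-draft path that constructs the class directly. Below is a minimal, self-contained sketch of that double-override pattern; every name in it (`fake_backend_mod`, `ATTENTION_BACKENDS`, the stand-in classes) is illustrative and not sglang's actual module or registry.

```python
import types

# Stand-in for sglang.srt.layers.attention.aiter_backend (illustrative only).
fake_backend_mod = types.ModuleType("fake_aiter_backend")


class AiterAttnBackend:
    """Stand-in for the upstream backend class."""

    def __init__(self, runner):
        self.runner = runner


fake_backend_mod.AiterAttnBackend = AiterAttnBackend

# Stand-in for sglang's attention backend registry.
ATTENTION_BACKENDS = {}


def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn

    return decorator


class ATOMAttnBackendForSgl(AiterAttnBackend):
    """Stand-in for the plugin backend."""


# Hook 1: registry path, used by callers that look the backend up by name.
@register_attention_backend("aiter")
def create_atom_backend(runner):
    return ATOMAttnBackendForSgl(runner)


# Hook 2: module-symbol path, used by callers (e.g. a multi-step draft
# backend) that instantiate the class through the module attribute and never
# consult the registry. Without this rebind they would still construct the
# upstream class.
fake_backend_mod.AiterAttnBackend = ATOMAttnBackendForSgl

runner = object()
assert isinstance(ATTENTION_BACKENDS["aiter"](runner), ATOMAttnBackendForSgl)
assert isinstance(fake_backend_mod.AiterAttnBackend(runner), ATOMAttnBackendForSgl)
```

The second assert is the reason the rebind is needed: a registry-only override leaves any direct `module.ClassName(runner)` call on the upstream implementation.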
atom/plugin/sglang/attention_backend/radix_attention.py (10 additions, 7 deletions)

@@ -88,17 +88,20 @@ def __init__(
                 self.attn.k_scale = atom_parameter(
                     torch.tensor([1.0], dtype=torch.float32, device="cuda")
                 )
+            elif not self.attn.k_scale.is_cuda:
+                self.attn.k_scale = torch.nn.Parameter(
+                    self.attn.k_scale.detach().to(device="cuda"),
+                    requires_grad=False,
+                )
             if self.attn.v_scale is None:
                 self.attn.v_scale = atom_parameter(
                     torch.tensor([1.0], dtype=torch.float32, device="cuda")
                 )
-            # Some SGLang attention backends consume the host-side float scales
-            # directly. Keep them in sync with the device-side defaults so the
-            # plugin path works even when checkpoint loading never populates them.
-            if self.attn.k_scale_float is None:
-                self.attn.k_scale_float = 1.0
-            if self.attn.v_scale_float is None:
-                self.attn.v_scale_float = 1.0
+            elif not self.attn.v_scale.is_cuda:
+                self.attn.v_scale = torch.nn.Parameter(
+                    self.attn.v_scale.detach().to(device="cuda"),
+                    requires_grad=False,
+                )
         else:
             raise NotImplementedError(
                 "RadixAttention is only supported for plugin mode for sglang for now"
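This hunk extends the scale handling so that `k_scale` and `v_scale` are not only created when missing but also migrated to the GPU when checkpoint loading leaves them on the host. A small sketch of that normalization follows, using the same torch calls as the diff; `ensure_cuda_scale` is a hypothetical helper name for illustration only, not part of the actual file, and running it requires a CUDA device.

```python
import torch


def ensure_cuda_scale(scale):
    """Return a CUDA-resident, frozen KV-cache scale.

    Covers the two cases handled in the diff: the checkpoint never populated
    the scale (None), or the scale was loaded onto the host.
    """
    if scale is None:
        # Default to 1.0, created directly on the device.
        scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    elif not scale.is_cuda:
        # Move a host-side scale to the device once, so attention kernels
        # always see a CUDA tensor.
        scale = scale.detach().to(device="cuda")
    return torch.nn.Parameter(scale, requires_grad=False)


# Hypothetical usage on a RadixAttention-like module:
# attn.k_scale = ensure_cuda_scale(attn.k_scale)
# attn.v_scale = ensure_cuda_scale(attn.v_scale)
```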