Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 24 additions & 55 deletions tensorrt_llm/_torch/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from transformers.utils import HF_MODULES_CACHE

from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
load_pretrained_config)
from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
from tensorrt_llm.functional import AllReduceStrategy
Expand All @@ -25,18 +26,6 @@
TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)


class LazyConfigDict(dict):

def __getitem__(self, key):
import tensorrt_llm._torch.configs as configs
return getattr(configs, super().__getitem__(key))


_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
deepseek_v32="DeepseekV3Config",
) # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class


@dataclass
class MoeLoadBalancerConfig:
num_slots: Optional[int] = None
Expand Down Expand Up @@ -432,51 +421,31 @@ def from_pretrained(cls,
# When handling the case where model_format is TLLM_ENGINE
# send cyclic requests to the NONE URL.
if checkpoint_dir is not None:
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
pretrained_config = load_pretrained_config(
checkpoint_dir,
trust_remote_code=trust_remote_code,
**kwargs,
)
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
pretrained_config = config_class.from_pretrained(
checkpoint_dir,
**kwargs,
)
if model_type == "deepseek_v32":
sparse_attention_config = kwargs.get(
'sparse_attention_config')
kwargs[
'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
index_n_heads=(
sparse_attention_config.index_n_heads
if sparse_attention_config
and sparse_attention_config.index_n_heads
is not None else
pretrained_config.index_n_heads),
index_head_dim=(
sparse_attention_config.index_head_dim
if sparse_attention_config
and sparse_attention_config.index_head_dim
is not None else
pretrained_config.index_head_dim),
index_topk=(sparse_attention_config.index_topk
if sparse_attention_config and
sparse_attention_config.index_topk
is not None else
pretrained_config.index_topk),
indexer_max_chunk_size=(
sparse_attention_config.
indexer_max_chunk_size
if sparse_attention_config
and sparse_attention_config.
indexer_max_chunk_size is not None else
None))
else:
pretrained_config = transformers.AutoConfig.from_pretrained(
checkpoint_dir,
trust_remote_code=trust_remote_code,
)
if pretrained_config.architectures[
0] == "DeepseekV32ForCausalLM":
sparse_attention_config = kwargs.get(
'sparse_attention_config')
if sparse_attention_config:
index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
else:
index_n_heads = pretrained_config.index_n_heads
index_head_dim = pretrained_config.index_head_dim
index_topk = pretrained_config.index_topk
indexer_max_chunk_size = None
kwargs[
'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
index_n_heads=index_n_heads,
index_head_dim=index_head_dim,
index_topk=index_topk,
indexer_max_chunk_size=indexer_max_chunk_size)
else:
raise ValueError(
"checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."
Expand Down
32 changes: 32 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import transformers


def is_nemotron_hybrid(config):
if hasattr(config, "hybrid_override_pattern"
) and config.hybrid_override_pattern is not None and len(
Expand All @@ -18,3 +21,32 @@ def is_qwen3_next(config):
config, 'architectures'
) and config.architectures is not None and config.architectures[
0] == 'Qwen3NextForCausalLM'


# TODO: remove this once the transformers can support all of those models in _CONFIG_REGISTRY
class LazyConfigDict(dict):

def __getitem__(self, key):
import tensorrt_llm._torch.configs as configs
return getattr(configs, super().__getitem__(key))


_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
deepseek_v32="DeepseekV3Config",
Comment thread
lfr-0531 marked this conversation as resolved.
) # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class


def load_pretrained_config(model_name_or_path: str,
trust_remote_code: bool = False,
**kwargs) -> transformers.PretrainedConfig:
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
model_name_or_path, **kwargs)
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
model_config = config_class.from_pretrained(model_name_or_path,
**kwargs)
else:
model_config = transformers.AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code)
return model_config
9 changes: 4 additions & 5 deletions tensorrt_llm/bench/build/build.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations
from transformers import AutoConfig

from pathlib import Path
from typing import Tuple, get_args
import click
from click_option_group import AllOptionGroup, optgroup

from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid, load_pretrained_config
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
Expand Down Expand Up @@ -86,9 +85,9 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
Raises:
ValueError: When model is not supported.
"""
if is_nemotron_hybrid(
AutoConfig.from_pretrained(model_path or model_name,
trust_remote_code=True)):
pretrained_config = load_pretrained_config(model_path or model_name,
trust_remote_code=True)
if is_nemotron_hybrid(pretrained_config):
return NemotronHybridConfig.from_hf(model_name, model_path)
return ModelConfig.from_hf(model_name, model_path)

Expand Down
10 changes: 6 additions & 4 deletions tensorrt_llm/bench/build/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from transformers import AutoConfig
from typing import Optional, Literal
from pydantic import AliasPath, BaseModel, Field, AliasChoices, model_validator
import huggingface_hub
Expand All @@ -14,6 +13,8 @@
import json
import struct

from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config


def parse_safetensors_file_metadata(model_path, filename):

Expand Down Expand Up @@ -192,9 +193,10 @@ def get_param_count(cls, model_hf_name, hf_model_path):

@classmethod
def from_hf(cls, model_hf_name, hf_model_path):
model_name_or_path = hf_model_path or model_hf_name
hf_config = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=True).to_dict()
pretrained_config = load_pretrained_config(hf_model_path
or model_hf_name,
trust_remote_code=True)
hf_config = pretrained_config.to_dict()
param_count = cls.get_param_count(model_hf_name, hf_model_path)

return cls(name=model_hf_name, param_count=param_count, **hf_config)
Expand Down
32 changes: 10 additions & 22 deletions tensorrt_llm/serve/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from transformers import AutoConfig, AutoProcessor
from transformers import AutoProcessor

from tensorrt_llm._tensorrt_engine import LLM
# yapf: disable
Expand Down Expand Up @@ -99,27 +99,15 @@ def __init__(self,
except Exception:
logger.debug("Failed to load AutoProcessor or AutoConfig for %s", hf_tokenizer_path)
self.processor = None
# Temporary workaround for DSv3.2 config.
import transformers

from tensorrt_llm._torch.model_config import _CONFIG_REGISTRY
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
hf_tokenizer_path,
trust_remote_code=trust_remote_code
)
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
self.model_config = config_class.from_pretrained(
hf_tokenizer_path,
trust_remote_code=trust_remote_code
)
else:
try:
self.model_config = AutoConfig.from_pretrained(hf_tokenizer_path, trust_remote_code=trust_remote_code)
except Exception:
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
self.model_config = None
# load model config
try:
from tensorrt_llm._torch.pyexecutor.config_utils import \
load_pretrained_config
self.model_config = load_pretrained_config(hf_tokenizer_path,
trust_remote_code=trust_remote_code)
except Exception:
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
self.model_config = None

# Enable response storage for Responses API
self.enable_store = True
Expand Down