auto_round/autoround.py (2 additions, 0 deletions)
@@ -163,6 +163,8 @@ def __new__(
         """

         local_args = {k: v for k, v in locals().items() if k not in cls.SKIP_ARGS}
+        if extra_config is not None:
+            local_args.update({k: v for k, v in extra_config.to_dict().items() if k in local_args and v is not None})

         if NEW_ARCH:
             from auto_round.compressors_new.entry import AutoRoundCompatible
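A minimal sketch of the merge semantics above, with illustrative values: only keys that already exist in `local_args` are taken from `extra_config`, and a `None` value never overwrites a caller-supplied argument.

```python
# Sketch of the extra_config merge; the keys and values are illustrative.
local_args = {"iters": 200, "seqlen": 2048}
extra_config_dict = {"iters": 1000, "seqlen": None, "not_an_arg": 1}

# Only keys already present in local_args are merged, and None never overwrites.
local_args.update({k: v for k, v in extra_config_dict.items() if k in local_args and v is not None})
assert local_args == {"iters": 1000, "seqlen": 2048}
```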
auto_round/compressors/base.py (17 additions, 0 deletions)
@@ -94,6 +94,7 @@
     get_layer_names_in_block,
     get_lm_head_name,
     get_module,
+    get_reverse_checkpoint_conversion_mapping,
     global_state,
     hook_ngram_embeddings_on_cpu,
     htcore,
@@ -107,6 +108,7 @@
     memory_monitor,
     mv_module_from_gpu,
     normalize_no_split_modules,
+    revert_checkpoint_conversion_mapping,
     set_amax_for_all_moe_layers,
     set_module,
     to_device,
@@ -3603,6 +3605,21 @@ def save_quantized(
         serialization_dict["autoround_version"] = __version__
         if "scale_dtype" in serialization_dict.keys():
             serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"])
+
+        # revert to the original checkpoint names
+        reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)
+
+        if isinstance(serialization_dict["to_quant_block_names"], str):
+            serialization_dict["to_quant_block_names"] = revert_checkpoint_conversion_mapping(
+                serialization_dict["to_quant_block_names"], reverse_checkpoint_conversion_mapping
+            )
+
+        elif isinstance(serialization_dict["to_quant_block_names"], list):
+            for idx in range(len(serialization_dict["to_quant_block_names"])):
+                serialization_dict["to_quant_block_names"][idx] = revert_checkpoint_conversion_mapping(
+                    serialization_dict["to_quant_block_names"][idx], reverse_checkpoint_conversion_mapping
+                )
+
         compressed_model = format.save_quantized(
             save_folder,
             model=self.model,
auto_round/compressors/shard_writer.py (11 additions, 1 deletion)
@@ -20,7 +20,12 @@
 from torch.nn import Parameter

 from auto_round.logger import logger
-from auto_round.utils import get_lm_head_name, get_module
+from auto_round.utils import (
+    get_lm_head_name,
+    get_module,
+    get_reverse_checkpoint_conversion_mapping,
+    revert_checkpoint_conversion_mapping,
+)


 class ShardWriter:
@@ -55,6 +60,7 @@ def __init__(self, rounder):
         self.shard_meta = []  # List of {tmp_file: str, params: list}
         self.global_weight_map = {}
         self.shard_counter = 0
+        self.reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)

         # Persistent set of all parameter names already flushed to a shard file.
         # Maintained incrementally in _flush_shard to avoid O(N^2) rebuilds in _add_tensor.
@@ -105,6 +111,10 @@ def save_module(self, m: torch.nn.Module, name: str = None):
             self._add_tensor(param_name, v)

     def _add_tensor(self, name: str, tensor: torch.Tensor):
+
+        # transformers will handle _checkpoint_conversion_mapping automatically if is_immediate_saving=False
+        name = revert_checkpoint_conversion_mapping(name, self.reverse_checkpoint_conversion_mapping)
+
         if isinstance(tensor, torch.Tensor) and tensor.device.type == "meta":
             self.skipped_meta_tensors.append(name)
             return
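To illustrate the immediate-saving path above: each runtime parameter name is reverted to its checkpoint form before it is written to a shard. A small sketch with a hypothetical reverse mapping (in practice the mapping comes from `get_reverse_checkpoint_conversion_mapping`):

```python
# Hypothetical sketch: map a runtime parameter name back to its checkpoint
# form before sharding. The mapping below is illustrative only.
from auto_round.utils import revert_checkpoint_conversion_mapping

reverse_mapping = {"model.language_model": "model"}  # runtime prefix -> checkpoint prefix
name = "model.language_model.layers.0.self_attn.q_proj.qweight"
assert revert_checkpoint_conversion_mapping(name, reverse_mapping) == "model.layers.0.self_attn.q_proj.qweight"
```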
auto_round/compressors_new/base.py (16 additions, 0 deletions)
@@ -52,10 +52,12 @@
     extract_block_names_to_str,
     find_matching_blocks,
     get_block_names,
+    get_reverse_checkpoint_conversion_mapping,
     is_debug_mode,
     is_hpex_available,
     is_quantized_input_module,
     memory_monitor,
+    revert_checkpoint_conversion_mapping,
 )
 from auto_round.utils.device import (
     _force_trim_malloc,
@@ -1170,6 +1172,20 @@ def save_quantized(
         if "scale_dtype" in serialization_dict.keys():
             serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"])

+        # revert to the original checkpoint names
+        reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)
+
+        if isinstance(serialization_dict["to_quant_block_names"], str):
+            serialization_dict["to_quant_block_names"] = revert_checkpoint_conversion_mapping(
+                serialization_dict["to_quant_block_names"], reverse_checkpoint_conversion_mapping
+            )
+
+        elif isinstance(serialization_dict["to_quant_block_names"], list):
+            for idx in range(len(serialization_dict["to_quant_block_names"])):
+                serialization_dict["to_quant_block_names"][idx] = revert_checkpoint_conversion_mapping(
+                    serialization_dict["to_quant_block_names"][idx], reverse_checkpoint_conversion_mapping
+                )
+
         compressed_model = format.save_quantized(
             save_folder,
             model=self.model_context.model,
auto_round/compressors_new/shard_writer.py (11 additions, 1 deletion)
@@ -22,7 +22,12 @@
 from auto_round.context.compress import CompressContext
 from auto_round.context.model import ModelContext
 from auto_round.logger import logger
-from auto_round.utils import get_lm_head_name, get_module
+from auto_round.utils import (
+    get_lm_head_name,
+    get_module,
+    get_reverse_checkpoint_conversion_mapping,
+    revert_checkpoint_conversion_mapping,
+)


 class ShardWriter:
@@ -78,6 +83,7 @@ def __init__(
         self.shard_meta = []  # List of {tmp_file: str, params: list}
         self.global_weight_map = {}
         self.shard_counter = 0
+        self.reverse_checkpoint_conversion_mapping = get_reverse_checkpoint_conversion_mapping(self.model)

         # Persistent set of all parameter names already flushed to a shard file.
         # Maintained incrementally in _flush_shard to avoid O(N^2) rebuilds in _add_tensor.
@@ -154,6 +160,10 @@ def save_module(self, m: torch.nn.Module, name: str = None):
             self._add_tensor(param_name, v)

     def _add_tensor(self, name: str, tensor: torch.Tensor):
+
+        # transformers will handle _checkpoint_conversion_mapping automatically if is_immediate_saving=False
+        name = revert_checkpoint_conversion_mapping(name, self.reverse_checkpoint_conversion_mapping)
+
         if isinstance(tensor, torch.Tensor) and tensor.device.type == "meta":
             self.skipped_meta_tensors.append(name)
             return
auto_round/inference/convert_model.py (10 additions, 1 deletion)
@@ -35,10 +35,12 @@
 from auto_round.special_model_handler import update_module
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
+    apply_checkpoint_conversion_mapping,
     check_start_with_block_name,
     check_to_quantized,
     find_matching_blocks,
     get_block_names,
+    get_checkpoint_conversion_mapping,
     get_module,
     is_hpex_available,
     is_transformers_version_greater_or_equal_5,
@@ -270,12 +272,15 @@ def get_layer_config(model, quantization_config):
     )

     # Determine the quantization block list
+    checkpoint_conversion_mapping = get_checkpoint_conversion_mapping(model)
     quant_block_list = getattr(quantization_config, "quant_block_list", None)
     if quant_block_list is not None:
         # Handle nested list format: [[block1, block2, ...], ...] -> [prefix1, ...]
         if quant_block_list and isinstance(quant_block_list[0], (list, tuple)):
             for i in range(len(quant_block_list)):
-                quant_block_list[i] = os.path.commonprefix(quant_block_list[i]).rstrip(".")
+                quant_block_list[i] = apply_checkpoint_conversion_mapping(
+                    os.path.commonprefix(quant_block_list[i]).rstrip("."), checkpoint_conversion_mapping
+                )
     elif quant_block_list is None:
         to_quant_block_names = getattr(quantization_config, "block_name_to_quantize", None)  # Prioritize this parameter
         if to_quant_block_names is None:
@@ -292,6 +297,10 @@ def get_layer_config(model, quantization_config):
         # Speed up the matching
         for i in range(len(quant_block_list)):
             quant_block_list[i] = os.path.commonprefix(quant_block_list[i]).rstrip(".")
+        for i in range(len(quant_block_list)):
+            quant_block_list[i] = apply_checkpoint_conversion_mapping(
+                quant_block_list[i], checkpoint_conversion_mapping
+            )

     # Get layer names that will be quantized
     layer_names = []
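A sketch of the prefix-then-remap step above, assuming a Qwen2.5-VL-style conversion mapping (the exact patterns depend on the transformers version, so the mapping here is illustrative):

```python
# Hypothetical sketch: collapse configured block names to their common prefix,
# then translate the checkpoint-style prefix to the runtime module name.
import os

from auto_round.utils import apply_checkpoint_conversion_mapping

mapping = {r"^model(?!\.(language_model|visual))": "model.language_model"}
blocks = ["model.layers.0", "model.layers.1"]

prefix = os.path.commonprefix(blocks).rstrip(".")  # -> "model.layers"
assert apply_checkpoint_conversion_mapping(prefix, mapping) == "model.language_model.layers"
```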
auto_round/utils/common.py (65 additions, 0 deletions)
@@ -1055,3 +1055,68 @@ def infer_bits_by_data_type(data_type: str):
         if str.isdigit(data_type[len(supported_dtype)]):
             return int(data_type[len(supported_dtype)])
     return None
+
+
+def get_checkpoint_conversion_mapping(model):
+    """Get the checkpoint conversion mapping for a given model, if it exists."""
+    checkpoint_conversion_mapping = {}
+
+    # transformers <= 5.3.0 uses _checkpoint_conversion_mapping
+    checkpoint_conversion_mapping.update(getattr(model, "_checkpoint_conversion_mapping", {}))
+
+    # transformers > 5.3.0 uses get_checkpoint_conversion_mapping
+    if hasattr(transformers, "conversion_mapping") and (
+        hasattr(model, "config") and hasattr(model.config, "model_type")
+    ):
+        from transformers.conversion_mapping import (
+            get_checkpoint_conversion_mapping as transformers_get_checkpoint_conversion_mapping,
+        )
+
+        conversion_mappings = transformers_get_checkpoint_conversion_mapping(model.config.model_type)
+        if conversion_mappings is not None:
+            for conversion_mapping in conversion_mappings:
+                for source_pattern in conversion_mapping.source_patterns:
+                    checkpoint_conversion_mapping[source_pattern] = conversion_mapping.target_patterns
+    return checkpoint_conversion_mapping
+
+
+def get_reverse_checkpoint_conversion_mapping(model):
+    """Get the reverse checkpoint conversion mapping for a given model, if it exists."""
+    reverse_checkpoint_conversion_mapping = {
+        v: k for k, v in getattr(model, "_checkpoint_conversion_mapping", {}).items()
+    }
+
+    if hasattr(model, "_weight_conversions"):
+        weight_conversions = model._weight_conversions
+        for weight_conversion in weight_conversions:
+            reverse_conversion_mapping = weight_conversion.reverse_transform()
+            for source_pattern in reverse_conversion_mapping.source_patterns:
+                reverse_checkpoint_conversion_mapping[source_pattern] = reverse_conversion_mapping.target_patterns
+
+    return reverse_checkpoint_conversion_mapping
+
+
+def revert_checkpoint_conversion_mapping(name: str, key_mapping: dict[str, str]) -> str:
+    for source_pattern, target_patterns in key_mapping.items():
+        if isinstance(target_patterns, str):
+            target_patterns = [target_patterns]
+        for target_pattern in target_patterns:
+            target_pattern = target_pattern.lstrip("^")  # strip regex-only chars so the target is a literal replacement
+            target_pattern = re.sub(r"\(.*\)", "", target_pattern)
+            name, n_replace = re.subn(source_pattern, target_pattern, name)
+            # Early exit of the loop
+            if n_replace > 0:
+                return name
+    return name
+
+
+def apply_checkpoint_conversion_mapping(name: str, key_mapping: dict[str, str]) -> str:
+    for source_pattern, target_patterns in key_mapping.items():
+        if isinstance(target_patterns, str):
+            target_patterns = [target_patterns]
+        for target_pattern in target_patterns:
+            name, n_replace = re.subn(source_pattern, target_pattern, name)
+            # Early exit of the loop
+            if n_replace > 0:
+                return name
+    return name
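A minimal round-trip sketch for these helpers, assuming a Qwen2.5-VL-style mapping; the exact patterns shipped by transformers differ across versions, so the dictionary below is illustrative only:

```python
# Round-trip sketch for apply/revert. The mapping is a Qwen2.5-VL-style
# example, not the exact one from any specific transformers release.
from auto_round.utils import (
    apply_checkpoint_conversion_mapping,
    revert_checkpoint_conversion_mapping,
)

mapping = {
    "^visual": "model.visual",
    r"^model(?!\.(language_model|visual))": "model.language_model",
}
reverse_mapping = {v: k for k, v in mapping.items()}

# Loading direction: checkpoint name -> runtime module name
assert apply_checkpoint_conversion_mapping("visual.blocks.0", mapping) == "model.visual.blocks.0"

# Saving direction: runtime name -> checkpoint name. The "^" anchor and the
# lookahead group are stripped from the replacement before substitution.
assert revert_checkpoint_conversion_mapping("model.language_model.layers.0", reverse_mapping) == "model.layers.0"
```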
test/test_cpu/models/test_mllm.py (1 addition, 0 deletions)
@@ -236,6 +236,7 @@ def test_qwen2_5(self, tiny_qwen_2_5_vl_model_path):
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             quantized_model_path, torch_dtype="auto", device_map="auto"
         )
+        assert model.config.quantization_config.block_name_to_quantize == "model.visual.blocks,model.layers"
         image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
         processor = AutoProcessor.from_pretrained(quantized_model_path)
         messages = [
test/test_cuda/integrations/test_sglang.py (34 additions, 0 deletions)
@@ -121,6 +121,40 @@ def test_mixed_ar_format_sglang(self, dataloader):

         shutil.rmtree(self.save_dir, ignore_errors=True)

+    def test_qwen2_5_vl_loading(self, tiny_qwen_2_5_vl_model_path):
+        from auto_round.utils import mllm_load_model
+
+        layer_config = {
+            "self_attn": {"bits": 8},
+            "lm_head": {"bits": 16},
+            "mlp": {"bits": 16, "act_bits": 16},
+        }
+
+        model, processor, tokenizer, image_processor = mllm_load_model(tiny_qwen_2_5_vl_model_path)
+
+        autoround = AutoRound(
+            model,
+            tokenizer,
+            scheme="W4A16",
+            iters=1,
+            nsamples=1,
+            seqlen=32,
+            processor=processor,
+            image_processor=image_processor,
+            layer_config=layer_config,
+        )
+
+        _, quantized_model_path = autoround.quantize_and_save(
+            output_dir=self.save_dir,
+            inplace=True,
+            format="auto_round",
+        )
+
+        generated_text = self._run_sglang_inference(quantized_model_path)
+        print(generated_text)
+
+        assert "!!!" not in generated_text
+
     @pytest.mark.skip_ci(reason="Cannot work well in CI env")
     def test_awq_format_sglang(self, dataloader):
         autoround = AutoRound(