Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions auto_round/algorithms/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,6 @@ def _resolve_block_forward(self):
self.config.is_act_quantize and (not self.config.act_dynamic or self.config.is_act_nv_fp)
) or self.enable_alg_ext:
self._resolved_block_forward = block_forward
elif self.compress_context.enable_torch_compile:
compiled = self.__dict__.get("_compiled_block_forward")
if compiled is None:
compiled = compile_func(block_forward, self.compress_context.device)
self._compiled_block_forward = compiled
self._resolved_block_forward = compiled
else:
self._resolved_block_forward = block_forward
Comment on lines 377 to 381
return self._resolved_block_forward
Expand Down
28 changes: 14 additions & 14 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,20 +533,20 @@ def __init__(
self.enable_torch_compile = enable_torch_compile
self._adjust_torch_compile(enable_torch_compile)

if (
(self.act_bits < 16 and (not self.act_dynamic or self.data_type == "nvfp")) # have hooks
or self.enable_alg_ext # Use imatrix
or not self.disable_opt_rtn # Use imatrix
):
self.block_forward = block_forward
else:
# TODO FIXME
# This function could not be compiled, causing a large accuracy drop when `enable_alg_ext` is used.
# To avoid issues, remove it in all scenarios except WOQ.
self.block_forward = (
compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
)

# if (
# (self.act_bits < 16 and (not self.act_dynamic or self.data_type == "nvfp")) # have hooks
# or self.enable_alg_ext # Use imatrix
# or not self.disable_opt_rtn # Use imatrix
# ):
# self.block_forward = block_forward
# else:
# # TODO FIXME
# # This function could not be compiled, causing a large accuracy drop when `enable_alg_ext` is used.
# # To avoid issues, remove it in all scenarios except WOQ.
# self.block_forward = (
# compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
# )
Comment on lines +536 to +548
self.block_forward = block_forward
self._check_configs()
torch.set_printoptions(precision=3, sci_mode=True)

Expand Down
9 changes: 5 additions & 4 deletions auto_round/compressors_new/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,10 +971,11 @@ def _hardware_setup(self) -> None:
# Only compile block_forward when it will actually be used (calibration path).
# For zero-shot compressors (need_calib=False), block_forward is never called,
# so skipping compilation avoids unnecessary HPU workspace allocation.
if self.enable_torch_compile and not _needs_plain_forward and self.need_calib:
self.block_forward = compile_func(block_forward, self.compress_context.device)
else:
self.block_forward = block_forward
# if self.enable_torch_compile and not _needs_plain_forward and self.need_calib:
# self.block_forward = compile_func(block_forward, self.compress_context.device)
# else:
# self.block_forward = block_forward
Comment on lines +974 to +977
self.block_forward = block_forward
if self.compress_context.low_cpu_mem_usage:
self._offloader.reset()

Expand Down
Loading