diff --git a/auto_round/algorithms/quantization/base.py b/auto_round/algorithms/quantization/base.py
index 12d972554..13739dbb0 100644
--- a/auto_round/algorithms/quantization/base.py
+++ b/auto_round/algorithms/quantization/base.py
@@ -377,12 +377,6 @@ def _resolve_block_forward(self):
             self.config.is_act_quantize and (not self.config.act_dynamic or self.config.is_act_nv_fp)
         ) or self.enable_alg_ext:
             self._resolved_block_forward = block_forward
-        elif self.compress_context.enable_torch_compile:
-            compiled = self.__dict__.get("_compiled_block_forward")
-            if compiled is None:
-                compiled = compile_func(block_forward, self.compress_context.device)
-                self._compiled_block_forward = compiled
-            self._resolved_block_forward = compiled
         else:
             self._resolved_block_forward = block_forward
         return self._resolved_block_forward
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 0503c8235..03b76135a 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -533,20 +533,20 @@ def __init__(
 
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
-        if (
-            (self.act_bits < 16 and (not self.act_dynamic or self.data_type == "nvfp"))  # have hooks
-            or self.enable_alg_ext  # Use imatrix
-            or not self.disable_opt_rtn  # Use imatrix
-        ):
-            self.block_forward = block_forward
-        else:
-            # TODO FIXME
-            # This function could not be compiled, causing a large accuracy drop when `enable_alg_ext` is used.
-            # To avoid issues, remove it in all scenarios except WOQ.
-            self.block_forward = (
-                compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
-            )
-
+        # if (
+        #     (self.act_bits < 16 and (not self.act_dynamic or self.data_type == "nvfp"))  # have hooks
+        #     or self.enable_alg_ext  # Use imatrix
+        #     or not self.disable_opt_rtn  # Use imatrix
+        # ):
+        #     self.block_forward = block_forward
+        # else:
+        #     # TODO FIXME
+        #     # This function could not be compiled, causing a large accuracy drop when `enable_alg_ext` is used.
+        #     # To avoid issues, remove it in all scenarios except WOQ.
+        #     self.block_forward = (
+        #         compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
+        #     )
+        self.block_forward = block_forward
         self._check_configs()
 
         torch.set_printoptions(precision=3, sci_mode=True)
diff --git a/auto_round/compressors_new/base.py b/auto_round/compressors_new/base.py
index 025dfbd3f..ca332a50a 100644
--- a/auto_round/compressors_new/base.py
+++ b/auto_round/compressors_new/base.py
@@ -971,10 +971,11 @@ def _hardware_setup(self) -> None:
         # Only compile block_forward when it will actually be used (calibration path).
         # For zero-shot compressors (need_calib=False), block_forward is never called,
         # so skipping compilation avoids unnecessary HPU workspace allocation.
-        if self.enable_torch_compile and not _needs_plain_forward and self.need_calib:
-            self.block_forward = compile_func(block_forward, self.compress_context.device)
-        else:
-            self.block_forward = block_forward
+        # if self.enable_torch_compile and not _needs_plain_forward and self.need_calib:
+        #     self.block_forward = compile_func(block_forward, self.compress_context.device)
+        # else:
+        #     self.block_forward = block_forward
+        self.block_forward = block_forward
 
         if self.compress_context.low_cpu_mem_usage:
             self._offloader.reset()