diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 66f2f3363437..c0f32feb9eca 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -209,7 +209,7 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy :class:`Policy`: The auto policy for the model """ full_name = _fullname(model) - if ShardConfig.inference_only: + if shard_config.inference_only: policy_location = _INFER_POLICY_LIST.get(full_name, None) else: policy_location = _POLICY_LIST.get(full_name, None) @@ -219,5 +219,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}" ) else: - policy = import_policy(policy_location, ShardConfig.inference_only) + policy = import_policy(policy_location, shard_config.inference_only) return policy(model, shard_config)