From efd6de273c093bdca29ddcade154096b27a8cb10 Mon Sep 17 00:00:00 2001 From: Guoming Zhang <37257613+nv-guomingz@users.noreply.github.com> Date: Wed, 31 Jul 2024 16:21:11 +0000 Subject: [PATCH] Add missing attribute _supports_param_buffer_assignment for GPT-J --- src/transformers/models/gptj/modeling_gptj.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 72a00f3ebe68..fa658d9e0578 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -500,6 +500,7 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_param_buffer_assignment = False def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs)