From 22f3bbc9759cdaf732ca2934489474732c959e83 Mon Sep 17 00:00:00 2001 From: Yuanchen Xu Date: Tue, 4 Apr 2023 09:17:41 +0800 Subject: [PATCH] fix sft training for bloom, gpt and opt --- applications/Chat/coati/models/bloom/bloom_lm.py | 3 +++ applications/Chat/coati/models/gpt/gpt_lm.py | 3 +++ applications/Chat/coati/models/opt/opt_lm.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/applications/Chat/coati/models/bloom/bloom_lm.py b/applications/Chat/coati/models/bloom/bloom_lm.py index 628af2e341a2..e4184fcd0d9c 100644 --- a/applications/Chat/coati/models/bloom/bloom_lm.py +++ b/applications/Chat/coati/models/bloom/bloom_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_lm.py b/applications/Chat/coati/models/gpt/gpt_lm.py index 23fc13bf23a4..c558d7e9ea8d 100644 --- a/applications/Chat/coati/models/gpt/gpt_lm.py +++ b/applications/Chat/coati/models/gpt/gpt_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs) diff --git a/applications/Chat/coati/models/opt/opt_lm.py b/applications/Chat/coati/models/opt/opt_lm.py index 65d79e1b2307..47afae847f13 100644 --- a/applications/Chat/coati/models/opt/opt_lm.py +++ b/applications/Chat/coati/models/opt/opt_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)