diff --git a/applications/Chat/coati/models/bloom/bloom_lm.py b/applications/Chat/coati/models/bloom/bloom_lm.py index 628af2e341a2..e4184fcd0d9c 100644 --- a/applications/Chat/coati/models/bloom/bloom_lm.py +++ b/applications/Chat/coati/models/bloom/bloom_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_lm.py b/applications/Chat/coati/models/gpt/gpt_lm.py index 23fc13bf23a4..c558d7e9ea8d 100644 --- a/applications/Chat/coati/models/gpt/gpt_lm.py +++ b/applications/Chat/coati/models/gpt/gpt_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs) diff --git a/applications/Chat/coati/models/opt/opt_lm.py b/applications/Chat/coati/models/opt/opt_lm.py index 65d79e1b2307..47afae847f13 100644 --- a/applications/Chat/coati/models/opt/opt_lm.py +++ b/applications/Chat/coati/models/opt/opt_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)