diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py
index cd565031e112..70bc13386cf3 100644
--- a/applications/Chat/coati/models/llama/llama_critic.py
+++ b/applications/Chat/coati/models/llama/llama_critic.py
@@ -1,8 +1,7 @@
 from typing import Optional
 
-import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaModel
 
 from ..base import Critic
 
@@ -28,15 +27,12 @@ def __init__(self,
                  **kwargs) -> None:
 
         if pretrained is not None:
-            model = LlamaForCausalLM.from_pretrained(pretrained)
+            model = LlamaModel.from_pretrained(pretrained)
         elif config is not None:
-            model = LlamaForCausalLM(config)
+            model = LlamaModel(config)
         else:
-            model = LlamaForCausalLM(LlamaConfig())
-
+            model = LlamaModel(LlamaConfig())
         if checkpoint:
             model.gradient_checkpointing_enable()
-
         value_head = nn.Linear(model.config.hidden_size, 1)
-
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
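
Context for the swap (not part of the patch): the critic only needs the transformer's hidden states, not next-token logits, so the bare LlamaModel backbone is enough and the lm_head carried by LlamaForCausalLM is dead weight; the value head maps each token's hidden state to a scalar. A minimal sketch of that data flow follows, using a deliberately tiny LlamaConfig and variable names that are illustrative only:

# Sketch (assumptions: tiny config sizes and names chosen for illustration).
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, vocab_size=1000)  # tiny, illustration only
backbone = LlamaModel(config)                    # bare transformer, no lm_head
value_head = nn.Linear(config.hidden_size, 1)    # same head the critic attaches

input_ids = torch.randint(0, config.vocab_size, (1, 8))
hidden = backbone(input_ids).last_hidden_state   # shape (1, 8, hidden_size)
values = value_head(hidden).squeeze(-1)          # shape (1, 8): per-token value estimates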