From 193014da52baff26646cf60f62b81bd0f54f028f Mon Sep 17 00:00:00 2001 From: gongenlei Date: Thu, 6 Apr 2023 09:39:13 +0000 Subject: [PATCH 1/2] mv LlamaForCausalLM to LlamaModel --- applications/Chat/coati/models/llama/llama_critic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py index cd565031e112..ba08306dced0 100644 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import AutoModelForCausalLM, LlamaConfig, LlamaModel from ..base import Critic @@ -28,11 +28,11 @@ def __init__(self, **kwargs) -> None: if pretrained is not None: - model = LlamaForCausalLM.from_pretrained(pretrained) + model = LlamaModel.from_pretrained(pretrained) elif config is not None: - model = LlamaForCausalLM(config) + model = LlamaModel(config) else: - model = LlamaForCausalLM(LlamaConfig()) + model = LlamaModel(LlamaConfig()) if checkpoint: model.gradient_checkpointing_enable() From 9478b6c1e14076ea3260d42bbe2ad2eb08a186f1 Mon Sep 17 00:00:00 2001 From: gongenlei Date: Fri, 7 Apr 2023 03:29:51 +0000 Subject: [PATCH 2/2] rm unused imports --- applications/Chat/coati/models/llama/llama_critic.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py index ba08306dced0..dd9e5e7bfa1a 100644 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -1,8 +1,7 @@ from typing import Optional -import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, LlamaConfig, LlamaModel +from transformers import LlamaConfig, LlamaModel from ..base import Critic