-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathGPT_Neo.py
More file actions
73 lines (53 loc) · 1.92 KB
/
GPT_Neo.py
File metadata and controls
73 lines (53 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
from transformers import pipeline
import sys
import torch
# Available EleutherAI GPT-Neo checkpoints, ordered by parameter count.
models = ["gpt-neo-125M", "gpt-neo-1.3B", "gpt-neo-2.7B"]
# Fully-qualified Hugging Face model id used by process_bot_answer;
# defaults to the smallest (fastest) checkpoint.
current_model_name = "EleutherAI/" + models[0]
def process_bot_answer(input_text, text_length=50):
    """Generate a GPT-Neo continuation of *input_text* and return it cleaned up.

    Loads the tokenizer and model named by the module-level
    ``current_model_name`` on every call, samples one continuation, strips
    the echoed prompt line and any trailing unfinished sentence, and frees
    the model before returning.

    Args:
        input_text: The prompt string to continue.
        text_length: Maximum total token length of the generated sequence.
            Keep this moderate — a value around 3000 previously caused a
            buffer overflow.

    Returns:
        The generated text as a single space-joined string.
    """
    # Fall back to CPU so the function also works on machines without CUDA
    # (the original hard-coded 'cuda' and crashed otherwise).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = GPT2Tokenizer.from_pretrained(current_model_name)
    model = GPTNeoForCausalLM.from_pretrained(current_model_name)
    model.to(device)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=text_length,
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        num_return_sequences=1,
        # Use the tokenizer's EOS id directly; the original built a whole
        # text-generation pipeline (loading the model a second time) just
        # to read this one value.
        pad_token_id=tokenizer.eos_token_id,
    )
    output_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    output_list = output_text.split("\n")
    # Drop trailing lines that do not end in a full sentence. The original
    # popped while iterating the same list, which skips elements; a while
    # loop expresses the intent safely.
    while output_list and not output_list[-1].endswith("."):
        output_list.pop()
    # Drop the first line, which echoes the prompt (guard against an
    # empty list so .pop(0) cannot raise IndexError).
    if output_list:
        output_list.pop(0)
    text = " ".join(output_list)
    # Clean up resources so repeated calls don't accumulate GPU memory.
    del model
    del tokenizer
    torch.cuda.empty_cache()
    return text.lstrip()
def main():
    """CLI entry point.

    Usage: ``python GPT_Neo.py PROMPT [MAX_LENGTH]``

    argv[1] is the text prompt; optional argv[2] is the maximum generation
    length (defaults to 50, matching process_bot_answer's default).

    Returns:
        The generated answer string, or None when no prompt was given.
    """
    if len(sys.argv) < 2:
        print("Please provide a text prompt as the first argument.")
        return
    input_text = sys.argv[1]
    # Default the length explicitly — the original left text_length
    # unassigned when only the prompt was given, raising NameError below.
    text_length = 50
    if len(sys.argv) >= 3:
        text_length = int(sys.argv[2])
    # Generate once and reuse the result; the original invoked the
    # expensive generation twice (once for print, once for return).
    answer = process_bot_answer(input_text, text_length)
    print(answer)
    return answer
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()