From 826dd9407108b56793852e8f856595a2bc1795ce Mon Sep 17 00:00:00 2001
From: fakezeta
Date: Tue, 12 Mar 2024 19:31:09 +0100
Subject: [PATCH 1/9] fixes #1775 and #1774

Add BitsAndBytes Quantization and fixes embedding on CUDA devices
---
 .../transformers/transformers-nvidia.yml      |  1 +
 .../transformers/transformers_server.py       | 47 ++++++++++++++-----
 2 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/backend/python/common-env/transformers/transformers-nvidia.yml b/backend/python/common-env/transformers/transformers-nvidia.yml
index f851677e2f44..7daafe51804a 100644
--- a/backend/python/common-env/transformers/transformers-nvidia.yml
+++ b/backend/python/common-env/transformers/transformers-nvidia.yml
@@ -30,6 +30,7 @@ dependencies:
   - async-timeout==4.0.3
   - attrs==23.1.0
   - bark==0.1.5
+  - bitsandbytes==0.43.0
   - boto3==1.28.61
   - botocore==1.31.61
   - certifi==2023.7.22
diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index 41112c44f6e5..31a606a12c05 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -23,7 +23,7 @@
     from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
     from transformers import AutoTokenizer, AutoModel, set_seed
 else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig

 _ONE_DAY_IN_SECONDS = 60 * 60 * 24

@@ -75,17 +75,34 @@ def LoadModel(self, request, context):
             A Result object that contains the result of the LoadModel operation.
         """
         model_name = request.Model
+
+        compute = torch.float32
+        if request.F16Memory == True:
+            compute=torch.bfloat16
+
+        self.CUDA = request.CUDA
+
+        device_map="cpu"
+
+        quantization = BitsAndBytesConfig(
+            load_in_4_bit=request.LowVRAM,
+            bnb_4bit_compute_dtype = compute,
+        )
+
+        if self.CUDA:
+            device_map="cuda"
+
         try:
             if request.Type == "AutoModelForCausalLM":
                 if XPU:
                     self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode,
                                                                       device_map="xpu", load_in_4bit=True)
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
+                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map)
             else:
                 self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)

-            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)

             self.CUDA = False
             self.XPU = False
@@ -97,13 +114,13 @@ def LoadModel(self, request, context):
                 except Exception as err:
                     print("Not using XPU:", err, file=sys.stderr)

-            if request.CUDA or torch.cuda.is_available():
-                try:
-                    print("Loading model", model_name, "to CUDA.", file=sys.stderr)
-                    self.model = self.model.to("cuda")
-                    self.CUDA = True
-                except Exception as err:
-                    print("Not using CUDA:", err, file=sys.stderr)
+            # if request.CUDA or torch.cuda.is_available():
+            #     try:
+            #         print("Loading model", model_name, "to CUDA.", file=sys.stderr)
+            #         self.model = self.model.to("cuda")
+            #         self.CUDA = True
+            #     except Exception as err:
+            #         print("Not using CUDA:", err, file=sys.stderr)

         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
@@ -130,13 +147,17 @@ def Embedding(self, request, context):
         encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

         # Create word embeddings
-        model_output = self.model(**encoded_input)
+        if self.CUDA:
+            encoded_input = encoded_input.to("cuda")
+
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)

         # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
-        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
+        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
         print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
-        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
+        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])

     def Predict(self, request, context):
         """
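Note on PATCH 1/9: the keyword passed to BitsAndBytesConfig above is misspelled (load_in_4_bit; the transformers argument is load_in_4bit), so the intended 4-bit flag is likely ignored here. PATCH 2/9 rewrites this block with the correct spelling. For reference, a minimal sketch of the 4-bit load path this patch is aiming for; the model id below is illustrative and not taken from the patch:

    # Minimal sketch of 4-bit loading with bitsandbytes; model id is illustrative.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_name = "facebook/opt-350m"  # assumption: any causal-LM checkpoint works

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,                      # note the spelling: load_in_4bit
        bnb_4bit_compute_dtype=torch.bfloat16,  # weights stay 4-bit, matmuls run in bf16
    )

    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 quantization_config=quantization,
                                                 device_map="cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_name)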
From 2f73c8de39b55b60210ae44ef08be4995ed61a20 Mon Sep 17 00:00:00 2001
From: fakezeta
Date: Thu, 14 Mar 2024 15:43:49 +0100
Subject: [PATCH 2/9] Manage 4bit and 8 bit quantization

Manage different BitsAndBytes options with the quantization: parameter in yaml
---
 .../transformers/transformers_server.py       | 55 ++++++++++---------
 1 file changed, 29 insertions(+), 26 deletions(-)

diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index 31a606a12c05..deb23a2abc20 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -76,7 +76,7 @@ def LoadModel(self, request, context):
         """
         model_name = request.Model

-        compute = torch.float32
+        compute = "auto"
         if request.F16Memory == True:
             compute=torch.bfloat16

@@ -84,26 +84,40 @@ def LoadModel(self, request, context):

         device_map="cpu"

-        quantization = BitsAndBytesConfig(
-            load_in_4_bit=request.LowVRAM,
-            bnb_4bit_compute_dtype = compute,
-        )
-
         if self.CUDA:
-            device_map="cuda"
-
+            if request.Device:
+                device_map=request.Device
+            else:
+                device_map="cuda:0"
+            if request.Quantization == "bnb_4bit":
+                quantization = BitsAndBytesConfig(
+                    load_in_4bit = True,
+                    bnb_4bit_compute_dtype = compute,
+                    bnb_4bit_quant_type = "nf4",
+                    bnb_4bit_use_double_quant = True,
+                    load_in_8bit = False,
+                )
+            elif request.Quantization == "bnb_8bit":
+                quantization = BitsAndBytesConfig(
+                    load_in_4bit=False,
+                    bnb_4bit_compute_dtype = None,
+                    load_in_8bit=True,
+                )
+            else:
+                quantization = None
+
         try:
             if request.Type == "AutoModelForCausalLM":
                 if XPU:
+                    if quantization == "xpu_4bit":
+                        xpu_4bit = True
                     self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode,
-                                                                      device_map="xpu", load_in_4bit=True)
+                                                                      device_map="xpu", load_in_4bit=xpu_4bit)
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map)
+                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
             else:
-                self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
-
+                self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
             self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
-
-            self.CUDA = False
             self.XPU = False

             if XPU:
@@ -114,13 +128,6 @@ def LoadModel(self, request, context):
                 except Exception as err:
                     print("Not using XPU:", err, file=sys.stderr)

-            # if request.CUDA or torch.cuda.is_available():
-            #     try:
-            #         print("Loading model", model_name, "to CUDA.", file=sys.stderr)
-            #         self.model = self.model.to("cuda")
-            #         self.CUDA = True
-            #     except Exception as err:
-            #         print("Not using CUDA:", err, file=sys.stderr)

         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
@@ -184,12 +191,8 @@ def Predict(self, request, context):
         if XPU:
             inputs = inputs.to("xpu")

-        outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
-
-        generated_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-        # Remove prompt from response if present
-        if request.Prompt in generated_text:
-            generated_text = generated_text.replace(request.Prompt, "")
+        outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
+        generated_text = self.tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]

         return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
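Note on PATCH 2/9: besides the per-mode BitsAndBytesConfig, the Predict() hunk replaces string-based prompt removal with token slicing. A short sketch of the idea, assuming model, tokenizer, and prompt are already defined:

    # Strip the prompt by token count instead of string replacement: outputs
    # holds the prompt tokens followed by the new tokens, so slicing at
    # inputs.shape[1] keeps only the completion, even if the completion
    # happens to repeat the prompt text verbatim.
    inputs = tokenizer(prompt, return_tensors="pt").input_ids   # shape (1, prompt_len)
    outputs = model.generate(inputs, max_new_tokens=200, do_sample=True,
                             pad_token_id=tokenizer.eos_token_id)
    completion = tokenizer.batch_decode(outputs[:, inputs.shape[1]:],
                                        skip_special_tokens=True)[0]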
From fbdbc58f8440337b03f5add407202149ec05138a Mon Sep 17 00:00:00 2001
From: fakezeta
Date: Thu, 14 Mar 2024 18:55:47 +0100
Subject: [PATCH 3/9] fix compilation errors on non CUDA environment

---
 backend/python/transformers/transformers_server.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index deb23a2abc20..264e7fad990b 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -84,6 +84,8 @@ def LoadModel(self, request, context):

         device_map="cpu"

+        quantization = None
+
         if self.CUDA:
             if request.Device:
                 device_map=request.Device
@@ -103,8 +105,7 @@ def LoadModel(self, request, context):
                     bnb_4bit_compute_dtype = None,
                     load_in_8bit=True,
                 )
-            else:
-                quantization = None
+

         try:
             if request.Type == "AutoModelForCausalLM":

From 9c1059ac8e953c5e30c016cd8b0a4bd49d97703f Mon Sep 17 00:00:00 2001
From: fakezeta
Date: Mon, 18 Mar 2024 00:21:56 +0100
Subject: [PATCH 4/9] OpenVINO draft

First draft of OpenVINO integration in transformer backend
---
 .../common-env/transformers/transformers.yml |  4 +++
 .../transformers/transformers_server.py      | 34 +++++++++++++++----
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/backend/python/common-env/transformers/transformers.yml b/backend/python/common-env/transformers/transformers.yml
index 5726abaf37c3..e47e83d6661a 100644
--- a/backend/python/common-env/transformers/transformers.yml
+++ b/backend/python/common-env/transformers/transformers.yml
@@ -56,6 +56,10 @@ dependencies:
   - multiprocess==0.70.15
   - networkx
   - numpy==1.26.0
+  - onnx==1.15.0
+  - openvino==2024.0.0
+  - openvino-telemetry==2023.2.1
+  - git+https://github.com/huggingface/optimum-intel.git
   - packaging==23.2
   - pandas
   - peft==0.5.0
diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index 264e7fad990b..50585c26f31f 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -22,6 +22,8 @@
     import intel_extension_for_pytorch as ipex
     from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
     from transformers import AutoTokenizer, AutoModel, set_seed
+    from optimum.intel.openvino import OVModelForCausalLM
+    from openvino.runtime import Core
 else:
     from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig

@@ -81,6 +83,7 @@ def LoadModel(self, request, context):
             compute=torch.bfloat16

         self.CUDA = request.CUDA
+        self.OV=False

         device_map="cpu"

@@ -105,17 +108,33 @@ def LoadModel(self, request, context):
                     bnb_4bit_compute_dtype = None,
                     load_in_8bit=True,
                 )
-
-
+
         try:
             if request.Type == "AutoModelForCausalLM":
                 if XPU:
-                    if quantization == "xpu_4bit":
+                    device_map="xpu"
+                    compute=torch.float16
+                    if request.Quantization == "xpu_4bit":
                         xpu_4bit = True
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode,
-                                                                      device_map="xpu", load_in_4bit=xpu_4bit)
+                        xpu_8bit = False
+                    elif request.Quantization == "xpu_8bit":
+                        xpu_4bit = False
+                        xpu_8bit = True
+                    else:
+                        xpu_4bit = False
+                        xpu_8bit = False
+
+                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True,
+                                                                      device_map=device_map, load_in_4bit=xpu_4bit, load_in_8bit=xpu_8bit, torch_dtype=compute)
                 else:
                     self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
+            elif request.Type == "OVModelForCausalLM":
+                if "GPU" in Core().available_devices:
+                    device_map="GPU"
+                else:
+                    device_map="CPU"
+                self.model = OVModelForCausalLM(model_name, compile=True, device=device_map)
+                self.OV = True
             else:
                 self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
             self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
@@ -189,8 +208,9 @@ def Predict(self, request, context):
         inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
         if self.CUDA:
             inputs = inputs.to("cuda")
-        if XPU:
-            inputs = inputs.to("xpu")
+        if XPU and self.OV == False:
+            inputs = inputs.to("xpu")
+
         outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
         generated_text = self.tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
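Note on PATCH 4/9: optimum-intel's OVModelForCausalLM mirrors the transformers from_pretrained/generate interface while running inference through the OpenVINO runtime. The draft above instantiates the class directly, which PATCH 5/9 replaces with the proper from_pretrained call. A hedged standalone sketch of the load path follows; the model id is illustrative, and export=True is assumed to be needed for checkpoints not already in OpenVINO IR format:

    # Sketch of the OpenVINO loading path (finalized in PATCH 5/9).
    from optimum.intel.openvino import OVModelForCausalLM
    from openvino.runtime import Core
    from transformers import AutoTokenizer

    model_name = "gpt2"  # illustrative checkpoint, not taken from the patch

    # Prefer the GPU plugin when OpenVINO reports one, as the patch does.
    device = "GPU" if "GPU" in Core().available_devices else "CPU"

    # export=True converts a plain PyTorch checkpoint to OpenVINO IR on load;
    # it can be dropped for models already published in IR format.
    model = OVModelForCausalLM.from_pretrained(model_name, export=True,
                                               compile=True, device=device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    inputs = tokenizer("Hello", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=16)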
From 44f2f9ed9379a3886ba01848801cc5e7f2a4d424 Mon Sep 17 00:00:00 2001
From: fakezeta
Date: Wed, 20 Mar 2024 13:31:55 +0100
Subject: [PATCH 5/9] first working implementation

---
 .../common-env/transformers/transformers.yml |  2 +
 .../transformers/transformers_server.py      | 48 ++++++++++++++-----
 2 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/backend/python/common-env/transformers/transformers.yml b/backend/python/common-env/transformers/transformers.yml
index e47e83d6661a..327533307e93 100644
--- a/backend/python/common-env/transformers/transformers.yml
+++ b/backend/python/common-env/transformers/transformers.yml
@@ -33,6 +33,7 @@ dependencies:
   - boto3==1.28.61
   - botocore==1.31.61
   - certifi==2023.7.22
+  - coloredlogs==15.0.1
   - TTS==0.22.0
   - charset-normalizer==3.3.0
   - datasets==2.14.5
@@ -47,6 +48,7 @@ dependencies:
   - funcy==2.0
   - grpcio==1.59.0
   - huggingface-hub
+  - humanfriendly==10.0
   - idna==3.4
   - jinja2==3.1.2
   - jmespath==1.0.1
diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index 50585c26f31f..de77aeab4b20 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -124,23 +124,40 @@ def LoadModel(self, request, context):
                     else:
                         xpu_4bit = False
                         xpu_8bit = False

-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True,
-                                                                      device_map=device_map, load_in_4bit=xpu_4bit, load_in_8bit=xpu_8bit, torch_dtype=compute)
+                    self.model = AutoModelForCausalLM.from_pretrained(model_name,
+                                                                      trust_remote_code=request.TrustRemoteCode,
+                                                                      use_safetensors=True,
+                                                                      device_map=device_map,
+                                                                      load_in_4bit=xpu_4bit,
+                                                                      load_in_8bit=xpu_8bit,
+                                                                      torch_dtype=compute)
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
+                    self.model = AutoModelForCausalLM.from_pretrained(model_name,
+                                                                      trust_remote_code=request.TrustRemoteCode,
+                                                                      use_safetensors=True,
+                                                                      quantization_config=quantization,
+                                                                      device_map=device_map,
+                                                                      torch_dtype=compute)
             elif request.Type == "OVModelForCausalLM":
                 if "GPU" in Core().available_devices:
                     device_map="GPU"
                 else:
                     device_map="CPU"
-                self.model = OVModelForCausalLM(model_name, compile=True, device=device_map)
+                self.model = OVModelForCausalLM.from_pretrained(model_name,
+                                                                compile=True,
+                                                                device=device_map)
                 self.OV = True
             else:
-                self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
+                self.model = AutoModel.from_pretrained(model_name,
+                                                       trust_remote_code=request.TrustRemoteCode,
+                                                       use_safetensors=True,
+                                                       quantization_config=quantization,
+                                                       device_map=device_map,
+                                                       torch_dtype=compute)
             self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)

             self.XPU = False
-            if XPU:
+            if XPU and self.OV == False:
                 self.XPU = True
                 try:
                     print("Optimizing model", model_name, "to XPU.", file=sys.stderr)
@@ -205,15 +222,22 @@ def Predict(self, request, context):
         if request.Tokens > 0:
             max_tokens = request.Tokens

-        inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
         if self.CUDA:
             inputs = inputs.to("cuda")
         if XPU and self.OV == False:
-            inputs = inputs.to("xpu")
-
-
-        outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
-        generated_text = self.tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
+            inputs = inputs.to("xpu")
+
+        inputs = self.tokenizer(request.Prompt, return_tensors="pt")
+        outputs = self.model.generate(inputs["input_ids"],
+                                      max_new_tokens=max_tokens,
+                                      temperature=float(request.Temperature),
+                                      top_p=request.TopP,
+                                      top_k=request.TopK,
+                                      do_sample=True,
+                                      attention_mask=inputs["attention_mask"],
+                                      eos_token_id=self.tokenizer.eos_token_id,
+                                      pad_token_id=self.tokenizer.eos_token_id)
+        generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]

         return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
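Note on PATCH 5/9: generate() now receives the full tokenizer output, so an explicit attention_mask is passed along with matching eos/pad token ids (reusing EOS as the pad token, the usual workaround for models that define none). One regression slips in here: the .to("cuda") / .to("xpu") moves reference inputs before the tokenizer call assigns it, which PATCH 7/9 fixes by reordering. Condensed to its essentials, and assuming model and tokenizer are already loaded, the call shape is:

    # Call shape introduced by this patch (condensed sketch; sampling values illustrative).
    inputs = tokenizer(prompt, return_tensors="pt")   # dict: input_ids + attention_mask
    outputs = model.generate(inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],  # explicit mask
                             max_new_tokens=64,
                             temperature=0.7,          # generate() expects a float here
                             top_p=0.9, top_k=40, do_sample=True,
                             eos_token_id=tokenizer.eos_token_id,
                             pad_token_id=tokenizer.eos_token_id)      # EOS doubles as pad
    text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:],
                                  skip_special_tokens=True)[0]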
From 30fdc7e6497b0db9b78955741909cf7d40ca6f21 Mon Sep 17 00:00:00 2001
From: fakezeta
Date: Wed, 20 Mar 2024 15:51:18 +0100
Subject: [PATCH 6/9] Streaming working

---
 .../transformers/transformers_server.py      | 51 ++++++++++++++-----
 1 file changed, 37 insertions(+), 14 deletions(-)

diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index de77aeab4b20..36f1de9e31a7 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -16,12 +16,13 @@
 import grpc
 import torch
 import torch.cuda
+from threading import Thread

 XPU=os.environ.get("XPU", "0") == "1"
 if XPU:
     import intel_extension_for_pytorch as ipex
     from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
-    from transformers import AutoTokenizer, AutoModel, set_seed
+    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
     from optimum.intel.openvino import OVModelForCausalLM
     from openvino.runtime import Core
 else:
@@ -203,7 +204,7 @@ def Embedding(self, request, context):
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
         return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])

-    def Predict(self, request, context):
+    def Predict(self, request, context, streaming=False):
         """
         Generates text based on the given prompt and sampling parameters.

@@ -228,17 +229,37 @@ def Predict(self, request, context, streaming=False):
             inputs = inputs.to("xpu")

         inputs = self.tokenizer(request.Prompt, return_tensors="pt")
-        outputs = self.model.generate(inputs["input_ids"],
-                                      max_new_tokens=max_tokens,
-                                      temperature=float(request.Temperature),
-                                      top_p=request.TopP,
-                                      top_k=request.TopK,
-                                      do_sample=True,
-                                      attention_mask=inputs["attention_mask"],
-                                      eos_token_id=self.tokenizer.eos_token_id,
-                                      pad_token_id=self.tokenizer.eos_token_id)
-        generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
-
+        if streaming:
+            streamer=TextIteratorStreamer(self.tokenizer,
+                                          skip_prompt=True,
+                                          skip_special_tokens=True)
+            config=dict(inputs,
+                        max_new_tokens=max_tokens,
+                        temperature=float(request.Temperature),
+                        top_p=request.TopP,
+                        top_k=request.TopK,
+                        do_sample=True,
+                        attention_mask=inputs["attention_mask"],
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        pad_token_id=self.tokenizer.eos_token_id,
+                        streamer=streamer)
+            thread=Thread(target=self.model.generate, kwargs=config)
+            thread.start()
+            generated_text = ""
+            for new_text in streamer:
+                generated_text += new_text
+                yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
+        else:
+            outputs = self.model.generate(inputs["input_ids"],
+                                          max_new_tokens=max_tokens,
+                                          temperature=float(request.Temperature),
+                                          top_p=request.TopP,
+                                          top_k=request.TopK,
+                                          do_sample=True,
+                                          attention_mask=inputs["attention_mask"],
+                                          eos_token_id=self.tokenizer.eos_token_id,
+                                          pad_token_id=self.tokenizer.eos_token_id)
+            generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
         return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))

@@ -246,16 +267,18 @@ def Predict(self, request, context, streaming=False):

     def PredictStream(self, request, context):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results.

         Args:
             request: The predict stream request.
             context: The gRPC context.

         Returns:
             backend_pb2.Result: The predict stream result.
         """
-        yield self.Predict(request, context)
+        iterations = self.Predict(request, context, streaming=True)
+        for iteration in iterations:
+            yield iteration


 def serve(address):
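Note on PATCH 6/9: the streaming branch is the standard transformers TextIteratorStreamer pattern. generate() blocks until completion, so it runs in a worker thread while the caller drains decoded chunks from the streamer. A self-contained sketch, assuming model and tokenizer are already loaded:

    # TextIteratorStreamer pattern used by the streaming branch above.
    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer,
                                    skip_prompt=True,          # don't re-emit the prompt
                                    skip_special_tokens=True)
    inputs = tokenizer("Once upon a time", return_tensors="pt")
    config = dict(inputs, max_new_tokens=64, do_sample=True,
                  pad_token_id=tokenizer.eos_token_id, streamer=streamer)

    # generate() runs in the background; the main thread consumes text chunks
    # as soon as they are decoded.
    thread = Thread(target=model.generate, kwargs=config)
    thread.start()
    for new_text in streamer:
        print(new_text, end="", flush=True)
    thread.join()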
From d92afe755491fbd17c25ba4388b8503fdfce88e8 Mon Sep 17 00:00:00 2001
From: fakezeta
Date: Mon, 25 Mar 2024 12:31:57 +0100
Subject: [PATCH 7/9] Small fix for regression on CUDA and XPU

---
 .../transformers/transformers_server.py      | 27 ++++++++++---------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index 36f1de9e31a7..93a717d91e34 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -8,6 +8,7 @@
 import signal
 import sys
 import os
+from threading import Thread
 import time

 import backend_pb2
@@ -16,7 +17,7 @@
 import grpc
 import torch
 import torch.cuda
-from threading import Thread
+

 XPU=os.environ.get("XPU", "0") == "1"
 if XPU:
@@ -26,7 +27,7 @@
     from optimum.intel.openvino import OVModelForCausalLM
     from openvino.runtime import Core
 else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer

 _ONE_DAY_IN_SECONDS = 60 * 60 * 24

@@ -124,7 +125,6 @@ def LoadModel(self, request, context):
                     else:
                         xpu_4bit = False
                         xpu_8bit = False
-
                     self.model = AutoModelForCausalLM.from_pretrained(model_name,
                                                                       trust_remote_code=request.TrustRemoteCode,
                                                                       use_safetensors=True,
@@ -167,6 +167,7 @@ def LoadModel(self, request, context):
                     print("Not using XPU:", err, file=sys.stderr)

         except Exception as err:
+            print("Error:", err, file=sys.stderr)
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
         # Replace this with your desired response
@@ -223,19 +224,20 @@ def Predict(self, request, context, streaming=False):
         if request.Tokens > 0:
             max_tokens = request.Tokens

+        inputs = self.tokenizer(request.Prompt, return_tensors="pt")
         if self.CUDA:
             inputs = inputs.to("cuda")
         if XPU and self.OV == False:
             inputs = inputs.to("xpu")
+            streaming = False

-        inputs = self.tokenizer(request.Prompt, return_tensors="pt")
         if streaming:
             streamer=TextIteratorStreamer(self.tokenizer,
                                           skip_prompt=True,
                                           skip_special_tokens=True)
             config=dict(inputs,
                         max_new_tokens=max_tokens,
-                        temperature=float(request.Temperature),
+                        temperature=request.Temperature,
                         top_p=request.TopP,
                         top_k=request.TopK,
                         do_sample=True,
@@ -251,15 +253,14 @@ def Predict(self, request, context, streaming=False):
                 yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
         else:
             outputs = self.model.generate(inputs["input_ids"],
-                                          max_new_tokens=max_tokens,
-                                          temperature=float(request.Temperature),
-                                          top_p=request.TopP,
-                                          top_k=request.TopK,
-                                          do_sample=True,
-                                          attention_mask=inputs["attention_mask"],
-                                          eos_token_id=self.tokenizer.eos_token_id,
-                                          pad_token_id=self.tokenizer.eos_token_id)
+                                          max_new_tokens=max_tokens,
+                                          temperature=request.Temperature,
+                                          top_p=request.TopP,
+                                          top_k=request.TopK,
+                                          do_sample=True,
+                                          pad_token=self.tokenizer.eos_token_id)
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
+            yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
         return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
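Note on PATCH 7/9: the substantive fix is ordering; the prompt must be tokenized before the BatchEncoding is moved to the accelerator, and streaming appears to be disabled on the XPU path. Two details worth flagging: generate()'s keyword is pad_token_id, so the pad_token= passed in the else branch looks like a typo, and the stray yield there references new_text, which is never assigned on that path (PATCH 9/9 deletes that yield). A sketch of the corrected ordering, assuming a loaded model and tokenizer:

    # Tokenize first, then move every tensor in the BatchEncoding to the device.
    import torch

    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")   # relocates input_ids and attention_mask together
    outputs = model.generate(inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],
                             max_new_tokens=64,
                             do_sample=True,
                             pad_token_id=tokenizer.eos_token_id)  # pad_token_id, not pad_token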
""" - yield self.Predict(request, context) + iterations = self.Predict(request, context, streaming=True) + for iteration in iterations: + yield iteration def serve(address): From d92afe755491fbd17c25ba4388b8503fdfce88e8 Mon Sep 17 00:00:00 2001 From: fakezeta Date: Mon, 25 Mar 2024 12:31:57 +0100 Subject: [PATCH 7/9] Small fix for regression on CUDA and XPU --- .../transformers/transformers_server.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py index 36f1de9e31a7..93a717d91e34 100755 --- a/backend/python/transformers/transformers_server.py +++ b/backend/python/transformers/transformers_server.py @@ -8,6 +8,7 @@ import signal import sys import os +from threading import Thread import time import backend_pb2 @@ -16,7 +17,7 @@ import grpc import torch import torch.cuda -from threading import Thread + XPU=os.environ.get("XPU", "0") == "1" if XPU: @@ -26,7 +27,7 @@ from optimum.intel.openvino import OVModelForCausalLM from openvino.runtime import Core else: - from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig + from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -124,7 +125,6 @@ def LoadModel(self, request, context): else: xpu_4bit = False xpu_8bit = False - self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, @@ -167,6 +167,7 @@ def LoadModel(self, request, context): print("Not using XPU:", err, file=sys.stderr) except Exception as err: + print("Error:", err, file=sys.stderr) return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service # Replace this with your desired response @@ -223,19 +224,20 @@ def Predict(self, request, context, streaming=False): if request.Tokens > 0: max_tokens = request.Tokens + inputs = self.tokenizer(request.Prompt, return_tensors="pt") if self.CUDA: inputs = inputs.to("cuda") if XPU and self.OV == False: inputs = inputs.to("xpu") + streaming = False - inputs = self.tokenizer(request.Prompt, return_tensors="pt") if streaming: streamer=TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) config=dict(inputs, max_new_tokens=max_tokens, - temperature=float(request.Temperature), + temperature=request.Temperature, top_p=request.TopP, top_k=request.TopK, do_sample=True, @@ -251,15 +253,14 @@ def Predict(self, request, context, streaming=False): yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8')) else: outputs = self.model.generate(inputs["input_ids"], - max_new_tokens=max_tokens, - temperature=float(request.Temperature), - top_p=request.TopP, - top_k=request.TopK, - do_sample=True, - attention_mask=inputs["attention_mask"], - eos_token_id=self.tokenizer.eos_token_id, - pad_token_id=self.tokenizer.eos_token_id) + max_new_tokens=max_tokens, + temperature=request.Temperature, + top_p=request.TopP, + top_k=request.TopK, + do_sample=True, + pad_token=self.tokenizer.eos_token_id) generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0] + yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8')) return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) def PredictStream(self, request, context): From 
From 18747efd60ddd1a7f158e0b8b3188ab0ff2c444c Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Tue, 26 Mar 2024 18:51:09 +0100
Subject: [PATCH 9/9] Update backend/python/transformers/transformers_server.py

Signed-off-by: Ettore Di Giacinto
---
 backend/python/transformers/transformers_server.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index 93a717d91e34..a87020218f3d 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -260,7 +260,6 @@ def Predict(self, request, context, streaming=False):
                                           do_sample=True,
                                           pad_token=self.tokenizer.eos_token_id)
             generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
-            yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
         return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))

     def PredictStream(self, request, context):
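Note on PATCH 9/9: the deleted yield referenced new_text, which is only bound in the streaming branch, so the non-streaming path would have raised a NameError. A related subtlety remains after this fix: any yield in a function body makes Python compile the whole function as a generator, so the trailing return ends iteration instead of handing a Reply to a plain caller, which is why PredictStream() drives Predict(streaming=True) as an iterator. A minimal illustration of that semantics:

    # Any 'yield' anywhere in a function makes it a generator function,
    # even on code paths that never reach the yield.
    def predict(streaming=False):
        if streaming:
            yield "chunk"
        return "final"  # ends iteration; not returned to a plain caller

    print(predict())        # <generator object predict at 0x...>, not "final"
    print(list(predict()))  # [] - the return value goes into StopIteration...
    gen = predict()
    try:
        next(gen)
    except StopIteration as stop:
        print(stop.value)   # ...where it survives as stop.value: "final"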