diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py
index 2051f844d888..a67d24f5bbc8 100644
--- a/nemo/collections/llm/__init__.py
+++ b/nemo/collections/llm/__init__.py
@@ -231,3 +231,10 @@
     __all__.append("deploy")
 except ImportError as error:
     logging.warning(f"The deploy module could not be imported: {error}")
+
+try:
+    from nemo.collections.llm.api import evaluate
+
+    __all__.append("evaluate")
+except ImportError as error:
+    logging.warning(f"The evaluate module could not be imported: {error}")
diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index fdceff5d959e..07899b2ee484 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import os
 from copy import deepcopy
 from pathlib import Path
@@ -256,84 +255,13 @@ def validate(
     return app_state.exp_dir
 
 
-def get_trtllm_deployable(
-    nemo_checkpoint,
-    model_type,
-    triton_model_repository,
-    num_gpus,
-    tensor_parallelism_size,
-    pipeline_parallelism_size,
-    max_input_len,
-    max_output_len,
-    max_batch_size,
-    dtype,
-):
-    from nemo.export.tensorrt_llm import TensorRTLLM
-
-    if triton_model_repository is None:
-        trt_llm_path = "/tmp/trt_llm_model_dir/"
-        Path(trt_llm_path).mkdir(parents=True, exist_ok=True)
-    else:
-        trt_llm_path = triton_model_repository
-
-    if nemo_checkpoint is None and triton_model_repository is None:
-        raise ValueError(
-            "The provided model repository is not a valid TensorRT-LLM model "
-            "directory. Please provide a --nemo_checkpoint or a TensorRT-LLM engine."
-        )
-
-    if nemo_checkpoint is None and not os.path.isdir(triton_model_repository):
-        raise ValueError(
-            "The provided model repository is not a valid TensorRT-LLM model "
-            "directory. Please provide a --nemo_checkpoint or a valid TensorRT-LLM engine."
-        )
-
-    if nemo_checkpoint is not None and model_type is None:
-        raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.")
-
-    trt_llm_exporter = TensorRTLLM(
-        model_dir=trt_llm_path,
-        load_model=(nemo_checkpoint is None),
-    )
-
-    if nemo_checkpoint is not None:
-        try:
-            logging.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.")
-            trt_llm_exporter.export(
-                nemo_checkpoint_path=nemo_checkpoint,
-                model_type=model_type,
-                n_gpus=num_gpus,
-                tensor_parallelism_size=tensor_parallelism_size,
-                pipeline_parallelism_size=pipeline_parallelism_size,
-                max_input_len=max_input_len,
-                max_output_len=max_output_len,
-                max_batch_size=max_batch_size,
-                dtype=dtype,
-            )
-        except Exception as error:
-            raise RuntimeError("An error has occurred during the model export. Error message: " + str(error))
-
-    return trt_llm_exporter
-
-
-def store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response):
-    args_dict = {
-        "triton_service_ip": triton_http_address,
-        "triton_service_port": triton_port,
-        "triton_request_timeout": triton_request_timeout,
-        "openai_format_response": openai_format_response,
-    }
-    with open("nemo/deploy/service/config.json", "w") as f:
-        json.dump(args_dict, f)
-
-
 @run.cli.entrypoint(namespace="llm")
 def deploy(
     nemo_checkpoint: Path = None,
     model_type: str = "llama",
-    triton_model_name: str = "xxx",
+    triton_model_name: str = 'triton_model',
     triton_model_version: Optional[int] = 1,
-    triton_port: int = 8080,
+    triton_port: int = 8000,
     triton_http_address: str = "0.0.0.0",
     triton_request_timeout: int = 60,
     triton_model_repository: Path = None,
@@ -344,21 +272,61 @@ def deploy(
     max_input_len: int = 256,
     max_output_len: int = 256,
     max_batch_size: int = 8,
-    start_rest_service: bool = False,
+    start_rest_service: bool = True,
     rest_service_http_address: str = "0.0.0.0",
-    rest_service_port: int = 8000,
-    openai_format_response: bool = False,
+    rest_service_port: int = 8080,
+    openai_format_response: bool = True,
+    output_generation_logits: bool = True,
 ):
+    """
+    Deploys a nemo model on a PyTriton server by converting the nemo ckpt to trtllm.
+    Also starts a rest service that is used to send OpenAI API compatible input requests
+    to the PyTriton server.
+
+    Args:
+        nemo_checkpoint (Path): Path for nemo checkpoint.
+        model_type (str): Type of the model. Choices: gpt, llama, falcon, starcoder. Default: llama.
+        triton_model_name (str): Name for the model that gets deployed on PyTriton. Please ensure that the same model
+            name is passed to the evaluate method for the model to be accessible while sending evaluation requests.
+            Default: 'triton_model'.
+        triton_model_version (Optional[int]): Version for the triton model. Default: 1.
+        triton_port (int): Port for the PyTriton server. Default: 8000.
+        triton_http_address (str): HTTP address for the PyTriton server. Default: "0.0.0.0".
+        triton_request_timeout (int): Timeout in seconds for Triton server. Default: 60.
+        triton_model_repository (Path): Folder for the trt-llm conversion, trt-llm engine gets saved in this specified
+            path. If None, saves it in /tmp dir. Default: None.
+        num_gpus (int): Number of GPUs for export to trtllm and deploy. Default: 1.
+        tensor_parallelism_size (int): Tensor parallelism size. Default: 1.
+        pipeline_parallelism_size (int): Pipeline parallelism size. Default: 1.
+        dtype (str): dtype of the TensorRT-LLM model. Default: "bfloat16".
+        max_input_len (int): Max input length of the model. Default: 256.
+        max_output_len (int): Max output length of the model. Default: 256.
+        max_batch_size (int): Max batch size of the model. Default: 8.
+        start_rest_service (bool): Start the rest service that is used to send evaluation requests to the PyTriton server.
+            Needs to be True to be able to run evaluation. Default: True.
+        rest_service_http_address (str): HTTP address for the rest service. Default: "0.0.0.0".
+        rest_service_port (int): Port for the rest service. Default: 8080.
+        openai_format_response (bool): Return the response from PyTriton server in OpenAI compatible format. Needs to
+            be True while running evaluation. Default: True.
+        output_generation_logits (bool): If True, builds the trtllm engine with gather_generation_logits set to True.
+            generation_logits are used to compute the logProb of the output token. Default: True.
+    """
+    from nemo.collections.llm import deploy
     from nemo.deploy import DeployPyTriton
 
+    deploy.unset_environment_variables()
     if start_rest_service:
         if triton_port == rest_service_port:
             logging.error("REST service port and Triton server port cannot use the same port.")
             return
-        # Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py
-        store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response)
-
-    triton_deployable = get_trtllm_deployable(
+        # Store triton ip, port and other args relevant for REST API as env vars to be accessible by rest_model_api.py
+        os.environ['TRITON_HTTP_ADDRESS'] = triton_http_address
+        os.environ['TRITON_PORT'] = str(triton_port)
+        os.environ['TRITON_REQUEST_TIMEOUT'] = str(triton_request_timeout)
+        os.environ['OPENAI_FORMAT_RESPONSE'] = str(openai_format_response)
+        os.environ['OUTPUT_GENERATION_LOGITS'] = str(output_generation_logits)
+
+    triton_deployable = deploy.get_trtllm_deployable(
         nemo_checkpoint,
         model_type,
         triton_model_repository,
@@ -369,6 +337,7 @@ def deploy(
         max_output_len,
         max_batch_size,
         dtype,
+        output_generation_logits,
     )
 
     try:
@@ -383,6 +352,7 @@ def deploy(
         logging.info("Triton deploy function will be called.")
         nm.deploy()
+        nm.run()
     except Exception as error:
         logging.error("Error message has occurred during deploy function. Error message: " + str(error))
         return
@@ -416,6 +386,81 @@ def deploy(
     nm.stop()
 
 
+def evaluate(
+    nemo_checkpoint_path: Path,
+    url: str = "http://0.0.0.0:8080",
+    model_name: str = "triton_model",
+    eval_task: str = "gsm8k",
+    num_fewshot: Optional[int] = None,
+    limit: Optional[Union[int, float]] = None,
+    bootstrap_iters: int = 100000,
+    # inference params
+    max_tokens_to_generate: Optional[int] = 256,
+    temperature: Optional[float] = 0.000000001,
+    top_p: Optional[float] = 0.0,
+    top_k: Optional[int] = 1,
+    add_bos: Optional[bool] = False,
+):
+    """
+    Evaluates a nemo model deployed on a PyTriton server (via trtllm) using lm-evaluation-harness
+    (https://github.com/EleutherAI/lm-evaluation-harness/tree/main).
+
+    Args:
+        nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt
+            which is required to tokenize the evaluation input and output prompts.
+        url (str): rest service url and port that were used in the deploy method above in the format:
+            http://{rest_service_http}:{rest_service_port}. Post requests with evaluation input prompts
+            (from lm-eval-harness) are sent to this url which is then passed to the model deployed on PyTriton server.
+            The rest service url and port serve as the entry point to evaluate the model deployed on PyTriton server.
+        model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as
+            triton_model_name passed to the deploy method above to be able to launch evaluation. Default: "triton_model".
+        eval_task (str): task to be evaluated on. For ex: "gsm8k", "gsm8k_cot", "mmlu", "lambada". Default: "gsm8k".
+            These are the tasks that are supported currently. Any other task of type generate_until or loglikelihood from
+            lm-evaluation-harness can be run, but only the above mentioned ones are tested. Tasks of type
+            loglikelihood_rolling are not supported yet.
+        num_fewshot (int): number of examples in few-shot context. Default: None.
+        limit (Union[int, float]): Limit the number of examples per task. If <1 (i.e. a float between 0 and 1), limit
+            is a percentage of the total number of examples. If an int, say x, evaluation is run only on x samples
+            from the eval dataset. Default: None, which means eval is run on the entire dataset.
+        bootstrap_iters (int): Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0
+            for no stderr calculations to be performed. Default: 100000.
+        # inference params
+        max_tokens_to_generate (int): max tokens to generate. Default: 256.
+        temperature: Optional[float]: float value between 0 and 1. temp of 0 indicates greedy decoding, where the token
+            with highest prob is chosen. Temperature can't be set to 0.0 currently, due to a bug with TRTLLM
+            (# TODO to be investigated). Hence using a very small value as the default. Default: 0.000000001.
+        top_p: Optional[float]: float value between 0 and 1. limits to the top tokens within a certain probability.
+            top_p=0 means the model will only consider the single most likely token for the next prediction. Default: 0.0.
+        top_k: Optional[int]: limits to a certain number (K) of the top tokens to consider. top_k=1 means the model
+            will only consider the single most likely token for the next prediction. Default: 1.
+        add_bos: Optional[bool]: whether a special token representing the beginning of a sequence should be added when
+            encoding a string. Default: False, since typically for CausalLM it's set to False. If needed, set add_bos to True.
+    """
+    try:
+        # lm-evaluation-harness import
+        from lm_eval import evaluator
+    except ImportError:
+        raise ImportError(
+            "Please ensure that lm-evaluation-harness is installed in your env as it is required " "to run evaluations"
+        )
+
+    from nemo.collections.llm import evaluation
+
+    # Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
+    tokenizer = io.load_context(Path(nemo_checkpoint_path) / "context", subpath="model").tokenizer
+    # Wait for rest service to be ready before starting evaluation
+    evaluation.wait_for_rest_service(rest_url=f"{url}/v1/health")
+    # Create an object of NeMoFWLMEval which is passed as a model to evaluator.simple_evaluate
+    model = evaluation.NeMoFWLMEval(
+        model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos
+    )
+    results = evaluator.simple_evaluate(
+        model=model, tasks=eval_task, limit=limit, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters
+    )
+
+    print("score", results['results'][eval_task])
+
+
 @run.cli.entrypoint(name="import", namespace="llm")
 def import_ckpt(
     model: pl.LightningModule,
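Reviewer note: a minimal usage sketch of the two new entrypoints wired together, based only on the docstrings above; it is not part of the patch. The checkpoint path is a placeholder and all other arguments use the new defaults (PyTriton on port 8000, REST service on port 8080). deploy() starts serving and blocks, so evaluate() is expected to be launched separately once the REST endpoint is healthy.

from pathlib import Path

from nemo.collections import llm

CKPT = Path("/workspace/llama3-8b-nemo2")  # placeholder NeMo 2.0 checkpoint

# Process 1: export to TensorRT-LLM, start PyTriton and the OpenAI-compatible REST service.
llm.deploy(nemo_checkpoint=CKPT, model_type="llama", triton_model_name="triton_model")

# Process 2: run lm-evaluation-harness against the deployed model through the REST service.
llm.evaluate(
    nemo_checkpoint_path=CKPT,
    model_name="triton_model",  # must match triton_model_name above
    eval_task="gsm8k",
    limit=100,  # score only 100 samples
)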
diff --git a/nemo/collections/llm/deploy/__init__.py b/nemo/collections/llm/deploy/__init__.py
new file mode 100644
index 000000000000..24c102bfa0d2
--- /dev/null
+++ b/nemo/collections/llm/deploy/__init__.py
@@ -0,0 +1,3 @@
+from nemo.collections.llm.deploy.base import get_trtllm_deployable, unset_environment_variables
+
+__all__ = ["unset_environment_variables", "get_trtllm_deployable"]
diff --git a/nemo/collections/llm/deploy/base.py b/nemo/collections/llm/deploy/base.py
new file mode 100644
index 000000000000..e21198f5884b
--- /dev/null
+++ b/nemo/collections/llm/deploy/base.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+from pathlib import Path
+
+from nemo.utils import logging
+
+
+def unset_environment_variables() -> None:
+    """
+    SLURM_, PMI_ and PMIX_ variables need to be unset for trtllm export to work
+    on clusters. This method takes care of unsetting these env variables.
+    """
+    logging.info("Unsetting all SLURM_, PMI_, PMIX_ Variables")
+
+    # Function to unset variables with a specific prefix
+    def unset_vars_with_prefix(prefix):
+        unset_vars = []
+        cmd = f"env | grep ^{prefix} | cut -d= -f1"
+        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+        vars_to_unset = result.stdout.strip().split('\n')
+        for var in vars_to_unset:
+            if var:  # Check if the variable name is not empty
+                os.environ.pop(var, None)
+                unset_vars.append(var)
+        return unset_vars
+
+    # Collect all unset variables across all prefixes
+    all_unset_vars = []
+
+    # Unset variables for each prefix
+    for prefix in ['SLURM_', 'PMI_', 'PMIX_']:
+        unset_vars = unset_vars_with_prefix(prefix)
+        all_unset_vars.extend(unset_vars)
+
+    if all_unset_vars:
+        logging.info(f"Unset env variables: {', '.join(all_unset_vars)}")
+    else:
+        logging.info("No env variables were unset.")
+
+
+def get_trtllm_deployable(
+    nemo_checkpoint,
+    model_type,
+    triton_model_repository,
+    num_gpus,
+    tensor_parallelism_size,
+    pipeline_parallelism_size,
+    max_input_len,
+    max_output_len,
+    max_batch_size,
+    dtype,
+    output_generation_logits,
+):
+    """
+    Exports the nemo checkpoint to trtllm and returns trt_llm_exporter that is used to deploy on PyTriton.
+    """
+    from nemo.export.tensorrt_llm import TensorRTLLM
+
+    if triton_model_repository is None:
+        trt_llm_path = "/tmp/trt_llm_model_dir/"
+        Path(trt_llm_path).mkdir(parents=True, exist_ok=True)
+    else:
+        trt_llm_path = triton_model_repository
+
+    if nemo_checkpoint is None and triton_model_repository is None:
+        raise ValueError(
+            "The provided model repository is not a valid TensorRT-LLM model "
+            "directory. Please provide a --nemo_checkpoint or a TensorRT-LLM engine."
+        )
+
+    if nemo_checkpoint is None and not os.path.isdir(triton_model_repository):
+        raise ValueError(
+            "The provided model repository is not a valid TensorRT-LLM model "
+            "directory. Please provide a --nemo_checkpoint or a valid TensorRT-LLM engine."
+ ) + + if nemo_checkpoint is not None and model_type is None: + raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") + + trt_llm_exporter = TensorRTLLM( + model_dir=trt_llm_path, + load_model=(nemo_checkpoint is None), + ) + + if nemo_checkpoint is not None: + try: + logging.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.") + trt_llm_exporter.export( + nemo_checkpoint_path=nemo_checkpoint, + model_type=model_type, + n_gpus=num_gpus, + tensor_parallelism_size=tensor_parallelism_size, + pipeline_parallelism_size=pipeline_parallelism_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + dtype=dtype, + gather_generation_logits=output_generation_logits, + ) + except Exception as error: + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) + + return trt_llm_exporter diff --git a/nemo/collections/llm/evaluation/__init__.py b/nemo/collections/llm/evaluation/__init__.py new file mode 100644 index 000000000000..3012689bb8da --- /dev/null +++ b/nemo/collections/llm/evaluation/__init__.py @@ -0,0 +1,3 @@ +from nemo.collections.llm.evaluation.base import NeMoFWLMEval, wait_for_rest_service + +__all__ = ["NeMoFWLMEval", "wait_for_rest_service"] diff --git a/nemo/collections/llm/evaluation/base.py b/nemo/collections/llm/evaluation/base.py new file mode 100644 index 000000000000..b1734d6f4d43 --- /dev/null +++ b/nemo/collections/llm/evaluation/base.py @@ -0,0 +1,210 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import requests +import torch +import torch.nn.functional as F +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from requests.exceptions import RequestException + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.utils import logging + + +class NeMoFWLMEval(LM): + """ + NeMoFWLMEval is a wrapper class subclassing lm_eval.api.model.LM class, that defines how lm_eval interfaces with + NeMo model deployed on PyTriton server. + Created based on: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.4/docs/model_guide.md + """ + + def __init__(self, model_name, api_url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos): + self.model_name = model_name + self.api_url = api_url + self.tokenizer = tokenizer + self.max_tokens_to_generate = max_tokens_to_generate + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.add_bos = add_bos + super().__init__() + + def _generate_tokens_logits(self, payload, return_text: bool = False, return_logits: bool = False): + """ + A private method that sends post request to the model on PyTriton server and returns either generated text or + logits. 
+ """ + # send a post request to /v1/completions/ endpoint with the payload + response = requests.post(f"{self.api_url}/v1/completions/", json=payload) + response_data = response.json() + + if 'error' in response_data: + raise Exception(f"API Error: {response_data['error']}") + + # Assuming the response is in OpenAI format + if return_text: + # in case of generate_until tasks return just the text + return response_data['choices'][0]['text'] + + if return_logits: + # in case of loglikelihood tasks return the logits + return response_data['choices'][0]['generation_logits'] + + def tokenizer_type(self, tokenizer): + """ + Returns the type of the tokenizer. + """ + if isinstance(tokenizer, AutoTokenizer): + return "AutoTokenizer" + elif isinstance(tokenizer, SentencePieceTokenizer): + return "SentencePieceTokenizer" + else: + raise ValueError( + "Tokenizer type is not one of SentencePieceTokenizer or HF's AutoTokenizer. Please check " + "how to handle special tokens for this tokenizer" + ) + + def loglikelihood(self, requests: list[Instance]): + """ + Defines the loglikelihood request. Takes input requests of type list[Instance] where Instance is a dataclass + defined in lm_eval.api.instance. Each Instance conists of the input prompt, output prompt, request type(here + loglikelihood) and other relevant args like few shot samples. + """ + special_tokens_kwargs = {} + tokenizer_type = self.tokenizer_type(self.tokenizer) + if tokenizer_type == "SentencePieceTokenizer": + special_tokens_kwargs['add_bos'] = self.add_bos + elif tokenizer_type == "AutoTokenizer": + special_tokens_kwargs['add_special_tokens'] = self.add_bos + + results = [] + for request in requests: + # get the input prompt from the request + context = request.arguments[0] + # get the output prompt from the request + continuation = request.arguments[1] + # get encoded tokens of continuation + continuation_enc = self.tokenizer.tokenizer.encode(continuation, **special_tokens_kwargs) + # for SentencePeice consider the encoded tokens from the 2nd token since first encoded token is space. 
+            if self.tokenizer_type(self.tokenizer) == "SentencePieceTokenizer":
+                continuation_enc = continuation_enc[1:]
+            num_cont_tokens = len(continuation_enc)
+            # Update self.max_tokens_to_generate with the number of continuation tokens (or output tokens) in the request
+            self.max_tokens_to_generate = num_cont_tokens
+            # Create payload to query the model deployed on PyTriton server
+            payload = {
+                "model": self.model_name,
+                "prompt": context,
+                "max_tokens": self.max_tokens_to_generate,
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "top_k": self.top_k,
+            }
+            # Get the logits from the model
+            generation_logits = self._generate_tokens_logits(payload, return_logits=True)
+            # Convert generation_logits to a torch tensor to easily get logprobs without a manual implementation of log_softmax
+            multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1)
+            # Convert encoded continuation tokens to a torch tensor
+            cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)
+            # Get the greedy token from the logits (i.e. the token with the highest prob)
+            greedy_tokens = multi_logits.argmax(dim=-1)
+            # Check if all greedy_tokens match the actual continuation tokens
+            is_greedy = (greedy_tokens == cont_toks).all()
+            # Get the logits corresponding to the actual continuation tokens
+            logits = torch.gather(multi_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
+            # result is a tuple of the logProb of generating the continuation tokens and is_greedy
+            result = (float(logits.sum()), bool(is_greedy))
+
+            results.append(result)
+
+        return results
+
+    def loglikelihood_rolling(self, requests: list[Instance]):
+        """
+        Defines the loglikelihood_rolling request type. Yet to be implemented.
+        """
+        pass
+
+    def generate_until(self, inputs: list[Instance]):
+        """
+        Defines the generate_until request type. Takes input requests of type list[Instance] where Instance is a
+        dataclass defined in lm_eval.api.instance. Each Instance consists of the input prompt, output prompt, request
+        type (here generate_until) and other relevant args like few shot samples.
+        """
+        results = []
+        for instance in inputs:
+            # Access the 'arguments' attribute of the Instance which contains the input prompt string
+            prompt = instance.arguments[0]
+            # Create payload to query the model deployed on PyTriton server
+            payload = {
+                "model": self.model_name,
+                "prompt": prompt,
+                "max_tokens": self.max_tokens_to_generate,
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "top_k": self.top_k,
+            }
+            # Get the text generated by the model
+            generated_text = self._generate_tokens_logits(payload, return_text=True)
+
+            results.append(generated_text)
+
+        return results
+
+
+def wait_for_rest_service(rest_url, max_retries=60, retry_interval=2):
+    """
+    Wait for REST service to be ready.
+
+    Args:
+        rest_url (str): URL of the REST service's health endpoint
+        max_retries (int): Maximum number of retry attempts. Default: 60.
+        retry_interval (int): Time to wait between retries in seconds. Default: 2.
+
+    Returns:
+        bool: True if rest service is ready, False otherwise
+    """
+
+    def check_service(url):
+        """
+        Check if the service is ready by making a GET request to its health endpoint.
+
+        Args:
+            url (str): URL of the service's health endpoint
+
+        Returns:
+            bool: True if the service is ready, False otherwise
+        """
+        try:
+            response = requests.get(url, timeout=5)
+            return response.status_code == 200
+        except RequestException:
+            return False
+
+    for _ in range(max_retries):
+        rest_ready = check_service(rest_url)
+
+        if rest_ready:
+            logging.info("REST service is ready.")
+            return True
+
+        logging.info(f"REST Service not ready yet. Retrying in {retry_interval} seconds...")
+        time.sleep(retry_interval)
+
+    logging.info("Timeout: REST service did not become ready.")
+    return False
diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py
index 7e873db6b5b1..e1d21bb54b76 100644
--- a/nemo/deploy/nlp/query_llm.py
+++ b/nemo/deploy/nlp/query_llm.py
@@ -174,6 +174,7 @@ def query_llm(
         end_strings=None,
         init_timeout=60.0,
         openai_format_response: bool = False,
+        output_generation_logits: bool = False,
     ):
         """
         Query the Triton server synchronously and return a list of responses.
@@ -190,6 +191,8 @@ def query_llm(
             no_repeat_ngram_size (int): no repeat ngram size.
             task_id (str): downstream task id if virtual tokens are used.
             init_timeout (flat): timeout for the connection.
+            openai_format_response: return response similar to OpenAI API format
+            output_generation_logits: return generation logits from model on PyTriton
         """
 
         prompts = str_list2numpy(prompts)
@@ -248,6 +251,9 @@ def query_llm(
         if end_strings is not None:
             inputs["end_strings"] = str_list2numpy(end_strings)
 
+        if output_generation_logits is not None:
+            inputs["output_generation_logits"] = np.full(prompts.shape, output_generation_logits, dtype=np.bool_)
+
         with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
             result_dict = client.infer_batch(**inputs)
             output_type = client.model_config.outputs[0].dtype
@@ -269,6 +275,9 @@ def query_llm(
                     "model": self.model_name,
                     "choices": [{"text": str(sentences)}],
                 }
+                # Convert generation logits to a list to make it json serializable and add it to openai_response dict
+                if output_generation_logits:
+                    openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"].tolist()
                 return openai_response
             else:
                 return sentences
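Reviewer note: the evaluation path depends on the OpenAI-style round trip that query_llm() above now produces and that rest_model_api.py below forwards. A sketch of the request/response shape, with placeholder host, port and prompt, not part of the patch:

import requests

payload = {
    "model": "triton_model",       # triton_model_name used at deploy time
    "prompt": "Q: 2 + 2 = ?\nA:",  # placeholder prompt
    "max_tokens": 32,
    "temperature": 1e-9,
    "top_p": 0.0,
    "top_k": 1,
}
resp = requests.post("http://0.0.0.0:8080/v1/completions/", json=payload).json()
text = resp["choices"][0]["text"]
# Present only when the engine was built with gather_generation_logits / OUTPUT_GENERATION_LOGITS enabled:
logits = resp["choices"][0].get("generation_logits")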
diff --git a/nemo/deploy/service/rest_model_api.py b/nemo/deploy/service/rest_model_api.py
index fbc774883faa..64afea167295
--- a/nemo/deploy/service/rest_model_api.py
+++ b/nemo/deploy/service/rest_model_api.py
@@ -8,8 +8,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import json
 import os
 from pathlib import Path
 
 import requests
@@ -19,6 +17,7 @@
 from pydantic_settings import BaseSettings
 
 from nemo.deploy.nlp import NemoQueryLLM
+from nemo.utils import logging
 
 
 class TritonSettings(BaseSettings):
@@ -29,14 +28,13 @@
     def __init__(self):
         super(TritonSettings, self).__init__()
         try:
-            with open(os.path.join(Path.cwd(), 'nemo/deploy/service/config.json')) as config:
-                config_json = json.load(config)
-                self._triton_service_port = config_json["triton_service_port"]
-                self._triton_service_ip = config_json["triton_service_ip"]
-                self._triton_request_timeout = config_json["triton_request_timeout"]
-                self._openai_format_response = config_json["openai_format_response"]
+            self._triton_service_port = int(os.environ.get('TRITON_PORT', 8080))
+            self._triton_service_ip = os.environ.get('TRITON_HTTP_ADDRESS', '0.0.0.0')
+            self._triton_request_timeout = int(os.environ.get('TRITON_REQUEST_TIMEOUT', 60))
+            self._openai_format_response = os.environ.get('OPENAI_FORMAT_RESPONSE', 'False').lower() == 'true'
+            self._output_generation_logits = os.environ.get('OUTPUT_GENERATION_LOGITS', 'False').lower() == 'true'
         except Exception as error:
-            print("An exception occurred:", error)
+            logging.error("An exception occurred trying to retrieve set args in TritonSettings class. Error:", error)
             return
 
     @property
@@ -54,11 +52,17 @@ def triton_request_timeout(self):
     @property
     def openai_format_response(self):
         """
-        Retuns the response from Triton server in OpenAI compatible formar if set to True,
-        default set in config.json is false.
+        Returns the response from Triton server in OpenAI compatible format if set to True.
         """
         return self._openai_format_response
 
+    @property
+    def output_generation_logits(self):
+        """
+        Returns the generation logits along with text in Triton server output if set to True.
+        """
+        return self._output_generation_logits
+
 
 app = FastAPI()
 triton_settings = TritonSettings()
@@ -70,19 +74,27 @@ class CompletionRequest(BaseModel):
     max_tokens: int = 512
     temperature: float = 1.0
     top_p: float = 0.0
-    n: int = 1
+    top_k: int = 1
     stream: bool = False
     stop: str | None = None
     frequency_penalty: float = 1.0
 
 
-@app.get("/triton_health")
+@app.get("/v1/health")
+def health_check():
+    return {"status": "ok"}
+
+
+@app.get("/v1/triton_health")
 async def check_triton_health():
     """
     This method exposes endpoint "/triton_health" which can be used to verify if Triton server is accessible while running the REST or FastAPI application.
-    Verify by running: curl http://service_http_address:service_port/triton_health and the returned status should inform if the server is accessible.
+    Verify by running: curl http://service_http_address:service_port/v1/triton_health and the returned status should inform if the server is accessible.
""" - triton_url = f"triton_settings.triton_service_ip:str(triton_settings.triton_service_port)/v2/health/ready" + triton_url = ( + f"http://{triton_settings.triton_service_ip}:{str(triton_settings.triton_service_port)}/v2/health/ready" + ) + logging.info(f"Attempting to connect to Triton server at: {triton_url}") try: response = requests.get(triton_url, timeout=5) if response.status_code == 200: @@ -101,11 +113,13 @@ def completions_v1(request: CompletionRequest): output = nq.query_llm( prompts=[request.prompt], max_output_len=request.max_tokens, - top_k=request.n, + # when these below params are passed as None + top_k=request.top_k, top_p=request.top_p, temperature=request.temperature, init_timeout=triton_settings.triton_request_timeout, openai_format_response=triton_settings.openai_format_response, + output_generation_logits=triton_settings.output_generation_logits, ) if triton_settings.openai_format_response: return output @@ -114,5 +128,5 @@ def completions_v1(request: CompletionRequest): "output": output[0][0], } except Exception as error: - print("An exception occurred:", error) + logging.error("An exception occurred with the post request to /v1/completions/ endpoint:", error) return {"error": "An exception occurred"} diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 08b0b822cad4..a1e6cb0e03c4 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -180,6 +180,8 @@ def export( reduce_fusion: bool = True, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None, + gather_context_logits: Optional[bool] = False, + gather_generation_logits: Optional[bool] = False, ): """ Exports nemo checkpoints to TensorRT-LLM. @@ -218,6 +220,8 @@ def export( reduce_fusion (bool): enables fusing extra kernels after custom TRT-LLM allReduce fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type. fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type. + gather_context_logits (Optional[bool]): if True, enables gather_context_logits while building trtllm engine. Default: False + gather_generation_logits (Optional[bool]): if True, enables gather_generation_logits while building trtllm engine. Default: False """ if n_gpus is not None: warnings.warn( @@ -495,6 +499,8 @@ def get_transformer_config(nemo_model_config): multiple_profiles=multiple_profiles, gpt_attention_plugin=gpt_attention_plugin, gemm_plugin=gemm_plugin, + gather_context_logits=gather_context_logits, + gather_generation_logits=gather_generation_logits, ) tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model") @@ -688,6 +694,7 @@ def forward( prompt_embeddings_checkpoint_path: str = None, streaming: bool = False, output_log_probs: bool = False, + output_generation_logits: bool = False, **sampling_kwargs, ): """ @@ -706,6 +713,7 @@ def forward( task_ids (List(str)): list of the task ids for the prompt tables. prompt_embeddings_table (List(float)): prompt embeddings table. prompt_embeddings_checkpoint_path (str): path for the nemo checkpoint for the prompt embedding table. + output_generation_logits (bool): if True returns generation_logits in the outout of generate method. sampling_kwargs: Additional kwargs to set in the SamplingConfig. 
""" @@ -784,6 +792,7 @@ def forward( no_repeat_ngram_size=no_repeat_ngram_size, output_log_probs=output_log_probs, multiprocessed_env=multiprocessed_env, + output_generation_logits=output_generation_logits, **sampling_kwargs, ) else: @@ -862,16 +871,21 @@ def get_triton_input(self): Tensor(name="no_repeat_ngram_size", shape=(-1,), dtype=np.single, optional=True), Tensor(name="task_id", shape=(-1,), dtype=bytes, optional=True), Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True), + Tensor(name="output_generation_logits", shape=(-1,), dtype=np.bool_, optional=False), ) return inputs @property def get_triton_output(self): - outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) + outputs = ( + Tensor(name="outputs", shape=(-1,), dtype=bytes), + Tensor(name="generation_logits", shape=(-1,), dtype=np.single), + ) return outputs @batch def triton_infer_fn(self, **inputs: np.ndarray): + output_dict = {} try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} if "max_output_len" in inputs: @@ -898,14 +912,20 @@ def triton_infer_fn(self, **inputs: np.ndarray): if "lora_uids" in inputs: lora_uids = np.char.decode(inputs.pop("lora_uids").astype("bytes"), encoding="utf-8") infer_input["lora_uids"] = lora_uids[0].tolist() + if "output_generation_logits" in inputs: + infer_input["output_generation_logits"] = inputs.pop("output_generation_logits")[0][0] - output_texts = self.forward(**infer_input) - output = cast_output(output_texts, np.bytes_) + if infer_input["output_generation_logits"]: + output_texts, generation_logits = self.forward(**infer_input) + output_dict["generation_logits"] = np.array(generation_logits.cpu().numpy()) + else: + output_texts = self.forward(**infer_input) + output_dict["outputs"] = cast_output(output_texts, np.bytes_) except Exception as error: err_msg = "An error occurred: {0}".format(str(error)) - output = cast_output([err_msg], np.bytes_) + output_dict["outputs"] = cast_output([err_msg], np.bytes_) - return {"outputs": output} + return output_dict @batch def triton_infer_fn_streaming(self, **inputs: np.ndarray): diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 4be2d42ebe4d..38fb80ca3272 100755 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -54,6 +54,8 @@ def build_and_save_engine( gpt_attention_plugin: str = "auto", gemm_plugin: str = "auto", reduce_fusion: bool = False, + gather_context_logits: bool = False, + gather_generation_logits: bool = False, ): architecture = "LLaMAForCausalLM" if model_config.architecture == "LlamaForCausalLM" else model_config.architecture try: @@ -96,8 +98,8 @@ def build_and_save_engine( 'max_num_tokens': max_num_tokens, 'opt_num_tokens': opt_num_tokens, 'max_prompt_embedding_table_size': max_prompt_embedding_table_size, - 'gather_context_logits': False, - 'gather_generation_logits': False, + 'gather_context_logits': gather_context_logits, + 'gather_generation_logits': gather_generation_logits, 'strongly_typed': False, 'builder_opt': None, 'use_refit': use_refit, diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index bd7b8abd5f9e..ef67c918290f 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -647,6 +647,7 @@ def generate( streaming: bool = False, output_log_probs=False, multiprocessed_env=False, + output_generation_logits=False, **sampling_kwargs, ) -> Optional[List[List[str]]]: """Generate the output 
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index bd7b8abd5f9e..ef67c918290f
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -647,6 +647,7 @@ def generate(
     streaming: bool = False,
     output_log_probs=False,
     multiprocessed_env=False,
+    output_generation_logits=False,
     **sampling_kwargs,
 ) -> Optional[List[List[str]]]:
     """Generate the output sequence from the input sequence.
@@ -692,6 +693,7 @@ def generate(
         multiprocessed_env=multiprocessed_env,
         **sampling_kwargs,
     )
+    assert outputs is not None
 
     if tensorrt_llm.mpi_rank() != 0:
         return None
@@ -705,8 +707,8 @@ def generate(
         for b in range(output_ids.shape[0])
     ]
 
-    if output_log_probs:
-        return output_lines_list, log_probs
+    if output_generation_logits:
+        return output_lines_list, outputs['generation_logits']
 
     return output_lines_list
 
diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py
index e3394726fa1c..154ffc90dc9c 100755
--- a/scripts/deploy/nlp/deploy_triton.py
+++ b/scripts/deploy/nlp/deploy_triton.py
@@ -419,13 +419,14 @@ def nemo_deploy(argv):
         LOGGER.info("Triton deploy function will be called.")
         nm.deploy()
+        nm.run()
     except Exception as error:
         LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error))
         return
 
     try:
         LOGGER.info("Model serving on Triton is will be started.")
-        if args.start_rest_service == "True":
+        if args.start_rest_service:
            try:
                LOGGER.info("REST service will be started.")
                uvicorn.run(