Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions nemo/collections/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero


if TYPE_CHECKING:
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest
Expand Down Expand Up @@ -325,7 +324,7 @@ def ptq(
def deploy(
nemo_checkpoint: Path = None,
model_type: str = "llama",
triton_model_name: str = 'triton_model',
triton_model_name: str = "triton_model",
triton_model_version: Optional[int] = 1,
triton_port: int = 8000,
triton_http_address: str = "0.0.0.0",
Expand Down Expand Up @@ -377,22 +376,22 @@ def deploy(
output_generation_logits (bool): If True builds trtllm engine with gather_generation_logits set to True.
generation_logits are used to compute the logProb of the output token. Default: True.
"""
from nemo.collections.llm import deploy
from nemo.collections.llm.deploy.base import get_trtllm_deployable, unset_environment_variables
from nemo.deploy import DeployPyTriton

deploy.unset_environment_variables()
unset_environment_variables()
if start_rest_service:
if triton_port == rest_service_port:
logging.error("REST service port and Triton server port cannot use the same port.")
return
# Store triton ip, port and other args relevant for REST API as env vars to be accessible by rest_model_api.py
os.environ['TRITON_HTTP_ADDRESS'] = triton_http_address
os.environ['TRITON_PORT'] = str(triton_port)
os.environ['TRITON_REQUEST_TIMEOUT'] = str(triton_request_timeout)
os.environ['OPENAI_FORMAT_RESPONSE'] = str(openai_format_response)
os.environ['OUTPUT_GENERATION_LOGITS'] = str(output_generation_logits)
os.environ["TRITON_HTTP_ADDRESS"] = triton_http_address
os.environ["TRITON_PORT"] = str(triton_port)
os.environ["TRITON_REQUEST_TIMEOUT"] = str(triton_request_timeout)
os.environ["OPENAI_FORMAT_RESPONSE"] = str(openai_format_response)
os.environ["OUTPUT_GENERATION_LOGITS"] = str(output_generation_logits)

triton_deployable = deploy.get_trtllm_deployable(
triton_deployable = get_trtllm_deployable(
nemo_checkpoint,
model_type,
triton_model_repository,
Expand Down Expand Up @@ -513,18 +512,22 @@ def evaluate(
from nemo.collections.llm import evaluation

# Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
tokenizer = io.load_context(nemo_checkpoint_path + '/context', subpath="model").tokenizer
tokenizer = io.load_context(nemo_checkpoint_path + "/context", subpath="model.tokenizer")
# Wait for rest service to be ready before starting evaluation
evaluation.wait_for_rest_service(rest_url=f"{url}/v1/health")
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
model = evaluation.NeMoFWLMEval(
model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos
)
results = evaluator.simple_evaluate(
model=model, tasks=eval_task, limit=limit, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters
model=model,
tasks=eval_task,
limit=limit,
num_fewshot=num_fewshot,
bootstrap_iters=bootstrap_iters,
)

print("score", results['results'][eval_task])
print("score", results["results"][eval_task])


@run.cli.entrypoint(name="import", namespace="llm")
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/evaluation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def generate_until(self, inputs: list[Instance]):
return results


def wait_for_rest_service(rest_url, max_retries=60, retry_interval=2):
def wait_for_rest_service(rest_url, max_retries=600, retry_interval=2):
"""
Wait for REST service to be ready.

Expand Down