switch to grpc for deploy and eval #11643

Open · wants to merge 10 commits into base: main
19 changes: 12 additions & 7 deletions nemo/collections/llm/api.py
@@ -326,7 +326,8 @@ def deploy(
model_type: str = "llama",
triton_model_name: str = "triton_model",
triton_model_version: Optional[int] = 1,
triton_port: int = 8000,
triton_http_port: int = 8000,
triton_grpc_port: int = 8001,
triton_http_address: str = "0.0.0.0",
triton_request_timeout: int = 60,
triton_model_repository: Path = None,
@@ -381,12 +382,12 @@ def deploy(

unset_environment_variables()
if start_rest_service:
if triton_port == rest_service_port:
if triton_http_port == rest_service_port:
logging.error("REST service port and Triton server port cannot use the same port.")
return
# Store triton ip, port and other args relevant for REST API as env vars to be accessible by rest_model_api.py
os.environ["TRITON_HTTP_ADDRESS"] = triton_http_address
os.environ["TRITON_PORT"] = str(triton_port)
os.environ["TRITON_PORT"] = str(triton_http_port)
os.environ["TRITON_REQUEST_TIMEOUT"] = str(triton_request_timeout)
os.environ["OPENAI_FORMAT_RESPONSE"] = str(openai_format_response)
os.environ["OUTPUT_GENERATION_LOGITS"] = str(output_generation_logits)
@@ -411,7 +412,8 @@ def deploy(
triton_model_name=triton_model_name,
triton_model_version=triton_model_version,
max_batch_size=max_batch_size,
port=triton_port,
http_port=triton_http_port,
grpc_port=triton_grpc_port,
address=triton_http_address,
)
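
For orientation, a minimal hedged sketch of calling the updated deploy API with the split HTTP/gRPC ports. The import path follows from the file being nemo/collections/llm/api.py; the checkpoint keyword, placeholder path, and the REST-service keywords are assumptions inferred from the surrounding code, not confirmed by this diff.

```python
# Hypothetical sketch: deploy a NeMo 2.0 checkpoint on PyTriton with separate
# HTTP and gRPC ports. Keywords other than the triton_* ones are assumed.
from nemo.collections.llm.api import deploy

deploy(
    nemo_checkpoint="/path/to/nemo2_ckpt",  # assumed keyword name; placeholder path
    model_type="llama",
    triton_model_name="triton_model",
    triton_http_port=8000,      # HTTP endpoint; must differ from rest_service_port
    triton_grpc_port=8001,      # gRPC endpoint that evaluate() talks to
    triton_http_address="0.0.0.0",
    start_rest_service=True,    # assumed keyword, mirrored from the check above
    rest_service_port=8080,     # assumed keyword and default
)
```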

@@ -453,7 +455,8 @@ def deploy(

def evaluate(
nemo_checkpoint_path: Path,
url: str = "http://0.0.0.0:8080/v1",
url: str = "grpc://0.0.0.0:8001",
rest_url: Optional[str] = None,
model_name: str = "triton_model",
eval_task: str = "gsm8k",
num_fewshot: Optional[int] = None,
@@ -473,7 +476,8 @@
Args:
nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt
which is required to tokenize the evaluation input and output prompts.
url (str): rest service url and port that were used in the deploy method above in the format:
url (str): grpc service url that was used in the deploy method above, in the format: grpc://{grpc_service_ip}:{grpc_port}.
rest_url (str): Optional rest service url and port that were used in the deploy method above, in the format:
http://{rest_service_http}:{rest_service_port}. When provided, it is used to wait for the REST service
to report healthy (via its /v1/health endpoint) before evaluation starts. The evaluation requests themselves
(from lm-eval-harness) are sent over gRPC to the url above and forwarded to the model deployed on the PyTriton server.
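
A minimal hedged sketch of the corresponding evaluate call, generating over gRPC while the optional rest_url only gates the readiness check; the checkpoint path is a placeholder.

```python
# Hypothetical sketch: evaluate a model already deployed as shown above.
from nemo.collections.llm.api import evaluate

evaluate(
    nemo_checkpoint_path="/path/to/nemo2_ckpt",  # placeholder path
    url="grpc://0.0.0.0:8001",        # gRPC endpoint of the Triton server
    rest_url="http://0.0.0.0:8080",   # optional; only used to wait for /v1/health
    model_name="triton_model",
    eval_task="gsm8k",
)
```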
@@ -514,7 +518,8 @@ def evaluate(
# Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
tokenizer = io.load_context(nemo_checkpoint_path + "/context", subpath="model.tokenizer")
# Wait for rest service to be ready before starting evaluation
evaluation.wait_for_rest_service(rest_url=f"{url}/v1/health")
if rest_url is not None:
evaluation.wait_for_rest_service(rest_url=f"{rest_url}/v1/health")
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
model = evaluation.NeMoFWLMEval(
model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos
1 change: 0 additions & 1 deletion nemo/collections/llm/deploy/base.py
@@ -102,7 +102,6 @@ def get_trtllm_deployable(
trt_llm_exporter.export(
nemo_checkpoint_path=nemo_checkpoint,
model_type=model_type,
n_gpus=num_gpus,
tensor_parallelism_size=tensor_parallelism_size,
pipeline_parallelism_size=pipeline_parallelism_size,
max_input_len=max_input_len,
19 changes: 6 additions & 13 deletions nemo/collections/llm/evaluation/base.py
@@ -20,9 +20,11 @@
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from requests.exceptions import RequestException
from tqdm import tqdm

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.collections.llm.evaluation.utils import query_llm
from nemo.utils import logging


@@ -49,21 +51,12 @@ def _generate_tokens_logits(self, payload, return_text: bool = False, return_log
A private method that sends post request to the model on PyTriton server and returns either generated text or
logits.
"""
# send a post request to /v1/completions/ endpoint with the payload
response = requests.post(f"{self.api_url}/v1/completions/", json=payload)
response_data = response.json()
response = query_llm(url=self.api_url, **payload)

if 'error' in response_data:
raise Exception(f"API Error: {response_data['error']}")

# Assuming the response is in OpenAI format
if return_text:
# in case of generate_until tasks return just the text
return response_data['choices'][0]['text']

return [[x[0].decode("utf-8")] for x in response['outputs']] # shape[batch_size, 1]
if return_logits:
# in case of loglikelihood tasks return the logits
return response_data['choices'][0]['generation_logits']
return response['generation_logits'] # shape[batch_size, 1, num_tokens, vocab_size]

def tokenizer_type(self, tokenizer):
"""
@@ -93,7 +86,7 @@ def loglikelihood(self, requests: list[Instance]):
special_tokens_kwargs['add_special_tokens'] = self.add_bos

results = []
for request in requests:
for request in tqdm(requests):
# get the input prompt from the request
context = request.arguments[0]
# get the output prompt from the request
107 changes: 107 additions & 0 deletions nemo/collections/llm/evaluation/utils.py
@@ -0,0 +1,107 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import numpy as np
from pytriton.client import ModelClient


def str_list2numpy(str_list: List[str]) -> np.ndarray:
""" "
Convert a list of strings to a numpy array of strings.
"""
str_ndarray = np.array(str_list)[..., np.newaxis]
return np.char.encode(str_ndarray, "utf-8")


def query_llm(
Collaborator: @HuiyingLi can we reuse the query_llm from here so that we can avoid redundant code?

Collaborator Author: @athitten Thank you for pointing out! Just changed to use the existing query_llm.

url: str,
model: str,
prompt: Union[str, List[str]],
output_generation_logits: bool = True,
stop_words_list: Optional[List[str]] = None,
bad_words_list: Optional[List[str]] = None,
no_repeat_ngram_size: Optional[int] = None,
max_tokens: int = 128,
top_k: int = 1,
top_p: float = 0.0,
temperature: float = 1.0,
random_seed: Optional[int] = None,
task_id: Optional[str] = None,
lora_uids: Optional[str] = None,
init_timeout: float = 60.0,
):
"""
A method that sends post request to the model on PyTriton server and returns either generated text or
logits.

Args:
url (str): The URL for the Triton server. Required.
model_name (str): The name of the Triton model. Required.
prompt (str, optional): The prompt to be used. Required if `prompt_file` is not provided.
prompt_file (str, optional): The file path to read the prompt from. Required if `prompt` is not provided.
stop_words_list (str, optional): A list of stop words.
bad_words_list (str, optional): A list of bad words.
no_repeat_ngram_size (int, optional): The size of the n-grams to disallow repeating.
max_output_len (int): The maximum length of the output tokens. Defaults to 128.
top_k (int): The top-k sampling parameter. Defaults to 1.
top_p (float): The top-p sampling parameter. Defaults to 0.0.
temperature (float): The temperature for sampling. Defaults to 1.0.
task_id (str, optional): The task ID for the prompt embedding tables.
"""
prompts = str_list2numpy([prompt] if isinstance(prompt, str) else prompt)
inputs = {"prompts": prompts}

if output_generation_logits:
inputs["output_generation_logits"] = np.full(prompts.shape, output_generation_logits, dtype=np.bool_)

if max_tokens is not None:
inputs["max_output_len"] = np.full(prompts.shape, max_tokens, dtype=np.int_)

if top_k is not None:
inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_)

if top_p is not None:
inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single)

if temperature is not None:
inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single)

if random_seed is not None:
inputs["random_seed"] = np.full(prompts.shape, random_seed, dtype=np.single)

if stop_words_list is not None:
stop_words_list = np.char.encode(stop_words_list, "utf-8")
inputs["stop_words_list"] = np.full((prompts.shape[0], len(stop_words_list)), stop_words_list)

if bad_words_list is not None:
bad_words_list = np.char.encode(bad_words_list, "utf-8")
inputs["bad_words_list"] = np.full((prompts.shape[0], len(bad_words_list)), bad_words_list)

if no_repeat_ngram_size is not None:
inputs["no_repeat_ngram_size"] = np.full(prompts.shape, no_repeat_ngram_size, dtype=np.single)

if task_id is not None:
task_id = np.char.encode(task_id, "utf-8")
inputs["task_id"] = np.full((prompts.shape[0], len([task_id])), task_id)

if lora_uids is not None:
lora_uids = np.char.encode(lora_uids, "utf-8")
inputs["lora_uids"] = np.full((prompts.shape[0], len(lora_uids)), lora_uids)

with ModelClient(url, model, init_timeout_s=init_timeout) as client:
result_dict = client.infer_batch(**inputs)

return result_dict
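
For reference, a small hedged usage sketch of this helper against a locally deployed Triton endpoint. The address and model name are placeholders, and the output tensor name "outputs" mirrors the usage in evaluation/base.py above.

```python
# Hypothetical usage sketch of query_llm over gRPC.
from nemo.collections.llm.evaluation.utils import query_llm

result = query_llm(
    url="grpc://0.0.0.0:8001",
    model="triton_model",
    prompt=["What is 2 + 2?"],
    max_tokens=32,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    output_generation_logits=False,
)
# PyTriton returns numpy arrays; generated text comes back as UTF-8 bytes,
# matching the decoding done in NeMoFWLMEval._generate_tokens_logits above.
print(result["outputs"][0][0].decode("utf-8"))
```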
6 changes: 4 additions & 2 deletions nemo/deploy/deploy_base.py
@@ -42,7 +42,8 @@ def __init__(
checkpoint_path: str = None,
model=None,
max_batch_size: int = 128,
port: int = 8000,
http_port: int = 8000,
grpc_port: int = 8001,
address="0.0.0.0",
allow_grpc=True,
allow_http=True,
@@ -54,7 +55,8 @@
self.triton_model_version = triton_model_version
self.max_batch_size = max_batch_size
self.model = model
self.port = port
self.http_port = http_port
self.grpc_port = grpc_port
self.address = address
self.triton = None
self.allow_grpc = allow_grpc
10 changes: 7 additions & 3 deletions nemo/deploy/deploy_pytriton.py
@@ -66,7 +66,8 @@ def __init__(
checkpoint_path: str = None,
model=None,
max_batch_size: int = 128,
port: int = 8000,
http_port: int = 8000,
grpc_port: int = 8001,
address="0.0.0.0",
allow_grpc=True,
allow_http=True,
@@ -92,7 +93,8 @@ def __init__(
checkpoint_path=checkpoint_path,
model=model,
max_batch_size=max_batch_size,
port=port,
http_port=http_port,
grpc_port=grpc_port,
address=address,
allow_grpc=allow_grpc,
allow_http=allow_http,
@@ -128,7 +130,9 @@ def deploy(self):
else:
triton_config = TritonConfig(
http_address=self.address,
http_port=self.port,
http_port=self.http_port,
grpc_address=self.address,
grpc_port=self.grpc_port,
allow_grpc=self.allow_grpc,
allow_http=self.allow_http,
)
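
For readers who want to replicate the server setup directly with pytriton, this is roughly the configuration deploy() now assembles with the defaults above: both protocols bound to the same address, HTTP on 8000 and gRPC on 8001 (the port evaluate() targets). A minimal sketch, not taken verbatim from this PR:

```python
# Hypothetical sketch of the TritonConfig built by DeployPyTriton.deploy()
# with the default arguments used in this PR.
from pytriton.triton import TritonConfig

triton_config = TritonConfig(
    http_address="0.0.0.0",
    http_port=8000,
    grpc_address="0.0.0.0",
    grpc_port=8001,
    allow_http=True,
    allow_grpc=True,
)
```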