GPT

This document explains how to build the GPT model using TensorRT-LLM and run it on a single GPU, a single node with multiple GPUs, or multiple nodes with multiple GPUs.

Overview

The TensorRT-LLM GPT implementation can be found in tensorrt_llm/models/gpt/model.py. The TensorRT-LLM GPT example code is located in examples/gpt. There is one main file:

  • convert_checkpoint.py to convert a checkpoint from the HuggingFace (HF) Transformers format to the TensorRT-LLM format.

In addition, there are two shared files in the parent folder examples for inference and evaluation:

  • ../run.py to run inference on an input text;
  • ../summarize.py to summarize the articles in the cnn_dailymail dataset.

Support Matrix

  • FP16
  • FP8
  • Inflight Batching
  • PAGED_KV_CACHE
  • FP8 KV CACHE
  • Tensor Parallel
  • STRONGLY TYPED
  • INT8 SmoothQuant
  • INT8 weight only
  • INT4 weight only

Usage

The next two sections describe how to convert the weights from the HuggingFace (HF) Transformers format to the TensorRT-LLM format.

1. Download weights from HuggingFace Transformers

Please install required packages first:

pip install -r requirements.txt
# Download hf gpt2 model
rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2
pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin && popd

2. Convert weights from HF Transformers to TensorRT-LLM format

The convert_checkpoint.py script converts HF weights to TensorRT-LLM checkpoints. The number of checkpoint files (in .safetensors format) is the same as the number of GPUs used to run inference.

# single gpu, dtype float16
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --output_dir gpt2/trt_ckpt/fp16/1-gpu

# 2-way tensor parallelism
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --tp_size 2 \
        --output_dir gpt2/trt_ckpt/fp16/2-gpu
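As a quick sanity check, each conversion writes a config.json plus one .safetensors file per rank into the output directory. The listing below is an illustrative sketch for the 2-GPU conversion above (the rank0/rank1 file names are an assumption based on the one-file-per-rank convention, not copied from actual output):

# Inspect the converted 2-GPU checkpoint (expected contents shown as a sketch)
ls gpt2/trt_ckpt/fp16/2-gpu
# config.json  rank0.safetensors  rank1.safetensors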

3. Build TensorRT engine(s)

The trtllm-build command builds TensorRT-LLM engines from TensorRT-LLM checkpoints. The checkpoint directory provides the model's weights and architecture configuration. The number of engine files is also the same as the number of GPUs used to run inference.

Normally, the trtllm-build command only requires a single GPU, but you can enable parallel building by passing the number of GPUs to the --workers argument.

# Build a single-GPU float16 engine from TensorRT-LLM checkpoint.
# Enable the special TensorRT-LLM GPT Attention plugin (--gpt_attention_plugin) to increase runtime performance.
# It is recommended to use --remove_input_padding along with --gpt_attention_plugin for better performance
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --output_dir gpt2/trt_engines/fp16/1-gpu

# Build 2-way tensor parallelism engines from TensorRT-LLM checkpoint.
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --output_dir gpt2/trt_engines/fp16/2-gpu

If the engines are built successfully, you will see output like:

......
[03/12/2024-10:21:08] [TRT] [I] Engine generation completed in 35.9738 seconds.
[03/12/2024-10:21:08] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 212 MiB, GPU 775 MiB
[03/12/2024-10:21:08] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +775, now: CPU 0, GPU 775 (MiB)
[03/12/2024-10:21:09] [TRT] [I] [MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 6600 MiB
[03/12/2024-10:21:09] [TRT-LLM] [I] Total time of building Unnamed Network 0: 00:00:36
[03/12/2024-10:21:09] [TRT-LLM] [I] Serializing engine to gpt2/trt_engines/fp16/1-gpu/rank0.engine...
[03/12/2024-10:21:11] [TRT-LLM] [I] Engine serialized. Total time: 00:00:02
[03/12/2024-10:21:11] [TRT-LLM] [I] Total time of building all engines: 00:00:41

Fused MultiHead Attention (FMHA)

You can enable the FMHA kernels by adding --context_fmha enable to the invocation of trtllm-build.

If the default fp16 accumulation (--context_fmha enable) does not meet your accuracy requirements, you can enable fp32 accumulation by adding --context_fmha_fp32_acc enable. However, a performance drop is expected.

Note that the FMHA kernels have to be used together with --gpt_attention_plugin float16.
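For example, a hedged sketch of the single-GPU build above with FMHA enabled (the output path is illustrative; add --context_fmha_fp32_acc enable only if fp16 accumulation is not accurate enough for your case):

# Sketch: single-GPU float16 engine with the FMHA kernels enabled
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --context_fmha enable \
        --output_dir gpt2/trt_engines/fp16-fmha/1-gpu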

In-flight batching and paged KV cache

To use in-flight batching with the C++ runtime, the engine must be built accordingly. In-flight batching in the C++ runtime works only with the attention plugin, the paged KV cache, and packed data. Hence, trtllm-build must be called with --gpt_attention_plugin float16, --paged_kv_cache enable, and --remove_input_padding enable. It is possible to choose a different precision for --gpt_attention_plugin if the flag is provided separately. You can additionally control the block size of the paged KV cache with --tokens_per_block=N. A sketch of such a build follows.
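For example, a minimal sketch of a build that satisfies these requirements (the --tokens_per_block value and output path are illustrative assumptions):

# Sketch: engine built for in-flight batching in the C++ runtime
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \
        --gpt_attention_plugin float16 \
        --paged_kv_cache enable \
        --remove_input_padding enable \
        --tokens_per_block 64 \
        --output_dir gpt2/trt_engines/fp16-ifb/1-gpu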

4. Build TensorRT engine(s) with Random Weights

You can build engine(s) using random weights, which is useful for benchmarking. First, the ../generate_checkpoint_config.py script can be used to generate a TensorRT-LLM checkpoint config file:

# Generate an 8-GPU GPT-175B float16 checkpoint config file.
python3 ../generate_checkpoint_config.py --architecture GPTForCausalLM \
        --vocab_size 51200 \
        --hidden_size 12288 \
        --num_hidden_layers 96 \
        --num_attention_heads 96 \
        --dtype float16 \
        --tp_size 8 \
        --output_path gpt_175b/trt_ckpt/fp16/8-gpu/config.json


# Generate a 16-GPU GPT-530B float16 checkpoint config file.
python3 ../generate_checkpoint_config.py --architecture GPTForCausalLM \
        --vocab_size 51200 \
        --hidden_size 20480 \
        --num_hidden_layers 105 \
        --num_attention_heads 128 \
        --dtype float16 \
        --tp_size 16 \
        --output_path gpt_530b/trt_ckpt/fp16/16-gpu/config.json

Then, use the trtllm-build command to build engine(s) with random weights and the model architecture specified by the generated config file.

# Build 8-GPU GPT-175B float16 engines using dummy weights, useful for performance tests.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
trtllm-build --model_config gpt_175b/trt_ckpt/fp16/8-gpu/config.json \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --context_fmha enable \
        --gemm_plugin float16 \
        --max_batch_size 256 \
        --output_dir gpt_175b/trt_engines/fp16/8-gpu \
        --workers 8

# Build 16-GPU GPT-530B float16 engines using dummy weights, useful for performance tests.
# Enable several TensorRT-LLM plugins to increase runtime performance. It also helps with build time.
trtllm-build --model_config gpt_530b/trt_ckpt/fp16/16-gpu/config.json \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --context_fmha enable \
        --gemm_plugin float16 \
        --max_batch_size 128 \
        --max_input_len 128 \
        --max_output_len 20 \
        --output_dir gpt_530b/trt_engines/fp16/16-gpu \
        --workers 8

5. Run inference

Single node, single GPU

The ../run.py script can be used to run inference with the built engine(s).

python3 ../run.py --engine_dir gpt2/trt_engines/fp16/1-gpu \
        --tokenizer_dir gpt2 \
        --max_output_len 8

If the engines are run successfully, you will see output like:

......
Input [Text 0]: "Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: " chef before moving to London in the early"

The ../summarize.py script can run the built engines to summarize articles from the cnn_dailymail dataset. For each summary, the script computes the ROUGE scores and uses the ROUGE-1 score to validate the implementation. By passing the --test_trt_llm flag, the script evaluates the TensorRT-LLM engines. You may also pass the --test_hf flag to evaluate the HF model.

python3 ../summarize.py --engine_dir gpt2/trt_engines/fp16/1-gpu \
        --hf_model_dir gpt2 \
        --test_trt_llm \
        --test_hf

If the engines are run successfully, you will see output like:

......
[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM (total latency: 1.520904541015625 sec)
[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 0)
[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 0.0)
[03/13/2024-05:43:18] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[03/13/2024-05:43:18] [TRT-LLM] [I]   rouge1 : 21.13474087351942
[03/13/2024-05:43:18] [TRT-LLM] [I]   rouge2 : 6.2641616526063775
[03/13/2024-05:43:18] [TRT-LLM] [I]   rougeL : 16.693574311238077
[03/13/2024-05:43:18] [TRT-LLM] [I]   rougeLsum : 18.477384201634088
[03/13/2024-05:43:18] [TRT-LLM] [I] Hugging Face (total latency: 8.76440143585205 sec)
[03/13/2024-05:43:18] [TRT-LLM] [I] HF beam 0 result
[03/13/2024-05:43:18] [TRT-LLM] [I]   rouge1 : 20.834898522466
[03/13/2024-05:43:18] [TRT-LLM] [I]   rouge2 : 5.6914719275508805
[03/13/2024-05:43:18] [TRT-LLM] [I]   rougeL : 16.297064309934132
[03/13/2024-05:43:18] [TRT-LLM] [I]   rougeLsum : 18.018627021792142

Single node, multiple GPUs

To run engines using multiple GPUs on a single node, you can use mpirun as:

mpirun -np 2 \
    python3 ../run.py --engine_dir gpt2/trt_engines/fp16/2-gpu \
        --tokenizer_dir gpt2 \
        --max_output_len 8

# Note that GPT-175B is built with random weights, so the output will also be random
mpirun -np 8 \
    python3 ../run.py --engine_dir gpt_175b/trt_engines/fp16/8-gpu \
        --max_output_len 8

Multiple nodes, multiple GPUs using Slurm

To run engines using multiple nodes, you should use a cluster manager like Slurm. The following section shows how to configure TensorRT-LLM to execute on two nodes using Slurm.

We start by preparing an sbatch script called tensorrt_llm_run.sub. That script contains the following code (you must replace the <REPLACE ...> strings with your own values):

#!/bin/bash
#SBATCH -o logs/tensorrt_llm.out
#SBATCH -e logs/tensorrt_llm.error
#SBATCH -J <REPLACE WITH YOUR JOB's NAME>
#SBATCH -A <REPLACE WITH YOUR ACCOUNT's NAME>
#SBATCH -p <REPLACE WITH YOUR PARTITION's NAME>
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --time=00:30:00

sudo nvidia-smi -lgc 1410,1410

srun --mpi=pmix \
    --container-image <image> \
    --container-mounts <path>:<path> \
    --container-workdir <path> \
    --output logs/tensorrt_llm_%t.out \
    --error logs/tensorrt_llm_%t.error \
        python3 -u ../run.py --engine_dir <engine_dir> --max_output_len 8

Then, submit the job using:

sbatch tensorrt_llm_run.sub

You might have to contact your cluster's administrator to help you customize the above script.

Quantization

SmoothQuant

This section explains how to use SmoothQuant on GPT models with TensorRT-LLM.

SmoothQuant is a post-training quantization (PTQ) method to quantize LLMs to INT8 for faster inference. As explained in the article, SmoothQuant modifies a model to enable INT8 quantization without significantly altering the accuracy.

Model Transformation

An LLM is made of multiple matrix-multiplication operations (or GEMMs): Y = XW, where X, of shape [n, k], holds the activations (produced at run time) and W, of shape [k, m], holds the learned weights. Y, of shape [n, m], is the matrix product of X and W.

SmoothQuant introduces scaling along the k dimension by defining a vector of strictly positive coefficients s: Y = X diag(s)^{-1} diag(s) W. We now have Y = X'W', where X' = X diag(s)^{-1} and W' = diag(s) W. This transformation is introduced so that quantization behaves better: in typical models, X tends to be ill-conditioned, with mostly small-magnitude coefficients but also some outliers that make quantization difficult. Conversely, the re-scaled X' is better suited for INT8 conversion.

In this example, we only replace the Attention's QKV and MLP's FC1 GEMMs with their SmoothQuant'd versions, since that is sufficient to maintain accuracy for the GPT model. During inference, X' is computed by fusing the channel-wise multiplication by diag(s)^{-1} into the preceding layernorm's lambda and beta parameters. W' is pre-computed and needs no additional modification during inference.

INT8 inference

The INT8 quantization scheme used in TensorRT-LLM theoretically works on any GPT model. However, SmoothQuant'd models tend to produce more accurate results with reduced precision.

INT8 inference modifies the GEMMs Y = XW so that both X and W use INT8. The matrix multiplication is sped up by the smaller weight size and by the fast INT8 matrix products of NVIDIA Tensor Cores.

During inference, X is transformed from its standard floating-point (fp) values: X_{i8} <- X_{fp} * s_x. This scaling puts the X values in the INT8 range [-128, 127]. Similarly, W is scaled, W_{i8} <- W_{fp} * s_w, but that operation is done at model export time; no further work is needed at run time.

The optimized TensorRT-LLM GEMM implementation for SmoothQuant does the integer matrix multiplication Y_{i32} <- X_{i8} W_{i8} and rescales the result to its original range, Y_{fp} <- Y_{i32} * (s_x)^{-1} * (s_w)^{-1}. Note that Y_{i32} isn't stored in memory; the re-scaling happens in the GEMM's epilogue and only Y_{fp} gets saved.

By default, s_x and s_w are single-value coefficients; this is the per-tensor mode. Their values are static, estimated at model export time.

TensorRT-LLM also supports more elaborate modes:

  • per-channel: s_w is a fixed vector of size [1, m]. For that, TensorRT-LLM loads the adequately scaled version of W_{i8} at model construction time.
  • per-token: s_x is a vector of size [n, 1] determined at run-time, based on the per-token (a.k.a. per-row) absolute maximum of X. Users can mix-and-match per-channel and per-token options. Both tend to increase the accuracy of the model at the cost of a slightly increased latency.

Usage

convert_checkpoint.py features a --smoothquant option. It must be set to a decimal value in [0, 1] and corresponds to the alpha parameter in the SmoothQuant paper. Setting --smoothquant will smooth the model as explained in model transformation and export the scaling factors needed for INT8 inference.

By default, it will run the model in the per-tensor mode, as explained in INT8 inference. You can add any combination of --per_token and --per_channel to get the corresponding behaviors.

# Per-tensor SmoothQuant
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --smoothquant 0.5 \
        --output_dir gpt2/trt_ckpt/int8-sq/1-gpu

# Per-token per-channel SmoothQuant
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --smoothquant 0.5 \
        --per_token \
        --per_channel \
        --output_dir gpt2/trt_ckpt/int8-sq-ptpc/1-gpu

Then, use trtllm-build to build engine(s).

# Per-tensor SmoothQuant
trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-sq/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --output_dir gpt2/trt_engines/int8-sq/1-gpu

# Per-token per-channel SmoothQuant
trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-sq-ptpc/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --output_dir gpt2/trt_engines/int8-sq-ptpc/1-gpu

Note that the GPT attention plugin must be enabled for SmoothQuant for now.

INT8 KV Cache

convert_checkpoint.py features a --int8_kv_cache option. Setting --int8_kv_cache will calibrate the model and export the scaling factors needed for INT8 KV cache inference.

# Int8 KV cache
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --int8_kv_cache \
        --output_dir gpt2/trt_ckpt/int8kv/1-gpu

trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8kv/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --strongly_typed \
        --output_dir gpt2/trt_engines/int8kv/1-gpu

The INT8 KV cache can be used with or without the GPT attention plugin.

Weight Only Quantization

convert_checkpoint.py features a --use_weight_only option that can enable weight-only quantization. You can further set the weight-only precision by passing int8 or int4 to the --weight_only_precision flag.

# Int8 weight-only quantization
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --use_weight_only \
        --weight_only_precision int8 \
        --output_dir gpt2/trt_ckpt/int8-wo/1-gpu

# Int4 weight-only quantization
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --use_weight_only \
        --weight_only_precision int4 \
        --output_dir gpt2/trt_ckpt/int4-wo/1-gpu

Then, use trtllm-build to build engine(s).

# Int8 weight-only quantization
trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-wo/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --output_dir gpt2/trt_engines/int8-wo/1-gpu

# Int4 weight-only quantization
trtllm-build --checkpoint_dir gpt2/trt_ckpt/int4-wo/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --output_dir gpt2/trt_engines/int4-wo/1-gpu

FP8 Quantization

../quantization/quantize.py can do FP8 quantization and/or FP8 KV cache quantization, and export a TensorRT-LLM checkpoint.

# FP8 quantization with FP8 kv cache
python3 ../quantization/quantize.py --model_dir gpt2 \
        --dtype float16 \
        --qformat fp8 \
        --kv_cache_dtype fp8 \
        --output_dir gpt2/trt_ckpt/fp8/1-gpu

trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp8/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --strongly_typed \
        --output_dir gpt2/trt_engines/fp8/1-gpu

Embedding Parallelism and Sharing

Since the embedding lookup table can be several gigabytes in size, we can distribute this weight across multiple GPUs to reduce the memory consumption per GPU.

1. Embedding parallelism

To enable this feature, add the flag --use_parallel_embedding to convert_checkpoint.py.

2. The sharding dimension for embedding parallelism

Assuming the embedding lookup table has shape (vocab_size, hidden_size), we can shard it along the vocab_size dimension (--embedding_sharding_dim 0) or the hidden_size dimension (--embedding_sharding_dim 1). For example, with --tp_size 2 and sharding along the hidden dimension, each GPU holds a (vocab_size, hidden_size/2) slice of the table.

2.1 To shard the embedding lookup table along the hidden_size dimension, set the flag --use_parallel_embedding --embedding_sharding_dim 1. Here is an example:

# 2-way tensor parallelism with embedding parallelism along hidden dimension
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --tp_size 2 \
        --use_parallel_embedding \
        --embedding_sharding_dim 1 \
        --output_dir gpt2/trt_ckpt/fp16/2-gpu

trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --output_dir gpt2/trt_engines/fp16/2-gpu

2.2 To shard the embedding lookup table along the vocab_size dimension, set the flag --use_parallel_embedding --embedding_sharding_dim 0. In this case, you can optionally enable the lookup plugin when building the engines.

# 2-way tensor parallelism with embedding parallelism along vocab dimension
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --tp_size 2 \
        --use_parallel_embedding \
        --embedding_sharding_dim 0 \
        --output_dir gpt2/trt_ckpt/fp16/2-gpu

# It is optional to add --lookup_plugin
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --lookup_plugin float16 \
        --output_dir gpt2/trt_engines/fp16/2-gpu

3. Embedding sharing

In some models, the embedding weight is used in both the embedding layer and lm_head (language modeling head) layer. In this case, sharing the embedding weight can reduce memory consumption.

Passing the --use_embedding_sharing flag to convert_checkpoint.py attempts to enable this feature. However, it only takes effect when the following criteria are met:

  • The embedding weight is shared between the embedding and lm_head layers. If not, we should not enable this feature.
  • For tensor parallelism cases, --use_parallel_embedding --embedding_sharding_dim 0 must be set. In other words, we must enable embedding parallelism along the vocab dimension, which minimizes the overall communication cost.
  • For TensorRT 9.0, the engine size is expected to be reduced when the lookup and GEMM plugins are enabled.

Here is an example using the embedding parallelism and sharing features:

# 2-way tensor parallelism with embedding sharing
# It requires enabling embedding parallelism along vocab dimension
python3 convert_checkpoint.py --model_dir gpt2 \
        --dtype float16 \
        --tp_size 2 \
        --use_embedding_sharing \
        --use_parallel_embedding \
        --embedding_sharding_dim 0 \
        --output_dir gpt2/trt_ckpt/fp16/2-gpu

# It is recommended to add --lookup_plugin and --gemm_plugin
trtllm-build --checkpoint_dir gpt2/trt_ckpt/fp16/2-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --lookup_plugin float16 \
        --gemm_plugin float16 \
        --output_dir gpt2/trt_engines/fp16/2-gpu

GPT Variant - SantaCoder

SantaCoder extends the GPT model with a multi-query attention mechanism. The following example shows how to build a 4-GPU engine and run a simple prompt to generate the implementation of print_hello_world().

# Download hf santacoder model
git clone https://huggingface.co/bigcode/santacoder

# Convert to TensorRT-LLM checkpoint
python3 convert_checkpoint.py --model_dir santacoder \
        --dtype float16 \
        --tp_size 4 \
        --output_dir santacoder/trt_ckpt/fp16/4-gpu

# Build TensorRT-LLM engines
trtllm-build --checkpoint_dir santacoder/trt_ckpt/fp16/4-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --context_fmha enable \
        --gemm_plugin float16 \
        --output_dir santacoder/trt_engines/fp16/4-gpu

# Run inference
mpirun -np 4 \
    python3 ../run.py --engine_dir santacoder/trt_engines/fp16/4-gpu \
        --tokenizer_dir santacoder \
        --input_text "def print_hello_world():" \
        --max_output_len 20

GPT Variant - StarCoder (v1 and v2)

For StarCoder, the steps are similar to those for SantaCoder.

# Download hf starcoder model
git clone https://huggingface.co/bigcode/starcoder

# Convert to TensorRT-LLM checkpoint
python3 convert_checkpoint.py --model_dir starcoder \
        --dtype float16 \
        --tp_size 4 \
        --output_dir starcoder/trt_ckpt/fp16/4-gpu

# Build TensorRT-LLM engines
trtllm-build --checkpoint_dir starcoder/trt_ckpt/fp16/4-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --context_fmha enable \
        --gemm_plugin float16 \
        --output_dir starcoder/trt_engines/fp16/4-gpu

# Run inference
mpirun -np 4 \
    python3 ../run.py --engine_dir starcoder/trt_engines/fp16/4-gpu \
        --tokenizer_dir starcoder \
        --input_text "def print_hello_world():" \
        --max_output_len 20

For StarCoder2, you can use almost the same steps as shown above.

  • Note that StarCoder2 has not been merged into an official release of the transformers package yet, so remember to use the main branch of the transformers repo.
  • Add --max_attention_window_size 4096 when running with run.py or summarization to enable sliding window attention (see the sketch after this list).
    • The sliding window size comes from the HF model's config.json.
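For example, a hedged sketch of running a StarCoder2 engine with the sliding window enabled (the engine and tokenizer paths are placeholders, assuming StarCoder2 was converted and built with the same 4-GPU commands as StarCoder above):

# Sketch: run a StarCoder2 engine with sliding window attention
mpirun -np 4 \
    python3 ../run.py --engine_dir starcoder2/trt_engines/fp16/4-gpu \
        --tokenizer_dir starcoder2 \
        --input_text "def print_hello_world():" \
        --max_attention_window_size 4096 \
        --max_output_len 20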

GPT-Next

NVIDIA has released a GPT-like model with some architectural improvements that you can find here: https://huggingface.co/nvidia/GPT-2B-001. This architecture is also supported by TensorRT-LLM.

Unlike with a HuggingFace checkpoint, you should specify the NeMo checkpoint path using --nemo_ckpt_path for convert_checkpoint.py. The script also extracts the tokenizer file from the NeMo checkpoint and saves it to the TensorRT-LLM checkpoint folder, where it can be used by the inference scripts.

# Download NeMo checkpoint
wget https://huggingface.co/nvidia/GPT-2B-001/resolve/main/GPT-2B-001_bf16_tp1.nemo

# Convert to TensorRT-LLM checkpoint
# It also extracts the tokenizer file and saves to the TensorRT-LLM checkpoint folder
python3 convert_checkpoint.py --nemo_ckpt_path GPT-2B-001_bf16_tp1.nemo \
        --dtype bfloat16 \
        --output_dir gpt-next-2B/trt_ckpt/bf16/1-gpu

# Build TensorRT-LLM engines
# --gpt_attention_plugin must be set for GPT-Next since Rotary positional embeddings (RoPE) is only supported by the gpt attention plugin at this time.
trtllm-build --checkpoint_dir gpt-next-2B/trt_ckpt/bf16/1-gpu \
        --gpt_attention_plugin bfloat16 \
        --remove_input_padding enable \
        --output_dir gpt-next-2B/trt_engines/bf16/1-gpu

# Run inference
python3 ../run.py --engine_dir gpt-next-2B/trt_engines/bf16/1-gpu \
        --vocab_file gpt-next-2B/trt_ckpt/bf16/1-gpu/tokenizer.model \
        --no_add_special_tokens \
        --max_output_len 8

Prompt-tuning

For efficient fine-tuning, the NeMo framework allows you to learn virtual tokens to accomplish a downstream task. For more details, please read the NeMo documentation here.

TensorRT-LLM supports inference with those virtual tokens. To enable it, pass the prompt embedding table's maximum size at build time with --max_prompt_embedding_table_size N. For example:

# Convert to TensorRT-LLM checkpoint
python3 convert_checkpoint.py --nemo_ckpt_path megatron_converted_8b_tp4_pp1.nemo \
        --dtype float16 \
        --output_dir gpt-next-8B/trt_ckpt/fp16/1-gpu

# Build TensorRT-LLM engines with prompt-tuning enabled
trtllm-build --checkpoint_dir gpt-next-8B/trt_ckpt/fp16/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --max_prompt_embedding_table_size 100 \
        --output_dir gpt-next-8B/trt_engines/fp16/1-gpu

You can now export the learned embedding table with:

python3 nemo_prompt_convert.py -i email_composition.nemo -o email_composition.npy

It will print a summary of the different tasks in the table, which you can specify at runtime.

Finally, you can run inference on pre-defined tokens:

python3 ../run.py --engine_dir gpt-next-8B/trt_engines/fp16/1-gpu \
        --vocab_file gpt-next-8B/trt_ckpt/fp16/1-gpu/tokenizer.model \
        --no_add_special_tokens \
        --prompt_table_path email_composition.npy \
        --prompt_tasks 0 \
        --max_output_len 8

MultiLoRA with the NeMo checkpoint

# Download NeMo checkpoint
wget https://huggingface.co/nvidia/GPT-2B-001/resolve/main/GPT-2B-001_bf16_tp1.nemo

# Convert to TensorRT-LLM checkpoint
python3 convert_checkpoint.py --nemo_ckpt_path GPT-2B-001_bf16_tp1.nemo \
        --dtype float16 \
        --output_dir gpt-next-2B/trt_ckpt/fp16/1-gpu

# Build TensorRT-LLM engines
trtllm-build --checkpoint_dir gpt-next-2B/trt_ckpt/fp16/1-gpu \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --lora_plugin float16 \
        --lora_dir gpt2b_lora-900.nemo gpt2b_lora-stories.nemo \
        --lora_ckpt_source "nemo" \
        --lora_target_modules attn_qkv \
        --max_batch_size 4 \
        --max_beam_width 2 \
        --max_input_len 512 \
        --max_output_len 50 \
        --output_dir gpt-next-2B/trt_engines/fp16/1-gpu

# Run inference directly from NeMo LoRA checkpoint
# --lora_task_uids correspond to the index of the models given with --lora_dir. -1 means no LoRA
python3 ../run.py --engine_dir gpt-next-2B/trt_engines/fp16/1-gpu \
        --vocab_file gpt-next-2B/trt_ckpt/fp16/1-gpu/tokenizer.model \
        --no_add_special_tokens \
        --max_output_len 20 \
        --use_py_session \
        --lora_task_uids 0 -1 1 \
        --input_text "After Washington had returned to Williamsburg, Dinwiddie ordered him to lead a larger force to assist Trent in his work. While en route, Washington learned of Trent's retreat. Since Tanaghrisson had promised support to the British, Washington continued toward Fort Duquesne and met with the Mingo leader. Learning of a French scouting party in the area, Washington, with Tanaghrisson and his party, surprised the Canadians on May 28 in what became known as the Battle of Jumonville Glen. They killed many of the Canadians, including their commanding officer, Joseph Coulon de Jumonville, whose head was reportedly split open by Tanaghrisson with a tomahawk. The historian Fred Anderson suggests that Tanaghrisson was acting to gain the support of the British and regain authority over his own people. They had been inclined to support the French, with whom they had long trading relationships. One of Tanaghrisson's men told Contrecoeur that Jumonville had been killed by British musket fire. Question: Upon learning of a French scounting party in the area, what did Washington do? Answer:" "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells" "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells"

The output would look like (Note that in this case the adapters have only been trained for a few epochs, so the result quality is poor):

......
Input [Text 0]: "After Washington had returned to Williamsburg, Dinwiddie ordered him to lead a larger force to assist Trent in his work. While en route, Washington learned of Trent's retreat. Since Tanaghrisson had promised support to the British, Washington continued toward Fort Duquesne and met with the Mingo leader. Learning of a French scouting party in the area, Washington, with Tanaghrisson and his party, surprised the Canadians on May 28 in what became known as the Battle of Jumonville Glen. They killed many of the Canadians, including their commanding officer, Joseph Coulon de Jumonville, whose head was reportedly split open by Tanaghrisson with a tomahawk. The historian Fred Anderson suggests that Tanaghrisson was acting to gain the support of the British and regain authority over his own people. They had been inclined to support the French, with whom they had long trading relationships. One of Tanaghrisson's men told Contrecoeur that Jumonville had been killed by British musket fire. Question: Upon learning of a French scounting party in the area, what did Washington do? Answer:"
Output [Text 0 Beam 0]: "He surprised the Canadians on May 28 in what became known as the Battle of Jumonville"
Input [Text 1]: "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells"
Output [Text 1 Beam 0]: ".

The game is played with a deck of cards, and the player who has the most"
Input [Text 2]: "You hold the job title in the Wizarding World of Harry Potter where you say random words looking for spells"
Output [Text 2 Beam 0]: ".

You are a wizard who is a wizard.

You are a wizard who is"