Torch-TensorRT v2.5.0

@narendasan narendasan released this 18 Oct 19:45
f2e1e6c

PyTorch 2.5, CUDA 12.4, TensorRT 10.3, Python 3.12

Torch-TensorRT 2.5.0 targets PyTorch 2.5, TensorRT 10.3 and CUDA 12.4.
(Builds for CUDA 11.8/12.1 are available via the PyTorch package index: https://download.pytorch.org/whl/cu118 and https://download.pytorch.org/whl/cu121)

Deprecation notice

The torchscript frontend will be deprecated in v2.6. Specifically, the following usage will no longer be supported and will issue a deprecation warning at runtime if used:

torch_tensorrt.compile(model, ir="torchscript")

Moving forward, we encourage users to transition to one of the supported options:

torch_tensorrt.compile(model)
torch_tensorrt.compile(model, ir="dynamo")
torch.compile(model, backend="tensorrt")

TorchScript will continue to be supported as a deployment format via post-compilation tracing:

dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[...])
ts_model = torch.jit.trace(dynamo_model, inputs=[...])
ts_model(...)

Please refer to the README for more information regarding our deprecation policy.

Refit (Beta)

v2.5.0 introduces direct model refitting from PyTorch for your compiled Torch-TensorRT programs. Sometimes the weights of a model need to change over the course of inference. In the past, changing them required full recompilation, either automatically through torch.compile or manually with torch_tensorrt.compile. Now, using the refit_module_weights API, compiled modules can be refitted by providing a new PyTorch module (with identical structure) containing the new weights. Modules must be compiled with make_refittable enabled to use this feature.

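import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

# Example input; the shape is assumed for ResNet-18
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# Create and export the updated model
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))

# Load the previously compiled (refittable) program
compiled_trt_ep = torch_trt.load("./compiled.ep")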

# This returns a new module with updated weights
new_trt_gm = refit_module_weights(
    compiled_module=compiled_trt_ep,
    new_weight_module=exp_program2,
)

Some ops are not compatible with refit, such as ops that use the ILoop layer. When make_refittable is enabled, these ops are forced to run in PyTorch. Note also that refit-enabled engines may be slightly less performant than non-refittable engines, because TensorRT cannot tune for the specific weights it will see at execution time.

Refit Caching (Experimental)

Refitting on its own can speed up model weight swaps by 0.5-2x. However, refit can be made faster still with refit caching. At compile time, refit caching stores hints for a direct mapping from PyTorch module members to TRT layer names in the metadata of TorchTensorRTModule. This caching can speed up refit by orders of magnitude. However, it currently has limitations when dealing with layers that have compile-time optimizations. The feature is still experimental, as some ops may not be amenable to refit caching. The cache is still used by default when refitting so that we can collect feedback on edge cases, and we provide an output validator that can be used to ensure the refit occurred properly. When verify_outputs is True and the cached refit fails validation, the refitter discards the cache and refits from scratch.

new_trt_gm = refit_module_weights(
    compiled_module=compiled_trt_ep,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
    verify_outputs=True,
)
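
The fallback behaviour described above can be illustrated with a small, library-free sketch. This is not Torch-TensorRT's implementation: `refit_with_cache`, the mapping dictionaries, and the `verify` callback are all hypothetical names chosen for illustration. It only shows the control flow of trying the cached name mapping first and rebuilding the mapping when verification fails.

```python
from typing import Callable, Dict, Optional, Tuple

Weights = Dict[str, float]


def refit_with_cache(
    new_weights: Weights,
    cached_mapping: Optional[Dict[str, str]],
    build_mapping: Callable[[Weights], Dict[str, str]],
    verify: Callable[[Weights], bool],
) -> Tuple[Weights, str]:
    """Hypothetical helper: use the cached mapping if it verifies, else start over."""
    if cached_mapping is not None and set(cached_mapping) == set(new_weights):
        # Fast path: map module member names to engine layer names via the cache.
        fast = {cached_mapping[name]: value for name, value in new_weights.items()}
        if verify(fast):
            return fast, "cache"
    # Cache missing, incomplete, or failed verification: rebuild the mapping from scratch.
    mapping = build_mapping(new_weights)
    return {mapping[name]: value for name, value in new_weights.items()}, "scratch"


weights = {"conv1.weight": 0.5, "fc.weight": 1.5}
good_cache = {"conv1.weight": "trt_conv1", "fc.weight": "trt_fc"}
stale_cache = {"conv1.weight": "trt_fc", "fc.weight": "trt_conv1"}  # wrong targets
slow_path = lambda w: {"conv1.weight": "trt_conv1", "fc.weight": "trt_fc"}
expected = {"trt_conv1": 0.5, "trt_fc": 1.5}
check = lambda engine_weights: engine_weights == expected

result_hit, path_hit = refit_with_cache(weights, good_cache, slow_path, check)
result_stale, path_stale = refit_with_cache(weights, stale_cache, slow_path, check)
```

With a correct cache the fast path is taken; with the stale cache the verification step rejects the result and the slower, from-scratch mapping is used instead.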

MutableTorchTensorRTModule (Experimental)

torch.compile is incredibly useful for optimizing models that may change over time, since it can automatically recompile the module when something changes. However, its major limitation is that the result cannot be serialized. For users who want similar flexibility plus the ability to serialize and move their work, we have introduced the MutableTorchTensorRTModule. This module wraps a PyTorch module and exposes its members transparently, while injecting listeners on setattr and overriding the forward function to use TensorRT-accelerated subgraphs. This means you can make changes to your module, such as applying adapters, and the MutableTorchTensorRTModule will detect the change and mark the module for refit or recompilation accordingly. As with torch.compile, this is done in a JIT manner, so the first inference after a change performs the refit or recompile operation.

import torch
import torch_tensorrt as torch_trt
from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "make_refittable": True,
    }

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda:0"

    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

    pipe = DiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Huggingface LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
        weight_name="moxin.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit triggered
    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./with_LoRA_mutable.jpg")
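
To make the "listeners on setattr" idea concrete, here is a minimal, library-free sketch of the pattern. It is an assumption-laden toy, not the actual MutableTorchTensorRTModule: `LazyRefitWrapper`, `Scale`, and the `rebuilds` counter are hypothetical names, and the "rebuild" step is just a counter increment standing in for a refit or recompile.

```python
class LazyRefitWrapper:
    """Toy wrapper: forwards attribute access, flags changes, rebuilds on next call."""

    def __init__(self, module):
        # Bypass our own __setattr__ for internal bookkeeping.
        object.__setattr__(self, "_module", module)
        object.__setattr__(self, "_dirty", True)  # nothing has been built yet
        object.__setattr__(self, "rebuilds", 0)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so wrapped members appear transparently.
        return getattr(object.__getattribute__(self, "_module"), name)

    def __setattr__(self, name, value):
        # Forward the change to the wrapped module and remember that it happened.
        setattr(self._module, name, value)
        object.__setattr__(self, "_dirty", True)

    def __call__(self, x):
        if self._dirty:
            # Stand-in for the refit/recompile step, performed lazily (JIT style).
            object.__setattr__(self, "rebuilds", self.rebuilds + 1)
            object.__setattr__(self, "_dirty", False)
        return self._module(x)


class Scale:
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


wrapped = LazyRefitWrapper(Scale(2))
first = wrapped(3)    # first call triggers the initial build
second = wrapped(3)   # nothing changed: no rebuild
wrapped.factor = 5    # change detected, wrapper marked dirty
third = wrapped(3)    # the next call rebuilds before running
```

The key design point mirrored here is that the change itself is cheap; the expensive work is deferred until the next call.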

Engine Caching

In some scenarios, users compile a module multiple times, and each compilation spends a long time building a TensorRT engine in the backend. Engine caching avoids this cost by reusing previously compiled TensorRT engines instead of rebuilding them every time. When a cached engine is loaded, it is refitted with the new module weights.

To make caching more effective, two graph modules with the same structure are treated as the same (isomorphic) even if their weights differ. Isomorphic graph modules with the same compilation settings share cached engines.
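
As a rough illustration of weight-independent cache keys, the sketch below hashes only a graph's structure and its compilation settings. The op descriptions, the `structural_cache_key` function, and the settings dictionary are hypothetical; Torch-TensorRT's real hashing logic may differ.

```python
import hashlib
import json


def structural_cache_key(ops, settings):
    """Hash graph structure and compile settings, deliberately ignoring weight values."""
    # Keep only the op name and weight shape; the "weights" entry is never read.
    structure = [(op["name"], list(op["weight_shape"])) for op in ops]
    payload = json.dumps({"structure": structure, "settings": settings}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


model_a = [
    {"name": "conv", "weight_shape": (64, 3, 7, 7), "weights": [0.1, 0.2]},
    {"name": "relu", "weight_shape": (), "weights": []},
]
model_b = [
    {"name": "conv", "weight_shape": (64, 3, 7, 7), "weights": [0.9, 0.8]},
    {"name": "relu", "weight_shape": (), "weights": []},
]
settings = {"enabled_precisions": ["float16"], "make_refittable": True}

key_a = structural_cache_key(model_a, settings)
key_b = structural_cache_key(model_b, settings)  # same structure, different weights
key_c = structural_cache_key(model_a, {**settings, "enabled_precisions": ["float32"]})
```

Here `key_a` and `key_b` match because only the weights differ, while `key_c` differs because the compilation settings changed.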

We implemented DiskEngineCache so that users can control how and where cached engines are saved to and loaded from the local disk. For example,

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refittable=True,
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_dir="/tmp/torch_trt_engine_cache",
    engine_cache_size=1 << 30,  # 1GB
)

In addition, because some users want to save engines to, or load them from, other servers, clusters, or cloud storage, we also provide a base class, BaseEngineCache, so that users can implement their own save and load logic. For example,

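from typing import Optional

# Import path as of v2.5.0; adjust if it differs in your installed version
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache


class MyEngineCache(BaseEngineCache):
    def __init__(
        self,
        addr: str,
    ) -> None:
        self.addr = addr

    def save(
        self,
        hash: str,
        blob: bytes,
        prefix: str = "blob",
    ) -> None:
        # User-defined logic to save engines; write_to is a placeholder for your own helper
        write_to(self.addr, name=f"{prefix}_{hash}.bin", content=blob)

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        # User-defined logic to load engines; read_from is likewise a placeholder
        return read_from(self.addr, name=f"{prefix}_{hash}.bin")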


trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refittable=True,
    cache_built_engines=True,
    reuse_cached_engines=True,
    custom_engine_cache=MyEngineCache("xxxxx"),
)
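
For local experimentation, the same save/load interface can be exercised without any remote storage. The sketch below is a standalone stand-in that keeps blobs in a dictionary; `InMemoryEngineCache` is a hypothetical name, and to plug it into Torch-TensorRT it would need to subclass BaseEngineCache as in the example above.

```python
from typing import Dict, Optional


class InMemoryEngineCache:
    """Stand-in mirroring the save/load interface, backed by a plain dict."""

    def __init__(self) -> None:
        self._store: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes, prefix: str = "blob") -> None:
        # Use the same naming scheme as the example above.
        self._store[f"{prefix}_{hash}.bin"] = blob

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        # Return None on a miss, matching the Optional[bytes] return type.
        return self._store.get(f"{prefix}_{hash}.bin")


cache = InMemoryEngineCache()
cache.save("abc123", b"serialized-engine")
hit = cache.load("abc123")
miss = cache.load("does-not-exist")
```

Returning None on a miss is what lets the caller decide to build the engine from scratch.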

CUDA Graphs

v2.5.0 adds CUDA graph support for in-engine kernel launch optimization through a new runtime mode. This mode can be activated from Python using:

import torch_tensorrt

my_torchtrt_model = torch_tensorrt.compile(...)

with torch_tensorrt.runtime.enable_cudagraphs():
    my_torchtrt_model(inputs)

This mode works by creating CUDA graphs around individual TensorRT engines, which improves their efficiency. Each graph is created through a capture phase that is tied to the shape of the engine's input. When the input shape changes, the graph is invalidated and automatically recaptured.
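
The capture-and-invalidate behaviour can be pictured with a toy, CPU-only model. Nothing below touches CUDA or Torch-TensorRT: `ShapeKeyedCapture` and its `capture` callback are hypothetical, and the "graph" is just a string. The sketch only shows that a capture happens on the first call and again whenever the input shape changes.

```python
from typing import Callable, Optional, Tuple

Shape = Tuple[int, ...]


class ShapeKeyedCapture:
    """Toy model of capture/replay: recaptures whenever the input shape changes."""

    def __init__(self, capture: Callable[[Shape], str]) -> None:
        self._capture = capture
        self._shape: Optional[Shape] = None
        self._graph: Optional[str] = None
        self.captures = 0  # number of capture phases run so far

    def run(self, shape: Shape) -> Optional[str]:
        if shape != self._shape:
            # First call or new shape: the old graph is invalid, so capture again.
            self._graph = self._capture(shape)
            self._shape = shape
            self.captures += 1
        # Otherwise the previously captured graph is replayed as-is.
        return self._graph


runner = ShapeKeyedCapture(lambda shape: f"graph{shape}")
runner.run((1, 3, 224, 224))  # first call: capture
runner.run((1, 3, 224, 224))  # same shape: replay
runner.run((8, 3, 224, 224))  # new shape: recapture
```

In this toy, frequent shape changes mean frequent recaptures, which is the cost to keep in mind when inputs vary a lot.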

Model Optimizer based Int8 Quantization (PTQ) support for Linux

This version introduces official support for int8 quantization via modelopt (https://github.com/NVIDIA/TensorRT-Model-Optimizer) 0.17.0 for Linux.
Full examples can be found at https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/vgg16_ptq.py

Running the VGG16 example for int8 PTQ:

Step 1: Generate a checkpoint file for VGG16:
cd examples/int8/training/vgg16
python main.py --lr 0.01 --batch-size 128 --drop-ratio 0.15 \
--ckpt-dir $(pwd)/vgg16_ckpts --epochs 20 --seed 545
This should produce a checkpoint file at examples/int8/training/vgg16/vgg16_ckpts/ckpt_epoch20.pth

Step 2: Run int8 PTQ for VGG16:
python examples/dynamo/vgg16_fp8_ptq.py --batch-size 128 \
--ckpt=examples/int8/training/vgg16/vgg16_ckpts/ckpt_epoch20.pth \
--quantize-type=int8

LLM examples

We now offer dynamic shape support for all converters (covering core ATen operations). Dynamic shapes are widely utilized in leading LLM models, where input sequence lengths may vary. With this release, we showcase full graph compilation for Llama2 and GPT2 models using Torch-TensorRT. For detailed examples, please refer to our documentation.

Full Changelog: v2.4.0...v2.5.0