
TRTorch v0.4.0

@narendasan released this 24 Aug 21:49


Support for PyTorch 1.9, TensorRT 8.0. Introducing INT8 Execution for QAT models, Module Based Partial Compilation, Auto Device Configuration, Input Class, Usability Improvements, New Converters, Bug Fixes

Target Platform Changes

This is the fourth beta release of TRTorch, targeting PyTorch 1.9, CUDA 11.1 (on x86_64; CUDA 10.2 on aarch64), cuDNN 8.2 and TensorRT 8.0, with backwards-compatible source for TensorRT 7.1. On aarch64, TRTorch primarily targets Jetpack 4.6, with backwards-compatible source for Jetpack 4.5. When building on Jetson, the flag --platforms //toolchains:jetpack_4.x must now be provided for C++ compilation to select the correct dependency paths. For Python builds, the Jetpack version is assumed to be 4.6 by default; to override this, add the --jetpack-version 4.5 flag when building.

TensorRT 8.0

This release adds support for compiling models trained with quantization-aware training (QAT), so users of the TensorRT PyTorch Quantization Toolkit (https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization) can compile their models with TRTorch. For more information and a tutorial, refer to https://www.github.com/NVIDIA/TRTorch/tree/v0.4.0/examples/int8/qat. This release also adds support for sparsity via the sparse_weights flag in the compile spec, which allows TensorRT to use specialized hardware in Ampere GPUs to minimize unnecessary computation and thereby increase computational efficiency.
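
QAT models carry fake-quantization nodes (aten::fake_quantize_per_tensor_affine appears in the converter list at the end of these notes). As a reference for what such a node computes, here is a minimal pure-Python sketch of per-tensor affine fake quantization: quantize to the integer grid, saturate, then dequantize. This is an illustration of the arithmetic only, not TRTorch or TensorRT code; the function name and list-based signature are assumptions.

```python
from typing import List


def fake_quantize_per_tensor_affine(values: List[float], scale: float, zero_point: int,
                                    quant_min: int, quant_max: int) -> List[float]:
    """Hypothetical reference model of per-tensor affine fake quantization."""
    out = []
    for x in values:
        q = round(x / scale) + zero_point       # map onto the integer grid (round half to even)
        q = max(quant_min, min(quant_max, q))   # saturate to the quantized range
        out.append((q - zero_point) * scale)    # map back to floating point
    return out


# Symmetric INT8 range with scale 0.5: 1.3 snaps to 1.5, 100.0 saturates at 127 * 0.5
print(fake_quantize_per_tensor_affine([1.3, 100.0, -0.2], 0.5, 0, -128, 127))
# [1.5, 63.5, 0.0]
```

The values a QAT model sees during training already include this rounding and clamping, which is why the learned scales can be reused directly when the model runs in INT8.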

Partial Compilation

In v0.4.0, the partial compilation feature of TRTorch can now be considered to have beta-level stability. New in this release is the ability to explicitly specify entire PyTorch modules that should run in PyTorch as part of partial compilation, which should let users easily isolate troublesome code when compiling. As before, feedback on this feature is greatly appreciated.
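
To make module-level fallback concrete, the sketch below models a graph as a flat list of nodes, each tagged with the submodule it came from, and groups consecutive nodes into TensorRT or PyTorch segments. A node falls back to PyTorch if its operator is unsupported or if it lives inside a module the user marked for fallback. This is a simplified, hypothetical model (the Node class, partition function and module paths are all illustrative), not TRTorch's actual partitioning algorithm.

```python
from dataclasses import dataclass
from typing import List, Set, Tuple


@dataclass
class Node:
    op: str           # operator name, e.g. "aten::relu"
    module_path: str  # dotted path of the owning submodule (hypothetical)


def partition(nodes: List[Node], supported_ops: Set[str],
              fallback_modules: Set[str]) -> List[Tuple[str, List[Node]]]:
    """Group consecutive nodes into ("tensorrt" | "torch", nodes) segments."""
    segments: List[Tuple[str, List[Node]]] = []
    for node in nodes:
        # Forced fallback if the node is inside (or under) a module marked to run in PyTorch
        forced = any(node.module_path == m or node.module_path.startswith(m + ".")
                     for m in fallback_modules)
        target = "tensorrt" if node.op in supported_ops and not forced else "torch"
        if segments and segments[-1][0] == target:
            segments[-1][1].append(node)        # extend the current segment
        else:
            segments.append((target, [node]))   # start a new segment
    return segments


nodes = [
    Node("aten::_convolution", "backbone.conv1"),
    Node("aten::relu", "backbone.conv1"),
    Node("aten::topk", "head.decoder"),
    Node("aten::linear", "classifier"),
]
supported = {"aten::_convolution", "aten::relu", "aten::topk", "aten::linear"}
segments = partition(nodes, supported, fallback_modules={"head"})
print([(target, len(group)) for target, group in segments])
# [('tensorrt', 2), ('torch', 1), ('tensorrt', 1)]
```

Note that aten::topk is supported here but still runs in PyTorch, because the whole "head" module was marked for fallback.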

Automatic Device Configuration at Runtime

v0.4.0 also changes the "ABI" of TRTorch to include information about the target device for the program. Programs compiled with v0.4.0 will look for and select the most compatible available device, using the following rules: any valid device option must have the same SM capability as the device that built the engine. Among those, TRTorch prefers the same device type (e.g. if the engine was built on an A100, an A100 is preferred over an A30), and finally prefers the same device ID. Users will be warned if the selected device is not the current active device during execution, since overhead may be incurred transferring input tensors from the current device to the target device; users can then modify their code to avoid this.

Due to this ABI change, existing compiled TRTorch programs are incompatible with the TRTorch v0.4.0 runtime. From v0.4.0 onwards, an internal ABI version is used to check program compatibility. This ABI version is only incremented with breaking changes to the ABI.
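
The device selection rules described above can be sketched as a filter followed by a preference ordering. The CudaDevice dataclass, its fields and select_device are hypothetical names for illustration; they are not TRTorch's internal CudaDevice struct or runtime code.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class CudaDevice:
    """Hypothetical device description (not TRTorch's CudaDevice struct)."""
    device_id: int
    sm_major: int  # compute capability, major version
    sm_minor: int  # compute capability, minor version
    name: str      # device type, e.g. "A100"


def select_device(built_on: CudaDevice, available: List[CudaDevice]) -> Optional[CudaDevice]:
    """Pick the most compatible device for an engine built on `built_on`, or None."""
    # Rule 1: only devices with the same SM capability are valid options
    candidates = [d for d in available
                  if (d.sm_major, d.sm_minor) == (built_on.sm_major, built_on.sm_minor)]
    if not candidates:
        return None
    # Rule 2: prefer the same device type; rule 3: then prefer the same device ID.
    # False sorts before True, so matches win; min() keeps list order on ties.
    return min(candidates,
               key=lambda d: (d.name != built_on.name, d.device_id != built_on.device_id))


built_on = CudaDevice(0, 8, 0, "A100")
available = [CudaDevice(0, 8, 0, "A30"), CudaDevice(1, 7, 5, "T4"), CudaDevice(2, 8, 0, "A100")]
selected = select_device(built_on, available)
print(selected.device_id)  # 2: the A100 beats the A30 that has the matching ID
current_device_id = 0
if selected.device_id != current_device_id:
    print("warning: inputs may be copied from the current device to the target device")
```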

API Changes (Input, enabled_precisions, Device)

TRTorch v0.4.0 changes the API for specifying input shapes and data types to give users more control over configuration. The new API uses the class trtorch.Input, which lets users set the shape (or shape range) as well as the memory layout and expected data type. These input specs are set in the inputs field of the CompileSpec:

import torch
import trtorch

compile_spec = {
    "inputs": [
        trtorch.Input((1, 3, 224, 224)),  # Static input shape for input #1
        trtorch.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last,
        ),  # Dynamic input shape for input #2, input type int and channels-last format
    ],
}
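
For a dynamic input such as input #2, TensorRT requires each dimension of the range to satisfy min <= opt <= max. The hypothetical helper below (not part of the trtorch API) checks a min/opt/max triple before it is handed to trtorch.Input, which can make mistakes easier to spot than a failure deep inside compilation.

```python
from typing import List, Sequence


def shape_range_errors(min_shape: Sequence[int], opt_shape: Sequence[int],
                       max_shape: Sequence[int]) -> List[str]:
    """Hypothetical check: return problems with a dynamic shape range (empty if valid)."""
    if not len(min_shape) == len(opt_shape) == len(max_shape):
        return ["min_shape, opt_shape and max_shape must have the same rank"]
    errors = []
    for dim, (lo, opt, hi) in enumerate(zip(min_shape, opt_shape, max_shape)):
        if not lo <= opt <= hi:
            errors.append(f"dim {dim}: expected min <= opt <= max, got {lo}, {opt}, {hi}")
    return errors


# The range used for input #2 is valid...
print(shape_range_errors((1, 224, 224, 3), (1, 512, 512, 3), (1, 1024, 1024, 3)))  # []
# ...but swapping opt and max is not.
print(shape_range_errors((1, 224, 224, 3), (1, 1024, 1024, 3), (1, 512, 512, 3)))
```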

The legacy input_shapes field and its associated usage with lists of tuples/InputRanges should now be considered deprecated. They remain usable in v0.4.0 but will be removed in the next release.

Similarly, the compile spec field op_precision is now deprecated in favor of enabled_precisions, a set containing the data types that kernels are allowed to use. Whereas setting op_precision = torch.int8 implicitly enabled FP32 and FP16 kernels as well, enabled_precisions should now be set to {torch.float32, torch.float16, torch.int8} to get the same behavior. To stay consistent with normal PyTorch behavior, if FP16 is the lowest precision enabled and no explicit data type is set for the model's inputs, the inputs are expected to be FP16. In all other cases (FP32, INT8), FP32 is the default, as in PyTorch and previous versions of TRTorch.

Finally, the Python API adds a trtorch.Device class. While users can continue to use torch.device or other torch APIs, trtorch.Device gives better control over use cases specific to compiling with TRTorch (e.g. setting the DLA core and GPU fallback). The class closely mirrors its C++ counterpart, with a couple of syntactic-sugar additions that make it easier and more familiar to use:

import trtorch

# Set the device to DLA core 0. This implicitly selects the GPU managing the DLA cores as the GPU; GPU fallback is disabled.
device = trtorch.Device("dla:0", allow_gpu_fallback=False)

trtorch.Device can be used instead of a dictionary in the compile spec if desired.
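
The enabled_precisions migration and the default input type rule described above can be expressed as two small functions. This sketch uses string stand-ins for torch.float32, torch.float16 and torch.int8 so it runs without PyTorch; the function names are hypothetical, and treating op_precision = torch.half as also enabling FP32 is an assumption (the notes only spell out the INT8 case).

```python
from typing import Optional, Set

# String stand-ins for torch.float32 / torch.float16 / torch.int8
FP32, FP16, INT8 = "float32", "float16", "int8"
LOWEST_FIRST = [INT8, FP16, FP32]


def precisions_from_op_precision(op_precision: str) -> Set[str]:
    """Hypothetical: expand a deprecated op_precision value into an enabled_precisions set."""
    if op_precision == INT8:
        return {FP32, FP16, INT8}  # INT8 implicitly enabled FP32 and FP16 kernels as well
    if op_precision == FP16:
        return {FP32, FP16}        # assumption: FP16 mode likewise kept FP32 kernels
    return {FP32}


def default_input_dtype(enabled_precisions: Set[str], explicit_dtype: Optional[str] = None) -> str:
    """Hypothetical: expected input dtype when the Input spec does not set one."""
    if explicit_dtype is not None:
        return explicit_dtype
    lowest = next((p for p in LOWEST_FIRST if p in enabled_precisions), FP32)
    # Only "FP16 is the lowest enabled precision" switches the default to FP16
    return FP16 if lowest == FP16 else FP32


print(precisions_from_op_precision(INT8) == {FP32, FP16, INT8})  # True
print(default_input_dtype({FP32, FP16}))        # float16
print(default_input_dtype({FP32, FP16, INT8}))  # float32
```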

trtorchc has been updated to reflect these API changes. Users can set the shape, dtype and format of inputs from the command line using the format "[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]@DTYPE%FORMAT", or, for a static shape, e.g. (3, 3, 32,32)@f16%NHWC.

-p is now a repeatable flag, so multiple precisions can be enabled. Also added are the repeatable flags --ffm and --ffo, which mark specific modules and operators, respectively, to run in PyTorch. To use these two options, --allow-torch-fallback should be set. Options for embedding serialized engines (--embed-engine) and sparsity (--sparse-weights) have been added as well.
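
To show how the pieces of the input spec string fit together, here is a minimal parser sketch for SHAPE[@DTYPE][%FORMAT], where SHAPE is either a static tuple or a bracketed MIN;OPT;MAX range. It is a hypothetical illustration of the string format, not trtorchc's own parsing code, and it does not validate dtype or format names.

```python
from typing import Dict, Tuple


def _parse_dims(text: str) -> Tuple[int, ...]:
    """Turn "(3, 3, 32,32)" into (3, 3, 32, 32)."""
    return tuple(int(d) for d in text.strip().strip("()").split(","))


def parse_input_spec(spec: str) -> Dict[str, object]:
    """Hypothetical parser for SHAPE[@DTYPE][%FORMAT]."""
    shape_part, _, fmt = spec.partition("%")          # split off the format, if any
    shape_part, _, dtype = shape_part.partition("@")  # then the dtype, if any
    shape_part = shape_part.strip()
    if shape_part.startswith("["):
        # Dynamic shape: [(MIN);(OPT);(MAX)]
        ranges = [_parse_dims(p) for p in shape_part.strip("[]").split(";")]
        if len(ranges) != 3:
            raise ValueError("a dynamic shape needs exactly MIN;OPT;MAX")
        shapes = dict(zip(("min", "opt", "max"), ranges))
    else:
        shapes = {"static": _parse_dims(shape_part)}
    return {"shapes": shapes, "dtype": dtype or None, "format": fmt or None}


print(parse_input_spec("(3, 3, 32,32)@f16%NHWC"))
# {'shapes': {'static': (3, 3, 32, 32)}, 'dtype': 'f16', 'format': 'NHWC'}
print(parse_input_spec("[(1,3,32,32);(1,3,64,64);(1,3,128,128)]@f32"))
```

Splitting on "%" before "@" means both suffixes are optional and independent, which keeps the shape parsing itself free of dtype or format handling.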

Usability

TRTorch v0.4.0 can now also provide backtraces for locations in your model that TRTorch does not support. This can help identify parts of the model that may need to change for TRTorch support, or modules that should run fully in PyTorch via partial compilation.

Dependencies

- Bazel 4.0.0
- LibTorch 1.9.0
- CUDA 11.1 (on x86_64 by default; newer CUDA 11 versions are supported with a compatible PyTorch build), 10.2 (on aarch64)
- cuDNN 8.2.2.3
- TensorRT 8.0.1.6

0.4.0 (2021-08-24)

  • feat(serde)!: Refactor CudaDevice struct, implement ABI versioning, (9327cce)
  • feat(//py)!: Implementing top level python api changes to reflect new (482265f)
  • feat(//cpp)!: Changes to TRTorch C++ api reflecting Input and (08b4942)
  • feat!: Pytorch 1.9 version bump (a12d249)
  • feat(//core/runtime)!: Better and more portable names for engines (6eb3bb2)

Bug Fixes

  • //core/conversion/conversionctx: Guard final engine building (dfa9ae8)
  • //core/lowering: use lower_info as parameter (370aeb9)
  • //cpp/ptq: fixing bad accuracy in just the example code (7efa11d)
  • //py: Fix python setup.py with new libtrtorch.so location (68ba63c)
  • //tests: fix optional jetson tests (4c32a83)
  • //tests: use right type for masked_fill test (4a5c28f)
  • aten::cat: support neg dim for cat (d8ca182)
  • aten::select and aten::var: Fix converters to handle negative axes (3a734a2)
  • aten::slice: Allow slicing of pytorch tensors (50f012e)
  • aten::tensor: Last dim doesn't always get written right (b68d4aa)
  • aten::tensor: Last dim doesn't always get written right (38744bc)
  • Address review comments, fix failing tests due to bool mishandling (13eef91)
  • Final working version of QAT in TRTorch (521a0cb)
  • fix aten::sub.scalar operator (9a09514)
  • Fix linear lowering pass, lift layer_norm scale layer restriction and matmul layer nbdims restriction (930d582)
  • Fix testcases using old InputRange API (ff87956)
  • Fix TRT8 engine capability flags (2b69742)
  • Fix warnings thrown by noexcept functions (c5f7eea)
  • Fix warnings thrown by noexcept functions (ddc8950)
  • Minor fixes to qat scripts (b244423)
  • Restrict TRTorch to compile only forward methods (9f006d5)
  • Transfer calibration data to gpu when it is not a batch (23739cb)
  • typo in aten::batch_norm (d47f48f)
  • qat: Rescale input data for C++ application (9dc6061)
  • Use len() to get size of dataset (ccc60d5)
  • device_conf: Devices never actually got switched in multi device (f1d0a43)
  • exception_elimination: Exception branches are no longer consistent (d61b667)
  • to_backend: Clean up to_backend implementation (4e15605)
  • trtorchc: Allow for workspaces larger than 2G and better debugging (e1e7812)
  • Using TensorRT 8 new API calls (14691e7)
  • Using TensorRT 8 new API calls (fa969a5)

Features

  • //core/conversion: Adding error prefix to python source traceback (4bf2a41)
  • //core/conversion: Handle adding and wrapping ITensors as (a22e99b)
  • //core/ir: Implementing new internal input spec type (316df28)
  • //core/lowering: Adding two passes, one to delimit and one to mark (2e04ce5)
  • //core/lowering: additional logging in module fallback (ad07645)
  • //core/plugins: Add adaptive_max_pool2d plugin, enable the plugins to run on GPU (6f4aa40)
  • //cpp/int8/qat: QAT application release (d8f5d29)
  • //examples/int8: Implement Makefile based execution for ptq and qat (b7f6d8a)
  • //examples/int8/qat: Install pytorch-quantization with (1ca1484)
  • //py: add user level device class in py for embed engine (d99169f)
  • aten::masked_fill: In progress implementation of masked_fill (fa7d6d9)
  • aten::ones: Adding support for aten::ones (2b45a3d)
  • aten::slice: Patching slice for new optional params (a11287f)
  • aten::sqrt: Adding support for sqrt evaluators (6aaba3b)
  • aten::std|aten::masked_fill: Implement masked_fill, aten::std (a086a5b)
  • aten::std|aten::masked_fill: Implement masked_fill, aten::std (2866627)
  • jetson: Support for Jetpack 4.6 (9760fe3)
  • to_backend: Updating backend integration preproc function (080b594)
  • Enable sparsity support in TRTorch (f9e1f2b)
  • trtorchc: Adding flag for sparse weights (bfdc6f5)
  • Add aten::full converter, quantization ops testcases (9f2ffd0)
  • Add aten::type_as lowering pass (b57a6dd)
  • Add functionality for QAT workflow (fc8eafb)
  • Add functionality for QAT workflow (f776e76)
  • Add support for providing input datatypes in TRTorch (a3f4a3c)
  • Adding automatic casting to compare layers (90af26e)
  • Enable sparsity support in TRTorch (decd0ed)
  • Enable TRT 8.0 QAT functionality in TRTorch (c76a28a)
  • Makefile for trtorchrt.so example (c60c521)
  • show pytorch code of unsupported operators (2ee2a84)
  • support aten::Int (5bc977d)
  • trtorchc: Adding more dtype aliases (652fb13)
  • trtorchc: Adding new support for dtypes and formats in (c39bf81)
  • Support fallback options in trtorchc (ad966b7)
  • Using shared_ptrs to manage TRT resources in runtime (e336630)
  • trtorchc: Embedding engines in modules from the CLI (2b4b9e3)

BREAKING CHANGES

  • This commit cleans up the WIP CudaDevice class,
    simplifying implementation and formalizing the serialized format for CUDA
    devices.

It also implements ABI Versioning. The first entry in the serialized
format of a TRTEngine now records the ABI that the engine was compiled
with, defining expected compatibility with the TRTorch runtime. If the
ABI version does not match, the runtime will error out asking to
recompile the program.

ABI version is a monotonically increasing integer and should be
incremented every time the serialization format changes in some way.

This commit cleans up the CudaDevice class, implementing a number of
constructors to replace the various utility functions that populate the
struct. Descriptive utility functions remain but solely call the
relevant constructor.

Signed-off-by: Naren Dasan

  • This commit introduces the next iteration of the Python
    TRTorch API. Starting in TRTorch v0.5.0, support for the "input_shapes"
    and "op_precision" compile spec keys will be removed. Users should port
    forward to using the "inputs" key, which expects a list of trtorch.Input
    objects, and the "enabled_precisions" key, which expects a set of data
    type specifying enums.


  • This change deprecates InputRange, and the CompileSpec
    fields "input_shapes", "op_precision" and associated constructors and
    functions. These are replaced with Input, "inputs" and
    "enabled_precisions" respectively. Deprecated components will be removed
    in TRTorch v0.5.0.


  • Updating PyTorch version to 1.9.0 which includes
    breaking changes to the to_backend api


  • This bumps the TRTorch ABI version to 3 due to
    a new field for engine name included in the serialized form of
    TRTEngine. This lets deserialized engines have the same name they were
    serialized with.

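
The ABI versioning described in these breaking changes amounts to a simple gate: read the first serialized entry, compare it with the runtime's version, and refuse to load on a mismatch. The sketch below models a serialized engine as a list of strings; the list layout, RUNTIME_ABI_VERSION, AbiVersionMismatch and load_engine_fields are all hypothetical stand-ins, not TRTorch's serialization code.

```python
from typing import List

# Per these release notes, v0.4.0 bumps the TRTorch ABI version to 3
RUNTIME_ABI_VERSION = "3"


class AbiVersionMismatch(RuntimeError):
    """Raised when a serialized engine was produced for a different ABI version."""


def load_engine_fields(serialized: List[str]) -> List[str]:
    """Check the ABI version stored in the first entry; return the remaining entries."""
    if not serialized:
        raise ValueError("serialized engine is empty")
    abi_version, fields = serialized[0], serialized[1:]
    if abi_version != RUNTIME_ABI_VERSION:
        raise AbiVersionMismatch(
            f"program was compiled with ABI version {abi_version}, but this runtime "
            f"expects version {RUNTIME_ABI_VERSION}; please recompile the program"
        )
    return fields


# Hypothetical layout: [abi_version, engine_name, device_info, engine_bytes]
print(load_engine_fields(["3", "my_engine", "<device info>", "<engine bytes>"]))
try:
    load_engine_fields(["2", "<device info>", "<engine bytes>"])
except AbiVersionMismatch as err:
    print(err)
```

Putting the version first means the check can run before any other field is interpreted, so later layout changes cannot be misread by an older runtime.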

Supported Operators in TRTorch v0.4.0

Operators Currently Supported Through Converters

  • aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor)
  • aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
  • aten::abs(Tensor self) -> (Tensor)
  • aten::acos(Tensor self) -> (Tensor)
  • aten::acosh(Tensor self) -> (Tensor)
  • aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)
  • aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)
  • aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
  • aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::asin(Tensor self) -> (Tensor)
  • aten::asinh(Tensor self) -> (Tensor)
  • aten::atan(Tensor self) -> (Tensor)
  • aten::atanh(Tensor self) -> (Tensor)
  • aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[0], bool ceil_mode=False, bool count_include_pad=True) -> (Tensor)
  • aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, Tensor? mean, Tensor? var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::bmm(Tensor self, Tensor mat2) -> (Tensor)
  • aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::ceil(Tensor self) -> (Tensor)
  • aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)
  • aten::clamp_max(Tensor self, Scalar max) -> (Tensor)
  • aten::clamp_min(Tensor self, Scalar min) -> (Tensor)
  • aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
  • aten::cos(Tensor self) -> (Tensor)
  • aten::cosh(Tensor self) -> (Tensor)
  • aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)
  • aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::div_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))
  • aten::div_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
  • aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)
  • aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::erf(Tensor self) -> (Tensor)
  • aten::exp(Tensor self) -> (Tensor)
  • aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))
  • aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))
  • aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)
  • aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)
  • aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
  • aten::floor(Tensor self) -> (Tensor)
  • aten::floor_divide(Tensor self, Tensor other) -> (Tensor)
  • aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::gelu(Tensor self) -> (Tensor)
  • aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor)
  • aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)
  • aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))
  • aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)
  • aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> (Tensor(a!))
  • aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor)
  • aten::log(Tensor self) -> (Tensor)
  • aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
  • aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)
  • aten::matmul(Tensor self, Tensor other) -> (Tensor)
  • aten::max(Tensor self) -> (Tensor)
  • aten::max.other(Tensor self, Tensor other) -> (Tensor)
  • aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], int[3] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::min(Tensor self) -> (Tensor)
  • aten::min.other(Tensor self, Tensor other) -> (Tensor)
  • aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::narrow(Tensor(a) self, int dim, int start, int length) -> (Tensor(a))
  • aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> (Tensor(a))
  • aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::neg(Tensor self) -> (Tensor)
  • aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)
  • aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))
  • aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)
  • aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)
  • aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)
  • aten::prelu(Tensor self, Tensor weight) -> (Tensor)
  • aten::prod(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::reciprocal(Tensor self) -> (Tensor)
  • aten::relu(Tensor input) -> (Tensor)
  • aten::relu_(Tensor(a!) self) -> (Tensor(a!))
  • aten::repeat(Tensor self, int[] repeats) -> (Tensor)
  • aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)
  • aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)
  • aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)
  • aten::reshape(Tensor self, int[] shape) -> (Tensor)
  • aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))
  • aten::sigmoid(Tensor input) -> (Tensor)
  • aten::sigmoid_(Tensor(a!) self) -> (Tensor(a!))
  • aten::sin(Tensor self) -> (Tensor)
  • aten::sinh(Tensor self) -> (Tensor)
  • aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a))
  • aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)
  • aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])
  • aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::sqrt(Tensor self) -> (Tensor)
  • aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::t(Tensor self) -> (Tensor)
  • aten::tan(Tensor self) -> (Tensor)
  • aten::tanh(Tensor input) -> (Tensor)
  • aten::tanh_(Tensor(a!) self) -> (Tensor(a!))
  • aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
  • aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
  • aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(a|b))
  • aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
  • aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))
  • aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)
  • aten::upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)
  • aten::upsample_nearest1d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_trilinear3d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::view(Tensor(a) self, int[] size) -> (Tensor(a))
  • trt::const(Tensor self) -> (Tensor)

Operators Currently Supported Through Evaluators

  • aten::Bool.float(float b) -> (bool)
  • aten::Bool.int(int a) -> (bool)
  • aten::Float.Scalar(Scalar a) -> float
  • aten::Float.bool(bool a) -> float
  • aten::Float.int(int a) -> float
  • aten::Int.Scalar(Scalar a) -> int
  • aten::Int.bool(bool a) -> int
  • aten::Int.float(float a) -> int
  • aten::Int.int(int a) -> int
  • aten::__and__(int a, int b) -> (bool)
  • aten::__getitem__.t(t[](a) list, int idx) -> (t(*))
  • aten::__is__(t1 self, t2 obj) -> bool
  • aten::__isnot__(t1 self, t2 obj) -> bool
  • aten::__not__(bool self) -> bool
  • aten::__or__(int a, int b) -> (bool)
  • aten::__round_to_zero_floordiv(int a, int b) -> (int)
  • aten::__xor__(int a, int b) -> (bool)
  • aten::add.float(float a, float b) -> (float)
  • aten::add.int(int a, int b) -> (int)
  • aten::add_.t(t[](a!) self, t[] b) -> (t[])
  • aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))
  • aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
  • aten::clone(Tensor self, *, int? memory_format=None) -> (Tensor)
  • aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!))
  • aten::dim(Tensor self) -> int
  • aten::div.float(float a, float b) -> (float)
  • aten::div.int(int a, int b) -> (float)
  • aten::eq.bool(bool a, bool b) -> (bool)
  • aten::eq.float(float a, float b) -> (bool)
  • aten::eq.float_int(float a, int b) -> (bool)
  • aten::eq.int(int a, int b) -> (bool)
  • aten::eq.int_float(int a, float b) -> (bool)
  • aten::floor.float(float a) -> (int)
  • aten::floor.int(int a) -> (int)
  • aten::floordiv.float(float a, float b) -> (int)
  • aten::floordiv.int(int a, int b) -> (int)
  • aten::ge.bool(bool a, bool b) -> (bool)
  • aten::ge.float(float a, float b) -> (bool)
  • aten::ge.float_int(float a, int b) -> (bool)
  • aten::ge.int(int a, int b) -> (bool)
  • aten::ge.int_float(int a, float b) -> (bool)
  • aten::gt.bool(bool a, bool b) -> (bool)
  • aten::gt.float(float a, float b) -> (bool)
  • aten::gt.float_int(float a, int b) -> (bool)
  • aten::gt.int(int a, int b) -> (bool)
  • aten::gt.int_float(int a, float b) -> (bool)
  • aten::is_floating_point(Tensor self) -> (bool)
  • aten::le.bool(bool a, bool b) -> (bool)
  • aten::le.float(float a, float b) -> (bool)
  • aten::le.float_int(float a, int b) -> (bool)
  • aten::le.int(int a, int b) -> (bool)
  • aten::le.int_float(int a, float b) -> (bool)
  • aten::len.t(t[] a) -> (int)
  • aten::lt.bool(bool a, bool b) -> (bool)
  • aten::lt.float(float a, float b) -> (bool)
  • aten::lt.float_int(float a, int b) -> (bool)
  • aten::lt.int(int a, int b) -> (bool)
  • aten::lt.int_float(int a, float b) -> (bool)
  • aten::mul.float(float a, float b) -> (float)
  • aten::mul.int(int a, int b) -> (int)
  • aten::ne.bool(bool a, bool b) -> (bool)
  • aten::ne.float(float a, float b) -> (bool)
  • aten::ne.float_int(float a, int b) -> (bool)
  • aten::ne.int(int a, int b) -> (bool)
  • aten::ne.int_float(int a, float b) -> (bool)
  • aten::neg.int(int a) -> (int)
  • aten::numel(Tensor self) -> int
  • aten::size(Tensor self) -> (int[])
  • aten::size.int(Tensor self, int dim) -> (int)
  • aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])
  • aten::sqrt.float(float a) -> (float)
  • aten::sqrt.int(int a) -> (float)
  • aten::sub.float(float a, float b) -> (float)
  • aten::sub.int(int a, int b) -> (int)
  • aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)
  • prim::dtype(Tensor a) -> (int)
  • prim::max.bool(bool a, bool b) -> (bool)
  • prim::max.float(float a, float b) -> (bool)
  • prim::max.float_int(float a, int b) -> (bool)
  • prim::max.int(int a, int b) -> (bool)
  • prim::max.int_float(int a, float b) -> (bool)
  • prim::max.self_int(int[] self) -> (int)
  • prim::min.bool(bool a, bool b) -> (bool)
  • prim::min.float(float a, float b) -> (bool)
  • prim::min.float_int(float a, int b) -> (bool)
  • prim::min.int(int a, int b) -> (bool)
  • prim::min.int_float(int a, float b) -> (bool)
  • prim::min.self_int(int[] self) -> (int)
  • prim::shape(Tensor a) -> (int[])