TSM-R50 from "TSM: Temporal Shift Module for Efficient Video Understanding" https://arxiv.org/abs/1811.08383
TSM is a widely used Action Recognition model. This TensorRT implementation is tested with TensorRT 5.1 and TensorRT 7.2.
For the PyTorch implementation, you can refer to open-mmlab/mmaction2 or mit-han-lab/temporal-shift-module.
For more details about the shift module (the core of TSM), refer to test_shift.py.
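To make the shift concrete, here is a minimal NumPy sketch of the temporal shift as described in the TSM paper. With `shift_div = 8`, 1/8 of the channels of each frame are taken from the next frame, 1/8 are taken from the previous frame, and the rest stay in place, with zero padding at the clip boundaries. It assumes the `(N * T, C, H, W)` layout that a 2D backbone sees (frames folded into the batch axis). It is an illustration only, not the code in test_shift.py or the TensorRT network definition.

```python
import numpy as np


def temporal_shift(x: np.ndarray, num_segments: int, shift_div: int = 8) -> np.ndarray:
    """Shift a fraction of the channels along the temporal axis.

    x has shape (N * T, C, H, W) with T == num_segments, i.e. the frames of
    each clip are stacked along the batch axis.
    """
    nt, c, h, w = x.shape
    n = nt // num_segments
    x = x.reshape(n, num_segments, c, h, w)
    fold = c // shift_div  # number of channels shifted in each direction

    out = np.zeros_like(x)  # zeros fill the frames that have no neighbour
    # First fold: frame t receives these channels from frame t + 1.
    out[:, :-1, :fold] = x[:, 1:, :fold]
    # Second fold: frame t receives these channels from frame t - 1.
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]
    # Remaining channels are copied unchanged.
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]
    return out.reshape(nt, c, h, w)
```

Because the shift only moves memory and adds no multiplications, TSM adds temporal modeling to a ResNet50 at essentially no extra FLOPs.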
- For a complete example, refer to demo.sh.
  - Requirements: `torch>=1.3.0` and `torchvision` installed.
- Step 1: Train or download TSM-R50 checkpoints from the official GitHub repo or MMAction2.
  - Supported settings: `num_segments`, `shift_div`, `num_classes`.
  - Fixed settings: `backbone` (ResNet50), `shift_place` (blockres), `temporal_pool` (False).
- Step 2: Convert PyTorch checkpoints to TensorRT weights:

```shell
python gen_wts.py /path/to/pytorch.pth --out-filename /path/to/tensorrt.wts
```
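The `.wts` file is plain text. As a rough sketch, assuming gen_wts.py writes the common layout (a tensor count on the first line, then one line per tensor with its name, element count, and each float32 value as the hex of its big-endian bits), a writer and reader could look like the following. `write_wts` and `read_wts` are hypothetical helpers, not functions from this repo; check gen_wts.py for the exact format.

```python
import struct

import numpy as np


def write_wts(weights: dict, path: str) -> None:
    """Hypothetical writer for the assumed plain-text .wts layout."""
    with open(path, "w") as f:
        f.write(f"{len(weights)}\n")  # line 1: number of tensors
        for name, tensor in weights.items():
            flat = np.asarray(tensor, dtype=np.float32).reshape(-1)
            # Each value becomes 8 hex digits of its big-endian float32 bits.
            hex_values = [struct.pack(">f", float(v)).hex() for v in flat]
            f.write(" ".join([name, str(flat.size)] + hex_values) + "\n")


def read_wts(path: str) -> dict:
    """Hypothetical reader: parse the file back into flat float32 arrays."""
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            name, size, *hex_values = f.readline().split()
            values = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_values]
            if len(values) != int(size):
                raise ValueError(f"{name}: expected {size} values, got {len(values)}")
            weights[name] = np.array(values, dtype=np.float32)
    return weights
```

Tensors are stored flattened, so the network definition has to know each layer's shape; only names and element counts travel in the file.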
- Step 3: Test the Python API.
  - Modify configs in `tsm_r50.py`.
  - Run inference with `tsm_r50.py`.
```python
# Supported settings
BATCH_SIZE = 1
NUM_SEGMENTS = 8
INPUT_H = 224
INPUT_W = 224
OUTPUT_SIZE = 400
SHIFT_DIV = 8
```
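As a sketch of how these settings relate to tensor shapes, assuming the usual TSM layout (the `NUM_SEGMENTS` frames of each clip folded into the batch axis, and per-frame scores averaged into one score vector per clip, as in the paper's average consensus), the shapes work out as below. `input_shape` and `average_consensus` are hypothetical names; the actual engine bindings are defined in `tsm_r50.py`.

```python
import numpy as np

# Values mirror the "Supported settings" above.
BATCH_SIZE = 1
NUM_SEGMENTS = 8
INPUT_H = 224
INPUT_W = 224
OUTPUT_SIZE = 400


def input_shape(batch_size=BATCH_SIZE, num_segments=NUM_SEGMENTS):
    # Frames are stacked along the batch axis for the 2D ResNet50 backbone.
    return (batch_size * num_segments, 3, INPUT_H, INPUT_W)


def average_consensus(per_frame_logits, batch_size=BATCH_SIZE, num_segments=NUM_SEGMENTS):
    """Average (batch * segments, classes) scores into (batch, classes)."""
    logits = np.asarray(per_frame_logits).reshape(batch_size, num_segments, OUTPUT_SIZE)
    return logits.mean(axis=1)
```

With the defaults, the backbone sees 8 frames of 3 x 224 x 224 and the clip ends up with 400 class scores (`OUTPUT_SIZE`).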
```
usage: tsm_r50.py [-h] [--tensorrt-weights TENSORRT_WEIGHTS] [--input-video INPUT_VIDEO] [--save-engine-path SAVE_ENGINE_PATH] [--load-engine-path LOAD_ENGINE_PATH] [--test-mmaction2] [--mmaction2-config MMACTION2_CONFIG] [--mmaction2-checkpoint MMACTION2_CHECKPOINT] [--test-cpp] [--cpp-result-path CPP_RESULT_PATH]

optional arguments:
  -h, --help            show this help message and exit
  --tensorrt-weights TENSORRT_WEIGHTS
                        Path to TensorRT weights, which is generated by gen_weights.py
  --input-video INPUT_VIDEO
                        Path to local video file
  --save-engine-path SAVE_ENGINE_PATH
                        Save engine to local file
  --load-engine-path LOAD_ENGINE_PATH
                        Saved engine file path
  --test-mmaction2      Compare TensorRT results with MMAction2 Results
  --mmaction2-config MMACTION2_CONFIG
                        Path to MMAction2 config file
  --mmaction2-checkpoint MMACTION2_CHECKPOINT
                        Path to MMAction2 checkpoint url or file path
  --test-cpp            Compare Python API results with C++ API results
  --cpp-result-path CPP_RESULT_PATH
                        Path to C++ API results
```
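When a video is passed with `--input-video`, only `NUM_SEGMENTS` frames can be fed to the network. A common TSN/TSM test-time strategy splits the video into equal segments and takes the center frame of each; a minimal sketch of that index selection follows. `sample_segment_indices` is a hypothetical helper, and `tsm_r50.py` may sample frames differently.

```python
def sample_segment_indices(num_frames: int, num_segments: int = 8) -> list:
    """Return the center frame index of each of `num_segments` equal segments."""
    if num_frames <= 0:
        raise ValueError("video has no frames")
    tick = num_frames / num_segments  # segment length in frames (may be < 1)
    # Clamp so short videos never index past the last frame.
    return [min(int(tick / 2 + tick * i), num_frames - 1) for i in range(num_segments)]
```

For an 80-frame video this picks frames 5, 15, ..., 75; videos shorter than `num_segments` simply repeat frames.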
- Step 4: Test the C++ API.
  - Modify configs in `tsm_r50.cpp`.
  - Build from source: `mkdir build && cd build && cmake .. && make`
  - Generate the engine file: `./tsm_r50 -s`
  - Run inference with the generated engine file and write predictions to a local file: `./tsm_r50 -d`
  - Compare results with the Python API: `python tsm_r50.py --tensorrt-weights /path/to/tensorrt.wts --test-cpp --cpp-result-path /path/to/cpp-result.txt`
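Comparing the two APIs usually comes down to a tolerance check on the scores plus a top-1 agreement check. A minimal sketch, assuming both results are already loaded as arrays of shape `(BATCH_SIZE, OUTPUT_SIZE)`; `compare_predictions` and its `atol` default are hypothetical and not part of `tsm_r50.py`.

```python
import numpy as np


def compare_predictions(py_scores, cpp_scores, atol=1e-4):
    """Return (match, max_abs_diff) for two score arrays of the same shape."""
    py_scores = np.asarray(py_scores, dtype=np.float32)
    cpp_scores = np.asarray(cpp_scores, dtype=np.float32)
    if py_scores.shape != cpp_scores.shape:
        raise ValueError(f"shape mismatch: {py_scores.shape} vs {cpp_scores.shape}")
    max_abs_diff = float(np.max(np.abs(py_scores - cpp_scores)))
    # Predicted class per clip must agree as well as the raw scores.
    same_top1 = bool(np.all(py_scores.argmax(axis=-1) == cpp_scores.argmax(axis=-1)))
    return max_abs_diff <= atol and same_top1, max_abs_diff
```

If the C++ result file stored one whitespace-separated row of scores per clip, `np.loadtxt` could load it, but that file format is an assumption; check `tsm_r50.cpp` for what `-d` actually writes.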
- Python shift module
- Generate wts from official TSM and MMAction2 TSM checkpoints
- Python API Definition
- Test with mmaction2 demo
- Tutorial
- C++ API Definition