This project is mainly used to generate TensorRT weight files for makaveli10/cpptensorrtz.
- Tested with Python 3.7.9
- Install torch and torchvision:
$ pip install torch==1.6.0
$ pip install torchvision==0.7.0
- Install torchsummary:
$ pip install torchsummary
All the models come from torchvision. models.py downloads and saves the PyTorch weights, then gen_trtwts.py writes them to a .wts file (e.g. vgg16.wts) in the format required by the TensorRT networks in cpptensorrtz.
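The exact .wts layout is not described here. As a rough sketch, the example below assumes the plain-text layout used by wang-xinyu/tensorrtx-style projects: the number of tensors on the first line, then one line per tensor holding its name, its element count and each float32 value as big-endian hex. `write_wts` is a hypothetical helper for illustration, not the repository's own code, and it takes plain Python lists so it runs without PyTorch.

```python
import struct
from typing import Dict, Sequence


def write_wts(weights: Dict[str, Sequence[float]], path: str) -> None:
    """Write named, flattened tensors to a text .wts file.

    Assumed (tensorrtx-style) layout, not confirmed for cpptensorrtz:
      line 1:  number of tensors
      then:    <name> <element count> <hex float32> <hex float32> ...
    """
    with open(path, "w") as f:
        f.write(f"{len(weights)}\n")
        for name, values in weights.items():
            values = list(values)
            f.write(f"{name} {len(values)}")
            for v in values:
                # Each value is packed as a big-endian IEEE-754 float32, then hex-encoded.
                f.write(" " + struct.pack(">f", float(v)).hex())
            f.write("\n")


# With PyTorch, a model's state_dict could be flattened first (not run here):
# state = {k: v.cpu().reshape(-1).tolist() for k, v in model.state_dict().items()}
# write_wts(state, "vgg16.wts")
```

For example, `write_wts({"fc.bias": [0.5]}, "tiny.wts")` would produce the two lines `1` and `fc.bias 1 3f000000`. Writing hex text rather than raw binary keeps the file human-readable and independent of the host's byte order, at the cost of a larger file.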
Example (VGG):
$ cd vgg
$ python models.py
$ python gen_trtwts.py
To run the tests:
$ pip install pytest
$ pytest
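By default pytest collects `test_*` functions from files named `test_*.py` or `*_test.py`. As an illustration only, the sketch below shows what a round-trip check on a .wts file could look like, assuming the same tensorrtx-style text layout (count line, then `<name> <count> <hex float32 ...>` per tensor). `read_wts` and `test_read_wts_roundtrip` are hypothetical names, not the repository's actual test suite; `tmp_path` is pytest's built-in temporary-directory fixture.

```python
import struct
from typing import Dict, List


def read_wts(path: str) -> Dict[str, List[float]]:
    """Parse an (assumed tensorrtx-style) .wts file back into float lists."""
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            # A missing line yields an empty split and raises ValueError on unpacking.
            name, n, *hex_values = f.readline().split()
            if len(hex_values) != int(n):
                raise ValueError(f"{name}: expected {n} values, got {len(hex_values)}")
            # Decode each big-endian float32 from its hex string.
            weights[name] = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_values]
    return weights


def test_read_wts_roundtrip(tmp_path):
    # 3f800000 is 1.0 and c0000000 is -2.0 as big-endian float32.
    path = tmp_path / "tiny.wts"
    path.write_text("1\nfc.weight 2 3f800000 c0000000\n")
    assert read_wts(str(path)) == {"fc.weight": [1.0, -2.0]}
```

Saved as e.g. `test_wts.py`, the check would be picked up by the `pytest` command above.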
TODO:
- Check whether the weights are compatible without CUDA.
- Restructure the code into a single main file that takes command-line arguments.
- Add model variants with different numbers of layers.
- torchsummary does not work with DenseNet.
- Add a function to load custom weights for each network.
- Improve the pylint score of generate_weights.py.