Skip to content

PyTorch to ONNX (DataParallel)

Katsuya Hyodo edited this page Aug 22, 2021 · 1 revision
import torch
import torch.nn as nn
from src.models.modnet import MODNet
from torch.autograd import Variable

# Step 1: load the pretrained checkpoint through an nn.DataParallel wrapper
# (a DataParallel state dict uses "module."-prefixed keys), then re-save the
# inner model's weights so the new file has un-prefixed keys.
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()
modnet.load_state_dict(torch.load('pretrained/modnet_webcam_portrait_matting.ckpt'))
modnet.eval()
torch.save(modnet.module.state_dict(), 'modnet_512x672_float32.pth')

# Step 2: rebuild a plain (un-wrapped) MODNet before loading the stripped
# weights. Loading them back into the DataParallel wrapper would fail the
# strict key check, because the wrapper expects "module."-prefixed keys.
modnet = MODNet(backbone_pretrained=False)
# map_location='cpu' keeps the model and the dummy input on the same device.
modnet.load_state_dict(torch.load('modnet_512x672_float32.pth', map_location='cpu'))
modnet.eval()

# NOTE(review): the output file names say 512x672 but the traced input is
# 512x512 -- confirm which resolution is intended.
dummy_input = torch.randn(1, 3, 512, 512)
torch.onnx.export(modnet, dummy_input, 'modnet_512x672_float32.onnx', export_params=True)
# Step 1: build the network and wrap it in nn.DataParallel so the pretrained
# checkpoint can be loaded through the wrapper.
# Uncomment the matching alternative lines below to convert another model.
net = GCANet(in_c=4, out_c=3, only_residual=True).to(device)
# net = FFANet(3, 19)
# net = MSBDNNet()

net = nn.DataParallel(net, device_ids=device_ids)

net.load_state_dict(torch.load('PSD-GCANET'))
# net.load_state_dict(torch.load('PSD-FFANET'))
# net.load_state_dict(torch.load('PSB-MSBDN'))
net.eval()


# Output naming: model tag and the input resolution used for the ONNX trace.
MODEL='psd_gcanet'
# MODEL='psd_ffanet'
# MODEL='psb_msbdn'
H=512
W=512

# Step 2: save the inner module's weights, i.e. without the "module." key
# prefix that the DataParallel wrapper adds.
torch.save(net.module.state_dict(), f"{MODEL}_{H}x{W}.pth")

# Step 3: rebuild the plain (un-wrapped) network and load the stripped weights.
net = GCANet(in_c=4, out_c=3, only_residual=True).to(device)
# net = FFANet(3, 19)
# net = MSBDNNet()

net.load_state_dict(torch.load(f"{MODEL}_{H}x{W}.pth"))
net.eval()

# Put the dummy input on the same device as the model; a hard-coded .cuda()
# mismatches whenever `device` is not the default CUDA device.
x = torch.randn(1, 4, H, W).to(device) # GCANet
# x = torch.randn(1, 3, H, W).to(device) # FFANet, MSBDNNet
torch.onnx.export(net, x, f"{MODEL}_{H}x{W}.onnx", opset_version=11)

# Stop here so nothing after the export runs.
import sys
sys.exit(0)
Clone this wiki locally