Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP API for Yolact, python package #323

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Empty file added __init__.py
Empty file.
2 changes: 1 addition & 1 deletion data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ def print(self):


# ----------------------- MASK BRANCH TYPES ----------------------- #

mask_type = Config({
# Direct produces masks directly as the output of each pred module.
# This is denoted as fc-mask in the paper.
Expand Down Expand Up @@ -819,6 +818,7 @@ def set_cfg(config_name:str):

if cfg.name is None:
cfg.name = config_name.split('_config')[0]
return cfg

def set_dataset(dataset_name:str):
""" Sets the dataset of the current config. """
Expand Down
4 changes: 3 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,8 +868,11 @@ def play_video():
cleanup_and_exit()

def evaluate(net:Yolact, dataset, train_mode=False):
net.eval()

net.detect.use_fast_nms = args.fast_nms
net.detect.use_cross_class_nms = args.cross_class_nms
cfg=net.cfg
cfg.mask_proto_debug = args.mask_proto_debug

# TODO: Currently we do not support Fast Mask Re-scoring in evalimage, evalimages, and evalvideo
Expand Down Expand Up @@ -1096,7 +1099,6 @@ def print_maps(all_maps):
print('Loading model...', end='')
net = Yolact()
net.load_weights(args.trained_model)
net.eval()
print(' Done.')

if args.cuda:
Expand Down
13 changes: 13 additions & 0 deletions layers/modules/concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch, torchvision
import torch.nn as nn

class Concat(nn.Module):
    """
    Runs several sub-networks on the same input and joins their outputs.

    Args:
        - nets (iterable of nn.Module): The sub-networks to run; registered as an nn.ModuleList.
        - extra_params (dict): Extra keyword arguments forwarded verbatim to torch.cat.
    """

    def __init__(self, nets, extra_params):
        super().__init__()

        # Stored as-is; unpacked into torch.cat on every forward pass.
        self.extra_params = extra_params
        self.nets = nn.ModuleList(nets)

    def forward(self, x):
        """ Feeds x to every sub-network and concatenates the results along dim=1. """
        branch_outputs = []
        for branch in self.nets:
            branch_outputs.append(branch(x))

        # dim=1 is the channel dimension for (N, C, H, W) inputs.
        return torch.cat(branch_outputs, dim=1, **self.extra_params)
24 changes: 24 additions & 0 deletions layers/modules/fast_mask_iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from torch import nn
import torch.nn.functional as F

#locals
from data.config import Config
from utils.functions import make_net
from utils.script_module_wrapper import ScriptModuleWrapper, script_method_wrapper

class FastMaskIoUNet(ScriptModuleWrapper):
    """
    Small network built with make_net from cfg.maskiou_net plus one final layer
    that has (cfg.num_classes - 1) output channels.

    Args:
        - config (Config): Must provide `num_classes` and `maskiou_net`.
    """

    def __init__(self, config:Config):
        super().__init__()

        # Final layer spec appended to the configured layers.
        # NOTE(review): tuple presumably means (out_channels, kernel_size, kwargs) -- confirm against make_net.
        final_layer_spec = [(config.num_classes - 1, 1, {})]
        layer_specs = config.maskiou_net + final_layer_spec

        # The network is built for a single input channel; make_net's second return value is unused here.
        self.maskiou_net, _ = make_net(1, layer_specs, include_last_relu=True)

    def forward(self, x):
        """ Runs the net, then max-pools over the full spatial extent and drops the two trailing dims. """
        features = self.maskiou_net(x)

        # A kernel as large as the feature map reduces each channel to a single value.
        spatial_size = features.size()[2:]
        pooled = F.max_pool2d(features, kernel_size=spatial_size)

        maskiou_p = pooled.squeeze(-1).squeeze(-1)
        return maskiou_p

112 changes: 112 additions & 0 deletions layers/modules/fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import torch
from torch import nn
import torch.nn.functional as F


from typing import List

#local imports
from data.config import Config
from utils.script_module_wrapper import ScriptModuleWrapper, script_method_wrapper



class FPN(ScriptModuleWrapper):
    """
    Implements a general version of the FPN introduced in
    https://arxiv.org/pdf/1612.03144.pdf

    Parameters (in cfg.fpn):
        - num_features (int): The number of output features in the fpn layers.
        - interpolation_mode (str): The mode to pass to F.interpolate.
        - num_downsample (int): The number of downsampled layers to add onto the selected layers.
                                These extra layers are downsampled from the last selected layer.
        - pad (bool): Whether the 3x3 prediction convs use padding=1 (otherwise 0).
        - use_conv_downsample (bool): Use stride-2 3x3 convs for the extra downsampled layers
                                      instead of a stride-2 max pool with kernel size 1.
        - relu_pred_layers (bool): Apply ReLU (in place) to each prediction layer output.
        - relu_downsample_layers (bool): Apply ReLU to each extra downsampled layer output.

    Args:
        - in_channels (list): For each conv layer you supply in the forward pass,
                              how many features will it have?
        - config (Config): The config object whose `fpn` attribute holds the parameters above.
    """
    __constants__ = ['interpolation_mode', 'num_downsample', 'use_conv_downsample', 'relu_pred_layers',
                     'lat_layers', 'pred_layers', 'downsample_layers', 'relu_downsample_layers']

    def __init__(self, in_channels, config:Config):
        super().__init__()

        cfg = config

        # 1x1 lateral convs, stored in reverse order of in_channels (see forward).
        self.lat_layers = nn.ModuleList([
            nn.Conv2d(x, cfg.fpn.num_features, kernel_size=1)
            for x in reversed(in_channels)
        ])

        # This is here for backwards compatibility
        padding = 1 if cfg.fpn.pad else 0
        self.pred_layers = nn.ModuleList([
            nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=padding)
            for _ in in_channels
        ])

        # Only created when requested; forward only touches it under the same flag.
        if cfg.fpn.use_conv_downsample:
            self.downsample_layers = nn.ModuleList([
                nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=1, stride=2)
                for _ in range(cfg.fpn.num_downsample)
            ])

        self.interpolation_mode     = cfg.fpn.interpolation_mode
        self.num_downsample         = cfg.fpn.num_downsample
        self.use_conv_downsample    = cfg.fpn.use_conv_downsample
        self.relu_downsample_layers = cfg.fpn.relu_downsample_layers
        self.relu_pred_layers       = cfg.fpn.relu_pred_layers

    @script_method_wrapper
    def forward(self, convouts:List[torch.Tensor]):
        """
        Args:
            - convouts (list): A list of convouts for the corresponding layers in in_channels.
        Returns:
            - A list of FPN convouts in the same order as x with extra downsample layers if requested.
        """

        # Pre-fill the output list with a placeholder so it can be assigned by index below.
        out = []
        x = torch.zeros(1, device=convouts[0].device)
        for i in range(len(convouts)):
            out.append(x)

        # For backward compatibility, the conv layers are stored in reverse but the input and output is
        # given in the correct order. Thus, j walks the inputs/outputs from last to first while the
        # conv layers are iterated in their stored order.
        j = len(convouts)
        for lat_layer in self.lat_layers:
            j -= 1

            # Every level except the first one processed (the last input) receives the
            # running feature map resized to its own spatial size.
            if j < len(convouts) - 1:
                _, _, h, w = convouts[j].size()
                x = F.interpolate(x, size=(h, w), mode=self.interpolation_mode, align_corners=False)

            # On the first iteration x is the zeros(1) placeholder, which broadcasts in the sum.
            x = x + lat_layer(convouts[j])
            out[j] = x

        # This janky second loop is here because TorchScript.
        j = len(convouts)
        for pred_layer in self.pred_layers:
            j -= 1
            out[j] = pred_layer(out[j])

            if self.relu_pred_layers:
                F.relu(out[j], inplace=True)

        # Index of the first extra downsampled layer appended below.
        cur_idx = len(out)

        # In the original paper, this takes care of P6
        if self.use_conv_downsample:
            for downsample_layer in self.downsample_layers:
                out.append(downsample_layer(out[-1]))
        else:
            for idx in range(self.num_downsample):
                # Note: this is an untested alternative to out.append(out[-1][:, :, ::2, ::2]). Thanks TorchScript.
                out.append(nn.functional.max_pool2d(out[-1], 1, stride=2))

        if self.relu_downsample_layers:
            for idx in range(len(out) - cur_idx):
                # Write the activation back to the downsampled layer itself. Writing to out[idx]
                # (as this code previously did) overwrote the first pyramid levels with the
                # downsampled maps and left the downsampled layers without a ReLU.
                # NOTE(review): weights trained with relu_downsample_layers enabled under the
                # old indexing may behave differently -- confirm before loading such checkpoints.
                out[idx + cur_idx] = F.relu(out[idx + cur_idx], inplace=False)

        return out
Loading