diff --git a/__init__.py b/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/data/config.py b/data/config.py
index 91b4c82ea..a29fd14bb 100644
--- a/data/config.py
+++ b/data/config.py
@@ -303,7 +303,6 @@ def print(self):
 
 
 # ----------------------- MASK BRANCH TYPES ----------------------- #
-
 mask_type = Config({
     # Direct produces masks directly as the output of each pred module.
     # This is denoted as fc-mask in the paper.
@@ -819,6 +818,7 @@ def set_cfg(config_name:str):
 
     if cfg.name is None:
         cfg.name = config_name.split('_config')[0]
+    return cfg
 
 def set_dataset(dataset_name:str):
    """ Sets the dataset of the current config. """
diff --git a/eval.py b/eval.py
index 547bc0aae..42349c5c4 100644
--- a/eval.py
+++ b/eval.py
@@ -868,8 +868,11 @@ def play_video():
         cleanup_and_exit()
 
 def evaluate(net:Yolact, dataset, train_mode=False):
+    net.eval()
+
     net.detect.use_fast_nms = args.fast_nms
     net.detect.use_cross_class_nms = args.cross_class_nms
+    cfg = net.cfg
     cfg.mask_proto_debug = args.mask_proto_debug
 
     # TODO Currently we do not support Fast Mask Re-scoring in evalimage, evalimages, and evalvideo
@@ -1096,7 +1099,6 @@ def print_maps(all_maps):
     print('Loading model...', end='')
     net = Yolact()
     net.load_weights(args.trained_model)
-    net.eval()
     print(' Done.')
 
     if args.cuda:
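With set_cfg() now returning the config and evaluate() switching the net to eval mode itself, the calling sequence simplifies. A minimal sketch of the intended flow after this change (the weights path is illustrative, and the dataset is assumed to be built elsewhere via eval.py's usual argument parsing):

    from yolact import Yolact
    from eval import evaluate

    net = Yolact()                                         # constructor now builds and owns net.cfg
    net.load_weights('weights/yolact_base_54_800000.pth')  # hypothetical weights file
    evaluate(net, dataset)                                 # evaluate() calls net.eval() itself now
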
diff --git a/layers/modules/concat.py b/layers/modules/concat.py
new file mode 100644
index 000000000..45ac83575
--- /dev/null
+++ b/layers/modules/concat.py
@@ -0,0 +1,13 @@
+import torch, torchvision
+import torch.nn as nn
+
+class Concat(nn.Module):
+    def __init__(self, nets, extra_params):
+        super().__init__()
+
+        self.nets = nn.ModuleList(nets)
+        self.extra_params = extra_params
+
+    def forward(self, x):
+        # Concat each along the channel dimension
+        return torch.cat([net(x) for net in self.nets], dim=1, **self.extra_params)
\ No newline at end of file
diff --git a/layers/modules/fast_mask_iou.py b/layers/modules/fast_mask_iou.py
new file mode 100644
index 000000000..f082e2a2d
--- /dev/null
+++ b/layers/modules/fast_mask_iou.py
@@ -0,0 +1,24 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+# local imports
+from data.config import Config
+from utils.functions import make_net
+from utils.script_module_wrapper import ScriptModuleWrapper, script_method_wrapper
+
+class FastMaskIoUNet(ScriptModuleWrapper):
+
+    def __init__(self, config:Config):
+        super().__init__()
+
+        cfg = config
+        input_channels = 1
+        last_layer = [(cfg.num_classes-1, 1, {})]
+        self.maskiou_net, _ = make_net(input_channels, cfg.maskiou_net + last_layer, include_last_relu=True)
+
+    def forward(self, x):
+        x = self.maskiou_net(x)
+        maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1)
+        return maskiou_p
+
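For reference, Concat simply runs several sub-networks on the same input and concatenates their outputs along the channel axis, forwarding extra_params to torch.cat. A small illustrative sketch (the branch shapes are made up):

    import torch
    import torch.nn as nn
    from layers.modules.concat import Concat

    branches = [nn.Conv2d(3, 8, kernel_size=1), nn.Conv2d(3, 4, kernel_size=1)]
    cat = Concat(branches, {})
    out = cat(torch.randn(1, 3, 16, 16))  # -> [1, 12, 16, 16]: 8 + 4 channels
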
diff --git a/layers/modules/fpn.py b/layers/modules/fpn.py
new file mode 100644
index 000000000..13cda2597
--- /dev/null
+++ b/layers/modules/fpn.py
@@ -0,0 +1,112 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+from typing import List
+
+# local imports
+from data.config import Config
+from utils.script_module_wrapper import ScriptModuleWrapper, script_method_wrapper
+
+
+
+class FPN(ScriptModuleWrapper):
+    """
+    Implements a general version of the FPN introduced in
+    https://arxiv.org/pdf/1612.03144.pdf
+
+    Parameters (in cfg.fpn):
+        - num_features (int): The number of output features in the fpn layers.
+        - interpolation_mode (str): The mode to pass to F.interpolate.
+        - num_downsample (int): The number of downsampled layers to add onto the selected layers.
+                                These extra layers are downsampled from the last selected layer.
+
+    Args:
+        - in_channels (list): For each conv layer you supply in the forward pass,
+                              how many features will it have?
+    """
+    __constants__ = ['interpolation_mode', 'num_downsample', 'use_conv_downsample', 'relu_pred_layers',
+                     'lat_layers', 'pred_layers', 'downsample_layers', 'relu_downsample_layers']
+
+    def __init__(self, in_channels, config:Config):
+        super().__init__()
+
+        cfg = config
+
+        self.lat_layers = nn.ModuleList([
+            nn.Conv2d(x, cfg.fpn.num_features, kernel_size=1)
+            for x in reversed(in_channels)
+        ])
+
+        # This is here for backwards compatibility
+        padding = 1 if cfg.fpn.pad else 0
+        self.pred_layers = nn.ModuleList([
+            nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=padding)
+            for _ in in_channels
+        ])
+
+        if cfg.fpn.use_conv_downsample:
+            self.downsample_layers = nn.ModuleList([
+                nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=1, stride=2)
+                for _ in range(cfg.fpn.num_downsample)
+            ])
+
+        self.interpolation_mode = cfg.fpn.interpolation_mode
+        self.num_downsample = cfg.fpn.num_downsample
+        self.use_conv_downsample = cfg.fpn.use_conv_downsample
+        self.relu_downsample_layers = cfg.fpn.relu_downsample_layers
+        self.relu_pred_layers = cfg.fpn.relu_pred_layers
+
+    @script_method_wrapper
+    def forward(self, convouts:List[torch.Tensor]):
+        """
+        Args:
+            - convouts (list): A list of convouts for the corresponding layers in in_channels.
+        Returns:
+            - A list of FPN convouts in the same order as x with extra downsample layers if requested.
+        """
+
+        out = []
+        x = torch.zeros(1, device=convouts[0].device)
+        for i in range(len(convouts)):
+            out.append(x)
+
+        # For backward compatibility, the conv layers are stored in reverse but the input and output is
+        # given in the correct order. Thus, use j=-i-1 for the input and output and i for the conv layers.
+        j = len(convouts)
+        for lat_layer in self.lat_layers:
+            j -= 1
+
+            if j < len(convouts) - 1:
+                _, _, h, w = convouts[j].size()
+                x = F.interpolate(x, size=(h, w), mode=self.interpolation_mode, align_corners=False)
+
+            x = x + lat_layer(convouts[j])
+            out[j] = x
+
+        # This janky second loop is here because TorchScript.
+        j = len(convouts)
+        for pred_layer in self.pred_layers:
+            j -= 1
+            out[j] = pred_layer(out[j])
+
+            if self.relu_pred_layers:
+                F.relu(out[j], inplace=True)
+
+        cur_idx = len(out)
+
+        # In the original paper, this takes care of P6
+        if self.use_conv_downsample:
+            for downsample_layer in self.downsample_layers:
+                out.append(downsample_layer(out[-1]))
+        else:
+            for idx in range(self.num_downsample):
+                # Note: this is an untested alternative to out.append(out[-1][:, :, ::2, ::2]). Thanks TorchScript.
+                out.append(nn.functional.max_pool2d(out[-1], 1, stride=2))
+
+        if self.relu_downsample_layers:
+            for idx in range(len(out) - cur_idx):
+                out[idx] = F.relu(out[idx + cur_idx], inplace=False)
+
+        return out
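As a sanity check on the new constructor signature (the config must now be passed explicitly), here is a sketch assuming yolact_base settings and the PyTorch version this repo targets, where the FPN takes C3-C5 from a ResNet-101 backbone and emits 256-channel maps (the spatial sizes are roughly what a 550x550 input produces):

    import torch
    from data.config import set_cfg
    from layers.modules.fpn import FPN

    cfg = set_cfg('yolact_base_config')           # set_cfg() now returns the config
    fpn = FPN([512, 1024, 2048], config=cfg)      # C3..C5 channel counts for ResNet-101
    convouts = [torch.randn(1, 512, 69, 69),
                torch.randn(1, 1024, 35, 35),
                torch.randn(1, 2048, 18, 18)]
    outs = fpn(convouts)  # 3 pred layers + cfg.fpn.num_downsample extras, each with cfg.fpn.num_features channels
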
+ """ + + def __init__(self, in_channels:int, cfg:Config, out_channels:int=1024, aspect_ratios=[[1]], scales=[1], parent=None, index=0): + """ + @param cfg - config, passed from Yolact class + """ + super().__init__() + + self.cfg = cfg + self.prior_cache = defaultdict(lambda: None) + self.num_classes = cfg.num_classes + self.mask_dim = cfg.mask_dim # Defined by Yolact + self.num_priors = sum(len(x)*len(scales) for x in aspect_ratios) + self.parent = [parent] # Don't include this in the state dict + self.index = index + self.num_heads = cfg.num_heads # Defined by Yolact + + if cfg.mask_proto_split_prototypes_by_head and cfg.mask_type == mask_type.lincomb: + self.mask_dim = self.mask_dim // self.num_heads + + if cfg.mask_proto_prototypes_as_features: + in_channels += self.mask_dim + + if parent is None: + if cfg.extra_head_net is None: + out_channels = in_channels + else: + self.upfeature, out_channels = make_net(in_channels, cfg.extra_head_net) + + if cfg.use_prediction_module: + self.block = Bottleneck(out_channels, out_channels // 4) + self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=True) + self.bn = nn.BatchNorm2d(out_channels) + + self.bbox_layer = nn.Conv2d(out_channels, self.num_priors * 4, **cfg.head_layer_params) + self.conf_layer = nn.Conv2d(out_channels, self.num_priors * self.num_classes, **cfg.head_layer_params) + self.mask_layer = nn.Conv2d(out_channels, self.num_priors * self.mask_dim, **cfg.head_layer_params) + + if cfg.use_mask_scoring: + self.score_layer = nn.Conv2d(out_channels, self.num_priors, **cfg.head_layer_params) + + if cfg.use_instance_coeff: + self.inst_layer = nn.Conv2d(out_channels, self.num_priors * cfg.num_instance_coeffs, **cfg.head_layer_params) + + # What is this ugly lambda doing in the middle of all this clean prediction module code? + def make_extra(num_layers): + if num_layers == 0: + return lambda x: x + else: + # Looks more complicated than it is. 
+
+    def __init__(self, in_channels:int, cfg:Config, out_channels:int=1024, aspect_ratios=[[1]], scales=[1], parent=None, index=0):
+        """
+        @param cfg: the config, passed down from the Yolact class
+        """
+        super().__init__()
+
+        self.cfg = cfg
+        self.prior_cache = defaultdict(lambda: None)
+        self.num_classes = cfg.num_classes
+        self.mask_dim = cfg.mask_dim # Defined by Yolact
+        self.num_priors = sum(len(x)*len(scales) for x in aspect_ratios)
+        self.parent = [parent] # Don't include this in the state dict
+        self.index = index
+        self.num_heads = cfg.num_heads # Defined by Yolact
+
+        if cfg.mask_proto_split_prototypes_by_head and cfg.mask_type == mask_type.lincomb:
+            self.mask_dim = self.mask_dim // self.num_heads
+
+        if cfg.mask_proto_prototypes_as_features:
+            in_channels += self.mask_dim
+
+        if parent is None:
+            if cfg.extra_head_net is None:
+                out_channels = in_channels
+            else:
+                self.upfeature, out_channels = make_net(in_channels, cfg.extra_head_net)
+
+            if cfg.use_prediction_module:
+                self.block = Bottleneck(out_channels, out_channels // 4)
+                self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=True)
+                self.bn = nn.BatchNorm2d(out_channels)
+
+            self.bbox_layer = nn.Conv2d(out_channels, self.num_priors * 4, **cfg.head_layer_params)
+            self.conf_layer = nn.Conv2d(out_channels, self.num_priors * self.num_classes, **cfg.head_layer_params)
+            self.mask_layer = nn.Conv2d(out_channels, self.num_priors * self.mask_dim, **cfg.head_layer_params)
+
+            if cfg.use_mask_scoring:
+                self.score_layer = nn.Conv2d(out_channels, self.num_priors, **cfg.head_layer_params)
+
+            if cfg.use_instance_coeff:
+                self.inst_layer = nn.Conv2d(out_channels, self.num_priors * cfg.num_instance_coeffs, **cfg.head_layer_params)
+
+            # What is this ugly lambda doing in the middle of all this clean prediction module code?
+            def make_extra(num_layers):
+                if num_layers == 0:
+                    return lambda x: x
+                else:
+                    # Looks more complicated than it is. This just creates an array of num_layers alternating conv-relu.
+                    return nn.Sequential(*sum([[
+                        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+                        nn.ReLU(inplace=True)
+                    ] for _ in range(num_layers)], []))
+
+            self.bbox_extra, self.conf_extra, self.mask_extra = [make_extra(x) for x in cfg.extra_layers]
+
+            if cfg.mask_type == mask_type.lincomb and cfg.mask_proto_coeff_gate:
+                self.gate_layer = nn.Conv2d(out_channels, self.num_priors * self.mask_dim, kernel_size=3, padding=1)
+
+        self.aspect_ratios = aspect_ratios
+        self.scales = scales
+
+        self.priors = None
+        self.last_conv_size = None
+        self.last_img_size = None
+
+    def forward(self, x):
+        """
+        Args:
+            - x: The convOut from a layer in the backbone network
+                 Size: [batch_size, in_channels, conv_h, conv_w])
+
+        Returns a tuple (bbox_coords, class_confs, mask_output, prior_boxes) with sizes
+            - bbox_coords: [batch_size, conv_h*conv_w*num_priors, 4]
+            - class_confs: [batch_size, conv_h*conv_w*num_priors, num_classes]
+            - mask_output: [batch_size, conv_h*conv_w*num_priors, mask_dim]
+            - prior_boxes: [conv_h*conv_w*num_priors, 4]
+        """
+        # In case we want to use another module's layers
+        src = self if self.parent[0] is None else self.parent[0]
+
+        conv_h = x.size(2)
+        conv_w = x.size(3)
+
+        cfg = self.cfg
+        if cfg.extra_head_net is not None:
+            x = src.upfeature(x)
+
+        if cfg.use_prediction_module:
+            # The two branches of PM design (c)
+            a = src.block(x)
+
+            b = src.conv(x)
+            b = src.bn(b)
+            b = F.relu(b)
+
+            # TODO: Possibly switch this out for a product
+            x = a + b
+
+        bbox_x = src.bbox_extra(x)
+        conf_x = src.conf_extra(x)
+        mask_x = src.mask_extra(x)
+
+        bbox = src.bbox_layer(bbox_x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4)
+        conf = src.conf_layer(conf_x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.num_classes)
+
+        if self.cfg.eval_mask_branch:
+            mask = src.mask_layer(mask_x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.mask_dim)
+        else:
+            mask = torch.zeros(x.size(0), bbox.size(1), self.mask_dim, device=bbox.device)
+
+        if self.cfg.use_mask_scoring:
+            score = src.score_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 1)
+
+        if self.cfg.use_instance_coeff:
+            inst = src.inst_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, cfg.num_instance_coeffs)
+
+        # See box_utils.decode for an explanation of this
+        if self.cfg.use_yolo_regressors:
+            bbox[:, :, :2] = torch.sigmoid(bbox[:, :, :2]) - 0.5
+            bbox[:, :, 0] /= conv_w
+            bbox[:, :, 1] /= conv_h
+
+        if self.cfg.eval_mask_branch:
+            if self.cfg.mask_type == mask_type.direct:
+                mask = torch.sigmoid(mask)
+            elif self.cfg.mask_type == mask_type.lincomb:
+                mask = self.cfg.mask_proto_coeff_activation(mask)
+
+                if self.cfg.mask_proto_coeff_gate:
+                    gate = src.gate_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.mask_dim)
+                    mask = mask * torch.sigmoid(gate)
+
+        if self.cfg.mask_proto_split_prototypes_by_head and self.cfg.mask_type == mask_type.lincomb:
+            mask = F.pad(mask, (self.index * self.mask_dim, (self.num_heads - self.index - 1) * self.mask_dim), mode='constant', value=0)
+
+        priors = self.make_priors(conv_h, conv_w, x.device)
+
+        preds = { 'loc': bbox, 'conf': conf, 'mask': mask, 'priors': priors }
+
+        if self.cfg.use_mask_scoring:
+            preds['score'] = score
+
+        if self.cfg.use_instance_coeff:
+            preds['inst'] = inst
+
+        return preds
+
+    def make_priors(self, conv_h, conv_w, device):
+        """ Note that priors are [x,y,width,height] where (x,y) is the center of the box. """
+        size = (conv_h, conv_w)
+        cfg = self.cfg
+
+        with timer.env('makepriors'):
+            if self.last_img_size != (cfg._tmp_img_w, cfg._tmp_img_h):
+                prior_data = []
+
+                # Iteration order is important (it has to sync up with the convout)
+                for j, i in product(range(conv_h), range(conv_w)):
+                    # +0.5 because priors are in center-size notation
+                    x = (i + 0.5) / conv_w
+                    y = (j + 0.5) / conv_h
+
+                    for ars in self.aspect_ratios:
+                        for scale in self.scales:
+                            for ar in ars:
+                                if not cfg.backbone.preapply_sqrt:
+                                    ar = sqrt(ar)
+
+                                if cfg.backbone.use_pixel_scales:
+                                    w = scale * ar / cfg.max_size
+                                    h = scale / ar / cfg.max_size
+                                else:
+                                    w = scale * ar / conv_w
+                                    h = scale / ar / conv_h
+
+                                # This is for backward compatibility with a bug where I made everything square by accident
+                                if cfg.backbone.use_square_anchors:
+                                    h = w
+
+                                prior_data += [x, y, w, h]
+
+                self.priors = torch.Tensor(prior_data, device=device).view(-1, 4).detach()
+                self.priors.requires_grad = False
+                self.last_img_size = (cfg._tmp_img_w, cfg._tmp_img_h)
+                self.last_conv_size = (conv_w, conv_h)
+                self.prior_cache[size] = None
+            elif self.priors.device != device:
+                # This whole weird situation is so that DataParallel doesn't copy the priors each iteration
+                if self.prior_cache[size] is None:
+                    self.prior_cache[size] = {}
+
+                if device not in self.prior_cache[size]:
+                    self.prior_cache[size][device] = self.priors.to(device)
+
+                self.priors = self.prior_cache[size][device]
+
+        return self.priors
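The make_priors() loop above is easier to follow in isolation. Below is a standalone sketch of the same center-size anchor math for a single head, under yolact_base-style settings (preapply_sqrt=False, use_pixel_scales=True); the concrete numbers are illustrative only:

    from math import sqrt
    from itertools import product

    def priors_for(conv_h, conv_w, scale, aspect_ratios, max_size):
        data = []
        for j, i in product(range(conv_h), range(conv_w)):
            x = (i + 0.5) / conv_w          # box center, normalized to [0, 1]
            y = (j + 0.5) / conv_h
            for ar in aspect_ratios:
                ar = sqrt(ar)               # preapply_sqrt=False, so sqrt is applied here
                w = scale * ar / max_size   # use_pixel_scales=True: scale is in input pixels
                h = scale / ar / max_size
                data.append([x, y, w, h])
        return data

    anchors = priors_for(69, 69, 24, [1, 0.5, 2], 550)  # 69*69*3 = 14283 anchors
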
""" + size = (conv_h, conv_w) + cfg = self.cfg + + with timer.env('makepriors'): + if self.last_img_size != (cfg._tmp_img_w, cfg._tmp_img_h): + prior_data = [] + + # Iteration order is important (it has to sync up with the convout) + for j, i in product(range(conv_h), range(conv_w)): + # +0.5 because priors are in center-size notation + x = (i + 0.5) / conv_w + y = (j + 0.5) / conv_h + + for ars in self.aspect_ratios: + for scale in self.scales: + for ar in ars: + if not cfg.backbone.preapply_sqrt: + ar = sqrt(ar) + + if cfg.backbone.use_pixel_scales: + w = scale * ar / cfg.max_size + h = scale / ar / cfg.max_size + else: + w = scale * ar / conv_w + h = scale / ar / conv_h + + # This is for backward compatability with a bug where I made everything square by accident + if cfg.backbone.use_square_anchors: + h = w + + prior_data += [x, y, w, h] + + self.priors = torch.Tensor(prior_data, device=device).view(-1, 4).detach() + self.priors.requires_grad = False + self.last_img_size = (cfg._tmp_img_w, cfg._tmp_img_h) + self.last_conv_size = (conv_w, conv_h) + self.prior_cache[size] = None + elif self.priors.device != device: + # This whole weird situation is so that DataParalell doesn't copy the priors each iteration + if self.prior_cache[size] is None: + self.prior_cache[size] = {} + + if device not in prior_cache[size]: + self.prior_cache[size][device] = self.priors.to(device) + + self.priors = self.prior_cache[size][device] + + return self.priors diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..1b90c5c39 --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +import setuptools + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="yolact", # Replace with your own username + version="1.2.0", + author="Daniel Bolya", + author_email="author@example.com", + description="YOLACT a real-time instance segmentation", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/dbolya/yolact", + packages=setuptools.find_packages(), + py_modules=['yolact','backbone','eval'], + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires='>=3.6', +) diff --git a/utils/functions.py b/utils/functions.py index 3b7a4e45a..0de8e25d1 100644 --- a/utils/functions.py +++ b/utils/functions.py @@ -6,6 +6,8 @@ from pathlib import Path from layers.interpolate import InterpolateModule +from layers.modules.concat import Concat + class MovingAverage(): """ Keeps an average window of the specified number of items. """ @@ -210,4 +212,4 @@ def make_layer(layer_cfg): if not include_last_relu: net = net[:-1] - return nn.Sequential(*(net)), in_channels \ No newline at end of file + return nn.Sequential(*(net)), in_channels diff --git a/utils/script_module_wrapper.py b/utils/script_module_wrapper.py new file mode 100644 index 000000000..f918b7c79 --- /dev/null +++ b/utils/script_module_wrapper.py @@ -0,0 +1,10 @@ +import torch +import torch.nn + +# As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules +use_jit = torch.cuda.device_count() <= 1 +if not use_jit: + print('Multiple GPUs detected! 
diff --git a/yolact.py b/yolact.py
index d83703bb7..948148e34 100644
--- a/yolact.py
+++ b/yolact.py
@@ -1,380 +1,28 @@
 import torch, torchvision
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
 from torchvision.models.resnet import Bottleneck
 import numpy as np
-from itertools import product
-from math import sqrt
-from typing import List
-from collections import defaultdict
 
-from data.config import cfg, mask_type
+from data.config import mask_type, set_cfg
 from layers import Detect
 from layers.interpolate import InterpolateModule
 from backbone import construct_backbone
 
-import torch.backends.cudnn as cudnn
 from utils import timer
 from utils.functions import MovingAverage, make_net
+from data.config import Config
+
+# locally defined modules
+from layers.modules.prediction import PredictionModule
+from layers.modules.fast_mask_iou import FastMaskIoUNet
+from layers.modules.fpn import FPN
 
 # This is required for Pytorch 1.0.1 on Windows to initialize Cuda on some driver versions.
 # See the bug report here: https://github.com/pytorch/pytorch/issues/17108
 torch.cuda.current_device()
 
-# As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules
-use_jit = torch.cuda.device_count() <= 1
-if not use_jit:
-    print('Multiple GPUs detected! Turning off JIT.')
-
-ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
-script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn
-
-
-
-class Concat(nn.Module):
-    def __init__(self, nets, extra_params):
-        super().__init__()
-
-        self.nets = nn.ModuleList(nets)
-        self.extra_params = extra_params
-
-    def forward(self, x):
-        # Concat each along the channel dimension
-        return torch.cat([net(x) for net in self.nets], dim=1, **self.extra_params)
-
-prior_cache = defaultdict(lambda: None)
-
-class PredictionModule(nn.Module):
-    """
-    The (c) prediction module adapted from DSSD:
-    https://arxiv.org/pdf/1701.06659.pdf
-
-    Note that this is slightly different to the module in the paper
-    because the Bottleneck block actually has a 3x3 convolution in
-    the middle instead of a 1x1 convolution. Though, I really can't
-    be arsed to implement it myself, and, who knows, this might be
-    better.
-
-    Args:
-        - in_channels:   The input feature size.
-        - out_channels:  The output feature size (must be a multiple of 4).
-        - aspect_ratios: A list of lists of priorbox aspect ratios (one list per scale).
-        - scales:        A list of priorbox scales relative to this layer's convsize.
-                         For instance: If this layer has convouts of size 30x30 for
-                                       an image of size 600x600, the 'default' (scale
-                                       of 1) for this layer would produce bounding
-                                       boxes with an area of 20x20px. If the scale is
-                                       .5 on the other hand, this layer would consider
-                                       bounding boxes with area 10x10px, etc.
-        - parent:        If parent is a PredictionModule, this module will use all the layers
-                         from parent instead of from this module.
-    """
- """ - - def __init__(self, in_channels, out_channels=1024, aspect_ratios=[[1]], scales=[1], parent=None, index=0): - super().__init__() - - self.num_classes = cfg.num_classes - self.mask_dim = cfg.mask_dim # Defined by Yolact - self.num_priors = sum(len(x)*len(scales) for x in aspect_ratios) - self.parent = [parent] # Don't include this in the state dict - self.index = index - self.num_heads = cfg.num_heads # Defined by Yolact - - if cfg.mask_proto_split_prototypes_by_head and cfg.mask_type == mask_type.lincomb: - self.mask_dim = self.mask_dim // self.num_heads - - if cfg.mask_proto_prototypes_as_features: - in_channels += self.mask_dim - - if parent is None: - if cfg.extra_head_net is None: - out_channels = in_channels - else: - self.upfeature, out_channels = make_net(in_channels, cfg.extra_head_net) - - if cfg.use_prediction_module: - self.block = Bottleneck(out_channels, out_channels // 4) - self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=True) - self.bn = nn.BatchNorm2d(out_channels) - - self.bbox_layer = nn.Conv2d(out_channels, self.num_priors * 4, **cfg.head_layer_params) - self.conf_layer = nn.Conv2d(out_channels, self.num_priors * self.num_classes, **cfg.head_layer_params) - self.mask_layer = nn.Conv2d(out_channels, self.num_priors * self.mask_dim, **cfg.head_layer_params) - - if cfg.use_mask_scoring: - self.score_layer = nn.Conv2d(out_channels, self.num_priors, **cfg.head_layer_params) - - if cfg.use_instance_coeff: - self.inst_layer = nn.Conv2d(out_channels, self.num_priors * cfg.num_instance_coeffs, **cfg.head_layer_params) - - # What is this ugly lambda doing in the middle of all this clean prediction module code? - def make_extra(num_layers): - if num_layers == 0: - return lambda x: x - else: - # Looks more complicated than it is. 
-                    # Looks more complicated than it is. This just creates an array of num_layers alternating conv-relu
-                    return nn.Sequential(*sum([[
-                        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-                        nn.ReLU(inplace=True)
-                    ] for _ in range(num_layers)], []))
-
-            self.bbox_extra, self.conf_extra, self.mask_extra = [make_extra(x) for x in cfg.extra_layers]
-
-            if cfg.mask_type == mask_type.lincomb and cfg.mask_proto_coeff_gate:
-                self.gate_layer = nn.Conv2d(out_channels, self.num_priors * self.mask_dim, kernel_size=3, padding=1)
-
-        self.aspect_ratios = aspect_ratios
-        self.scales = scales
-
-        self.priors = None
-        self.last_conv_size = None
-        self.last_img_size = None
-
-    def forward(self, x):
-        """
-        Args:
-            - x: The convOut from a layer in the backbone network
-                 Size: [batch_size, in_channels, conv_h, conv_w])
-
-        Returns a tuple (bbox_coords, class_confs, mask_output, prior_boxes) with sizes
-            - bbox_coords: [batch_size, conv_h*conv_w*num_priors, 4]
-            - class_confs: [batch_size, conv_h*conv_w*num_priors, num_classes]
-            - mask_output: [batch_size, conv_h*conv_w*num_priors, mask_dim]
-            - prior_boxes: [conv_h*conv_w*num_priors, 4]
-        """
-        # In case we want to use another module's layers
-        src = self if self.parent[0] is None else self.parent[0]
-
-        conv_h = x.size(2)
-        conv_w = x.size(3)
-
-        if cfg.extra_head_net is not None:
-            x = src.upfeature(x)
-
-        if cfg.use_prediction_module:
-            # The two branches of PM design (c)
-            a = src.block(x)
-
-            b = src.conv(x)
-            b = src.bn(b)
-            b = F.relu(b)
-
-            # TODO: Possibly switch this out for a product
-            x = a + b
-
-        bbox_x = src.bbox_extra(x)
-        conf_x = src.conf_extra(x)
-        mask_x = src.mask_extra(x)
-
-        bbox = src.bbox_layer(bbox_x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4)
-        conf = src.conf_layer(conf_x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.num_classes)
-
-        if cfg.eval_mask_branch:
-            mask = src.mask_layer(mask_x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.mask_dim)
-        else:
-            mask = torch.zeros(x.size(0), bbox.size(1), self.mask_dim, device=bbox.device)
-
-        if cfg.use_mask_scoring:
-            score = src.score_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 1)
-
-        if cfg.use_instance_coeff:
-            inst = src.inst_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, cfg.num_instance_coeffs)
-
-        # See box_utils.decode for an explanation of this
-        if cfg.use_yolo_regressors:
-            bbox[:, :, :2] = torch.sigmoid(bbox[:, :, :2]) - 0.5
-            bbox[:, :, 0] /= conv_w
-            bbox[:, :, 1] /= conv_h
-
-        if cfg.eval_mask_branch:
-            if cfg.mask_type == mask_type.direct:
-                mask = torch.sigmoid(mask)
-            elif cfg.mask_type == mask_type.lincomb:
-                mask = cfg.mask_proto_coeff_activation(mask)
-
-                if cfg.mask_proto_coeff_gate:
-                    gate = src.gate_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.mask_dim)
-                    mask = mask * torch.sigmoid(gate)
-
-        if cfg.mask_proto_split_prototypes_by_head and cfg.mask_type == mask_type.lincomb:
-            mask = F.pad(mask, (self.index * self.mask_dim, (self.num_heads - self.index - 1) * self.mask_dim), mode='constant', value=0)
-
-        priors = self.make_priors(conv_h, conv_w, x.device)
-
-        preds = { 'loc': bbox, 'conf': conf, 'mask': mask, 'priors': priors }
-
-        if cfg.use_mask_scoring:
-            preds['score'] = score
-
-        if cfg.use_instance_coeff:
-            preds['inst'] = inst
-
-        return preds
-
-    def make_priors(self, conv_h, conv_w, device):
-        """ Note that priors are [x,y,width,height] where (x,y) is the center of the box. """
""" - global prior_cache - size = (conv_h, conv_w) - - with timer.env('makepriors'): - if self.last_img_size != (cfg._tmp_img_w, cfg._tmp_img_h): - prior_data = [] - - # Iteration order is important (it has to sync up with the convout) - for j, i in product(range(conv_h), range(conv_w)): - # +0.5 because priors are in center-size notation - x = (i + 0.5) / conv_w - y = (j + 0.5) / conv_h - - for ars in self.aspect_ratios: - for scale in self.scales: - for ar in ars: - if not cfg.backbone.preapply_sqrt: - ar = sqrt(ar) - - if cfg.backbone.use_pixel_scales: - w = scale * ar / cfg.max_size - h = scale / ar / cfg.max_size - else: - w = scale * ar / conv_w - h = scale / ar / conv_h - - # This is for backward compatability with a bug where I made everything square by accident - if cfg.backbone.use_square_anchors: - h = w - - prior_data += [x, y, w, h] - - self.priors = torch.Tensor(prior_data, device=device).view(-1, 4).detach() - self.priors.requires_grad = False - self.last_img_size = (cfg._tmp_img_w, cfg._tmp_img_h) - self.last_conv_size = (conv_w, conv_h) - prior_cache[size] = None - elif self.priors.device != device: - # This whole weird situation is so that DataParalell doesn't copy the priors each iteration - if prior_cache[size] is None: - prior_cache[size] = {} - - if device not in prior_cache[size]: - prior_cache[size][device] = self.priors.to(device) - - self.priors = prior_cache[size][device] - - return self.priors - -class FPN(ScriptModuleWrapper): - """ - Implements a general version of the FPN introduced in - https://arxiv.org/pdf/1612.03144.pdf - - Parameters (in cfg.fpn): - - num_features (int): The number of output features in the fpn layers. - - interpolation_mode (str): The mode to pass to F.interpolate. - - num_downsample (int): The number of downsampled layers to add onto the selected layers. - These extra layers are downsampled from the last selected layer. - - Args: - - in_channels (list): For each conv layer you supply in the forward pass, - how many features will it have? - """ - __constants__ = ['interpolation_mode', 'num_downsample', 'use_conv_downsample', 'relu_pred_layers', - 'lat_layers', 'pred_layers', 'downsample_layers', 'relu_downsample_layers'] - - def __init__(self, in_channels): - super().__init__() - - self.lat_layers = nn.ModuleList([ - nn.Conv2d(x, cfg.fpn.num_features, kernel_size=1) - for x in reversed(in_channels) - ]) - - # This is here for backwards compatability - padding = 1 if cfg.fpn.pad else 0 - self.pred_layers = nn.ModuleList([ - nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=padding) - for _ in in_channels - ]) - - if cfg.fpn.use_conv_downsample: - self.downsample_layers = nn.ModuleList([ - nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=1, stride=2) - for _ in range(cfg.fpn.num_downsample) - ]) - - self.interpolation_mode = cfg.fpn.interpolation_mode - self.num_downsample = cfg.fpn.num_downsample - self.use_conv_downsample = cfg.fpn.use_conv_downsample - self.relu_downsample_layers = cfg.fpn.relu_downsample_layers - self.relu_pred_layers = cfg.fpn.relu_pred_layers - - @script_method_wrapper - def forward(self, convouts:List[torch.Tensor]): - """ - Args: - - convouts (list): A list of convouts for the corresponding layers in in_channels. - Returns: - - A list of FPN convouts in the same order as x with extra downsample layers if requested. 
- """ - - out = [] - x = torch.zeros(1, device=convouts[0].device) - for i in range(len(convouts)): - out.append(x) - - # For backward compatability, the conv layers are stored in reverse but the input and output is - # given in the correct order. Thus, use j=-i-1 for the input and output and i for the conv layers. - j = len(convouts) - for lat_layer in self.lat_layers: - j -= 1 - - if j < len(convouts) - 1: - _, _, h, w = convouts[j].size() - x = F.interpolate(x, size=(h, w), mode=self.interpolation_mode, align_corners=False) - - x = x + lat_layer(convouts[j]) - out[j] = x - - # This janky second loop is here because TorchScript. - j = len(convouts) - for pred_layer in self.pred_layers: - j -= 1 - out[j] = pred_layer(out[j]) - - if self.relu_pred_layers: - F.relu(out[j], inplace=True) - - cur_idx = len(out) - - # In the original paper, this takes care of P6 - if self.use_conv_downsample: - for downsample_layer in self.downsample_layers: - out.append(downsample_layer(out[-1])) - else: - for idx in range(self.num_downsample): - # Note: this is an untested alternative to out.append(out[-1][:, :, ::2, ::2]). Thanks TorchScript. - out.append(nn.functional.max_pool2d(out[-1], 1, stride=2)) - - if self.relu_downsample_layers: - for idx in range(len(out) - cur_idx): - out[idx] = F.relu(out[idx + cur_idx], inplace=False) - - return out - -class FastMaskIoUNet(ScriptModuleWrapper): - - def __init__(self): - super().__init__() - input_channels = 1 - last_layer = [(cfg.num_classes-1, 1, {})] - self.maskiou_net, _ = make_net(input_channels, cfg.maskiou_net + last_layer, include_last_relu=True) - - def forward(self, x): - x = self.maskiou_net(x) - maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1) - - return maskiou_p - - class Yolact(nn.Module): """ @@ -396,9 +44,31 @@ class Yolact(nn.Module): - pred_aspect_ratios: A list of lists of aspect ratios with len(selected_layers) (see PredictionModule) """ - def __init__(self): + def __init__(self, + config_name="yolact_base_config", + device_type="gpu" + ): + """ + @param config_name: string name of used config, choose from ./data/config.py, default "yolact_base" + @param device_type: string, type of devices used, choose from "gpu","cpu","tpu". Default "gpu". + """ super().__init__() + ## set (custom) config + cfg = set_cfg(str(config_name)) + self.cfg = cfg + + ## GPU + #TODO try half: net = net.half() + assert(device_type == "gpu" or device_type == "cpu" or device_type == "tpu") + assert(device_type != "tpu"), "TPU not yet supported!" 
+        self.device_type = device_type
+        if self.device_type == "gpu":
+            self.cuda()
+            torch.set_default_tensor_type('torch.cuda.FloatTensor')
+
+
+
         self.backbone = construct_backbone(cfg.backbone)
 
         if cfg.freeze_bn:
@@ -432,11 +102,11 @@ def __init__(self):
         src_channels = self.backbone.channels
 
         if cfg.use_maskiou:
-            self.maskiou_net = FastMaskIoUNet()
+            self.maskiou_net = FastMaskIoUNet(self.cfg)
 
         if cfg.fpn is not None:
             # Some hacky rewiring to accommodate the FPN
-            self.fpn = FPN([src_channels[i] for i in self.selected_layers])
+            self.fpn = FPN([src_channels[i] for i in self.selected_layers], config=self.cfg)
             self.selected_layers = list(range(len(self.selected_layers) + cfg.fpn.num_downsample))
             src_channels = [cfg.fpn.num_features] * len(self.selected_layers)
 
@@ -450,7 +120,7 @@ def __init__(self):
             if cfg.share_prediction_module and idx > 0:
                 parent = self.prediction_layers[0]
 
-            pred = PredictionModule(src_channels[layer_idx], src_channels[layer_idx],
+            pred = PredictionModule(src_channels[layer_idx], self.cfg, src_channels[layer_idx],
                                     aspect_ratios = cfg.backbone.pred_aspect_ratios[idx],
                                     scales        = cfg.backbone.pred_scales[idx],
                                     parent        = parent,
@@ -470,6 +140,11 @@ def __init__(self):
 
         self.detect = Detect(cfg.num_classes, bkg_label=0, top_k=cfg.nms_top_k,
                              conf_thresh=cfg.nms_conf_thresh, nms_thresh=cfg.nms_thresh)
+
+        # set default backbone weights
+        self.init_weights(backbone_path='weights/' + cfg.backbone.path)
+
+
     def save_weights(self, path):
         """ Saves the model's weights using compression because the file sizes were getting too big. """
         torch.save(self.state_dict(), path)
@@ -485,7 +160,7 @@ def load_weights(self, path):
                 del state_dict[key]
 
             # Also for backward compatibility with v1.0 weights, do this check
             if key.startswith('fpn.downsample_layers.'):
-                if cfg.fpn is not None and int(key.split('.')[2]) >= cfg.fpn.num_downsample:
+                if self.cfg.fpn is not None and int(key.split('.')[2]) >= self.cfg.fpn.num_downsample:
                     del state_dict[key]
         self.load_state_dict(state_dict)
@@ -526,7 +201,7 @@ def all_in(x, y):
                 nn.init.xavier_uniform_(module.weight.data)
 
                 if module.bias is not None:
-                    if cfg.use_focal_loss and 'conf_layer' in name:
-                        if not cfg.use_sigmoid_focal_loss:
+                    if self.cfg.use_focal_loss and 'conf_layer' in name:
+                        if not self.cfg.use_sigmoid_focal_loss:
                             # Initialize the last layer as in the focal loss paper.
                             # Because we use softmax and not sigmoid, I had to derive an alternate expression
@@ -549,7 +224,7 @@ def all_in(x, y):
 
     def train(self, mode=True):
         super().train(mode)
 
-        if cfg.freeze_bn:
+        if self.cfg.freeze_bn:
             self.freeze_bn()
 
@@ -564,20 +239,21 @@ def freeze_bn(self, enable=False):
     def forward(self, x):
         """ The input should be of size [batch_size, 3, img_h, img_w] """
         _, _, img_h, img_w = x.size()
+        cfg = self.cfg
         cfg._tmp_img_h = img_h
        cfg._tmp_img_w = img_w
 
         with timer.env('backbone'):
             outs = self.backbone(x)
 
-        if cfg.fpn is not None:
+        if self.cfg.fpn is not None:
             with timer.env('fpn'):
                 # Use backbone.selected_layers because we overwrote self.selected_layers
                 outs = [outs[i] for i in cfg.backbone.selected_layers]
                 outs = self.fpn(outs)
 
         proto_out = None
-        if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
+        if self.cfg.mask_type == mask_type.lincomb and self.cfg.eval_mask_branch:
             with timer.env('proto'):
                 proto_x = x if self.proto_src is None else outs[self.proto_src]
 
@@ -586,24 +262,25 @@ def forward(self, x):
                     proto_x = torch.cat([proto_x, grids], dim=1)
 
                 proto_out = self.proto_net(proto_x)
-                proto_out = cfg.mask_proto_prototype_activation(proto_out)
+                proto_out = self.cfg.mask_proto_prototype_activation(proto_out)
 
-                if cfg.mask_proto_prototypes_as_features:
+                if self.cfg.mask_proto_prototypes_as_features:
                     # Clone here because we don't want to permute this, though idk if contiguous makes this unnecessary
                     proto_downsampled = proto_out.clone()
 
-                    if cfg.mask_proto_prototypes_as_features_no_grad:
+                    if self.cfg.mask_proto_prototypes_as_features_no_grad:
                         proto_downsampled = proto_out.detach()
 
                 # Move the features last so the multiplication is easy
                 proto_out = proto_out.permute(0, 2, 3, 1).contiguous()
 
-                if cfg.mask_proto_bias:
+                if self.cfg.mask_proto_bias:
                     bias_shape = [x for x in proto_out.size()]
                     bias_shape[-1] = 1
                     proto_out = torch.cat([proto_out, torch.ones(*bias_shape)], -1)
 
+        cfg = self.cfg
         with timer.env('pred_heads'):
             pred_outs = { 'loc': [], 'conf': [], 'mask': [], 'priors': [] }
 
@@ -683,19 +360,16 @@ def forward(self, x):
     from utils.functions import init_console
     init_console()
 
+    # initialize yolact
+    net = Yolact()
+
     # Use the first argument to set the config if you want
     import sys
     if len(sys.argv) > 1:
-        from data.config import set_cfg
-        set_cfg(sys.argv[1])
+        net = Yolact(config_name=sys.argv[1])
 
-    net = Yolact()
+    cfg = net.cfg
     net.train()
-    net.init_weights(backbone_path='weights/' + cfg.backbone.path)
-
-    # GPU
-    net = net.cuda()
-    torch.set_default_tensor_type('torch.cuda.FloatTensor')
 
     x = torch.zeros((1, 3, cfg.max_size, cfg.max_size))
     y = net(x)