diff --git a/configs/configAVD2.py b/configs/configAVD2.py
old mode 100644
new mode 100755
index 49f32b5..41c7b36
--- a/configs/configAVD2.py
+++ b/configs/configAVD2.py
@@ -7,32 +7,38 @@ class Config():
     """
 
     #Directories - MUST BE CHANGED for your environment
-    DATA_BASE_DIR = '/net/bvisionserver3/playpen/ammirato/sandbox/code/target_driven_instance_detection/Data/'
-    AVD_ROOT_DIR = '/net/bvisionserver3/playpen10/ammirato/Data/HalvedRohitData/'
+    # DATA_BASE_DIR = '/net/bvisionserver3/playpen/ammirato/sandbox/code/TDID/target_driven_instance_detection/Data/'
+    DATA_BASE_DIR = '/net/bvisionserver1/playmirror/mshvets/Phil_TDID/target_driven_instance_detection/Data/'
+    AVD_ROOT_DIR = '/net/bvisionserver3/playpen1/ammirato/Data/TDID_datasets/HalvedActiveVisionDataset/'
     FULL_MODEL_LOAD_DIR= os.path.join(DATA_BASE_DIR, 'Models/')
     SNAPSHOT_SAVE_DIR= os.path.join(DATA_BASE_DIR , 'Models/')
     META_SAVE_DIR = os.path.join(DATA_BASE_DIR, 'ModelsMeta/')
     TARGET_IMAGE_DIR= os.path.join(DATA_BASE_DIR, 'AVD_and_BigBIRD_targets_v1/')
     TEST_OUTPUT_DIR = os.path.join(DATA_BASE_DIR, 'TestOutputs/')
     TEST_GROUND_TRUTH_BOXES = os.path.join(DATA_BASE_DIR, 'GT/AVD_split2_test.json')
-    VAL_GROUND_TRUTH_BOXES = os.path.join(DATA_BASE_DIR ,'GT/AVD_part3_val.json')
+    VAL_GROUND_TRUTH_BOXES = os.path.join(DATA_BASE_DIR ,'GT/AVD_split2_test.json')
+    ATTENTION_ENABLED = True
+    # "sum", or "stack-reduce", or "attention-only"
+    ATTENTION_COMBINATION_MODE = "attention-only"
+    # "average" or "softmax"
+    ATTENTION_NORMALIZE_MODE = "average"
 
     #Model Loading and saving
     FEATURE_NET_NAME= 'vgg16_bn'
     PYTORCH_FEATURE_NET= True
     USE_PRETRAINED_WEIGHTS = True
-    FULL_MODEL_LOAD_NAME= 'TDID_AVD2_03_15_26806_0.36337_0.35057.h5'
-    LOAD_FULL_MODEL= True
-    MODEL_BASE_SAVE_NAME = 'TDID_AVD2_04'
-    SAVE_FREQ = 15
+    FULL_MODEL_LOAD_NAME= ''
+    LOAD_FULL_MODEL= False
+    MODEL_BASE_SAVE_NAME = 'TDID_AVD2average_002'
+    SAVE_FREQ = 5
     SAVE_BY_EPOCH = True
 
     #Training
-    MAX_NUM_EPOCHS= 16
+    MAX_NUM_EPOCHS= 50
     BATCH_SIZE = 5
-    LEARNING_RATE = .0001
+    LEARNING_RATE = .001
     MOMENTUM = .9
     WEIGHT_DECAY = .0005
     DISPLAY_INTERVAL = 10
diff --git a/model_defs/TDID.py b/model_defs/TDID.py
old mode 100644
new mode 100755
index 24f25bb..f910f29
--- a/model_defs/TDID.py
+++ b/model_defs/TDID.py
@@ -8,6 +8,7 @@
 from .anchors.proposal_layer import proposal_layer as proposal_layer_py
 from .anchors.anchor_target_layer import anchor_target_layer as anchor_target_layer_py
+from .attention import CoAttention
 from utils import *
 
 
 class TDID(torch.nn.Module):
@@ -27,13 +28,22 @@ def __init__(self, cfg):
         self.features,self._feat_stride,self.num_feature_channels = \
                                     self.get_feature_net(cfg.FEATURE_NET_NAME)
-        self.embedding_conv = self.get_embedding_conv(cfg)
-        self.corr_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
-                                self.num_feature_channels, 3,
-                                relu=True, same_padding=True)
-        self.diff_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
-                                self.num_feature_channels, 3,
-                                relu=True, same_padding=True)
+
+        use_attention = getattr(cfg, "ATTENTION_ENABLED", False)
+        if use_attention:
+            print(">" * 100)
+            print("Construct the attention block")
+            self.coattention_layer = CoAttention(
+                in_channels=512,
+                combination_mode=cfg.ATTENTION_COMBINATION_MODE)
+        else:
+            self.embedding_conv = self.get_embedding_conv(cfg)
+            self.corr_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
+                                    self.num_feature_channels, 3,
+                                    relu=True, same_padding=True)
+            self.diff_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
+                                    self.num_feature_channels, 3,
+                                    relu=True, same_padding=True)
 
         #for getting output size of score and bbbox convs
         # 3 = number of anchor aspect ratios
         # 2 = number of classes (background, target)
@@ -46,6 +56,8 @@ def __init__(self, cfg):
         self.box_regression_loss = None
         self.roi_cross_entropy_loss = None
 
+        self.use_attention = use_attention
+
     @property
     def loss(self):
         '''
@@ -85,75 +97,96 @@ def forward(self, target_data, img_data, img_info, gt_boxes=None,
             img_features = img_data
             target_features = target_data
         else:
+            # B x C x H x W (33x60)
             img_features = self.features(img_data)
+            # B x C x h x w (h << H) (5x5)
             target_features = self.features(target_data)
 
-        all_corrs = []
-        all_diffs = []
-        for batch_ind in range(img_features.size()[0]):
-            img_ind = np_to_variable(np.asarray([batch_ind]),
-                                     is_cuda=True, dtype=torch.LongTensor)
-            cur_img_feats = torch.index_select(img_features,0,img_ind)
-
-            cur_diffs = []
-            cur_corrs = []
-            for target_type in range(self.cfg.NUM_TARGETS):
-                target_ind = np_to_variable(np.asarray([batch_ind*
-                                           self.cfg.NUM_TARGETS+target_type]),
-                                           is_cuda=True,dtype=torch.LongTensor)
-                cur_target_feats = torch.index_select(target_features,0,
-                                                      target_ind[0])
-                cur_target_feats = cur_target_feats.view(-1,1,
-                                                         cur_target_feats.size()[2],
-                                                         cur_target_feats.size()[3])
-                pooled_target_feats = F.max_pool2d(cur_target_feats,
-                                                   (cur_target_feats.size()[2],
-                                                    cur_target_feats.size()[3]))
-
-                cur_diffs.append(cur_img_feats -
-                    pooled_target_feats.permute(1,0,2,3).expand_as(cur_img_feats))
-                if self.cfg.CORR_WITH_POOLED:
-                    cur_corrs.append(F.conv2d(cur_img_feats,
-                                              pooled_target_feats,
-                                              groups=self.num_feature_channels))
+        if self.use_attention:
+            embedding_feats = self.coattention_layer(
+                img_features,
+                target_features,
+                mode=self.cfg.ATTENTION_NORMALIZE_MODE)
+        else:
+            # batch-separated attention here instead of the ###
+
+            ######################
+            all_corrs = []
+            all_diffs = []
+            for batch_ind in range(img_features.size()[0]):
+                img_ind = np_to_variable(np.asarray([batch_ind]),
+                                         is_cuda=True, dtype=torch.LongTensor)
+                # 1 x C x H x W
+                cur_img_feats = torch.index_select(img_features,0,img_ind)
+
+                cur_diffs = []
+                cur_corrs = []
+                for target_type in range(self.cfg.NUM_TARGETS):
+                    target_ind = np_to_variable(np.asarray([batch_ind*
+                                               self.cfg.NUM_TARGETS+target_type]),
+                                               is_cuda=True,dtype=torch.LongTensor)
+                    cur_target_feats = torch.index_select(target_features,0,
+                                                          target_ind[0])
+                    cur_target_feats = cur_target_feats.view(-1,1,
+                                                             cur_target_feats.size()[2],
+                                                             cur_target_feats.size()[3])
+                    # 1 x 1 x 1 x C
+                    pooled_target_feats = F.max_pool2d(cur_target_feats,
+                                                       (cur_target_feats.size()[2],
+                                                        cur_target_feats.size()[3]))
+
+                    cur_diffs.append(cur_img_feats -
+                        pooled_target_feats.permute(1,0,2,3).expand_as(cur_img_feats))
+                    if self.cfg.CORR_WITH_POOLED:
+                        # we are here
+                        cur_corrs.append(F.conv2d(cur_img_feats,
+                                                  pooled_target_feats,
+                                                  groups=self.num_feature_channels))
+                    else:
+                        target_conv_padding = (max(0,int(
+                                            target_features.size()[2]/2)),
+                                               max(0,int(
+                                            target_features.size()[3]/2)))
+                        cur_corrs.append(F.conv2d(cur_img_feats,cur_target_feats,
+                                                  padding=target_conv_padding,
+                                                  groups=self.num_feature_channels))
+
+
+                cur_corrs = torch.cat(cur_corrs,1)
+                cur_corrs = self.select_to_match_dimensions(cur_corrs,cur_img_feats)
+                all_corrs.append(cur_corrs)
+                all_diffs.append(torch.cat(cur_diffs,1))
+
+            ######################
+
+            corr = self.corr_conv(torch.cat(all_corrs,0))
+            diff = self.diff_conv(torch.cat(all_diffs,0))
+
+            if self.cfg.USE_IMG_FEATS and self.cfg.USE_DIFF_FEATS:
+                if self.cfg.USE_CC_FEATS:
+                    concat_feats = torch.cat([corr,img_features, diff],1)
                 else:
-                    target_conv_padding = (max(0,int(
-                                        target_features.size()[2]/2)),
-                                           max(0,int(
-                                        target_features.size()[3]/2)))
-                    cur_corrs.append(F.conv2d(cur_img_feats,cur_target_feats,
-                                              padding=target_conv_padding,
-                                              groups=self.num_feature_channels))
-
-
-            cur_corrs = torch.cat(cur_corrs,1)
-            cur_corrs = self.select_to_match_dimensions(cur_corrs,cur_img_feats)
-            all_corrs.append(cur_corrs)
-            all_diffs.append(torch.cat(cur_diffs,1))
-
-        corr = self.corr_conv(torch.cat(all_corrs,0))
-        diff = self.diff_conv(torch.cat(all_diffs,0))
-
-        if self.cfg.USE_IMG_FEATS and self.cfg.USE_DIFF_FEATS:
-            if self.cfg.USE_CC_FEATS:
-                concat_feats = torch.cat([corr,img_features, diff],1)
-            else:
-                concat_feats = torch.cat([img_features, diff],1)
-        elif self.cfg.USE_IMG_FEATS:
-            if self.cfg.USE_CC_FEATS:
-                concat_feats = torch.cat([corr,img_features],1)
-            else:
-                concat_feats = torch.cat([img_features],1)
-        elif self.cfg.USE_DIFF_FEATS:
-            if self.cfg.USE_CC_FEATS:
-                concat_feats = torch.cat([corr,diff],1)
+                    concat_feats = torch.cat([img_features, diff],1)
+            elif self.cfg.USE_IMG_FEATS:
+                if self.cfg.USE_CC_FEATS:
+                    concat_feats = torch.cat([corr,img_features],1)
+                else:
+                    concat_feats = torch.cat([img_features],1)
+            elif self.cfg.USE_DIFF_FEATS:
+                if self.cfg.USE_CC_FEATS:
+                    # we are here
+                    # B x 2D x H x W
+                    concat_feats = torch.cat([corr,diff],1)
+                else:
+                    concat_feats = torch.cat([diff],1)
             else:
-                concat_feats = torch.cat([diff],1)
-        else:
-            concat_feats = corr
+                concat_feats = corr
+
+            embedding_feats = self.embedding_conv(concat_feats)
+
-        embedding_feats = self.embedding_conv(concat_feats)
+        #####################
         class_score = self.score_conv(embedding_feats)
         class_score_reshape = self.reshape_layer(class_score, 2)
         class_prob = F.softmax(class_score_reshape)
diff --git a/model_defs/attention.py b/model_defs/attention.py
new file mode 100755
index 0000000..6849a64
--- /dev/null
+++ b/model_defs/attention.py
@@ -0,0 +1,257 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class CoAttention(nn.Module):
+    def __init__(self, in_channels, inter_channels=None, bn_layer=False,
+                 combination_mode="sum"):
+        assert combination_mode in ["sum", "stack-reduce", "attention-only"]
+
+        super(CoAttention, self).__init__()
+
+        self.in_channels = in_channels
+        self.inter_channels = inter_channels
+
+        if self.inter_channels is None:
+            self.inter_channels = in_channels // 2
+            if self.inter_channels == 0:
+                self.inter_channels = 1
+
+        conv_nd = nn.Conv2d
+        bn = nn.BatchNorm2d
+
+        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+                         kernel_size=1, stride=1, padding=0)
+        nn.init.kaiming_uniform(self.g.weight, a=1)
+        nn.init.constant(self.g.bias, 0)
+
+        if bn_layer:
+            self.W = nn.Sequential(
+                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
+                        kernel_size=1, stride=1, padding=0),
+                bn(self.in_channels)
+            )
+            nn.init.constant(self.W[1].weight, 0)
+            nn.init.constant(self.W[1].bias, 0)
+        else:
+            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
+                             kernel_size=1, stride=1, padding=0)
+            # nn.init.constant(self.W.weight, 0)
+            # nn.init.constant(self.W.bias, 0)
+            nn.init.kaiming_uniform(self.W.weight, a=1)
+            nn.init.constant(self.W.bias, 0)
+
+        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+                             kernel_size=1, stride=1, padding=0)
+        nn.init.kaiming_uniform(self.theta.weight, a=1)
+        nn.init.constant(self.theta.bias, 0)
+
+        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+                           kernel_size=1, stride=1, padding=0)
+        nn.init.kaiming_uniform(self.phi.weight, a=1)
+        nn.init.constant(self.phi.bias, 0)
+
+        self.combination_mode = combination_mode
+        if combination_mode == "stack-reduce":
+            self.output_reduce = conv_nd(in_channels=2*self.in_channels,
+                                         out_channels=self.in_channels,
+                                         kernel_size=1,
+                                         stride=1,
+                                         padding=0)
+
+    def forward(self, x, target, mode="average"):
+        '''
+        :param x: B x C x H x W
+        :param y: B x C x h x w
+        :return:
+        '''
+
+        assert mode in ["softmax", "average"]
+
+        batch_size = x.size(0)
+
+        # embed the object features
+        g_target = self.g(target).view(batch_size, self.inter_channels, -1)
+        # B x hw x C
+        g_target = g_target.permute(0, 2, 1)
+
+        # B x C x HW
+        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
+        # B x HW x C
+        theta_x = theta_x.permute(0, 2, 1)
+        # B x C x hw
+        phi_target = self.phi(target).view(batch_size, self.inter_channels, -1)
+
+        # B x HW x hw
+        f = torch.matmul(theta_x, phi_target)
+        if mode == "softmax":
+            f_div_C = F.softmax(f, dim=-1)
+        else:
+            f_div_C = f / f.size(-1)
+
+        # B x HW x C
+        y = torch.matmul(f_div_C, g_target)
+
+        get_norm = lambda z, dim: torch.norm(z, p=2, dim=dim)
+        print("="*10)
+        print("Input x norm: {}".format(
+            get_norm(x, 1).mean().data.cpu().numpy()))
+        print("Target norm: {}".format(
+            get_norm(target, 1).mean().data.cpu().numpy()))
+        print("Embedded target norm: {}/{}/{}".format(
+            get_norm(g_target, 2).min().data.cpu().numpy(),
+            get_norm(g_target, 2).max().data.cpu().numpy(),
+            get_norm(g_target, 2).mean().data.cpu().numpy()))
+        print("Min/Max/Mean abs correlation: {}/{}/{}".format(
+            f.abs().min().data.cpu().numpy(),
+            f.abs().max().data.cpu().numpy(),
+            f.abs().mean().data.cpu().numpy()))
+        print("Aggregated norm: {}/{}/{}".format(
+            get_norm(y, 2).min().data.cpu().numpy(),
+            get_norm(y, 2).max().data.cpu().numpy(),
+            get_norm(y, 2).mean().data.cpu().numpy()))
+
+        # B x C x HW
+        y = y.permute(0, 2, 1).contiguous()
+        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
+        W_y = self.W(y)
+        if self.combination_mode == "stack-reduce":
+            z = torch.cat([W_y, x], dim=1)
+            z = self.output_reduce(z)
+        elif self.combination_mode == "attention-only":
+            z = W_y
+        else:
+            z = W_y + x
+
+        return z
+
+
+class _NonLocalBlockND(nn.Module):
+    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
+        super(_NonLocalBlockND, self).__init__()
+
+        assert dimension in [1, 2, 3]
+
+        self.dimension = dimension
+        self.sub_sample = sub_sample
+
+        self.in_channels = in_channels
+        self.inter_channels = inter_channels
+
+        if self.inter_channels is None:
+            self.inter_channels = in_channels // 2
+            if self.inter_channels == 0:
+                self.inter_channels = 1
+
+        if dimension == 3:
+            conv_nd = nn.Conv3d
+            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
+            bn = nn.BatchNorm3d
+        elif dimension == 2:
+            conv_nd = nn.Conv2d
+            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
+            bn = nn.BatchNorm2d
+        else:
+            conv_nd = nn.Conv1d
+            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
+            bn = nn.BatchNorm1d
+
+        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+                         kernel_size=1, stride=1, padding=0)
+
+        if bn_layer:
+            self.W = nn.Sequential(
+                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
+                        kernel_size=1, stride=1, padding=0),
+                bn(self.in_channels)
+            )
+            nn.init.constant(self.W[1].weight, 0)
+            nn.init.constant(self.W[1].bias, 0)
+        else:
+            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
+                             kernel_size=1, stride=1, padding=0)
+            nn.init.constant(self.W.weight, 0)
+            nn.init.constant(self.W.bias, 0)
+
+        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+                             kernel_size=1, stride=1, padding=0)
+        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+                           kernel_size=1, stride=1, padding=0)
+
+        if sub_sample:
+            self.g = nn.Sequential(self.g, max_pool_layer)
+            self.phi = nn.Sequential(self.phi, max_pool_layer)
+
+    def forward(self, x):
+        '''
+        :param x: (b, c, t, h, w)
+        :return:
+        '''
+
+        batch_size = x.size(0)
+
+        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
+        theta_x = theta_x.permute(0, 2, 1)
+        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
+        f = torch.matmul(theta_x, phi_x)
+        f_div_C = F.softmax(f, dim=-1)
+
+        y = torch.matmul(f_div_C, g_x)
+        y = y.permute(0, 2, 1).contiguous()
+        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
+        W_y = self.W(y)
+        z = W_y + x
+
+        return z
+
+
+class NONLocalBlock1D(_NonLocalBlockND):
+    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
+        super(NONLocalBlock1D, self).__init__(in_channels,
+                                              inter_channels=inter_channels,
+                                              dimension=1, sub_sample=sub_sample,
+                                              bn_layer=bn_layer)
+
+
+class NONLocalBlock2D(_NonLocalBlockND):
+    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
+        super(NONLocalBlock2D, self).__init__(in_channels,
+                                              inter_channels=inter_channels,
+                                              dimension=2, sub_sample=sub_sample,
+                                              bn_layer=bn_layer)
+
+
+class NONLocalBlock3D(_NonLocalBlockND):
+    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
+        super(NONLocalBlock3D, self).__init__(in_channels,
+                                              inter_channels=inter_channels,
+                                              dimension=3, sub_sample=sub_sample,
+                                              bn_layer=bn_layer)
+
+
+if __name__ == '__main__':
+    from torch.autograd import Variable
+    import torch
+
+    sub_sample = True
+    bn_layer = True
+
+    img = Variable(torch.zeros(2, 3, 20))
+    net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
+    out = net(img)
+    print(out.size())
+
+    img = Variable(torch.zeros(2, 3, 20, 20))
+    net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
+    out = net(img)
+    print(out.size())
+
+    img = Variable(torch.randn(2, 3, 10, 20, 20))
+    net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
+    out = net(img)
+    print(out.size())
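
A minimal standalone sketch (not part of the patch) of how the new CoAttention block can be smoke-tested, in the spirit of the __main__ block at the bottom of attention.py. The feature-map shapes (a 2 x 512 x 33 x 60 scene map and a 2 x 512 x 5 x 5 target map) are taken from the shape comments in TDID.forward; running from the repository root and using the same Variable-era PyTorch as the rest of the code are assumptions.

# Illustrative smoke test for CoAttention (assumes the repo root is on the path
# and an older, Variable-based PyTorch, matching the rest of the codebase).
import torch
from torch.autograd import Variable

from model_defs.attention import CoAttention

img_feats = Variable(torch.randn(2, 512, 33, 60))   # B x C x H x W scene features
target_feats = Variable(torch.randn(2, 512, 5, 5))  # B x C x h x w target features

for combo in ["sum", "stack-reduce", "attention-only"]:
    net = CoAttention(in_channels=512, combination_mode=combo)
    out = net(img_feats, target_feats, mode="average")
    # every combination mode should preserve the scene shape:
    # torch.Size([2, 512, 33, 60])
    print(combo, out.size())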