Add attention instead of TDID correlation #8

Open: wants to merge 1 commit into base: master
24 changes: 15 additions & 9 deletions configs/configAVD2.py
100644 → 100755
@@ -7,32 +7,38 @@ class Config():
"""

#Directories - MUST BE CHANGED for your environment
DATA_BASE_DIR = '/net/bvisionserver3/playpen/ammirato/sandbox/code/target_driven_instance_detection/Data/'
AVD_ROOT_DIR = '/net/bvisionserver3/playpen10/ammirato/Data/HalvedRohitData/'
# DATA_BASE_DIR = '/net/bvisionserver3/playpen/ammirato/sandbox/code/TDID/target_driven_instance_detection/Data/'
DATA_BASE_DIR = '/net/bvisionserver1/playmirror/mshvets/Phil_TDID/target_driven_instance_detection/Data/'
AVD_ROOT_DIR = '/net/bvisionserver3/playpen1/ammirato/Data/TDID_datasets/HalvedActiveVisionDataset/'
FULL_MODEL_LOAD_DIR= os.path.join(DATA_BASE_DIR, 'Models/')
SNAPSHOT_SAVE_DIR= os.path.join(DATA_BASE_DIR , 'Models/')
META_SAVE_DIR = os.path.join(DATA_BASE_DIR, 'ModelsMeta/')
TARGET_IMAGE_DIR= os.path.join(DATA_BASE_DIR, 'AVD_and_BigBIRD_targets_v1/')
TEST_OUTPUT_DIR = os.path.join(DATA_BASE_DIR, 'TestOutputs/')
TEST_GROUND_TRUTH_BOXES = os.path.join(DATA_BASE_DIR, 'GT/AVD_split2_test.json')
VAL_GROUND_TRUTH_BOXES = os.path.join(DATA_BASE_DIR ,'GT/AVD_part3_val.json')
VAL_GROUND_TRUTH_BOXES = os.path.join(DATA_BASE_DIR ,'GT/AVD_split2_test.json')

ATTENTION_ENABLED = True
# "sum", or "stack-reduce", or "attention-only"
ATTENTION_COMBINATION_MODE = "attention-only"
# "average" or "softmax"
ATTENTION_NORMALIZE_MODE = "average"
Author commented:

I am adding these three flags. The first is optional, since the main code reads it with getattr and defaults to False. The other two are required only when the first is True.

Let's try to train a couple of settings on the idle machines. Suggested:

  1. sum + softmax
  2. sum + average
  3. attention-only + average
  4. stack-reduce + softmax
  5. stack-reduce + average

Based on our discussion, I don't have much faith in the softmax version.
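
A minimal sketch of how the training code can consume these flags (the getattr default matches the behavior described above; the assertions are illustrative validation, not code from this PR):

use_attention = getattr(cfg, "ATTENTION_ENABLED", False)
if use_attention:
    # the next two settings are only required when ATTENTION_ENABLED is True
    assert cfg.ATTENTION_COMBINATION_MODE in ("sum", "stack-reduce", "attention-only")
    assert cfg.ATTENTION_NORMALIZE_MODE in ("average", "softmax")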


#Model Loading and saving
FEATURE_NET_NAME= 'vgg16_bn'
PYTORCH_FEATURE_NET= True
USE_PRETRAINED_WEIGHTS = True
FULL_MODEL_LOAD_NAME= 'TDID_AVD2_03_15_26806_0.36337_0.35057.h5'
LOAD_FULL_MODEL= True
MODEL_BASE_SAVE_NAME = 'TDID_AVD2_04'
SAVE_FREQ = 15
FULL_MODEL_LOAD_NAME= ''
LOAD_FULL_MODEL= False
MODEL_BASE_SAVE_NAME = 'TDID_AVD2average_002'
SAVE_FREQ = 5
SAVE_BY_EPOCH = True


#Training
MAX_NUM_EPOCHS= 16
MAX_NUM_EPOCHS= 50
BATCH_SIZE = 5
LEARNING_RATE = .0001
LEARNING_RATE = .001
MOMENTUM = .9
WEIGHT_DECAY = .0005
DISPLAY_INTERVAL = 10
171 changes: 102 additions & 69 deletions model_defs/TDID.py
100644 → 100755
@@ -8,6 +8,7 @@

from .anchors.proposal_layer import proposal_layer as proposal_layer_py
from .anchors.anchor_target_layer import anchor_target_layer as anchor_target_layer_py
from .attention import CoAttention
from utils import *

class TDID(torch.nn.Module):
@@ -27,13 +28,22 @@ def __init__(self, cfg):

self.features,self._feat_stride,self.num_feature_channels = \
self.get_feature_net(cfg.FEATURE_NET_NAME)
self.embedding_conv = self.get_embedding_conv(cfg)
self.corr_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
self.num_feature_channels, 3,
relu=True, same_padding=True)
self.diff_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
self.num_feature_channels, 3,
relu=True, same_padding=True)

use_attention = getattr(cfg, "ATTENTION_ENABLED", False)
if use_attention:
print(">" * 100)
print("Construct the attention block")
self.coattention_layer = CoAttention(
in_channels=512,
combination_mode=cfg.ATTENTION_COMBINATION_MODE)
else:
self.embedding_conv = self.get_embedding_conv(cfg)
self.corr_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
self.num_feature_channels, 3,
relu=True, same_padding=True)
self.diff_conv = Conv2d(cfg.NUM_TARGETS*self.num_feature_channels,
self.num_feature_channels, 3,
relu=True, same_padding=True)
#for getting output size of score and bbox convs
# 3 = number of anchor aspect ratios
# 2 = number of classes (background, target)
@@ -46,6 +56,8 @@ def __init__(self, cfg):
self.box_regression_loss = None
self.roi_cross_entropy_loss = None

self.use_attention = use_attention

@property
def loss(self):
'''
@@ -85,75 +97,96 @@ def forward(self, target_data, img_data, img_info, gt_boxes=None,
img_features = img_data
target_features = target_data
else:
# B x C x H x W (33x60)
img_features = self.features(img_data)
# (B*NUM_TARGETS) x C x h x w (h << H) (5x5)
target_features = self.features(target_data)


all_corrs = []
all_diffs = []
for batch_ind in range(img_features.size()[0]):
img_ind = np_to_variable(np.asarray([batch_ind]),
is_cuda=True, dtype=torch.LongTensor)
cur_img_feats = torch.index_select(img_features,0,img_ind)

cur_diffs = []
cur_corrs = []
for target_type in range(self.cfg.NUM_TARGETS):
target_ind = np_to_variable(np.asarray([batch_ind*
self.cfg.NUM_TARGETS+target_type]),
is_cuda=True,dtype=torch.LongTensor)
cur_target_feats = torch.index_select(target_features,0,
target_ind[0])
cur_target_feats = cur_target_feats.view(-1,1,
cur_target_feats.size()[2],
cur_target_feats.size()[3])
pooled_target_feats = F.max_pool2d(cur_target_feats,
(cur_target_feats.size()[2],
cur_target_feats.size()[3]))

cur_diffs.append(cur_img_feats -
pooled_target_feats.permute(1,0,2,3).expand_as(cur_img_feats))
if self.cfg.CORR_WITH_POOLED:
cur_corrs.append(F.conv2d(cur_img_feats,
pooled_target_feats,
groups=self.num_feature_channels))
if self.use_attention:
embedding_feats = self.coattention_layer(
img_features,
target_features,
mode=self.cfg.ATTENTION_NORMALIZE_MODE)
else:
# a batch-separated attention could replace the block between the ### markers below

######################
all_corrs = []
all_diffs = []
for batch_ind in range(img_features.size()[0]):
img_ind = np_to_variable(np.asarray([batch_ind]),
is_cuda=True, dtype=torch.LongTensor)
# 1 x C x H x W
cur_img_feats = torch.index_select(img_features,0,img_ind)

cur_diffs = []
cur_corrs = []
for target_type in range(self.cfg.NUM_TARGETS):
target_ind = np_to_variable(np.asarray([batch_ind*
self.cfg.NUM_TARGETS+target_type]),
is_cuda=True,dtype=torch.LongTensor)
cur_target_feats = torch.index_select(target_features,0,
target_ind[0])
cur_target_feats = cur_target_feats.view(-1,1,
cur_target_feats.size()[2],
cur_target_feats.size()[3])
# C x 1 x 1 x 1 after pooling over the target's spatial dims
pooled_target_feats = F.max_pool2d(cur_target_feats,
(cur_target_feats.size()[2],
cur_target_feats.size()[3]))

cur_diffs.append(cur_img_feats -
pooled_target_feats.permute(1,0,2,3).expand_as(cur_img_feats))
if self.cfg.CORR_WITH_POOLED:
# we are here
cur_corrs.append(F.conv2d(cur_img_feats,
pooled_target_feats,
groups=self.num_feature_channels))
else:
target_conv_padding = (max(0,int(
target_features.size()[2]/2)),
max(0,int(
target_features.size()[3]/2)))
cur_corrs.append(F.conv2d(cur_img_feats,cur_target_feats,
padding=target_conv_padding,
groups=self.num_feature_channels))


cur_corrs = torch.cat(cur_corrs,1)
cur_corrs = self.select_to_match_dimensions(cur_corrs,cur_img_feats)
all_corrs.append(cur_corrs)
all_diffs.append(torch.cat(cur_diffs,1))

######################

corr = self.corr_conv(torch.cat(all_corrs,0))
diff = self.diff_conv(torch.cat(all_diffs,0))

if self.cfg.USE_IMG_FEATS and self.cfg.USE_DIFF_FEATS:
if self.cfg.USE_CC_FEATS:
concat_feats = torch.cat([corr,img_features, diff],1)
else:
target_conv_padding = (max(0,int(
target_features.size()[2]/2)),
max(0,int(
target_features.size()[3]/2)))
cur_corrs.append(F.conv2d(cur_img_feats,cur_target_feats,
padding=target_conv_padding,
groups=self.num_feature_channels))


cur_corrs = torch.cat(cur_corrs,1)
cur_corrs = self.select_to_match_dimensions(cur_corrs,cur_img_feats)
all_corrs.append(cur_corrs)
all_diffs.append(torch.cat(cur_diffs,1))

corr = self.corr_conv(torch.cat(all_corrs,0))
diff = self.diff_conv(torch.cat(all_diffs,0))

if self.cfg.USE_IMG_FEATS and self.cfg.USE_DIFF_FEATS:
if self.cfg.USE_CC_FEATS:
concat_feats = torch.cat([corr,img_features, diff],1)
else:
concat_feats = torch.cat([img_features, diff],1)
elif self.cfg.USE_IMG_FEATS:
if self.cfg.USE_CC_FEATS:
concat_feats = torch.cat([corr,img_features],1)
else:
concat_feats = torch.cat([img_features],1)
elif self.cfg.USE_DIFF_FEATS:
if self.cfg.USE_CC_FEATS:
concat_feats = torch.cat([corr,diff],1)
concat_feats = torch.cat([img_features, diff],1)
elif self.cfg.USE_IMG_FEATS:
if self.cfg.USE_CC_FEATS:
concat_feats = torch.cat([corr,img_features],1)
else:
concat_feats = torch.cat([img_features],1)
elif self.cfg.USE_DIFF_FEATS:
if self.cfg.USE_CC_FEATS:
# we are here
# B x 2D x H x W
concat_feats = torch.cat([corr,diff],1)
else:
concat_feats = torch.cat([diff],1)
else:
concat_feats = torch.cat([diff],1)
else:
concat_feats = corr
concat_feats = corr

embedding_feats = self.embedding_conv(concat_feats)


embedding_feats = self.embedding_conv(concat_feats)
#####################
class_score = self.score_conv(embedding_feats)
class_score_reshape = self.reshape_layer(class_score, 2)
class_prob = F.softmax(class_score_reshape)
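
model_defs/attention.py is not part of this diff, so the CoAttention internals are not visible here. Below is a minimal sketch of the interface that the constructor and forward calls above assume (in_channels=512, combination_mode one of "sum" / "stack-reduce" / "attention-only", mode one of "average" / "softmax"); the body is illustrative only, assumes one target per image, and is not the actual attention.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CoAttention(nn.Module):
    # Illustrative co-attention: each image location attends over the target's
    # feature vectors; the attended map is then combined with the image features.
    def __init__(self, in_channels, combination_mode="attention-only"):
        super(CoAttention, self).__init__()
        self.combination_mode = combination_mode
        if combination_mode == "stack-reduce":
            # reduce the stacked (image, attended) features back to in_channels
            self.reduce = nn.Conv2d(2 * in_channels, in_channels, kernel_size=1)

    def forward(self, img_feats, target_feats, mode="average"):
        # img_feats: B x C x H x W; target_feats: B x C x h x w
        # (assumes one target per image; the real forward receives B*NUM_TARGETS targets)
        B, C, H, W = img_feats.size()
        t = target_feats.view(B, C, -1)                    # B x C x (h*w)
        q = img_feats.view(B, C, -1)                       # B x C x (H*W)
        affinity = torch.bmm(q.transpose(1, 2), t)         # B x (H*W) x (h*w)
        if mode == "softmax":
            weights = F.softmax(affinity, dim=2)
        else:
            # "average": plain averaging of the raw affinities instead of a softmax
            weights = affinity / affinity.size(2)
        attended = torch.bmm(t, weights.transpose(1, 2))   # B x C x (H*W)
        attended = attended.view(B, C, H, W)
        if self.combination_mode == "sum":
            return img_feats + attended
        if self.combination_mode == "stack-reduce":
            return self.reduce(torch.cat([img_feats, attended], 1))
        return attended                                    # "attention-only"

Usage matching the calls in TDID.py:

coattention_layer = CoAttention(in_channels=512,
                                combination_mode=cfg.ATTENTION_COMBINATION_MODE)
embedding_feats = coattention_layer(img_features, target_features,
                                    mode=cfg.ATTENTION_NORMALIZE_MODE)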