
Commit

add lvis
hhaAndroid committed Nov 21, 2023
1 parent 266d038 commit 2cb3143
Showing 3 changed files with 154 additions and 19 deletions.
@@ -0,0 +1,26 @@
_base_ = '../grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py'

model = dict(test_cfg=dict(
    max_per_img=300,
    chunked_size=40,  # number of category names sent to the text encoder per chunk
))


dataset_type = 'LVISV1Dataset'
data_root = 'data/coco/'

val_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        type=dataset_type,
        ann_file='annotations/lvis_v1_minival_inserted_image_name.json',
        data_prefix=dict(img='')))
test_dataloader = val_dataloader

# LVISFixedAPMetric requires numpy < 1.24.0
val_evaluator = dict(
    _delete_=True,
    type='LVISFixedAPMetric',
    ann_file=data_root +
    'annotations/lvis_v1_minival_inserted_image_name.json')
test_evaluator = val_evaluator
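For context, a config like this is normally launched through MMDetection's standard test script. A minimal sketch, where both the config path and the checkpoint path are hypothetical placeholders:

python tools/test.py \
    path/to/this_lvis_minival_config.py \
    path/to/grounding_dino_swin-t_checkpoint.pth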
5 changes: 5 additions & 0 deletions demo/image_demo.py
@@ -37,6 +37,11 @@
    --texts '$: lvis' --pred-score-thr 0.7 \
    --palette random --chunked-size 80

python demo/image_demo.py demo/demo.jpg \
    grounding_dino_swin-t_pretrain_obj365_goldg_cap4m \
    --texts '$: lvis' --pred-score-thr 0.4 \
    --palette random --chunked-size 80
Visualize prediction results::
142 changes: 123 additions & 19 deletions mmdet/models/detectors/grounding_dino.py
@@ -16,6 +16,7 @@
from .dino import DINO
from .glip import (create_positive_map, create_positive_map_label_to_token,
run_ner)
import copy


def clean_label_name(name: str) -> str:
@@ -25,6 +26,20 @@ def clean_label_name(name: str) -> str:
    return name


def chunks(lst: list, n: int) -> list:
    """Split lst into successive n-sized chunks."""
    all_ = []
    for i in range(0, len(lst), n):
        data_index = lst[i:i + n]
        all_.append(data_index)
    # sanity check: the chunks exactly cover lst
    counter = 0
    for i in all_:
        counter += len(i)
    assert counter == len(lst)

    return all_
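For reference, the helper partitions a list in order and asserts that the chunks cover it exactly; a quick check (illustrative interpreter session):

>>> chunks(list(range(5)), 2)
[[0, 1], [2, 3], [4]]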


@MODELS.register_module()
class GroundingDINO(DINO):
"""Implementation of `Grounding DINO: Marrying DINO with Grounded Pre-
@@ -190,14 +205,69 @@ def get_tokens_positive_and_prompts(
            id, which is numbered from 1, to its positive token id.
            The str represents the prompts.
        """
        # At inference time, a large vocabulary can be split into chunks of
        # ``chunked_size`` names and handled chunk by chunk.
        chunked_size = self.test_cfg.get('chunked_size', -1)
        if not self.training and chunked_size > 0:
            assert isinstance(original_caption,
                              (list, tuple)) or custom_entities is True
            all_output = self.get_tokens_positive_and_prompts_chunked(
                original_caption, enhanced_text_prompt)
            positive_map_label_to_token, \
                caption_string, \
                positive_map, \
                entities = all_output
        else:
            tokenized, caption_string, tokens_positive, entities = \
                self.get_tokens_and_prompts(
                    original_caption, custom_entities, enhanced_text_prompt)
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive)
        return positive_map_label_to_token, caption_string, \
            positive_map, entities

    def get_tokens_positive_and_prompts_chunked(
            self,
            original_caption: Union[list, tuple],
            enhanced_text_prompts: Optional[ConfigType] = None):
        """Get the positive maps and prompts chunk by chunk.

        The caption list is split into chunks of ``chunked_size`` names so
        that each prompt stays within the text encoder's token limit.
        """
        chunked_size = self.test_cfg.get('chunked_size', -1)
        original_caption = [clean_label_name(i) for i in original_caption]

        original_caption_chunked = chunks(original_caption, chunked_size)
        ids_chunked = chunks(
            list(range(1, len(original_caption) + 1)), chunked_size)

        positive_map_label_to_token_chunked = []
        caption_string_chunked = []
        positive_map_chunked = []
        entities_chunked = []

        for i in range(len(ids_chunked)):
            if enhanced_text_prompts is not None:
                caption_string, tokens_positive = self.to_enhance_text_prompts(
                    original_caption_chunked[i], enhanced_text_prompts)
            else:
                caption_string, tokens_positive = self.to_plain_text_prompts(
                    original_caption_chunked[i])
            tokenized = self.language_model.tokenizer([caption_string],
                                                      return_tensors='pt')
            if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
                warnings.warn('Inputting a text that is too long will result '
                              'in poor prediction performance. '
                              'Please reduce the --chunked-size.')
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive)

            caption_string_chunked.append(caption_string)
            positive_map_label_to_token_chunked.append(
                positive_map_label_to_token)
            positive_map_chunked.append(positive_map)
            entities_chunked.append(original_caption_chunked[i])

        return positive_map_label_to_token_chunked, \
            caption_string_chunked, \
            positive_map_chunked, \
            entities_chunked
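For intuition: with LVIS's 1203 category names and chunked_size=40 (as in the config above), the vocabulary splits into ceil(1203 / 40) = 31 chunks, so each of the four returned lists holds 31 aligned entries. A minimal sketch, where `model` and `lvis_class_names` are hypothetical stand-ins for a constructed detector and the name list:

# Hypothetical objects: `model` has test_cfg.chunked_size = 40 and
# `lvis_class_names` holds the 1203 LVIS category names.
token_maps, captions, pos_maps, ents = \
    model.get_tokens_positive_and_prompts_chunked(lvis_class_names)
assert len(captions) == 31   # ceil(1203 / 40)
assert len(ents[0]) == 40    # the first 30 chunks hold 40 names each
assert len(ents[-1]) == 3    # the last chunk holds the remainder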

    def forward_transformer(
            self,
            img_feats: Tuple[Tensor],
@@ -429,24 +499,58 @@ def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
        ]
        token_positive_maps, text_prompts, _, entities = zip(
            *_positive_maps_and_prompts)

        # image feature extraction
        visual_feats = self.extract_feat(batch_inputs)

        if isinstance(text_prompts[0], list):
            # chunked text prompts, only bs=1 is supported
            assert len(batch_inputs) == 1
            count = 0
            results_list = []

            # flatten the chunked entity names back into a single list
            entities = [[item for lst in entities[0] for item in lst]]

            for b in range(len(text_prompts[0])):
                text_prompts_once = [text_prompts[0][b]]
                token_positive_maps_once = token_positive_maps[0][b]
                text_dict = self.language_model(text_prompts_once)
                # text feature map layer
                if self.text_feat_map is not None:
                    text_dict['embedded'] = self.text_feat_map(
                        text_dict['embedded'])

                batch_data_samples[
                    0].token_positive_map = token_positive_maps_once

                head_inputs_dict = self.forward_transformer(
                    copy.deepcopy(visual_feats), text_dict,
                    batch_data_samples)
                pred_instances = self.bbox_head.predict(
                    **head_inputs_dict,
                    rescale=rescale,
                    batch_data_samples=batch_data_samples)[0]

                # labels are local to this chunk; offset them by the number
                # of classes handled in earlier chunks
                if len(pred_instances) > 0:
                    pred_instances.labels += count
                count += len(token_positive_maps_once)
                results_list.append(pred_instances)
            # merge the per-chunk detections into a single InstanceData
            results_list = [results_list[0].cat(results_list)]
        else:
            # extract text feats
            text_dict = self.language_model(list(text_prompts))
            # text feature map layer
            if self.text_feat_map is not None:
                text_dict['embedded'] = self.text_feat_map(
                    text_dict['embedded'])

            for i, data_samples in enumerate(batch_data_samples):
                data_samples.token_positive_map = token_positive_maps[i]

            head_inputs_dict = self.forward_transformer(
                visual_feats, text_dict, batch_data_samples)
            results_list = self.bbox_head.predict(
                **head_inputs_dict,
                rescale=rescale,
                batch_data_samples=batch_data_samples)

        for data_sample, pred_instances, entity in zip(batch_data_samples,
                                                       results_list, entities):
            if len(pred_instances) > 0:
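The subtle part of the chunked branch above is the label bookkeeping: each chunk's pred_instances.labels index only that chunk's slice of the vocabulary, so they are offset by the running count before the per-chunk results are concatenated. A standalone sketch of the same idea, with plain lists standing in for InstanceData (all names here are illustrative):

def merge_chunk_labels(chunk_label_lists, chunk_sizes):
    """Map per-chunk local label ids into the global vocabulary."""
    merged, count = [], 0
    for labels, size in zip(chunk_label_lists, chunk_sizes):
        merged.extend(label + count for label in labels)
        count += size
    return merged

# Two chunks of 40 classes each: local ids 0/39 in chunk 0 stay 0/39,
# local ids 0/5 in chunk 1 become 40/45.
assert merge_chunk_labels([[0, 39], [0, 5]], [40, 40]) == [0, 39, 40, 45]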
