references/segmentation/coco_utils might require merging rles? #8661

Open
crazyboy9103 opened this issue Sep 26, 2024 · 1 comment

crazyboy9103 commented Sep 26, 2024

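```python
import torch
from pycocotools import mask as coco_mask


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        # One RLE per polygon of this object
        rles = coco_mask.frPyObjects(polygons, height, width)
        # Decoding a list of N RLEs gives an H x W x N array
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            # A single RLE decodes to H x W; add the trailing axis
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # Collapse this object's N polygon masks into one H x W mask
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks
```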
The code above seems to assume that objects are not occluded, since it does not merge the RLEs returned by `frPyObjects`. If that is the case, I think it should be changed to:

```python
rles = coco_mask.frPyObjects(polygons, height, width)
rle = coco_mask.merge(rles)
mask = coco_mask.decode(rle)
```

Is there any specific reason for this, or am I wrong?
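A minimal NumPy sketch of what the `decode` + `mask.any(dim=2)` steps compute for a single object whose segmentation is split into several polygons. It assumes, as in pycocotools, that decoding a list of N RLEs yields an `H x W x N` array; the two rectangles are hypothetical stand-ins for rasterized polygons, not real COCO data. Reducing with `any` over the last axis is a per-pixel logical OR, i.e. the union of the polygons, which is also what `coco_mask.merge(rles)` computes by default (`intersect=0`) before decoding.

```python
import numpy as np

H, W = 4, 6

# Hypothetical stand-ins for two decoded polygons of ONE object.
part_a = np.zeros((H, W), dtype=np.uint8)
part_a[0:2, 0:3] = 1          # rows 0-1, columns 0-2
part_b = np.zeros((H, W), dtype=np.uint8)
part_b[1:4, 2:5] = 1          # rows 1-3, columns 2-4; overlaps part_a at (1, 2)

# Mimic decode() on a list of 2 RLEs: an H x W x 2 array.
decoded = np.stack([part_a, part_b], axis=2)

# The step in convert_coco_poly_to_mask: OR across the polygon axis.
object_mask = decoded.any(axis=2)

# The same union computed explicitly.
union = np.logical_or(part_a, part_b)

assert object_mask.shape == (H, W)
assert np.array_equal(object_mask, union)
print(object_mask.astype(np.uint8))
```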

@NicolasHug
Member

Hi @crazyboy9103, thanks for the report.
I'm not very familiar with that part of the codebase, so I could be way off, but I suspect the logic you're looking for is implemented later, in:

```python
masks = convert_coco_poly_to_mask(segmentations, h, w)
cats = torch.as_tensor(cats, dtype=masks.dtype)
# merge all instance masks into a single segmentation map
# with its corresponding categories
target, _ = (masks * cats[:, None, None]).max(dim=0)
# discard overlapping instances
target[masks.sum(0) > 1] = 255
```
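To illustrate the snippet above, here is a minimal NumPy sketch of the same steps on made-up data: two hypothetical instance masks that overlap in one column, with hypothetical category ids 3 and 7. NumPy's `max(axis=0)` returns only the values, whereas `torch.max(dim=0)` also returns indices, which the original discards with `_`. Since each object contributes one mask, `masks.sum(0) > 1` flags pixels shared by *different* objects; polygons of the same object have already been collapsed into one mask by `mask.any(dim=2)`.

```python
import numpy as np

H, W = 3, 4

# Two hypothetical per-object masks, shape N x H x W.
masks = np.zeros((2, H, W), dtype=np.uint8)
masks[0, :, 0:2] = 1   # object 0 covers columns 0-1
masks[1, :, 1:3] = 1   # object 1 covers columns 1-2 (overlap in column 1)
cats = np.array([3, 7], dtype=masks.dtype)  # hypothetical category ids

# Each object's pixels take its category id; max over objects gives one map.
target = (masks * cats[:, None, None]).max(axis=0)
# Pixels covered by more than one object are set to 255 (discarded).
target[masks.sum(axis=0) > 1] = 255
print(target)
```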
