Commit 699c5c8 ("tt")
frankfliu committed Sep 30, 2024 (1 parent: f703b93)
Showing 1 changed file: examples/docs/trace_sam2_img.py, 213 changes (147 additions, 66 deletions).
@@ -10,86 +10,139 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
+import os
 import sys
-from typing import Tuple
+from typing import Any
 
 import torch
 from sam2.modeling.sam2_base import SAM2Base
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 from torch import nn


-class Sam2Wrapper(nn.Module):
+class SAM2ImageEncoder(nn.Module):
 
-    def __init__(
-        self,
-        sam_model: SAM2Base,
-    ) -> None:
+    def __init__(self, sam_model: SAM2Base) -> None:
         super().__init__()
         self.model = sam_model
+        self.image_encoder = sam_model.image_encoder
+        self.no_mem_embed = sam_model.no_mem_embed
 
-        # Spatial dim for backbone feature maps
-        self._bb_feat_sizes = [
-            (256, 256),
-            (128, 128),
-            (64, 64),
-        ]
-
-    def extract_features(
-        self,
-        input_image: torch.Tensor,
-    ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
-        backbone_out = self.model.forward_image(input_image)
-        _, vision_feats, _, _ = self.model._prepare_backbone_features(
-            backbone_out)
-        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
-        if self.model.directly_add_no_mem_embed:
-            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
+        backbone_out = self.image_encoder(x)
+        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
+            backbone_out["backbone_fpn"][0])
+        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
+            backbone_out["backbone_fpn"][1])
+
+        feature_maps = backbone_out["backbone_fpn"][-self.model.
+                                                    num_feature_levels:]
+        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.
+                                                           num_feature_levels:]
+
+        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+
+        # flatten NxCxHxW to HWxNxC
+        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
 
         feats = [
-            feat.permute(1, 2,
-                         0).view(1, -1, *feat_size) for feat, feat_size in zip(
-                             vision_feats[::-1], self._bb_feat_sizes[::-1])
+            feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
         ][::-1]
 
-        return feats[-1], feats[0], feats[1]
+        return feats[0], feats[1], feats[2]

-    def forward(
-        self,
-        input_image: torch.Tensor,
-        point_coords: torch.Tensor,
-        point_labels: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        image_embed, feature_1, feature_2 = self.extract_features(input_image)
-        return self.predict(point_coords, point_labels, image_embed, feature_1,
-                            feature_2)
-
-    def predict(
-        self,
-        point_coords: torch.Tensor,
-        point_labels: torch.Tensor,
-        image_embed: torch.Tensor,
-        feats_1: torch.Tensor,
-        feats_2: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        concat_points = (point_coords, point_labels)
-
-        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
-            points=concat_points,
-            boxes=None,
-            masks=None,
-        )
-
-        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
-            image_embeddings=image_embed[0].unsqueeze(0),
-            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
-            dense_prompt_embeddings=dense_embeddings,
-            multimask_output=True,
-            repeat_image=False,
-            high_res_features=[feats_1, feats_2],
-        )
-        return low_res_masks, iou_predictions
+class SAM2ImageDecoder(nn.Module):
+
+    def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
+        super().__init__()
+        self.mask_decoder = sam_model.sam_mask_decoder
+        self.prompt_encoder = sam_model.sam_prompt_encoder
+        self.model = sam_model
+        self.img_size = sam_model.image_size
+        self.multimask_output = multimask_output
+        self.sparse_embedding = None
+
+    @torch.no_grad()
+    def forward(
+        self,
+        image_embed: torch.Tensor,
+        high_res_feats_0: torch.Tensor,
+        high_res_feats_1: torch.Tensor,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+        mask_input: torch.Tensor,
+        has_mask_input: torch.Tensor,
+    ):
+        sparse_embedding = self._embed_points(point_coords, point_labels)
+        self.sparse_embedding = sparse_embedding
+        dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+        high_res_feats = [high_res_feats_0, high_res_feats_1]
+        image_embed = image_embed
+
+        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
+            image_embeddings=image_embed,
+            image_pe=self.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embedding,
+            dense_prompt_embeddings=dense_embedding,
+            repeat_image=False,
+            high_res_features=high_res_feats,
+        )
+
+        if self.multimask_output:
+            masks = masks[:, 1:, :, :]
+            iou_predictions = iou_predictions[:, 1:]
+        else:
+            masks, iou_predictions = (
+                self.mask_decoder._dynamic_multimask_via_stability(
+                    masks, iou_predictions))
+
+        masks = torch.clamp(masks, -32.0, 32.0)
+
+        return masks, iou_predictions
+
+    def _embed_points(self, point_coords: torch.Tensor,
+                      point_labels: torch.Tensor) -> torch.Tensor:
+
+        point_coords = point_coords + 0.5
+
+        padding_point = torch.zeros((point_coords.shape[0], 1, 2),
+                                    device=point_coords.device)
+        padding_label = -torch.ones(
+            (point_labels.shape[0], 1), device=point_labels.device)
+        point_coords = torch.cat([point_coords, padding_point], dim=1)
+        point_labels = torch.cat([point_labels, padding_label], dim=1)
+
+        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
+        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
+
+        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(
+            point_coords)
+        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+        point_embedding = point_embedding * (point_labels != -1)
+        point_embedding = (point_embedding +
+                           self.prompt_encoder.not_a_point_embed.weight *
+                           (point_labels == -1))
+
+        for i in range(self.prompt_encoder.num_point_embeddings):
+            point_embedding = (point_embedding +
+                               self.prompt_encoder.point_embeddings[i].weight *
+                               (point_labels == i))
+
+        return point_embedding
+
+    def _embed_masks(self, input_mask: torch.Tensor,
+                     has_mask_input: torch.Tensor) -> torch.Tensor:
+        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(
+            input_mask)
+        mask_embedding = mask_embedding + (
+            1 - has_mask_input
+        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        return mask_embedding

@@ -98,19 +151,47 @@ def trace_model(model_id: str):
     else:
         device = torch.device("cpu")
 
+    model_name = f"{model_id[9:]}"
+    os.makedirs(model_name)
+
     predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
-    model = Sam2Wrapper(predictor.model)
+    encoder = SAM2ImageEncoder(predictor.model)
+    decoder = SAM2ImageDecoder(predictor.model, True)
 
     input_image = torch.ones(1, 3, 1024, 1024).to(device)
     input_point = torch.ones(1, 1, 2).to(device)
     input_labels = torch.ones(1, 1, dtype=torch.int32, device=device)
 
-    converted = torch.jit.trace_module(
-        model, {
-            "extract_features": input_image,
-            "forward": (input_image, input_point, input_labels)
-        })
-    torch.jit.save(converted, f"{model_id[9:]}.pt")
+    high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image)
+
+    converted = torch.jit.trace(encoder, input_image)
+    torch.jit.save(converted, f"{model_name}/encoder.pt")
+
+    # trace decoder model
+    embed_size = (
+        predictor.model.image_size // predictor.model.backbone_stride,
+        predictor.model.image_size // predictor.model.backbone_stride,
+    )
+    mask_input_size = [4 * x for x in embed_size]
+
+    point_coords = torch.randint(low=0,
+                                 high=1024,
+                                 size=(1, 5, 2),
+                                 dtype=torch.float)
+    point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float)
+    mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
+    has_mask_input = torch.tensor([1], dtype=torch.float)
+
+    converted = torch.jit.trace(
+        decoder, (image_embed, high_res_feats_0, high_res_feats_1,
+                  point_coords, point_labels, mask_input, has_mask_input))
+    torch.jit.save(converted, f"{model_name}/{model_name}.pt")
+
+    # save serving.properties
+    serving_file = os.path.join(model_name, "serving.properties")
+    with open(serving_file, "w") as f:
+        f.write(
+            f"engine=PyTorch\n"
+            f"option.modelName={model_name}\n"
+            f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n"
+            f"encoder=encoder.pt")
 
 
 if __name__ == '__main__':
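
A minimal sanity-check sketch (not part of the commit): it assumes the script above was run with a model id such as "facebook/sam2-hiera-tiny", so model_name would be "sam2-hiera-tiny" and the traced files sit under that directory; the input shapes simply mirror the tracing inputs.

import torch

model_name = "sam2-hiera-tiny"  # hypothetical; trace_model derives this from model_id[9:]
encoder = torch.jit.load(f"{model_name}/encoder.pt")
decoder = torch.jit.load(f"{model_name}/{model_name}.pt")

# dummy 1024x1024 image, same shape as the tracing input
image = torch.ones(1, 3, 1024, 1024)
high_res_feats_0, high_res_feats_1, image_embed = encoder(image)

# prompt inputs, mirroring the shapes used when tracing the decoder
point_coords = torch.randint(0, 1024, (1, 5, 2), dtype=torch.float)
point_labels = torch.ones(1, 5, dtype=torch.float)
mask_input = torch.zeros(1, 1, 256, 256)  # 4 * (1024 // 16), assuming the default backbone stride
has_mask_input = torch.zeros(1)           # no mask prompt

masks, scores = decoder(image_embed, high_res_feats_0, high_res_feats_1,
                        point_coords, point_labels, mask_input, has_mask_input)
print(masks.shape, scores.shape)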
