From 699c5c8f328c33974496844779cca058b1926ab5 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Sun, 29 Sep 2024 20:06:14 -0700
Subject: [PATCH] tt

---
 examples/docs/trace_sam2_img.py | 213 ++++++++++++++++++++++----------
 1 file changed, 147 insertions(+), 66 deletions(-)

diff --git a/examples/docs/trace_sam2_img.py b/examples/docs/trace_sam2_img.py
index bb80795202b..ce5a94ed0ce 100644
--- a/examples/docs/trace_sam2_img.py
+++ b/examples/docs/trace_sam2_img.py
@@ -10,8 +10,9 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
+import os
 import sys
-from typing import Tuple
+from typing import Any
 
 import torch
 from sam2.modeling.sam2_base import SAM2Base
@@ -19,77 +20,129 @@
 from torch import nn
 
 
-class Sam2Wrapper(nn.Module):
+class SAM2ImageEncoder(nn.Module):
 
-    def __init__(
-        self,
-        sam_model: SAM2Base,
-    ) -> None:
+    def __init__(self, sam_model: SAM2Base) -> None:
         super().__init__()
         self.model = sam_model
+        self.image_encoder = sam_model.image_encoder
+        self.no_mem_embed = sam_model.no_mem_embed
 
-        # Spatial dim for backbone feature maps
-        self._bb_feat_sizes = [
-            (256, 256),
-            (128, 128),
-            (64, 64),
-        ]
+    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
+        backbone_out = self.image_encoder(x)
+        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
+            backbone_out["backbone_fpn"][0])
+        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
+            backbone_out["backbone_fpn"][1])
 
-    def extract_features(
-        self,
-        input_image: torch.Tensor,
-    ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
-        backbone_out = self.model.forward_image(input_image)
-        _, vision_feats, _, _ = self.model._prepare_backbone_features(
-            backbone_out)
-        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
-        if self.model.directly_add_no_mem_embed:
-            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+        feature_maps = backbone_out["backbone_fpn"][-self.model.
+                                                    num_feature_levels:]
+        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.
+                                                           num_feature_levels:]
+
+        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+
+        # flatten NxCxHxW to HWxNxC
+        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
 
         feats = [
-            feat.permute(1, 2,
-                         0).view(1, -1, *feat_size) for feat, feat_size in zip(
-                             vision_feats[::-1], self._bb_feat_sizes[::-1])
+            feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
         ][::-1]
 
-        return feats[-1], feats[0], feats[1]
+        return feats[0], feats[1], feats[2]
 
-    def forward(
-        self,
-        input_image: torch.Tensor,
-        point_coords: torch.Tensor,
-        point_labels: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        image_embed, feature_1, feature_2 = self.extract_features(input_image)
-        return self.predict(point_coords, point_labels, image_embed, feature_1,
-                            feature_2)
 
-    def predict(
+class SAM2ImageDecoder(nn.Module):
+
+    def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
+        super().__init__()
+        self.mask_decoder = sam_model.sam_mask_decoder
+        self.prompt_encoder = sam_model.sam_prompt_encoder
+        self.model = sam_model
+        self.img_size = sam_model.image_size
+        self.multimask_output = multimask_output
+        self.sparse_embedding = None
+
+    @torch.no_grad()
+    def forward(
         self,
+        image_embed: torch.Tensor,
+        high_res_feats_0: torch.Tensor,
+        high_res_feats_1: torch.Tensor,
         point_coords: torch.Tensor,
         point_labels: torch.Tensor,
-        image_embed: torch.Tensor,
-        feats_1: torch.Tensor,
-        feats_2: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        concat_points = (point_coords, point_labels)
-
-        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
-            points=concat_points,
-            boxes=None,
-            masks=None,
-        )
-
-        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
-            image_embeddings=image_embed[0].unsqueeze(0),
-            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
-            dense_prompt_embeddings=dense_embeddings,
-            multimask_output=True,
+        mask_input: torch.Tensor,
+        has_mask_input: torch.Tensor,
+    ):
+        sparse_embedding = self._embed_points(point_coords, point_labels)
+        self.sparse_embedding = sparse_embedding
+        dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+        high_res_feats = [high_res_feats_0, high_res_feats_1]
+        image_embed = image_embed
+
+        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
+            image_embeddings=image_embed,
+            image_pe=self.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embedding,
+            dense_prompt_embeddings=dense_embedding,
             repeat_image=False,
-            high_res_features=[feats_1, feats_2],
+            high_res_features=high_res_feats,
         )
-        return low_res_masks, iou_predictions
+
+        if self.multimask_output:
+            masks = masks[:, 1:, :, :]
+            iou_predictions = iou_predictions[:, 1:]
+        else:
+            masks, iou_predictions = (
+                self.mask_decoder._dynamic_multimask_via_stability(
+                    masks, iou_predictions))
+
+        masks = torch.clamp(masks, -32.0, 32.0)
+
+        return masks, iou_predictions
+
+    def _embed_points(self, point_coords: torch.Tensor,
+                      point_labels: torch.Tensor) -> torch.Tensor:
+
+        point_coords = point_coords + 0.5
+
+        padding_point = torch.zeros((point_coords.shape[0], 1, 2),
+                                    device=point_coords.device)
+        padding_label = -torch.ones(
+            (point_labels.shape[0], 1), device=point_labels.device)
+        point_coords = torch.cat([point_coords, padding_point], dim=1)
+        point_labels = torch.cat([point_labels, padding_label], dim=1)
+
+        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
+        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
+
+        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(
+            point_coords)
+        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+        point_embedding = point_embedding * (point_labels != -1)
+        point_embedding = (point_embedding +
+                           self.prompt_encoder.not_a_point_embed.weight *
+                           (point_labels == -1))
+
+        for i in range(self.prompt_encoder.num_point_embeddings):
+            point_embedding = (point_embedding +
+                               self.prompt_encoder.point_embeddings[i].weight *
+                               (point_labels == i))
+
+        return point_embedding
+
+    def _embed_masks(self, input_mask: torch.Tensor,
+                     has_mask_input: torch.Tensor) -> torch.Tensor:
+        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(
+            input_mask)
+        mask_embedding = mask_embedding + (
+            1 - has_mask_input
+        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        return mask_embedding
 
 
 def trace_model(model_id: str):
@@ -98,19 +151,47 @@
     else:
         device = torch.device("cpu")
 
+    model_name = f"{model_id[9:]}"
+    os.makedirs(model_name)
+
     predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
 
-    model = Sam2Wrapper(predictor.model)
+    encoder = SAM2ImageEncoder(predictor.model)
+    decoder = SAM2ImageDecoder(predictor.model, True)
     input_image = torch.ones(1, 3, 1024, 1024).to(device)
-    input_point = torch.ones(1, 1, 2).to(device)
-    input_labels = torch.ones(1, 1, dtype=torch.int32, device=device)
-
-    converted = torch.jit.trace_module(
-        model, {
-            "extract_features": input_image,
-            "forward": (input_image, input_point, input_labels)
-        })
-    torch.jit.save(converted, f"{model_id[9:]}.pt")
+    high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image)
+
+    converted = torch.jit.trace(encoder, input_image)
+    torch.jit.save(converted, f"{model_name}/encoder.pt")
+
+    # trace decoder model
+    embed_size = (
+        predictor.model.image_size // predictor.model.backbone_stride,
+        predictor.model.image_size // predictor.model.backbone_stride,
+    )
+    mask_input_size = [4 * x for x in embed_size]
+
+    point_coords = torch.randint(low=0,
+                                 high=1024,
+                                 size=(1, 5, 2),
+                                 dtype=torch.float)
+    point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float)
+    mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
+    has_mask_input = torch.tensor([1], dtype=torch.float)
+
+    converted = torch.jit.trace(
+        decoder, (image_embed, high_res_feats_0, high_res_feats_1,
+                  point_coords, point_labels, mask_input, has_mask_input))
+    torch.jit.save(converted, f"{model_name}/{model_name}.pt")
+
+    # save serving.properties
+    serving_file = os.path.join(model_name, "serving.properties")
+    with open(serving_file, "w") as f:
+        f.write(
+            f"engine=PyTorch\n"
+            f"option.modelName={model_name}\n"
+            f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n"
+            f"encoder=encoder.pt")
 
 
 if __name__ == '__main__':
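
Below is a quick, optional sanity check for the traced artifacts. It is a minimal sketch, not part of the patch: it assumes the script above was run with model_id "facebook/sam2-hiera-tiny" (so model_id[9:] yields the hypothetical directory name "sam2-hiera-tiny") and that both modules were traced on CPU; adjust the directory name and map_location for your setup. The dummy shapes mirror the ones used for tracing in trace_model.

    import torch

    model_name = "sam2-hiera-tiny"  # assumed example; model_id[9:] strips the "facebook/" prefix
    encoder = torch.jit.load(f"{model_name}/encoder.pt", map_location="cpu")
    decoder = torch.jit.load(f"{model_name}/{model_name}.pt", map_location="cpu")

    # 1024x1024 input image, as used when tracing the encoder above
    image = torch.ones(1, 3, 1024, 1024)
    high_res_feats_0, high_res_feats_1, image_embed = encoder(image)

    # prompt inputs with the same shapes the decoder was traced with
    point_coords = torch.randint(0, 1024, (1, 5, 2), dtype=torch.float)
    point_labels = torch.ones(1, 5, dtype=torch.float)
    mask_input = torch.zeros(1, 1, 256, 256, dtype=torch.float)
    has_mask_input = torch.tensor([0], dtype=torch.float)  # 0 = ignore mask_input

    masks, iou_predictions = decoder(image_embed, high_res_feats_0,
                                     high_res_feats_1, point_coords,
                                     point_labels, mask_input, has_mask_input)
    print(masks.shape, iou_predictions.shape)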