Skip to content

Commit

Permalink
Add i-JEPA masking class (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
vahid0001 authored Oct 7, 2024
1 parent 0602c6d commit db0d597
Showing 1 changed file with 126 additions and 1 deletion.
127 changes: 126 additions & 1 deletion mmlearn/datasets/processors/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import math
import random
from typing import Any, List, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from hydra_zen import store
Expand Down Expand Up @@ -264,3 +265,127 @@ def apply_masks(

# Concatenate along the batch dimension
return torch.cat(all_x, dim=0)


@dataclass
class IJEPAMaskGenerator:
    """Generates encoder and predictor masks for preprocessing.

    This class generates masks dynamically for individual examples and can be
    passed to a data loader as a preprocessing step. Each call produces
    ``nenc`` encoder masks and ``npred`` predictor masks, each a flattened
    binary grid of shape ``(height * width,)`` over the patch grid.

    Parameters
    ----------
    input_size : tuple[int, int], default=(224, 224)
        Input image size.
    patch_size : int, default=16
        Size of each patch.
    min_keep : int, default=4
        Minimum number of patches to keep.
        NOTE(review): currently unused; kept for compatibility with the
        reference I-JEPA implementation.
    allow_overlap : bool, default=False
        Whether to allow overlap between encoder and predictor masks.
        NOTE(review): currently unused; overlap is never prevented.
    enc_mask_scale : tuple[float, float], default=(0.2, 0.8)
        Scale range for encoder mask.
    pred_mask_scale : tuple[float, float], default=(0.2, 0.8)
        Scale range for predictor mask.
    aspect_ratio : tuple[float, float], default=(0.3, 3.0)
        Aspect ratio range for mask blocks.
    nenc : int, default=1
        Number of encoder masks to generate.
    npred : int, default=2
        Number of predictor masks to generate.
    """

    input_size: Tuple[int, int] = (224, 224)
    patch_size: int = 16
    min_keep: int = 4
    allow_overlap: bool = False
    enc_mask_scale: Tuple[float, float] = (0.2, 0.8)
    pred_mask_scale: Tuple[float, float] = (0.2, 0.8)
    aspect_ratio: Tuple[float, float] = (0.3, 3.0)
    nenc: int = 1
    npred: int = 2

    def __post_init__(self) -> None:
        """Compute the patch-grid dimensions from image and patch sizes."""
        self.height = self.input_size[0] // self.patch_size
        self.width = self.input_size[1] // self.patch_size

    def _sample_block_size(
        self,
        generator: torch.Generator,
        scale: Tuple[float, float],
        aspect_ratio: Tuple[float, float],
    ) -> Tuple[int, int]:
        """Sample the size of the mask block based on scale and aspect ratio.

        Parameters
        ----------
        generator : torch.Generator
            Generator used for all random draws, for reproducibility.
        scale : Tuple[float, float]
            (min, max) fraction of the patch grid the block may cover.
        aspect_ratio : Tuple[float, float]
            (min, max) height/width ratio of the block.

        Returns
        -------
        Tuple[int, int]
            Block (height, width) in patch units, each clamped to
            ``[1, grid_dim - 1]`` where the grid allows.
        """
        # Draw scale and aspect ratio INDEPENDENTLY. Reusing one draw for
        # both would perfectly correlate them, collapsing the joint
        # distribution onto a line instead of covering the full rectangle.
        min_s, max_s = scale
        rand_scale = torch.rand(1, generator=generator).item()
        mask_scale = min_s + rand_scale * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)

        min_ar, max_ar = aspect_ratio
        rand_ar = torch.rand(1, generator=generator).item()
        aspect_ratio_val = min_ar + rand_ar * (max_ar - min_ar)

        # Solve h * w = max_keep with h / w = aspect_ratio_val.
        h = int(round(math.sqrt(max_keep * aspect_ratio_val)))
        w = int(round(math.sqrt(max_keep / aspect_ratio_val)))

        # Clamp to at least 1 (so the block is never empty) and at most
        # grid_dim - 1 (so randint below always has a non-empty range).
        h = min(max(h, 1), self.height - 1)
        w = min(max(w, 1), self.width - 1)

        return h, w

    def _sample_block_mask(
        self,
        b_size: Tuple[int, int],
        acceptable_regions: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample a rectangular mask block at a uniformly random position.

        Parameters
        ----------
        b_size : Tuple[int, int]
            Block (height, width) in patch units.
        acceptable_regions : Optional[torch.Tensor], optional
            NOTE(review): currently unused; kept for interface compatibility
            with the reference I-JEPA implementation, where it constrains
            block placement.
        generator : Optional[torch.Generator], optional
            Generator for the position draws; ``None`` falls back to the
            global RNG (the previous behavior).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Flattened ``(height * width,)`` int32 mask (1 inside the block)
            and its complement (0 inside the block).
        """
        h, w = b_size
        # Use the supplied generator so that a seeded call is fully
        # reproducible (previously only block sizes used the seed, while
        # positions drew from the global RNG).
        top = int(
            torch.randint(0, self.height - h, (1,), generator=generator).item()
        )
        left = int(
            torch.randint(0, self.width - w, (1,), generator=generator).item()
        )
        mask = torch.zeros((self.height, self.width), dtype=torch.int32)
        mask[top : top + h, left : left + w] = 1

        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top : top + h, left : left + w] = 0

        return mask.flatten(), mask_complement.flatten()

    def __call__(
        self,
    ) -> Dict[str, Any]:
        """Generate encoder and predictor masks for a single example.

        Returns
        -------
        Dict[str, Any]
            A dictionary with keys ``"encoder_masks"`` (shape
            ``(nenc, height * width)``) and ``"predictor_masks"`` (shape
            ``(npred, height * width)``).
        """
        # Derive a per-call seed from the global RNG, then drive every
        # subsequent draw (sizes AND positions) from one local generator
        # so the whole call is reproducible from this single seed.
        seed = torch.randint(0, 2**32, (1,)).item()
        g = torch.Generator().manual_seed(seed)

        # Sample block sizes. Encoder blocks use a fixed 1:1 aspect ratio,
        # matching the original code.
        p_size = self._sample_block_size(
            generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
        )
        e_size = self._sample_block_size(
            generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
        )

        # Generate predictor masks
        masks_pred, masks_enc = [], []
        for _ in range(self.npred):
            mask_p, _ = self._sample_block_mask(p_size, generator=g)
            masks_pred.append(mask_p)

        # Generate encoder masks
        for _ in range(self.nenc):
            mask_e, _ = self._sample_block_mask(e_size, generator=g)
            masks_enc.append(mask_e)

        return {
            "encoder_masks": torch.stack(masks_enc),
            "predictor_masks": torch.stack(masks_pred),
        }

0 comments on commit db0d597

Please sign in to comment.