import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch


@dataclass
class IJEPAMaskGenerator:
    """Generate encoder and predictor block masks for I-JEPA-style pretraining.

    Masks are generated dynamically for individual examples, so an instance
    can be passed to a data loader as a preprocessing step.

    Parameters
    ----------
    input_size : Tuple[int, int], default=(224, 224)
        Input image size (height, width) in pixels.
    patch_size : int, default=16
        Side length of each square patch, in pixels.
    min_keep : int, default=4
        Minimum number of patches to keep.
        NOTE(review): declared but not enforced anywhere below — confirm intent.
    allow_overlap : bool, default=False
        Whether to allow overlap between encoder and predictor masks.
        NOTE(review): declared but not enforced anywhere below — confirm intent.
    enc_mask_scale : Tuple[float, float], default=(0.2, 0.8)
        Fraction-of-grid range for the encoder mask block.
    pred_mask_scale : Tuple[float, float], default=(0.2, 0.8)
        Fraction-of-grid range for each predictor mask block.
    aspect_ratio : Tuple[float, float], default=(0.3, 3.0)
        Aspect-ratio range for predictor mask blocks (encoder blocks use a
        fixed unit aspect ratio; see ``__call__``).
    nenc : int, default=1
        Number of encoder masks to generate.
    npred : int, default=2
        Number of predictor masks to generate.
    """

    input_size: Tuple[int, int] = (224, 224)
    patch_size: int = 16
    min_keep: int = 4
    allow_overlap: bool = False
    enc_mask_scale: Tuple[float, float] = (0.2, 0.8)
    pred_mask_scale: Tuple[float, float] = (0.2, 0.8)
    aspect_ratio: Tuple[float, float] = (0.3, 3.0)
    nenc: int = 1
    npred: int = 2

    def __post_init__(self) -> None:
        """Derive the patch-grid dimensions from the image and patch sizes."""
        self.height = self.input_size[0] // self.patch_size
        self.width = self.input_size[1] // self.patch_size

    def _sample_block_size(
        self,
        generator: torch.Generator,
        scale: Tuple[float, float],
        aspect_ratio: Tuple[float, float],
    ) -> Tuple[int, int]:
        """Sample a mask block's (height, width) in patch-grid units.

        A single uniform draw is shared between the scale and the aspect
        ratio, so the two are coupled rather than independent — this appears
        to mirror the reference I-JEPA implementation and is kept as-is.
        """
        _rand = torch.rand(1, generator=generator).item()

        # Target number of patches the block should cover.
        min_s, max_s = scale
        mask_scale = min_s + _rand * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)

        min_ar, max_ar = aspect_ratio
        aspect_ratio_val = min_ar + _rand * (max_ar - min_ar)

        # Solve h * w ~= max_keep subject to h / w ~= aspect_ratio_val.
        h = int(round(math.sqrt(max_keep * aspect_ratio_val)))
        w = int(round(math.sqrt(max_keep / aspect_ratio_val)))

        # Clamp so a valid top-left corner always exists in _sample_block_mask
        # (torch.randint needs a non-empty range).
        h = min(h, self.height - 1)
        w = min(w, self.width - 1)

        return h, w

    def _sample_block_mask(
        self,
        b_size: Tuple[int, int],
        acceptable_regions: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample one rectangular block mask and its complement, flattened.

        Fix: the block position is now drawn from ``generator`` when one is
        supplied, so mask sampling is fully driven by the caller's seeded
        generator. Previously ``top``/``left`` came from the global RNG,
        which defeated the seeding performed in ``__call__``. Passing
        ``generator=None`` (the default) preserves the old behavior.

        NOTE(review): ``acceptable_regions`` is accepted but never used —
        overlap constraints are not enforced here; confirm intent.
        """
        h, w = b_size

        # Uniformly place the block's top-left corner on the patch grid.
        top = int(torch.randint(0, self.height - h, (1,), generator=generator).item())
        left = int(torch.randint(0, self.width - w, (1,), generator=generator).item())

        # 1 inside the block, 0 elsewhere.
        mask = torch.zeros((self.height, self.width), dtype=torch.int32)
        mask[top : top + h, left : left + w] = 1

        # Complement: 0 inside the block, 1 elsewhere.
        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top : top + h, left : left + w] = 0

        return mask.flatten(), mask_complement.flatten()

    def __call__(self) -> Dict[str, Any]:
        """Generate encoder and predictor masks for a single example.

        Returns
        -------
        Dict[str, Any]
            ``"encoder_masks"``: int32 tensor of shape ``(nenc, height * width)``;
            ``"predictor_masks"``: int32 tensor of shape ``(npred, height * width)``.
            Entries are 1 inside the sampled block and 0 elsewhere.
        """
        # Derive a per-call seed from the global RNG: results are reproducible
        # whenever the caller has seeded torch globally (e.g. torch.manual_seed).
        seed = int(torch.randint(0, 2**32, (1,)).item())
        g = torch.Generator().manual_seed(seed)

        # One block size per mask role; encoder blocks use a unit aspect ratio.
        p_size = self._sample_block_size(
            generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
        )
        e_size = self._sample_block_size(
            generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
        )

        # All predictor masks share p_size; all encoder masks share e_size.
        # The seeded generator now drives block placement as well (see
        # _sample_block_mask).
        masks_pred = [
            self._sample_block_mask(p_size, generator=g)[0]
            for _ in range(self.npred)
        ]
        masks_enc = [
            self._sample_block_mask(e_size, generator=g)[0]
            for _ in range(self.nenc)
        ]

        return {
            "encoder_masks": torch.stack(masks_enc),
            "predictor_masks": torch.stack(masks_pred),
        }