From 56ecec340a09d5e299a432519f6422a7483d2464 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Thu, 5 Dec 2024 11:58:37 +0100
Subject: [PATCH] override checkpointing hooks to exclude backbone while saving checkpoints

---
 src/eva/core/models/modules/head.py           | 20 +++++++++++++++++++-
 src/eva/core/models/modules/utils/__init__.py |  3 ++-
 .../core/models/modules/utils/checkpoint.py   | 21 +++++++++++++++++++++
 .../models/modules/semantic_segmentation.py   | 20 +++++++++++++++++++-
 4 files changed, 61 insertions(+), 3 deletions(-)
 create mode 100644 src/eva/core/models/modules/utils/checkpoint.py

diff --git a/src/eva/core/models/modules/head.py b/src/eva/core/models/modules/head.py
index 9f2a8d300..67ee4f4db 100644
--- a/src/eva/core/models/modules/head.py
+++ b/src/eva/core/models/modules/head.py
@@ -12,7 +12,7 @@
 from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
 from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
-from eva.core.models.modules.utils import batch_postprocess, grad
+from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
 from eva.core.utils import parser
 
 
@@ -32,6 +32,7 @@ def __init__(
         lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
         metrics: metrics_lib.MetricsSchema | None = None,
         postprocess: batch_postprocess.BatchPostProcess | None = None,
+        save_head_only: bool = True,
     ) -> None:
         """Initializes the neural net head module.
 
@@ -48,6 +49,8 @@ def __init__(
             postprocess: A list of helper functions to apply after the
                 loss and before the metrics calculation to the model
                 predictions and targets.
+            save_head_only: Whether to save only the head during checkpointing. If False,
+                the backbone will also be saved (not recommended when the backbone is frozen).
         """
         super().__init__(metrics=metrics, postprocess=postprocess)
 
@@ -56,6 +59,7 @@ def __init__(
         self.backbone = backbone
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
+        self.save_head_only = save_head_only
 
     @override
     def configure_model(self) -> Any:
@@ -72,6 +76,20 @@ def configure_optimizers(self) -> Any:
         lr_scheduler = self.lr_scheduler(optimizer)
         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
 
+    @override
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        if self.save_head_only:
+            checkpoint["state_dict"] = submodule_state_dict(checkpoint["state_dict"], "head")
+        super().on_save_checkpoint(checkpoint)
+
+    @override
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        if self.save_head_only and self.backbone is not None:
+            checkpoint["state_dict"].update(
+                {f"backbone.{k}": v for k, v in self.backbone.state_dict().items()}
+            )
+        super().on_load_checkpoint(checkpoint)
+
     @override
     def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         features = tensor if self.backbone is None else self.backbone(tensor)
diff --git a/src/eva/core/models/modules/utils/__init__.py b/src/eva/core/models/modules/utils/__init__.py
index f190e07b5..0623c298c 100644
--- a/src/eva/core/models/modules/utils/__init__.py
+++ b/src/eva/core/models/modules/utils/__init__.py
@@ -2,5 +2,6 @@
 
 from eva.core.models.modules.utils import grad
 from eva.core.models.modules.utils.batch_postprocess import BatchPostProcess
+from eva.core.models.modules.utils.checkpoint import submodule_state_dict
 
-__all__ = ["grad", "BatchPostProcess"]
+__all__ = ["grad", "BatchPostProcess", "submodule_state_dict"]
diff --git a/src/eva/core/models/modules/utils/checkpoint.py b/src/eva/core/models/modules/utils/checkpoint.py
new file mode 100644
index 000000000..4c83d768e
--- /dev/null
+++ b/src/eva/core/models/modules/utils/checkpoint.py
@@ -0,0 +1,21 @@
+"""Checkpointing related utilities and helper functions."""
+
+from typing import Any, Dict
+
+
+def submodule_state_dict(state_dict: Dict[str, Any], submodule_key: str) -> Dict[str, Any]:
+    """Gets the state dict of a specific submodule.
+
+    Args:
+        state_dict: The state dict to extract the submodule from.
+        submodule_key: The key of the submodule to extract.
+
+    Returns:
+        The subset of the state dict corresponding to the specified submodule.
+    """
+    submodule_key = submodule_key if submodule_key.endswith(".") else submodule_key + "."
+    return {
+        module: weights
+        for module, weights in state_dict.items()
+        if module.startswith(submodule_key)
+    }
diff --git a/src/eva/vision/models/modules/semantic_segmentation.py b/src/eva/vision/models/modules/semantic_segmentation.py
index 83eb337df..2d6001db5 100644
--- a/src/eva/vision/models/modules/semantic_segmentation.py
+++ b/src/eva/vision/models/modules/semantic_segmentation.py
@@ -12,7 +12,7 @@
 from eva.core.metrics import structs as metrics_lib
 from eva.core.models.modules import module
 from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
-from eva.core.models.modules.utils import batch_postprocess, grad
+from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
 from eva.core.utils import parser
 from eva.vision.models.networks import decoders
 from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
@@ -31,6 +31,7 @@ def __init__(
         lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
         metrics: metrics_lib.MetricsSchema | None = None,
         postprocess: batch_postprocess.BatchPostProcess | None = None,
+        save_decoder_only: bool = True,
     ) -> None:
         """Initializes the neural net head module.
 
@@ -49,6 +50,8 @@ def __init__(
             postprocess: A list of helper functions to apply after the
                 loss and before the metrics calculation to the model
                 predictions and targets.
+            save_decoder_only: Whether to save only the decoder during checkpointing. If False,
+                the encoder will also be saved (not recommended when the encoder is frozen).
         """
         super().__init__(metrics=metrics, postprocess=postprocess)
 
@@ -58,6 +61,7 @@ def __init__(
         self.lr_multiplier_encoder = lr_multiplier_encoder
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
+        self.save_decoder_only = save_decoder_only
 
     @override
     def configure_model(self) -> None:
@@ -83,6 +87,20 @@ def configure_optimizers(self) -> Any:
         lr_scheduler = self.lr_scheduler(optimizer)
         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
 
+    @override
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        if self.save_decoder_only:
+            checkpoint["state_dict"] = submodule_state_dict(checkpoint["state_dict"], "decoder")
+        super().on_save_checkpoint(checkpoint)
+
+    @override
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        if self.save_decoder_only and self.encoder is not None:
+            checkpoint["state_dict"].update(
+                {f"encoder.{k}": v for k, v in self.encoder.state_dict().items()}  # type: ignore
+            )
+        super().on_load_checkpoint(checkpoint)
+
     @override
     def forward(
         self,
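Reviewer note (not part of the patch): a minimal, self-contained sketch of the save/load
round trip these hooks implement. The toy model and the inlined copy of
submodule_state_dict below are illustrative assumptions, not eva API; in the real modules
the frozen submodule is the backbone/encoder and the persisted one is the head/decoder.

# Sketch only: mirrors what on_save_checkpoint / on_load_checkpoint do with the
# checkpoint's "state_dict", using plain PyTorch.
from typing import Any, Dict

from torch import nn


def submodule_state_dict(state_dict: Dict[str, Any], submodule_key: str) -> Dict[str, Any]:
    # Inlined copy of the helper added by this patch, for a standalone demo.
    submodule_key = submodule_key if submodule_key.endswith(".") else submodule_key + "."
    return {k: v for k, v in state_dict.items() if k.startswith(submodule_key)}


# Toy stand-in for HeadModule: a frozen backbone plus a trainable head.
model = nn.Module()
model.backbone = nn.Linear(8, 4)  # stays frozen, not worth persisting
model.head = nn.Linear(4, 2)  # trainable part we want to checkpoint

# on_save_checkpoint: keep only the head weights in the checkpoint.
checkpoint = {"state_dict": model.state_dict()}
checkpoint["state_dict"] = submodule_state_dict(checkpoint["state_dict"], "head")
assert sorted(checkpoint["state_dict"]) == ["head.bias", "head.weight"]

# on_load_checkpoint: re-inject the backbone weights from the live module so that
# load_state_dict sees a complete state dict again.
checkpoint["state_dict"].update(
    {f"backbone.{k}": v for k, v in model.backbone.state_dict().items()}
)
model.load_state_dict(checkpoint["state_dict"])

The net effect is that checkpoints persist only the trainable head (or decoder), keeping
them small, while the frozen backbone (or encoder) is restored from the in-memory module
at load time.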