From 97fe4bcb2599e8eea8e87de2b966dccaca3dccbd Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 29 May 2023 17:12:33 +0200 Subject: [PATCH] remove 'encoder' --- torchgeo/trainers/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index e1b6678ed73..7a4cb00cb63 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -46,6 +46,14 @@ def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]: state_dict = OrderedDict( {k.replace("model.backbone.model.", ""): v for k, v in state_dict.items()} ) + elif checkpoint["model"] in ["deeplabv3+", "unet"]: + state_dict = OrderedDict( + { + k.replace("encoder.", ""): v + for k, v in state_dict.items() + if "encoder" in k + } + ) else: raise ValueError( "Unknown checkpoint task. Only backbone or model extraction is supported"