From 3b1b4f0e4650eb32ea48b0247e028404872641fa Mon Sep 17 00:00:00 2001 From: alisterburt Date: Fri, 12 May 2023 21:01:18 +0100 Subject: [PATCH] load model onto cpu first to avoid mps error on linux (#26) --- src/fidder/model/_tests/test_download_checkpoint.py | 2 +- src/fidder/predict/predict.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fidder/model/_tests/test_download_checkpoint.py b/src/fidder/model/_tests/test_download_checkpoint.py index 1f8c593..8fe6497 100644 --- a/src/fidder/model/_tests/test_download_checkpoint.py +++ b/src/fidder/model/_tests/test_download_checkpoint.py @@ -4,5 +4,5 @@ def test_download_and_load_latest_checkpoint(): checkpoint_file = get_latest_checkpoint() model = Fidder() - model.load_from_checkpoint(checkpoint_file) + model.load_from_checkpoint(checkpoint_file, map_location="cpu") assert isinstance(model, Fidder) diff --git a/src/fidder/predict/predict.py b/src/fidder/predict/predict.py index 3869616..4ef3dc2 100644 --- a/src/fidder/predict/predict.py +++ b/src/fidder/predict/predict.py @@ -47,7 +47,7 @@ def predict_fiducial_mask( # prepare model if model_checkpoint_file is None: model_checkpoint_file = get_latest_checkpoint() - model = Fidder.load_from_checkpoint(model_checkpoint_file) + model = Fidder.load_from_checkpoint(model_checkpoint_file, map_location="cpu") model.eval() # predict