diff --git a/uisrnn/uisrnn.py b/uisrnn/uisrnn.py index e951346..7beff5e 100644 --- a/uisrnn/uisrnn.py +++ b/uisrnn/uisrnn.py @@ -151,7 +151,7 @@ def load(self, filepath): Args: filepath: the path of the file. """ - var_dict = torch.load(filepath) + var_dict = torch.load(filepath, map_location=self.device) self.rnn_model.load_state_dict(var_dict['rnn_state_dict']) self.rnn_init_hidden = nn.Parameter( torch.from_numpy(var_dict['rnn_init_hidden']).to(self.device))