FIX: Issues with saving/loading with accelerate (#1008)
* FIX: Issues with saving/loading with accelerate

Description

There were a few issues with saving and loading parameters for an
accelerated net (some only in a multi-GPU setting though):

- load_params set the device to CPU when device=None, but None needs to be kept
- not waiting for all processes to finish before saving parameters
- all processes saving the parameters, when only the main process should
- an issue with parameter names depending on the wrapping state

Regarding the last point, the issue was that if the module(s) are
wrapped with accelerate, the parameters have an additional prefix,
"module.". So e.g. "dense0.weight" would become "module.dense0.weight".
This could result in a key mismatch and error when loading a model.

The solution to this is to always unwrap the net before saving and
before loading. That way, the extra prefix is never present and there is
no mismatch.
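To illustrate the prefix issue in isolation (a minimal, hypothetical sketch, not
part of this diff): any wrapper that stores the original net under a `module`
attribute, as DistributedDataParallel does, produces exactly this key shift in
the state dict.

```python
from torch import nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Linear(100, 2)


class Wrapper(nn.Module):
    # stand-in for the wrapping accelerate performs (e.g. DistributedDataParallel),
    # which keeps the original net as self.module
    def __init__(self, module):
        super().__init__()
        self.module = module


print(list(MyModule().state_dict()))           # ['dense0.weight', 'dense0.bias']
print(list(Wrapper(MyModule()).state_dict()))  # ['module.dense0.weight', 'module.dense0.bias']
```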

A test was added to check this behavior, but since the GitHub CI does
not offer multi-GPU support, it does not test for all failure cases.
Therefore, I added a script,
examples/accelerate-multigpu/run-save-load.py, that can be run on a
multi-GPU setup to test the issue.

This unit test checks the correct behavior on CPU, iterating through all
4 combinations of wrapping/not wrapping the initial/loaded model.

Implementation

The changes needed were often just a few lines that sync the processes
or place a guard to only run on the main process. A few of these were
quite unintuitive to me, so I added a comment for them.
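As a rough sketch of that pattern (assuming a script launched via
`accelerate launch`; the file name is a placeholder):

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()
module = accelerator.prepare(nn.Linear(10, 2))

# make sure every process has reached this point before anything is written
accelerator.wait_for_everyone()
# only the main process writes to disk
if accelerator.is_main_process:
    torch.save(module.state_dict(), "params.pt")
```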

The one big change is that the preparation of the components by the
accelerator is now moved to a separate method, _initialize_accelerator.
This way, it is now possible to unwrap and re-wrap the model with a
single method call each. Without that change, re-wrapping was only
possible by calling net.initialize(), which would re-initialize
everything, which is not desired.

This change can be backwards incompatible: If a user saved the
parameters of an accelerated net while it's still wrapped (not the
default), and tries to load it into a wrapped net, it will no longer
work. I think this case is rare enough that we can accept it. In the
worst case, the user can still apply the state dict manually on the
wrapped net.
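For that fallback, a hedged sketch (assuming `net` is an already initialized,
still wrapped AccelerateMixin net and the file was saved from a wrapped net as
well):

```python
import torch

# load the raw state dict and apply it directly to the wrapped module,
# bypassing load_params and its unwrapping logic
state_dict = torch.load("model_params.pt", map_location="cpu")
net.module_.load_state_dict(state_dict)
```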

I did consider an alternative solution that would inspect the names of
the keys in the state dict and try to determine from those if the
loaded/current weights are from a wrapped model or not, and consequently
rename the keys of the state dict. However, this method is unreliable
and also not easy to implement with the current code, so I opted for the
solution described above.
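For reference, the rejected approach would have amounted to something like the
following key-renaming helper (a sketch only, not what this PR implements):

```python
def strip_module_prefix(state_dict):
    # drop the leading "module." that accelerate's wrapping adds to every key;
    # fragile, e.g. if a net legitimately has a submodule called "module"
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```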

Also note that this PR does _not_ fix potential issues that might occur
during checkpointing of the model while it's training. For this, we need
to use accelerator.{save_state,load_state}, see here:

https://huggingface.co/docs/accelerate/usage_guides/checkpoint

Probably this use case is best served with a separate checkpoint
callback.
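For completeness, a minimal sketch of what checkpointing via those accelerate
methods looks like (the directory name is a placeholder):

```python
from accelerate import Accelerator
from torch import nn, optim

accelerator = Accelerator()
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
model, optimizer = accelerator.prepare(model, optimizer)

# snapshot everything accelerate tracks (model, optimizer, RNG states, ...)
accelerator.save_state("checkpoint_dir")
# ... and later resume from that snapshot
accelerator.load_state("checkpoint_dir")
```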

* Changelog entry, references to the GH PR

* Temporarily set device in load_params if necessary

Reviewer feedback: To prevent a confusing warning message, temporarily
set the device to 'cpu' when it was None during load_params. After
loading has finished, set it back to the original value.
BenjaminBossan authored Aug 18, 2023
1 parent e6023a1 commit 07fc260
Showing 5 changed files with 303 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
### Fixed

- Fixed a couple of issues when saving and loading parameters while using accelerate (via `AccelerateMixin`) in a multi-GPU setting (#1008)

## [0.14.0] - 2023-06-24

### Added
16 changes: 16 additions & 0 deletions examples/accelerate-multigpu/README.md
@@ -1,5 +1,9 @@
# Testing skorch with accelerate in multi GPU setting

This directory contains a couple of scripts used to test skorch with accelerate in a multi-GPU setting. The scripts cannot run as unit tests because they require a specific hardware setup not provided by the GitHub Action runners.

## `run-with-skorch.py`

The full history of this can be found here: https://github.com/skorch-dev/skorch/issues/944

There was an issue with using skorch in a multi-GPU setting with accelerate. After some searching, it turns out there were two problems:
@@ -36,3 +40,15 @@ tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
## `run-save-load.py`

The context of this script is that there were issues with saving and loading when using `AccelerateMixin`. The script ensures that everything works as expected. As with the first script, a proper test requires running it in a multi-GPU setting. For more information, check PR #1008.

Run the script like this:

```sh
accelerate launch run-save-load.py
```

The accelerate config is the same as for the first script.
104 changes: 104 additions & 0 deletions examples/accelerate-multigpu/run-save-load.py
@@ -0,0 +1,104 @@
"""Check that saving and loading works with accelerate.
In particular, note that both the initial model and the loaded model can each
be either wrapped with accelerate or not, i.e. there are 4 possible
combinations.
"""

import numpy as np
import torch
from accelerate import Accelerator
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from torch import nn
from torch.distributed import TCPStore

from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin
from skorch.history import DistributedHistory


PORT = 8080


class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.dense0 = nn.Linear(100, 2)
self.nonlin = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.dense0(X)
X = self.nonlin(X)
return X


# make use of accelerate by creating a class with the AccelerateMixin
class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
pass


def get_accelerate_model(accelerator):
global PORT
PORT += 1

is_master = accelerator.is_main_process
world_size = accelerator.num_processes
rank = accelerator.local_process_index
store = TCPStore(
"127.0.0.1", port=PORT, world_size=world_size, is_master=is_master)
dist_history = DistributedHistory(
store=store, rank=rank, world_size=world_size)

return AcceleratedNeuralNetClassifier(
MyModule,
criterion=nn.CrossEntropyLoss,
accelerator=accelerator,
max_epochs=3,
lr=0.001,
history=dist_history,
)


def get_vanilla_model():
return NeuralNetClassifier(
MyModule,
criterion=nn.CrossEntropyLoss,
max_epochs=3,
lr=0.001,
)


def main(wrap_initial_model=True, wrap_loaded_model=True):
X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0)
X = X.astype(np.float32)

accelerator = Accelerator()
model = get_accelerate_model(accelerator)
    model.unwrap_after_train = wrap_initial_model
model.fit(X, y)

model.save_params(f_params="model_params.pt")
y_pred = model.predict(X)
accuracy_before = accuracy_score(y, y_pred)
print(f"Accuracy before loading: {accuracy_before}")

if wrap_loaded_model:
model_loaded = get_accelerate_model(accelerator).initialize()
else:
model_loaded = get_vanilla_model().initialize()

model_loaded.load_params(f_params="model_params.pt")
y_pred = model_loaded.predict(X)
accuracy_after = accuracy_score(y, y_pred)
print(f"Accuracy after loading: {accuracy_after}")

assert accuracy_before == accuracy_after


if __name__ == '__main__':
main(True, True)
main(True, False)
main(False, True)
main(False, False)
130 changes: 101 additions & 29 deletions skorch/hf.py
@@ -925,6 +925,7 @@ def __init__(
)
self.accelerator = accelerator
self.unwrap_after_train = unwrap_after_train
self._wrapped_with_accelerator = False

def _validate_params(self):
super()._validate_params()
@@ -934,53 +935,60 @@ def _validate_params(self):
"When device placement is performed by the accelerator, set device=None"
)

def _initialize_callbacks(self):
if self.callbacks__print_log__sink == 'auto':
print_func = getattr(self.accelerator, 'print', print)
self.callbacks__print_log__sink = print_func
super()._initialize_callbacks()
return self

def _initialize_criterion(self, *args, **kwargs):
super()._initialize_criterion(*args, **kwargs)
def _initialize_accelerator(self):
"""Prepare everything for use with accelerate"""
if self._wrapped_with_accelerator:
return self

with self._current_init_context('criterion'):
for name in self._criteria:
criterion = getattr(self, name + '_')
if isinstance(criterion, torch.nn.Module):
setattr(self, name + '_', self.accelerator.prepare(criterion))

return self

def _initialize_module(self, *args, **kwargs):
super()._initialize_module(*args, **kwargs)

with self._current_init_context('module'):
for name in self._modules:
module = getattr(self, name + '_')
if isinstance(module, torch.nn.Module):
setattr(self, name + '_', self.accelerator.prepare(module))

return self

def _initialize_optimizer(self, *args, **kwargs):
super()._initialize_optimizer(*args, **kwargs)

with self._current_init_context('optimizer'):
for name in self._optimizers:
optimizer = getattr(self, name + '_')
if isinstance(optimizer, torch.optim.Optimizer):
setattr(self, name + '_', self.accelerator.prepare(optimizer))

return self

def initialize_callbacks(self, *args, **kwargs):
super().initialize_callbacks(*args, **kwargs)

for _, callback in self.callbacks_:
if isinstance(callback, LRScheduler):
callback.policy_ = self.accelerator.prepare(callback.policy_)

self._wrapped_with_accelerator = True
return self

def initialize(self):
"""Initializes all of its components and returns self."""
# this should be the same as the parent class, except for the one marked
# line
self.check_training_readiness()

self._initialize_virtual_params()
self._initialize_callbacks()
self._initialize_module()
self._initialize_criterion()
self._initialize_optimizer()
self._initialize_history()
self._initialize_accelerator() # <= added

self._validate_params()

self.initialized_ = True
return self

def _initialize_callbacks(self):
if self.callbacks__print_log__sink == 'auto':
print_func = getattr(self.accelerator, 'print', print)
self.callbacks__print_log__sink = print_func
super()._initialize_callbacks()
return self

def train_step(self, batch, **fit_params):
@@ -1021,17 +1029,23 @@ def _step_optimizer(self, step_fn):
optimizer = getattr(self, name + '_')
optimizer.step()

# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
super().on_train_end(net, X=X, y=y, **kwargs)
if not self.unwrap_after_train:
return self
def _unwrap_accelerator(self):
if not self._wrapped_with_accelerator:
return

for name in self._modules + self._criteria:
module = getattr(self, name + '_')
if isinstance(module, torch.nn.Module):
orig = self.accelerator.unwrap_model(module, keep_fp32_wrapper=False)
setattr(self, name + '_', orig)
self._wrapped_with_accelerator = False

# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
self.accelerator.wait_for_everyone()
super().on_train_end(net, X=X, y=y, **kwargs)
if self.unwrap_after_train:
self._unwrap_accelerator()
return self

def evaluation_step(self, batch, training=False):
@@ -1042,6 +1056,63 @@ def evaluation_step(self, batch, training=False):
y_pred = self.accelerator.gather_for_metrics(output)
return y_pred

# pylint: disable=missing-function-docstring
def save_params(self, *args, **kwargs):
        # has to be called even if not main process, or else there is a deadlock
self.accelerator.wait_for_everyone()

if not self._wrapped_with_accelerator:
if self.accelerator.is_main_process:
super().save_params(*args, **kwargs)
else:
# A potential issue with using accelerate is that a model that has
# been prepared with accelerate is wrapped, so that the keys of the
# state dict have an additional prefix, "module.". Therefore, when
# the model is unwrapped when saving and wrapped when loading, or
# vice versa, there will be a mismatch in the state dict keys. To
# prevent this, always unwrap before saving. During loading, in case
# the model is wrapped, this would result in an error, but we take
# care of unwrapping the model in that case during loading.
self._unwrap_accelerator()
try:
# note: although saving is only done on the main process,
# unwrapping+wrapping has to be done on all processes, or else
# there is an error, not sure why
if self.accelerator.is_main_process:
super().save_params(*args, **kwargs)
finally:
self._initialize_accelerator()

# pylint: disable=missing-function-docstring
def load_params(self, *args, **kwargs):
self.accelerator.wait_for_everyone()
prev_device = self.device
if self.device is None:
self.device = 'cpu'

try:
if not self._wrapped_with_accelerator:
super().load_params(*args, **kwargs)
else:
# A potential issue with using accelerate is that a model that
# has been prepared with accelerate is wrapped, so that the keys
# of the state dict have an additional prefix, "module.".
# Therefore, when the model is unwrapped when saving and wrapped
# when loading, or vice versa, there will be a mismatch in the
# state dict keys. Here, we always unwrap the model first before
# loading (1st case). This would still result in an error in the
# 2nd case, but we take care of unwrapping the model in that
# case during saving.
self._unwrap_accelerator()
try:
super().load_params(*args, **kwargs)
finally:
self._initialize_accelerator()
finally:
# ensure that the device remains unchanged in case it was None
# before calling load_params
self.device = prev_device


class HfHubStorage:
"""Helper class that allows writing data to the Hugging Face Hub.
@@ -1213,6 +1284,7 @@ def flush(self):
if self.verbose:
self.sink(f"Uploaded file to {return_url}")

# pylint: disable=unused-argument
def close(self, *args):
self.flush()
