
Support for optimizers which don't have a "fused" parameter, such as grokadamw and 8-bit bnb #1744

Merged (12 commits, Sep 26, 2024)
30 changes: 27 additions & 3 deletions litgpt/utils.py
@@ -557,13 +557,37 @@ def instantiate_bnb_optimizer(optimizer, model_parameters):
 
 
 def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
+    # Special care taken where some optimizers do not have some parameters referenced in some of the code, for example "fused" in the pretrain.py script:
+    #   bnb.optim.AdamW8bit
+    #   grokadamw.GrokAdamW
+    #   torch.optim.RMSprop
+
     if isinstance(optimizer, str):
-        optimizer_cls = getattr(torch.optim, optimizer)
+        if "." in optimizer:
+            class_module, class_name = optimizer.rsplit(".", 1)
+        else:
+            class_module, class_name = "torch.optim", optimizer
+
+        module = __import__(class_module, fromlist=[class_name])
+        optimizer_cls = getattr(module, class_name)
+
+        valid_params = set(inspect.signature(optimizer_cls).parameters)
+        kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}
         optimizer = optimizer_cls(model_parameters, **kwargs)
-    else:
-        optimizer = dict(optimizer)  # copy
+    elif isinstance(optimizer, dict):
+        optimizer = dict(optimizer)
+        class_module, class_name = optimizer["class_path"].rsplit(".", 1)
+        module = __import__(class_module, fromlist=[class_name])
+        optimizer_cls = getattr(module, class_name)
+
+        valid_params = set(inspect.signature(optimizer_cls).parameters)
+        kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}
+
         optimizer["init_args"].update(kwargs)
         optimizer = instantiate_class(model_parameters, optimizer)
+    else:
+        raise ValueError(f'Unrecognized "optimizer" value: {optimizer}')
 
     return optimizer


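For reference, a minimal usage sketch (not part of the diff) of how the updated instantiate_torch_optimizer handles the string and dict forms; the dotted-path call is commented out because it assumes the optional grokadamw package is installed.

# Minimal usage sketch, assuming the diff above is applied; the Linear layer is a stand-in model.
import torch
from litgpt.utils import instantiate_torch_optimizer

model = torch.nn.Linear(4, 4)

# String form, plain torch.optim name: "fused" is not in RMSprop's signature,
# so it is silently dropped instead of raising a TypeError.
rmsprop = instantiate_torch_optimizer("RMSprop", model.parameters(), lr=0.01, fused=True)

# String form, dotted path to a third-party class (assumes `pip install grokadamw`).
# grok = instantiate_torch_optimizer("grokadamw.GrokAdamW", model.parameters(), lr=1e-3)

# Dict form, matching LitGPT config files: "weight_decay" is in AdamW's signature,
# so it passes the filter and is merged into init_args before instantiation.
adamw = instantiate_torch_optimizer(
    {"class_path": "torch.optim.AdamW", "init_args": {"lr": 0.01}},
    model.parameters(),
    weight_decay=0.1,
)
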
21 changes: 21 additions & 0 deletions tests/test_pretrain.py
@@ -18,6 +18,27 @@
 from tests.conftest import RunIf
 
 
+@RunIf(min_cuda_gpus=1, standalone=True)
+@mock.patch("litgpt.pretrain.save_hyperparameters")
+def test_optimizer_args(_, tmp_path):
+    model_config = Config(block_size=2, n_layer=2, n_embd=4, n_head=2, padded_vocab_size=8)
+
+    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
+    dataloader = DataLoader(dataset)
+    pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))
+
+    for optimizer_name in ("AdamW", "SGD", "RMSprop"):
+        pretrain.setup(
+            "pythia-14m",
+            devices=1,
+            optimizer=optimizer_name,
+            model_config=model_config,
+            out_dir=tmp_path,
+            train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
+            eval=EvalArgs(interval=1, max_iters=1, final_validation=False),
+        )
+
+
 @RunIf(min_cuda_gpus=2, standalone=True)
 # Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
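The new test only loops over built-in torch.optim names. A hypothetical extension (not part of this diff) could cover the third-party optimizers named in the PR title; the sketch below would sit inside test_optimizer_args, reusing model_config and tmp_path from the test, and assumes the optional grokadamw package is installed in the test environment.

# Hypothetical extension of the loop in test_optimizer_args, not part of this diff.
# The dotted path exercises the new rsplit(".") branch in instantiate_torch_optimizer.
for optimizer_name in ("AdamW", "RMSprop", "grokadamw.GrokAdamW"):
    pretrain.setup(
        "pythia-14m",
        devices=1,
        optimizer=optimizer_name,
        model_config=model_config,
        out_dir=tmp_path,
        train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
        eval=EvalArgs(interval=1, max_iters=1, final_validation=False),
    )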