
Support for optimizers which don't have "fused" parameter such as grokadamw and 8bit bnb #1744

Merged
12 commits merged into Lightning-AI:main on Sep 26, 2024

Conversation

mtasic85 (Contributor)

This fixes #1743

1. These changes bring support for optimizers that do not have a "fused" parameter, for example:
   optimizer: grokadamw.GrokAdamW
   optimizer: bitsandbytes.optim.AdamW8bit
   optimizer: bitsandbytes.optim.PagedAdamW8bit
2. Additionally, PyTorch optimizers can be written in either of the following formats:
   optimizer: torch.optim.AdamW
   optimizer: AdamW
3. In the case of a YAML config file, an error is thrown if the value for optimizer: is neither a str nor a dict (a sketch of the intended behavior follows below).
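For illustration only, here is a minimal sketch of the behavior described above, not litgpt's actual implementation: a bare name resolves against torch.optim, a dotted path is imported directly, a dict is assumed to carry a class path plus init kwargs, and anything else raises an error. The resolve_optimizer helper and the "class_path"/"init_args" keys are hypothetical; unsupported arguments such as "fused" are dropped based on the optimizer's signature.

import importlib
import inspect


def _import_class(path):
    # Hypothetical helper: "AdamW" -> torch.optim.AdamW, "bitsandbytes.optim.AdamW8bit" -> that class.
    if "." in path:
        module_name, class_name = path.rsplit(".", 1)
    else:
        module_name, class_name = "torch.optim", path
    return getattr(importlib.import_module(module_name), class_name)


def resolve_optimizer(optimizer, model_parameters, **kwargs):
    # Hypothetical helper name; sketches the str/dict/other handling described above.
    if isinstance(optimizer, str):
        optimizer_cls = _import_class(optimizer)
    elif isinstance(optimizer, dict):
        # Assumed dict layout with "class_path" / "init_args" keys; the real config schema may differ.
        optimizer_cls = _import_class(optimizer["class_path"])
        kwargs = {**optimizer.get("init_args", {}), **kwargs}
    else:
        raise ValueError(f"Unrecognized 'optimizer' value: {optimizer!r} (expected str or dict)")

    # Drop arguments the optimizer does not accept, e.g. "fused" for optimizers that lack it.
    valid_params = set(inspect.signature(optimizer_cls).parameters)
    kwargs = {key: value for key, value in kwargs.items() if key in valid_params}
    return optimizer_cls(model_parameters, **kwargs)

With this sketch, resolve_optimizer("AdamW", model.parameters(), lr=3e-4, fused=True) and resolve_optimizer("grokadamw.GrokAdamW", model.parameters(), lr=3e-4, fused=True) would both work, with fused silently dropped for the latter (assuming GrokAdamW does not accept it).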

litgpt/utils.py Outdated
Comment on lines 566 to 576
if "." in optimizer:
class_module, class_name = optimizer.rsplit(".", 1)
else:
class_module, class_name = "torch.optim", optimizer

module = __import__(class_module, fromlist=[class_name])
optimizer_cls = getattr(module, class_name)

if "fused" in kwargs and "fused" not in inspect.signature(optimizer_cls).parameters:
kwargs = dict(kwargs) # copy
del kwargs["fused"]
Collaborator
Thanks for the PR! I am thinking maybe we can make this even more general so that it would also remove other unsupported arguments if present, something like

Suggested change

- if "." in optimizer:
-     class_module, class_name = optimizer.rsplit(".", 1)
- else:
-     class_module, class_name = "torch.optim", optimizer
- module = __import__(class_module, fromlist=[class_name])
- optimizer_cls = getattr(module, class_name)
- if "fused" in kwargs and "fused" not in inspect.signature(optimizer_cls).parameters:
-     kwargs = dict(kwargs)  # copy
-     del kwargs["fused"]
+ if "." in optimizer:
+     class_module, class_name = optimizer.rsplit(".", 1)
+ else:
+     class_module, class_name = "torch.optim", optimizer
+ module = __import__(class_module, fromlist=[class_name])
+ optimizer_cls = getattr(module, class_name)
+ valid_params = set(inspect.signature(optimizer_cls).parameters)
+ kwargs = {key: value for key, value in kwargs.items() if key in valid_params}
+ optimizer = optimizer_cls(model_parameters, **kwargs)
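To make the effect of the suggested change concrete, here is a small self-contained sketch. ToyOptimizer is a made-up stand-in for a third-party optimizer whose constructor has no "fused" argument, used only so the example does not assert any real library's signature.

import inspect


class ToyOptimizer:
    # Stand-in for an optimizer (e.g. a GrokAdamW-style class) whose __init__ has no "fused" parameter.
    def __init__(self, params, lr=1e-3, weight_decay=0.0):
        self.params, self.lr, self.weight_decay = params, lr, weight_decay


kwargs = {"lr": 3e-4, "weight_decay": 0.1, "fused": True}

# Keep only the keyword arguments the optimizer actually accepts.
valid_params = set(inspect.signature(ToyOptimizer).parameters)
kwargs = {key: value for key, value in kwargs.items() if key in valid_params}

optimizer = ToyOptimizer([0.0], **kwargs)
print(sorted(kwargs))  # ['lr', 'weight_decay']  ("fused" was filtered out)

The same filtering also removes any other keyword the target optimizer does not support, which is what makes the suggestion more general than deleting "fused" alone.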

rasbt (Collaborator) left a comment

Thanks for the PR, this looks really good to me. What do you think about the suggestions above?

litgpt/utils.py (outdated; review thread resolved)
mtasic85 (Contributor, Author) commented Sep 26, 2024

@rasbt Can you please check CI/CD logs?

AttributeError: 'GPTNeoXRotaryEmbedding' object has no attribute 'cos_cached'

This should not be an issue with this pull request, but it appears for some reason.

It looks like it is related to #1745.

rasbt (Collaborator) commented Sep 26, 2024

@mtasic85 I noticed that elsewhere too. It's related to the recent transformers release (4.45) from yesterday, which introduced a backward-incompatible change. You are right, it has nothing to do with your PR. And no worries, I will fix this now.

rasbt (Collaborator) commented Sep 26, 2024

Alright! So there were two things that broke CI yesterday: a new transformers release and a new jsonargparse release. Should be all addressed now.

rasbt (Collaborator) commented Sep 26, 2024

Looks good now, it should be ready to merge. Thanks again for submitting this PR!

Resolved review threads: litgpt/utils.py
@rasbt rasbt merged commit b4b8dfc into Lightning-AI:main Sep 26, 2024
8 of 9 checks passed
Development

Successfully merging this pull request may close these issues: grokadamw and 8bit bnb optimizers (#1743)

2 participants