
torch distributed: add support for user-specified parameter synchronization #1612

Open

wants to merge 7 commits into master

Conversation

NeoLegends
Collaborator

This allows extending the existing param_avg and gradient_sync strategies with custom user-defined ones for easier experimentation.

@NeoLegends NeoLegends self-assigned this Sep 4, 2024
@NeoLegends NeoLegends requested review from albertz and a team as code owners September 4, 2024 10:14
@albertz
Member

albertz commented Sep 4, 2024

What specific synchronization schemes do you have in mind?

While this API looks flexible, I think it's actually not very flexible and has lots of implicit assumptions:

  • It assumes there is such a concept of rank/size. (I could imagine many situations where this would not really make sense, e.g. where you dynamically add/remove compute nodes.)
  • It additionally assumes there is a distinction between local and global rank/size. (I could imagine situations where there are even more hierarchies.)
  • The relevant parameter synchronization is done right after the local parameter update. That's true for our current parameter averaging logic. But other synchronization schemes might need hooks elsewhere. For example, for grad sync instead of param sync, you need to do it right after calculating the gradients (or maybe accumulated gradients). So again, this is very limited. I can imagine many potential places where you might want to hook.
  • I think the make_distributed_model does not make sense. This is only for when you want to use DistributedDataParallel, but in that case, I don't think you want to use some other custom logic in the post-param-update step.
  • Also see your TODO on the torch.no_grad().

For most of the extensions I have in mind for distributed training, this API would not work. I'm also not sure there is a good way to design such an API, because there are basically infinitely many ways distributed training could be done, and we don't really know yet what we want to do, or what other people might want.

So I think we should really only implement a new API for the parts where we know exactly that we need it right now. That brings me to my initial question: What specific synchronization schemes do you have in mind? What API would you need for exactly that synchronization scheme, so that it would be possible to implement it in user space?

Or, instead of trying to implement such a flexible API, implement the specific method you have in mind directly? The flexible API might actually only work well for this specific method you have in mind now, but not really for anything else. In that case, we have just added complexity without real value.

@NeoLegends
Collaborator Author

NeoLegends commented Sep 5, 2024

Thank you for your comment, that makes sense to me. I was thinking of a scheme where you do parameter averaging after different steps depending on whether you are averaging within one node or across nodes. But your comment, especially w.r.t. GPUs leaving and joining the computation, is a good argument for why such an abstract API does not make that much sense, at least not as long as the local/global rank indices cannot always be updated when GPUs join and leave the cluster.

@albertz
Member

albertz commented Sep 5, 2024

I was thinking of a scheme where you do parameter averaging after different steps depending on whether you are averaging within one node or across nodes.

Ah yea, that's a good idea. Ok, this is very specific, so let's think about what API we need to allow for flexible user-defined schemes for this.

I think a flexible way would be if the user just defines a custom step_after_param_update function. So in the config, it could look like this:

import torch


def custom_step_after_param_update(*, module: torch.nn.Module, epoch_step_idx: int, **_kwargs):
    ...


torch_distributed = {
    "reduce_type": "custom_step_after_param_update",
    "custom_step_after_param_update": custom_step_after_param_update,
}

That's all that is needed, right?

If you need to know the global/local rank/size inside that custom step func, or any other environment information, that would be up to the user. E.g. the user can always do:

from returnn.torch import distributed

rank = distributed.get_ctx().rank
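
For the intra-/inter-node averaging scheme, such a custom step could then look roughly like this. This is a minimal sketch: _get_intra_node_group is a hypothetical helper (one possible shape of it is sketched further below in this thread), and the 100-step interval for cross-node averaging is just an example.

import torch
import torch.distributed as dist

from returnn.torch import distributed


def custom_step_after_param_update(*, module: torch.nn.Module, epoch_step_idx: int, **_kwargs):
    ctx = distributed.get_ctx()
    # Hypothetical helper: lazily creates a process group spanning only the GPUs of this node.
    intra_node_group = _get_intra_node_group(ctx)
    # Average within the node on every step, across all nodes on every 100th step.
    # group=None means the default (global) process group.
    group = intra_node_group if (epoch_step_idx + 1) % 100 else None
    with torch.no_grad():
        for param in module.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=group)
            param.data /= dist.get_world_size(group=group)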

@NeoLegends
Collaborator Author

NeoLegends commented Sep 6, 2024

That's all that is needed, right?

Hmm, when would you set up the sub process groups needed for synchronization? E.g. on the first invocation of the function? In the class-based approach this is quite easy, because you can initialize the class (and any sub process groups) right after the global process group is initialized; the moment is well-defined. I'm not sure it's feasible to initialize them when the RETURNN config is parsed (i.e. by initializing a callable class like e.g. LaplaceOrdering(...) in the PP dataset).

@albertz
Member

albertz commented Sep 6, 2024

E.g. on the first invocation of the function?

Yes, that should work just fine, right?
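
For example, the lazy setup on first invocation could look roughly like this. This is a sketch under some assumptions: _get_intra_node_group is a hypothetical helper, the context is assumed to expose size, local_size, and rank attributes like the rank attribute shown above, and global ranks are assumed to be assigned contiguously per node. Note that torch.distributed.new_group is a collective call, so every rank has to call it for every sub group, in the same order.

import torch.distributed as dist

_intra_node_group = None  # created lazily on the first invocation of the custom step


def _get_intra_node_group(ctx):
    """Return a process group covering only the GPUs of this node, creating it on first use."""
    global _intra_node_group
    if _intra_node_group is None:
        # Assumes ctx exposes size/local_size/rank, and that global ranks are
        # assigned contiguously per node (rank = node_idx * local_size + local_rank).
        num_nodes = ctx.size // ctx.local_size
        for node_idx in range(num_nodes):
            ranks = list(range(node_idx * ctx.local_size, (node_idx + 1) * ctx.local_size))
            # Collective: every rank calls new_group() for every node's rank list,
            # but keeps only the group it belongs to.
            group = dist.new_group(ranks=ranks)
            if ctx.rank in ranks:
                _intra_node_group = group
    return _intra_node_group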

@albertz
Member

albertz commented Oct 12, 2024

Some update: in the future, I want to implement something similar to this: https://github.com/PrimeIntellect-ai/prime (or maybe just reuse the existing code there).

Specifically, this includes ElasticDeviceMesh and OpenDiLoCo/DiLoCo.

This is just an example of what to keep in mind when making the API here more flexible.
