How to set up TP Overlap configs #1344
Hi @TJ-Solergibert -- the chapter in the NeMo documentation is, I believe, the best conceptual description of what TP overlap does, but you're correct that there is not much guidance anywhere on the overlap configs. Improving the documentation for this on the TE side is on our roadmap, and I'm hoping to include it among some other TP overlap-related items on our agenda in January. In the meantime, I'll share some guidance on the config options here in terms of TE's own API.

In pure-TE/PyTorch code (without NeMo), we initialize and configure TP overlap with transformer_engine.pytorch.initialize_ub(...). The docstring for this shows the structure of the config dictionary, which NeMo constructs from the YAML config files you've seen in their repo. I'll list these options below with some descriptions of what they are and guidance on how to use them:
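As a rough, untuned illustration of what the call and the per-GEMM config dictionary can look like, here is a sketch. The dimensions, TP size, GEMM names, and per-GEMM values below are placeholders made up for the example, not a recommended config; the initialize_ub docstring remains the authoritative reference for the options.

```python
# Illustrative sketch only -- all dimensions and per-GEMM settings are placeholders.
# Assumes torch.distributed has already been initialized across the TP ranks.
import transformer_engine.pytorch as te

seq_len, micro_batch_size, hidden_size = 4096, 1, 8192  # placeholder model dims
tp_size = 2                                              # e.g. the 2 GPUs of one node

ub_cfgs = {
    # Pipelined overlap for the attention-projection forward GEMM.
    "proj_fprop": {
        "method": "pipeline",
        "num_splits": 4,        # chunks the GEMM/comm are split into
        "num_sm": 24,           # SMs reserved for the communication kernels
        "cga_size": 2,
        "set_sm_margin": 1,     # carve the comm SMs out of the GEMM's budget
    },
    # Ring-exchange overlap (Copy Engine send/recv) for the FC1 forward GEMM.
    "fc1_fprop": {
        "method": "ring_exchange",
        "aggregate": 0,
    },
    # Bulk overlap for the QKV weight-gradient GEMM.
    "qkv_wgrad": {
        "method": "bulk",
        "num_sm": 8,
        "cga_size": 2,
        "set_sm_margin": 0,
    },
}

# Call once per process before constructing TE modules with the ub_* overlap flags.
te.initialize_ub(
    shape=[seq_len * micro_batch_size, hidden_size],  # comm buffer shape (assumed layout)
    tp_size=tp_size,
    use_fp8=False,
    ub_cfgs=ub_cfgs,
)
```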
I would recommend starting out with one of the existing configs in NeMo and looking at execution profiles to see whether any communication is left that's not hidden by a GEMM execution. For pipeline overlaps, that typically means you need to increase your SM margin to slow down the GEMM kernel and speed up the comms until you're able to cover all your communication behind the GEMM compute (a hypothetical example of this adjustment is sketched below). Conversely, you would need to decrease your SM margin if the GEMM is too slow and the communication streams get stuck waiting for chunks to finish compute. Ring-exchange overlaps with Copy Engine send/recv can't be tuned with SM margins, but these are also typically performant out-of-the-box.

Since you're using GH200s, I should point out that TP overlap currently does not support Multi-Node NVLink, so it restricts your TP size to the 2 GPUs on a single GH200 node. Adding this support is one of the other TP overlap-related items I mentioned on our roadmap, but until then, it may not make sense to use it.
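For instance, continuing the made-up config from the sketch above, widening the SM margin on a pipeline-overlap GEMM might look like this (the values are purely illustrative):

```python
# Hypothetical tuning step on the illustrative config above: if a profile shows
# exposed communication after the proj_fprop GEMM, reserve more SMs for the
# communication kernels (which also slows the GEMM) until the chunks are hidden.
ub_cfgs["proj_fprop"]["num_sm"] = 32        # was 24 in the sketch above
ub_cfgs["proj_fprop"]["set_sm_margin"] = 1  # keep those SMs out of the GEMM's budget

# Conversely, if the GEMM is now the long pole and the comm chunks sit idle,
# reduce num_sm again. Ring-exchange entries ("method": "ring_exchange") use the
# Copy Engine, so there is no SM margin to tune for them.
```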
Hi!
After checking across multiple NVIDIA projects (NeMo, NeMo Framework Launcher & Megatron), I'm still wondering how I have to set up all the configs related to TP Overlap. The only explanation I've found is this chapter in the NeMo docs.
So far I've seen this folder in NeMo Framework Launcher, which contains configs similar to the ones in NeMo 2.0. From the names of the configs I can more or less derive which hidden size, tp, cp, mbs & sequence length they are optimised for, but I would like to understand a bit more about what is happening & also get some guidelines to tune these parameters for my underlying hardware (GH200).
Thanks!