How to setup TP Overlap configs #1344

Open

TJ-Solergibert opened this issue Nov 21, 2024 · 1 comment

@TJ-Solergibert
Hi!

After checking across multiple NVIDIA projects (NeMo, NeMo Framework Launcher & Megatron) I'm still wondering how I have to set up all the configs related to TP Overlap. The only explanation I've found is this chapter in the NeMo docs.

So far I've seen this folder in NeMo Framework Launcher, which contains configs similar to the ones in NeMo 2.0. From the config names I can more or less derive which hidden size, TP, CP, MBS & sequence length they are optimised for, but I would like to understand a bit more about what is happening, and also get some guidelines for tuning these parameters for my underlying hardware (GH200).

Thanks!

@denera denera self-assigned this Dec 11, 2024
@denera
Collaborator

denera commented Dec 11, 2024

Hi @TJ-Solergibert -- the chapter in the NeMo documentation is, I believe, the best conceptual description of what TP overlap does, but you're correct that there is not much guidance anywhere on the overlap configs. It's on our roadmap to improve the documentation for this on the TE side, and I'm hoping to include it with some other TP overlap-related items on our agenda in January. In the meantime, I'll share some guidance here on the config options in TE's own API.

In pure-TE/PyTorch code (without NeMo), we initialize and configure TP overlap with transformer_engine.pytorch.initialize_ub(...). The docstring for this shows the structure of the config dictionary, which NeMo constructs from the YAML config files you've seen in their repo. I'll list these options below with descriptions of what they are and guidance on how to use them, followed by a rough example of what such a config might look like:

  • "method" : str
    • "ring_exchange" -- Overlap GEMM with point-to-point send/recv kernels. Send/recv is executed by the Copy Engine and does not require any SMs if "use_ce" : True. Supports both all-gather and reduce-scatter overlaps.
    • "pipeline" -- Overlap GEMM with collective/pipelined reduce-scatter kernel. Does not support all-gather overlap, only reduce-scatter. Requires an SM margin reservation via "num_sm" : int and "set_sm_margin": True.
    • "bulk" -- Overlap GEMM with an independent bulk all-gather or reduce-scatter. Requires an SM margin reservation via reservation via "num_sm" : int and "set_sm_margin": True.
  • "num_splits": int -- Number of chunks to divide the GEMM operation into. This option is ignored when "method" : "ring_exchange" and forced to be equal to the # of GPUs in the tensor-parallel group.
  • "num_max_streams": int -- Maximum number of compute streams for GEMM chunks, capped by the number of chunks (i.e. min(num_splits, num_max_streams)). Default value is set to 3.
  • "use_ce" : bool -- Flag for using the Copy Engine to execute send/recv kernels when "method" : "ring_exchange". Ignored when "method" : "pipeline" or "method": "bulk". Default value set to True.
  • "atomic_gemm" : bool -- Instead of looping over GEMM chunks on the host and launching a separate GEMM kernel for each, launch a single atomic GEMM kernel that loops over chunks on device. Set to False by default, and typically sub-optimal with CUDA graphs.
  • "aggregate" : bool -- Aggregate 2X GEMM chunks when overlapping all-gather. Since "num_splits" is fixed to the tensor-parallel number of GPUs when "method" : "ring_exchange", aggregating 2X chunks per compute stream can help improve the overlap efficiency for smaller problem sizes where the GEMM kernel executes too quickly. Set to False by default.
  • "num_sm" : int -- Number of streaming multiprocessors (SM) to reserve for communication kernels. This is ignored when "set_sm_margin" : False.
  • "set_sm_margin" : bool -- When set to False, GEMM kernels flood all available SMs on device. Otherwise, "num_sm" count of SMs are excluded from the GEMM kernel. This should not be used with "method" : "ring_exchange" and "use_ce" : True` since Copy Engine can execute P2P send/recv without requiring any SMs.

I would recommend starting out with one of the existing configs in NeMo and looking at execution profiles to see whether any communication is left that is not hidden by a GEMM execution. For pipeline overlaps, that typically means you need to increase your SM margin to slow down the GEMM kernel and speed up the comms until you're able to cover all your communication behind the GEMM compute. Conversely, you would need to decrease your SM margin if the GEMM is too slow and the communication streams are stuck waiting for chunks to finish compute. Ring-exchange overlaps with Copy Engine send/recv can't be tuned with SM margins, but these are also typically performant out of the box.
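
For the pipeline case, the tuning knob is just the SM margin in the config; a hedged sketch of one tuning iteration, reusing the placeholder "proj_fprop" entry from the example above (the values are arbitrary, not recommendations):

```python
# If the profile shows exposed reduce-scatter time after the GEMM, reserve more
# SMs for communication; if the comm stream sits idle waiting on GEMM chunks,
# reserve fewer. Then re-run initialize_ub(...) with the updated config.
ub_cfgs["proj_fprop"]["num_sm"] = 24           # was 16; profile showed exposed comms
ub_cfgs["proj_fprop"]["set_sm_margin"] = True  # the margin only applies when this is True
```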

Since you're using GH200s, I should point out that TP overlap currently does not support Multi-Node NVLink, so it restricts your TP size to the 2 GPUs on a single GH200 node. Adding this support is one of the other TP overlap-related items I mentioned on our roadmap, but until then, it may not make sense to use it.
