[BUG] Inconsistent loss between overlap_comm=true and overlap_comm=false
#1004
Labels: bug
Describe the bug
DeepSpeed provides a ZeRO configuration property, overlap_comm, which according to the documentation "Attempts to overlap the reduction of the gradients with backward computation" (Ref: https://www.deepspeed.ai/docs/config-json/). I'm noticing that the lm loss is different when overlap_comm=true compared to when overlap_comm=false.
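For context, this is a minimal sketch of the relevant ZeRO section of the config, written as a Python dict; only the overlap_comm flag is the point, and the other values are placeholders rather than my actual settings.

```python
# Minimal sketch of the relevant ZeRO config section as a Python dict.
# Only the overlap_comm flag matters here; the other values are placeholders,
# not the actual configuration used for this run.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,  # placeholder
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,  # flip to False for the comparison run
    },
}
```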
To Reproduce
I'm running training with the enwik8_text_document corpus. Here is my configuration.
Here are the observed losses.
One may argue that the losses are "close". However, the expectation is that the computation should be exact. Given that only the overlap_comm setting has changed, it makes me wonder whether overlap_comm implicitly introduces some sort of data-copy race condition, with the losses diverging more over time.
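To make the "close" vs. exact comparison concrete, here is a sketch of the kind of check I have in mind, assuming each run's per-step lm loss has been dumped to a text file with one value per line; the file names and format are hypothetical, not artifacts produced by the actual runs.

```python
# Hypothetical comparison of per-step lm losses from the two runs; the file names
# and the one-float-per-line format are assumptions for illustration only.
def read_losses(path):
    with open(path) as f:
        return [float(line) for line in f if line.strip()]

a = read_losses("lm_loss_overlap_true.txt")
b = read_losses("lm_loss_overlap_false.txt")

# Since only overlap_comm changed, the expectation is bitwise-identical losses,
# not merely "close" ones.
first_diff = next((i for i, (x, y) in enumerate(zip(a, b)) if x != y), None)
print("runs identical" if first_diff is None else f"first divergence at step {first_diff}")
```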
Expected behavior
The expectation is that the losses should be exact. DeepSpeed doesn't have a contract that, with overlap_comm, the computation is only loosely correct.
Proposed solution
I've spent some time looking at the nsys profiles for this training, but nothing immediately stands out. With Stage 2, there is some data movement that happens across different buffers.
In stage_1_and_2.py:905 the stream is set to reduction_stream, which waits for the default stream, after which a sequence of reduce operations is scheduled. After average_tensor completes, there is a call to copy_grads_in_partition, which copies the reduced values from the ipg_buffer to the newly allocated buffer that holds the gradients.
With overlap_comm enabled, the cudaMemcpyAsync op is not synchronized with the completion of the reduce operation, and hence the data that is copied over (on the assumption that it is the reduced result) may or may not have been reduced. This happens because the collective wait semantic only synchronizes the completion of the reduce op on the collective stream with the default stream. When overlap_comm is enabled, the reduction_stream is used and the wait() operation will not synchronize with it. This can be confirmed from the implementation in ProcessGroupNCCL as well as the PyTorch documentation on the use of the wait semantic with asynchronous collectives.
However, if what I'm describing here is the case, the whole overlap_comm implementation is incorrect. I'll create a similar issue with DeepSpeed as well, but I wanted to bring this up here in case anyone else has noticed different loss dynamics when overlap_comm is toggled.
Screenshots
N/A
Environment (please complete the following information):
Additional context