How to properly use Flops Profiler with pipelined parallelism? #385

flyingdown opened this issue May 9, 2023 · 0 comments
flyingdown commented May 9, 2023

  1. I am training a GPT-2 model with pipeline parallelism, and the Flops Profiler enabled in the ds config does not work; it outputs nothing (a rough sketch of the config section is included after the profile summary below for reference).
  2. So I added some code like this:
from deepspeed.profiling.flops_profiler import FlopsProfiler

# with DeepSpeed enabled, `model` is a list of module chunks, so profile model[0]
if args.deepspeed:
    prof = FlopsProfiler(model[0])
else:
    prof = FlopsProfiler(model)
...
train_step(forward_step_func,
           train_data_iterator,
           model,
           optimizer,
           lr_scheduler)
...
# print the profile once, on a single rank, at the chosen step
if (iteration == profile_step
        and mpu.get_data_parallel_rank() == 0
        and mpu.is_pipeline_last_stage()
        and mpu.get_tensor_model_parallel_rank() == 0):
    prof.print_model_profile(profile_step=profile_step)
    prof.end_profile()
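(For reference, the standalone usage pattern in the DeepSpeed flops profiler documentation only wraps the forward pass of the profiled step. A rough sketch with placeholder names, not my exact training loop:

from deepspeed.profiling.flops_profiler import FlopsProfiler

prof = FlopsProfiler(model)  # or model[0] under pipeline parallelism

for step, batch in enumerate(data_loader):
    if step == profile_step:
        prof.start_profile()              # attach profiling hooks before the forward pass

    loss = model(batch)                   # only this forward pass is profiled

    if step == profile_step:
        prof.stop_profile()               # stop timing, keep the collected stats
        prof.print_model_profile(profile_step=profile_step)
        prof.end_profile()                # remove hooks and free profiler state

    loss.backward()
    optimizer.step()
)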

However, with the code above the output only contains forward-pass (fwd) information, like:

-------------------------- DeepSpeed Flops Profiler --------------------------
Profile Summary at step 8:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

params per gpu: 463.21 M
params of model = params per GPU * mp_size: 463.21 M
fwd MACs per GPU: 242304.87 GMACs
fwd flops per GPU: 484678.47 G
fwd flops of model = fwd flops per GPU * mp_size: 484678.47 G
fwd latency: 29.87 s
fwd FLOPS per GPU = fwd flops per GPU / fwd latency: 16.22 TFLOPS
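
(Going back to point 1: the flops_profiler section enabled in my ds_config is of roughly this form. The key names follow the DeepSpeed documentation; the values here are only illustrative, shown as a Python dict:

ds_config_flops_profiler = {
    "flops_profiler": {
        "enabled": True,
        "profile_step": 8,      # training step at which to profile
        "module_depth": -1,     # -1 prints all module depths
        "top_modules": 1,       # number of top modules in the aggregated profile
        "detailed": True,       # print the per-module profile
        "output_file": None,    # None prints to stdout
    }
}
)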
  3. How should I interpret the following data? Depth 0 and depth 1 report less fwd latency than depth 2. Aren't they inclusive relationships, i.e. shouldn't a parent's latency include its children's?
----------------------------- Aggregated Profile per GPU -----------------------------
Top 1 modules in terms of params, MACs or fwd latency at different model depths:
depth 0:
params - {'PipelineEngine': '463.21 M'}
MACs - {'PipelineEngine': '242304.87 GMACs'}
fwd latency - {'PipelineEngine': '29.87 s'}
depth 1:
params - {'GPTModelPipe': '463.21 M'}
MACs - {'GPTModelPipe': '242304.87 GMACs'}
fwd latency - {'GPTModelPipe': '29.87 s'}
depth 2:
params - {'ParallelTransformerLayerPipe': '402.91 M'}
MACs - {'ParallelTransformerLayerPipe': '228698.42 GMACs'}
fwd latency - {'ParallelTransformerLayerPipe': '55.83 s'}
depth 3:
params - {'ParallelMLP': '268.5 M'}
MACs - {'ParallelMLP': '140737.49 GMACs'}
fwd latency - {'ParallelMLP': '27.79 s'}
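(To make question 3 concrete, using only the numbers printed above:

# figures copied from the aggregated profile above
depth0_fwd_latency_s = 29.87    # PipelineEngine (depth 0)
depth2_fwd_latency_s = 55.83    # ParallelTransformerLayerPipe (depth 2)
depth0_macs_gmacs = 242304.87
depth2_macs_gmacs = 228698.42

# the depth-2 entry reports almost twice the fwd latency of its depth-0 ancestor,
# even though it accounts for fewer MACs
print(depth2_fwd_latency_s / depth0_fwd_latency_s)   # ~1.87
print(depth2_macs_gmacs / depth0_macs_gmacs)         # ~0.94
)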
  4. The same question applies to the following output.
PipelineEngine(
  463.21 M, 100.00% Params, 233508.78 GMACs, 100.00% MACs, 31.21 s, 100.00% latency, 14.97 TFLOPS,
  (module): GPTModelPipe(
    463.21 M, 100.00% Params, 233508.78 GMACs, 100.00% MACs, 31.21 s, 100.00% latency, 14.97 TFLOPS,
    (tied_modules): ModuleDict(
      60.29 M, 13.02% Params, 0 MACs, 0.00% MACs, 0, 0.00% latency, 0.0 FLOPS,
      (embed): EmbeddingPipe(
        60.29 M, 13.02% Params, 0 MACs, 0.00% MACs, 0, 0.00% latency, 0.0 FLOPS,
        (word_embeddings): VocabParallelEmbedding(51.9 M, 11.21% Params, 0 MACs, 0.00% MACs, 0, 0.00% latency, 0.0 FLOPS, )
        (position_embeddings): Embedding(8.39 M, 1.81% Params, 0 MACs, 0.00% MACs, 0, 0.00% latency, 0.0 FLOPS, 2048, 4096)
        (embedding_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 0, 0.00% latency, 0.0 FLOPS, p=0.1, inplace=False)
      )
    )
    (27): ParallelTransformerLayerPipe(
      50.36 M, 10.87% Params, 27487.79 GMACs, 11.77% MACs, 7.29 s, 23.37% latency, 7.54 TFLOPS,
      (input_layernorm): MixedFusedLayerNorm(8.19 k, 0.00% Params, 0 MACs, 0.00% MACs, 119.17 ms, 0.38% latency, 90.1 GFLOPS, )
      (self_attention): ParallelAttention(
        16.78 M, 3.62% Params, 9895.6 GMACs, 4.24% MACs, 3.0 s, 9.63% latency, 6.59 TFLOPS,
        (query_key_value): ColumnParallelLinear(12.59 M, 2.72% Params, 6597.07 GMACs, 2.83% MACs, 1.11 s, 3.57% latency, 11.86 TFLOPS, )
        (scale_mask_softmax): FusedScaleMaskSoftmax(0, 0.00% Params, 0 MACs, 0.00% MACs, 286.5 ms, 0.92% latency, 0.0 FLOPS, )
        (attention_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 106.24 ms, 0.34% latency, 0.0 FLOPS, p=0.1, inplace=False)
        (dense): RowParallelLinear(4.2 M, 0.91% Params, 2199.02 GMACs, 0.94% MACs, 883.0 ms, 2.83% latency, 4.98 TFLOPS, )
      )
      (post_attention_layernorm): MixedFusedLayerNorm(8.19 k, 0.00% Params, 0 MACs, 0.00% MACs, 115.44 ms, 0.37% latency, 93.01 GFLOPS, )
      (mlp): ParallelMLP(
        33.56 M, 7.25% Params, 17592.19 GMACs, 7.53% MACs, 3.63 s, 11.63% latency, 9.69 TFLOPS,
        (dense_h_to_4h): ColumnParallelLinear(16.78 M, 3.62% Params, 8796.09 GMACs, 3.77% MACs, 1.49 s, 4.77% latency, 11.83 TFLOPS, )
        (dense_4h_to_h): RowParallelLinear(16.78 M, 3.62% Params, 8796.09 GMACs, 3.77% MACs, 2.08 s, 6.65% latency, 8.47 TFLOPS, )
      )
    )