diff --git a/examples/transformer/models/GPT/docs/hybrid_profiler.md b/examples/transformer/models/GPT/docs/hybrid_profiler.md
index 21e258ba2..ec9c51e0f 100644
--- a/examples/transformer/models/GPT/docs/hybrid_profiler.md
+++ b/examples/transformer/models/GPT/docs/hybrid_profiler.md
@@ -64,7 +64,7 @@ cd PaddleFleetX/examples/transformer/models/GPT # if already in this directory, then
 ```
 python -m paddle.distributed.launch \
     ./pretrain/run.py -c \
-    ./pretrain/configs/pretrain_gpt_1.3B_dp8.yaml -o Profiler.enable=True
+    ./pretrain/configs/pretrain_gpt_1.3B_dp8.yaml -o Profiler.enable=True -o Global.max_steps=6
 ```
diff --git a/examples/transformer/utils/components.py b/examples/transformer/utils/components.py
index c04432521..8467cd987 100644
--- a/examples/transformer/utils/components.py
+++ b/examples/transformer/utils/components.py
@@ -151,6 +151,7 @@ def build_profiler(profiler_config):
     profiler_log = profiler_config.get('profiler_log', './profiler_log')
     record_shapes = profiler_config.get('record_shapes', True)
     profile_memory = profiler_config.get('profile_memory', True)
+    with_flops = profiler_config.get('with_flops', True)
     profiler = paddle.profiler.Profiler(
         targets=[
             paddle.profiler.ProfilerTarget.CPU,
@@ -159,7 +160,8 @@ def build_profiler(profiler_config):
         scheduler=scheduler,
         on_trace_ready=paddle.profiler.export_chrome_tracing(profiler_log),
         record_shapes=record_shapes,
-        profile_memory=profile_memory)
+        profile_memory=profile_memory,
+        with_flops=with_flops)
     profiler.start()
     logger.warning("Profiler is enabled, do not enable it in production.")
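
For reviewers, here is a minimal standalone sketch (not part of the patch) of what the `with_flops` option threaded through `build_profiler` enables at the `paddle.profiler` level. The model, batch, and step count are illustrative placeholders, and a GPU build of Paddle is assumed for the GPU target.

```python
import paddle
import paddle.profiler as profiler

# Placeholder model and batch; any dygraph model works here.
model = paddle.nn.Linear(1024, 1024)
data = paddle.randn([8, 1024])

prof = profiler.Profiler(
    targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
    scheduler=(1, 5),  # profile steps 1..4
    on_trace_ready=profiler.export_chrome_tracing('./profiler_log'),
    record_shapes=True,   # record operator input shapes
    profile_memory=True,
    with_flops=True)      # report estimated FLOPs in the operator summary

prof.start()
for step in range(6):  # mirrors Global.max_steps=6 from the doc change above
    loss = model(data).mean()
    loss.backward()
    model.clear_gradients()
    prof.step()
prof.stop()
prof.summary(op_detail=True)  # per-operator table, including estimated FLOPs
```

Since `build_profiler` reads the flag from the profiler config with a default of `True`, it can presumably also be toggled from the launch command via the same override mechanism shown in the doc, e.g. `-o Profiler.with_flops=False`.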