diff --git a/include/kamping/measurements/timer.hpp b/include/kamping/measurements/timer.hpp index ec2f1d83a..e51ef193d 100644 --- a/include/kamping/measurements/timer.hpp +++ b/include/kamping/measurements/timer.hpp @@ -128,7 +128,7 @@ class Timer { /// @brief Constructs a timer using a given communicator. /// /// @param comm Communicator in which the time measurements are executed. - Timer(CommunicatorType const& comm) : _timer_tree{}, _comm{comm} {} + Timer(CommunicatorType const& comm) : _timer_tree{}, _comm{comm}, _is_timer_enabled{true} {} /// @brief Synchronizes all ranks in the underlying communicator via a barrier and then start the measurement with /// the given key. @@ -199,6 +199,15 @@ class Timer { _timer_tree.reset(); } + /// @brief (Re-)Enable start/stop operations. + void enable() { + _is_timer_enabled = true; + } + /// @brief Disable start/stop operations, i.e., start()/stop() operations do not have any effect. + void disable() { + _is_timer_enabled = false; + } + /// @brief Aggregates and outputs the executed measurements. The output is done via the print() /// method of a given Printer object. /// @@ -220,11 +229,15 @@ class Timer { private: internal::TimerTree - _timer_tree; ///< Timer tree used to represent the hierarchical time measurements. - CommunicatorType const& _comm; ///< Communicator in which the time measurements take place. + _timer_tree; ///< Timer tree used to represent the hierarchical time measurements. + CommunicatorType const& _comm; ///< Communicator in which the time measurements take place. + bool _is_timer_enabled; ///< Flag indicating whether start/stop operations are enabled. /// @brief Starts a time measurement. void start_impl(std::string const& key, bool use_barrier) { + if (!_is_timer_enabled) { + return; + } auto& node = _timer_tree.current_node->find_or_insert(key); node.is_active(true); if (use_barrier) { @@ -243,6 +256,9 @@ class Timer { void stop_impl( LocalAggregationMode local_aggregation_mode, std::vector const& global_aggregation_modes ) { + if (!_is_timer_enabled) { + return; + } auto endpoint = Environment<>::wtime(); KASSERT( _timer_tree.current_node->is_active(), diff --git a/tests/measurements/mpi_timer_test.cpp b/tests/measurements/mpi_timer_test.cpp index 3ee7726b7..288c8ef95 100644 --- a/tests/measurements/mpi_timer_test.cpp +++ b/tests/measurements/mpi_timer_test.cpp @@ -67,7 +67,7 @@ struct VisitorReturningSizeAndCategory { return std::make_pair(vec.size(), false); } }; -// Traverses the evaluation tree and returns a smmary of the aggregated data that can be used to verify to some degree +// Traverses the evaluation tree and returns a summary of the aggregated data that can be used to verify to some degree // the executed timings struct ValidationPrinter { void print(measurements::AggregatedTreeNode const& node) { @@ -373,3 +373,41 @@ TEST(TimerTest, singleton) { EXPECT_EQ(printer.output, expected_output); } } + +TEST(TimerTest, enable_disable) { + Communicator<> comm; + Timer<> timer; + timer.disable(); + timer.start("measurement1"); + timer.enable(); + { + timer.start("measurement11"); + timer.stop({measurements::GlobalAggregationMode::gather, measurements::GlobalAggregationMode::max}); + timer.start("measurement12"); + { + timer.synchronize_and_start("measurement121"); + timer.stop(); + } + timer.stop(); + timer.start("measurement11"); + timer.stop(); + } + timer.disable(); + timer.stop_and_append(); + timer.enable(); + ValidationPrinter printer; + timer.aggregate_and_print(printer); + if (comm.is_root()) { + std::unordered_map expected_output{ + {"root.measurement11:gather", + AggregatedDataSummary{}.set_is_scalar(false).set_num_entries(1).set_num_values(comm.size())}, + {"root.measurement12:max", + AggregatedDataSummary{}.set_is_scalar(true).set_num_entries(1).set_num_values(1)}, + {"root.measurement12.measurement121:max", + AggregatedDataSummary{}.set_is_scalar(true).set_num_entries(1).set_num_values(1)}, + {"root.measurement11:max", + AggregatedDataSummary{}.set_is_scalar(true).set_num_entries(1).set_num_values(1)}, + }; + EXPECT_EQ(printer.output, expected_output); + }; +}