From 84b8202764c04f55ca955b2ac9351c8cb2a394b0 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Tue, 10 Sep 2024 11:53:15 +0200 Subject: [PATCH] Make random number generation thread-safe --- brian2/devices/cpp_standalone/codeobject.py | 2 +- brian2/devices/cpp_standalone/device.py | 19 ++++++++++++++----- .../cpp_standalone/templates/objects.cpp | 14 ++++++++++---- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/brian2/devices/cpp_standalone/codeobject.py b/brian2/devices/cpp_standalone/codeobject.py index f0d581335..8fecdabf0 100644 --- a/brian2/devices/cpp_standalone/codeobject.py +++ b/brian2/devices/cpp_standalone/codeobject.py @@ -146,7 +146,7 @@ def generate_rand_code(rand_func, owner): raise AssertionError(rand_func) code = """ double _%RAND_FUNC%(const int _vectorisation_idx) { - return brian::%RK_CALL%(brian::_mersenne_twister_generators[%THREAD_NUMBER%]); + return brian::%RK_CALL%[%THREAD_NUMBER%](brian::_mersenne_twister_generators[%THREAD_NUMBER%]); } """ code = replace( diff --git a/brian2/devices/cpp_standalone/device.py b/brian2/devices/cpp_standalone/device.py index 72eef38f4..f83df4e08 100644 --- a/brian2/devices/cpp_standalone/device.py +++ b/brian2/devices/cpp_standalone/device.py @@ -870,15 +870,24 @@ def generate_main_source(self, writer): nb_threads = prefs.devices.cpp_standalone.openmp_threads if nb_threads == 0: # no OpenMP nb_threads = 1 - main_lines.append(f"for (int _i=0; _i<{nb_threads}; _i++)") + main_lines.append(f"for (int _i=0; _i<{nb_threads}; _i++) {{") if seed is None: # random - main_lines.append( - " brian::_mersenne_twister_generators[_i] = std::mt19937(_rd());" + main_lines.extend( + [ + " brian::_mersenne_twister_generators[_i] = std::mt19937(_rd());", + " brian::_uniform_random[_i].reset();", + " brian::_normal_random[_i].reset();", + ] ) else: - main_lines.append( - f"brian::_mersenne_twister_generators[_i].seed({seed!r}L + _i);" + main_lines.extend( + [ + f" brian::_mersenne_twister_generators[_i].seed({seed!r}L + _i);" + " brian::_uniform_random[_i].reset();", + " brian::_normal_random[_i].reset();", + ] ) + main_lines.append("}") else: raise NotImplementedError(f"Unknown main queue function type {func}") diff --git a/brian2/devices/cpp_standalone/templates/objects.cpp b/brian2/devices/cpp_standalone/templates/objects.cpp index 0008db62f..9e1fac3a8 100644 --- a/brian2/devices/cpp_standalone/templates/objects.cpp +++ b/brian2/devices/cpp_standalone/templates/objects.cpp @@ -32,9 +32,13 @@ set_variable_from_value(name, {{array_name}}, var_size, (char)atoi(s_value.c_str namespace brian { std::string results_dir = "results/"; // can be overwritten by --results_dir command line arg + +// For multhreading, we need one generator for each thread. We also create a distribution for +// each thread, even though this is not strictly necessary for the uniform distribution, as +// the distribution is stateless. std::vector< std::mt19937 > _mersenne_twister_generators; -std::uniform_real_distribution _uniform_random; -std::normal_distribution _normal_random; +std::vector> _uniform_random; +std::vector> _normal_random; //////////////// networks ///////////////// {% for net in networks | sort(attribute='name') %} @@ -228,6 +232,8 @@ void _init_arrays() std::random_device rd; for (int i=0; i<{{openmp_pragma('get_num_threads')}}; i++) _mersenne_twister_generators.push_back(std::mt19937(rd())); + _uniform_random.push_back(std::uniform_real_distribution()); + _normal_random.push_back(std::normal_distribution()); } void _load_arrays() @@ -381,8 +387,8 @@ namespace brian { extern std::string results_dir; // In OpenMP we need one state per thread extern std::vector< std::mt19937 > _mersenne_twister_generators; -extern std::uniform_real_distribution _uniform_random; -extern std::normal_distribution _normal_random; +extern std::vector> _uniform_random; +extern std::vector> _normal_random; //////////////// clocks /////////////////// {% for clock in clocks | sort(attribute='name') %}