Skip to content

Commit

Permalink
Make random number generation thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Sep 10, 2024
1 parent 78322cc commit 84b8202
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
2 changes: 1 addition & 1 deletion brian2/devices/cpp_standalone/codeobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def generate_rand_code(rand_func, owner):
raise AssertionError(rand_func)
code = """
double _%RAND_FUNC%(const int _vectorisation_idx) {
return brian::%RK_CALL%(brian::_mersenne_twister_generators[%THREAD_NUMBER%]);
return brian::%RK_CALL%[%THREAD_NUMBER%](brian::_mersenne_twister_generators[%THREAD_NUMBER%]);
}
"""
code = replace(
Expand Down
19 changes: 14 additions & 5 deletions brian2/devices/cpp_standalone/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,15 +870,24 @@ def generate_main_source(self, writer):
nb_threads = prefs.devices.cpp_standalone.openmp_threads
if nb_threads == 0: # no OpenMP
nb_threads = 1
main_lines.append(f"for (int _i=0; _i<{nb_threads}; _i++)")
main_lines.append(f"for (int _i=0; _i<{nb_threads}; _i++) {{")
if seed is None: # random
main_lines.append(
" brian::_mersenne_twister_generators[_i] = std::mt19937(_rd());"
main_lines.extend(
[
" brian::_mersenne_twister_generators[_i] = std::mt19937(_rd());",
" brian::_uniform_random[_i].reset();",
" brian::_normal_random[_i].reset();",
]
)
else:
main_lines.append(
f"brian::_mersenne_twister_generators[_i].seed({seed!r}L + _i);"
main_lines.extend(
[
f" brian::_mersenne_twister_generators[_i].seed({seed!r}L + _i);"
" brian::_uniform_random[_i].reset();",
" brian::_normal_random[_i].reset();",
]
)
main_lines.append("}")
else:
raise NotImplementedError(f"Unknown main queue function type {func}")

Expand Down
14 changes: 10 additions & 4 deletions brian2/devices/cpp_standalone/templates/objects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ set_variable_from_value(name, {{array_name}}, var_size, (char)atoi(s_value.c_str
namespace brian {

std::string results_dir = "results/"; // can be overwritten by --results_dir command line arg

// For multhreading, we need one generator for each thread. We also create a distribution for
// each thread, even though this is not strictly necessary for the uniform distribution, as
// the distribution is stateless.
std::vector< std::mt19937 > _mersenne_twister_generators;
std::uniform_real_distribution<double> _uniform_random;
std::normal_distribution<double> _normal_random;
std::vector<std::uniform_real_distribution<double>> _uniform_random;
std::vector<std::normal_distribution<double>> _normal_random;

//////////////// networks /////////////////
{% for net in networks | sort(attribute='name') %}
Expand Down Expand Up @@ -228,6 +232,8 @@ void _init_arrays()
std::random_device rd;
for (int i=0; i<{{openmp_pragma('get_num_threads')}}; i++)
_mersenne_twister_generators.push_back(std::mt19937(rd()));
_uniform_random.push_back(std::uniform_real_distribution<double>());
_normal_random.push_back(std::normal_distribution<double>());
}

void _load_arrays()
Expand Down Expand Up @@ -381,8 +387,8 @@ namespace brian {
extern std::string results_dir;
// In OpenMP we need one state per thread
extern std::vector< std::mt19937 > _mersenne_twister_generators;
extern std::uniform_real_distribution<double> _uniform_random;
extern std::normal_distribution<double> _normal_random;
extern std::vector<std::uniform_real_distribution<double>> _uniform_random;
extern std::vector<std::normal_distribution<double>> _normal_random;

//////////////// clocks ///////////////////
{% for clock in clocks | sort(attribute='name') %}
Expand Down

0 comments on commit 84b8202

Please sign in to comment.