Skip to content

Commit

Permalink
Merge pull request #134 from denisalevi/support_profiling_summary
Browse files Browse the repository at this point in the history
Support reporting kernel timings via `profiling_summary()`
  • Loading branch information
mstimberg authored Jul 28, 2023
2 parents 97a8a74 + c30a2df commit 49fe075
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 56 deletions.
129 changes: 81 additions & 48 deletions brian2genn/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def __init__(self):
'before_end': [],
'after_end': []}

#: Use GeNN's kernel timings?
self.kernel_timings = False

def insert_code(self, slot, code):
'''
Insert custom C++ code directly into ``main.cpp``. The available slots
Expand Down Expand Up @@ -407,14 +410,15 @@ def code_object(self, owner, name, abstract_code, variables, template_name,
# the run_regularly operation (will be directly called from
# engine.cpp)
codeobj = super().code_object(owner, name,
abstract_code,
variables,
'stateupdate',
variable_indices,
codeobj_class=CPPStandaloneCodeObject,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
)
abstract_code,
variables,
'stateupdate',
variable_indices,
codeobj_class=CPPStandaloneCodeObject,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
)

# FIXME: The following is redundant with what is done during
# the code object creation above. At the moment, the code
# object does not allow us to access the information we
Expand Down Expand Up @@ -464,13 +468,13 @@ def code_object(self, owner, name, abstract_code, variables, template_name,
elif template_name in ['reset', 'synapses', 'stateupdate', 'threshold']:
codeobj_class = GeNNCodeObject
codeobj = super().code_object(owner, name,
abstract_code,
variables,
template_name,
variable_indices,
codeobj_class=codeobj_class,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
abstract_code,
variables,
template_name,
variable_indices,
codeobj_class=codeobj_class,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
)
self.simple_code_objects[codeobj.name] = codeobj
else:
Expand All @@ -491,13 +495,13 @@ def code_object(self, owner, name, abstract_code, variables, template_name,
else:
mrl_template_name='max_row_length_array'
codeobj = super().code_object(owner, mrl_name,
abstract_code,
variables,
mrl_template_name,
variable_indices,
codeobj_class=codeobj_class,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
abstract_code,
variables,
mrl_template_name,
variable_indices,
codeobj_class=codeobj_class,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
)
#self.code_objects['%s_max_row_length' % owner.name] = codeobj
self.code_objects.pop(mrl_name, None) # remove this from the normal list of code objects
Expand All @@ -507,14 +511,14 @@ def code_object(self, owner, name, abstract_code, variables, template_name,
self.max_row_length_run_calls.append('_run_%s();' % mrl_name)

codeobj = super().code_object(owner, name,
abstract_code,
variables,
template_name,
variable_indices,
codeobj_class=codeobj_class,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
)
abstract_code,
variables,
template_name,
variable_indices,
codeobj_class=codeobj_class,
template_kwds=template_kwds,
override_conditional_write=override_conditional_write,
)
# FIXME: is this actually necessary or is it already added by the super?
self.code_objects[codeobj.name] = codeobj
return codeobj
Expand Down Expand Up @@ -754,14 +758,7 @@ def build(self, directory='GeNNworkspace', compile=True, run=True,
logger.debug(
"Writing GeNN project to directory " + os.path.normpath(directory))

# FIXME: This is only needed to keep Brian2GeNN compatible with Brian2 2.0.1 and earlier
if isinstance(self.arange_arrays, dict):
arange_arrays = sorted([(var, start)
for var, start in
self.arange_arrays.items()],
key=lambda var_start: var_start[0].name)
else:
arange_arrays = self.arange_arrays
arange_arrays = self.arange_arrays

# write the static arrays
for code_object in self.code_objects.values():
Expand Down Expand Up @@ -1726,6 +1723,7 @@ def generate_model_source(self, writer, main_lines, use_GPU):
max_row_length_synapses=self.max_row_length_synapses,
codeobj_inc=codeobj_inc,
dtDef=self.dtDef,
profiled=self.kernel_timings,
prefs=prefs,
precision=precision,
header_files=prefs['codegen.cpp.headers']
Expand All @@ -1741,7 +1739,7 @@ def generate_main_source(self, writer, main_lines):
main_lines=main_lines,
header_files=header_files,
source_files=sorted(self.source_files),
prefs=prefs,
profiled=self.kernel_timings,
)
writer.write('main.*', runner_tmp)

Expand Down Expand Up @@ -1875,11 +1873,17 @@ def copy_source_files(self, writer, directory):
self.header_files.add('b2glib/' + file)

def network_run(self, net, duration, report=None, report_period=10 * second,
namespace=None, profile=False, level=0, **kwds):
if profile is True:
raise NotImplementedError('Brian2GeNN does not yet support '
'detailed profiling.')

namespace=None, profile=None, level=0, **kwds):
self.kernel_timings = profile
# Allow setting `profile` in the `set_device` call (used e.g. in brian2cuda
# SpeedTest configurations)
if profile is None:
self.kernel_timings = self.build_options.pop("profile", None)
# If not set, check the deprecated preference
if profile is None and prefs.devices.genn.kernel_timing:
logger.warn("The preference 'devices.genn.kernel_timing' is "
"deprecated, please set profile=True instead")
self.kernel_timings = True
if kwds:
logger.warn(('Unsupported keyword argument(s) provided for run: '
+ '%s') % ', '.join(kwds.keys()))
Expand All @@ -1906,13 +1910,42 @@ def network_run(self, net, duration, report=None, report_period=10 * second,
# Network.objects to avoid memory leaks
self.net_objects = _get_all_objects(self.net.objects)
super().network_run(net=net, duration=duration,
report=report,
report_period=report_period,
namespace=namespace,
level=level + 1)
report=report,
report_period=report_period,
namespace=namespace,
level=level + 1,
profile=False)

self.run_statement_used = True


def network_get_profiling_info(self, net):
fname = os.path.join(self.project_dir, 'test_output', 'test.time')
if not self.kernel_timings:
raise ValueError("No profiling info collected (need to set "
"profile = True ?)")
net._profiling_info = []
keys = ['neuronUpdateTime',
'presynapticUpdateTime',
'postsynapticUpdateTime',
'synapseDynamicsTime',
'initTime',
'initSparseTime']
with open(fname) as f:
# times are appended as new line in each run
last_line = f.read().splitlines()[-1]
times = last_line.split()
n_time = len(times)
n_key = len(keys)
assert n_time == n_key, (
f'{n_time} != {n_key} \ntimes: {times}\nkeys: {keys}'
)
for key, time in zip(keys, times):
net._profiling_info.append((key, float(time)*second))
return sorted(net._profiling_info, key=lambda item: item[1],
reverse=True)


# ------------------------------------------------------------------------------
# End of GeNNDevice
# ------------------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion brian2genn/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def __call__(self, value):
validator=lambda value: value is None or os.path.isdir(value)
),
kernel_timing=BrianPreference(
docs='''This preference determines whether GeNN should record kernel runtimes; note that this can affect performance.''',
docs='''This preference determines whether GeNN should record kernel runtimes; note that this can affect performance.
This preference is deprecated, use profile=True in the set_device or run call instead.''',
default=False,
)
)
Expand Down
6 changes: 4 additions & 2 deletions brian2genn/templates/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ int main(int argc, char *argv[])
string cmd= std::string("mkdir ") +OutDir;
system(cmd.c_str());
string name;
{% if profiled %}
name= OutDir+ "/"+ argv[1] + ".time";
FILE *timef= fopen(name.c_str(),"a");

{% endif %}
fprintf(stderr, "# DT %g \n", DT);
fprintf(stderr, "# totalTime %f \n", totalTime);

Expand Down Expand Up @@ -133,10 +134,11 @@ int main(int argc, char *argv[])
eng.run(totalTime); // run for the full duration
{{'\n'.join(code_lines['after_network_run'])|autoindent}}
cerr << t << " done ..." << endl;
{% if prefs['devices.genn.kernel_timing'] %}
{% if profiled %}
{% for kt in ('neuronUpdateTime', 'presynapticUpdateTime', 'postsynapticUpdateTime', 'synapseDynamicsTime', 'initTime', 'initSparseTime') %}
fprintf(timef,"%f ", {{kt}});
{% endfor %}
fprintf(timef,"\n");
{% endif %}

// get the final results from the GPU
Expand Down
2 changes: 1 addition & 1 deletion brian2genn/templates/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ void modelDefinition(NNmodel &model)
{% if precision == 'GENN_FLOAT' %}
model.setTimePrecision(TimePrecision::DOUBLE);
{% endif %}
{% if prefs['devices.genn.kernel_timing'] %}
{% if profiled %}
model.setTiming(true);
{% endif %}
{% for neuron_model in neuron_models %}
Expand Down
6 changes: 5 additions & 1 deletion scripts/run_brian_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
import brian2genn
import brian2

import test_utils
skip_args = test_utils.get_skip_args()

if __name__ == '__main__':
success = brian2.test([], test_codegen_independent=False,
test_standalone='genn',
fail_for_not_implemented=False)
fail_for_not_implemented=False,
additional_args=skip_args)
if not success:
sys.exit(1)
6 changes: 5 additions & 1 deletion scripts/run_brian_tests_32bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

import numpy as np

import test_utils
skip_args = test_utils.get_skip_args()

if __name__ == '__main__':
success = brian2.test([], test_codegen_independent=False,
test_standalone='genn',
fail_for_not_implemented=False,
float_dtype=np.float32)
float_dtype=np.float32,
additional_args=skip_args)
if not success:
sys.exit(1)
6 changes: 5 additions & 1 deletion scripts/run_brian_tests_CPU.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import brian2genn
import brian2

import test_utils
skip_args = test_utils.get_skip_args()

if __name__ == '__main__':
success = brian2.test([], test_codegen_independent=False,
test_standalone='genn',
build_options={'use_GPU': False},
fail_for_not_implemented=False,
reset_preferences=False)
reset_preferences=False,
additional_args=skip_args + ['--collect-only'])
if not success:
sys.exit(1)
6 changes: 5 additions & 1 deletion scripts/run_brian_tests_CPU_32bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import brian2genn
import brian2

import test_utils
skip_args = test_utils.get_skip_args()

if __name__ == '__main__':
success = brian2.test([], test_codegen_independent=False,
test_standalone='genn',
build_options={'use_GPU': False},
fail_for_not_implemented=False,
float_dtype=np.float32,
reset_preferences=False)
reset_preferences=False,
additional_args=skip_args)
if not success:
sys.exit(1)
5 changes: 5 additions & 0 deletions scripts/skip_tests.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# A list of test name prefixes to skip (usually because they are marked as "standalone-compatible", but make assumptions
# that are only valid for C++ standalone).
# Test names have to use pytest's syntax, i.e. "test_name.py::test_function". Note that these are *prefixes*, so all
# tests with names starting with the given prefix will be skipped.
test_network.py::test_profile
11 changes: 11 additions & 0 deletions scripts/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
Utility functions for testing.
"""
from pathlib import Path
def get_skip_args():
fname = Path(__file__).parent / 'skip_tests.txt'
if not fname.exists():
return []
with open(fname) as f:
lines = f.readlines()
return ["--deselect="+line.strip() for line in lines if line.strip() and not line.startswith('#')]

0 comments on commit 49fe075

Please sign in to comment.