-
Notifications
You must be signed in to change notification settings - Fork 0
/
nest_utils.py
363 lines (312 loc) · 12 KB
/
nest_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
# -*- coding: utf-8 -*-
from __future__ import print_function
import __builtin__ # for Python 3: builtins as __builtin__
import os
import sys
import threading
import time
import random
import atexit
import contextlib
import nett_python as nett
import float_message_pb2 as fm
import string_message_pb2 as sm
# Module-level setup: runs once on import, after the imports above.
# NOTE(review): ``reload`` is a Python 2 builtin; under Python 3 this would
# need ``importlib.reload``.
nett = reload(nett)  # In case nett has been changed by the testsuite.
# Initialise nett with this process's address (presumably the endpoint the
# NEST client talks back to -- confirm against the nett documentation).
nett.initialize('tcp://127.0.0.1:2001')
if os.name == 'posix' and sys.version_info[0] < 3:
    # Import a backport of the subprocess module from Python 3 for Python 2
    try:
        import subprocess32 as sp
    except ImportError:
        # The module-level ``print`` override is defined further down, so
        # this call still uses the plain builtin print (print_function).
        print('Module subprocess32 not found, using old subprocess module.')
        import subprocess as sp
else:
    import subprocess as sp
def print(*args, **kwargs):
    """
    A cosmetic change to the print function to show more clearly which script
    is actually printing when we are running the NESTClient in the same
    terminal.

    Prints a coloured ``[server]`` prefix followed by the normal output.
    Accepts exactly the same arguments as the built-in print function.

    :returns: Whatever the built-in print returns (``None``)
    """
    # Write the prefix to the same stream as the message itself; otherwise a
    # call with ``file=sys.stderr`` would put the prefix on stdout and the
    # text on stderr. ``file=None`` means sys.stdout, as for the builtin.
    __builtin__.print('[\033[1m\033[93mserver\033[0m] ', end='',
                      file=kwargs.get('file'))
    return __builtin__.print(*args, **kwargs)
class observe_slot(threading.Thread):
    """
    A listener for messages from the NESTClient. Each listener runs in its
    own daemon thread.

    :param slot: The nett type slot to receive from
    :param message_type: An instance of the nett type data-type to receive;
        every received payload is parsed into a new instance of the same class
    :param callback: Optional function to call on receiving data
    """

    def __init__(self, slot, message_type, callback=None):
        super(observe_slot, self).__init__()
        self.slot = slot
        # Template instance on construction; afterwards the most recently
        # parsed message (replaced on every receive, never mutated in place).
        self.msg = message_type
        self.last_message = None
        self.callback = callback
        # Daemon thread: a listener blocked in receive() does not keep the
        # interpreter alive at exit.
        self.daemon = True
        # Set by the owner to ask the loop to stop after its next receive.
        self.ceased = False

    def get_last_message(self):
        """
        Gets the last message received.

        :returns: The last message received, or None if nothing arrived yet
        """
        return self.last_message

    def run(self):
        """
        Receives and dispatches messages until ``ceased`` is set.

        Each payload is parsed into a fresh message object, so a message that
        was passed to ``callback`` or returned by ``get_last_message`` is
        never overwritten by a later receive.
        """
        message_cls = type(self.msg)
        while not self.ceased:
            payload = self.slot.receive()
            # ``ceased`` may have been set while we were blocked in receive();
            # the payload is then only a wake-up and is discarded unparsed.
            if self.ceased:
                break
            msg = message_cls()
            msg.ParseFromString(payload)
            if msg.value is None:
                continue
            self.msg = msg
            self.last_message = msg
            if self.callback is not None:
                self.callback(msg)
class NESTInterface(object):
    """
    For interacting with the NESTClient.

    Constructing an instance starts the NEST client subprocess, connects the
    nett slots to it, resets its kernel, optionally sends the device
    projections and finally asks it to build the network.

    :param networkSpecs: Network specifications, including synapse
        specifications and projections between layers. It is concatenated
        into the outgoing message, so it must already be serialised to a
        string.
    :param user_id: Identifier of the user; used to name the nett slots, to
        seed the choice of the client's port and as the client's argument
    :param device_projections: Optional (string-serialised) list of
        projections between layers and devices; ``'[]'`` means none
    :param silent: If true, suppress this object's console output and start
        the client with ``-s``
    :param socketio: Optional Socket.IO server used to forward status
        messages; if None, status messages are only printed
    """

    def __init__(self, networkSpecs,
                 user_id,
                 device_projections='[]',
                 silent=False,
                 socketio=None):
        self.networkSpecs = networkSpecs
        self.device_projections = device_projections
        self.user_id = user_id
        self.device_results = None
        self.silent = silent
        self.socketio = socketio
        # Registered before the client is started; terminate_nest_client
        # tolerates a missing client if start-up fails part way.
        atexit.register(self.terminate_nest_client)

        self.slot_out_data = nett.slot_out_string_message(
            'data_{}'.format(self.user_id))
        self.slot_in_complete = nett.slot_in_float_message()
        self.slot_in_nconnections = nett.slot_in_float_message()
        self.slot_in_device_results = nett.slot_in_string_message()
        self.slot_in_status_message = nett.slot_in_string_message()

        # Seeding with the user id makes the port deterministic per user.
        # NOTE(review): presumably nest_client.py derives its port the same
        # way -- confirm. This also reseeds the global ``random`` state.
        random.seed(self.user_id)
        port_increment = random.randint(1, 1000)
        client_address = 'tcp://127.0.0.1:{}'.format(8000 + port_increment)
        self.slot_in_complete.connect(client_address,
                                      'task_complete_{}'.format(self.user_id))
        self.slot_in_nconnections.connect(
            client_address, 'nconnections_{}'.format(self.user_id))
        self.slot_in_device_results.connect(
            client_address, 'device_results_{}'.format(self.user_id))
        self.slot_in_status_message.connect(
            client_address, 'status_message_{}'.format(self.user_id))

        self.observe_slot_ready = observe_slot(self.slot_in_complete,
                                               fm.float_message(),
                                               self.handle_complete)
        self.observe_slot_nconnections = observe_slot(
            self.slot_in_nconnections,
            fm.float_message())
        self.observe_slot_device_results = observe_slot(
            self.slot_in_device_results,
            sm.string_message(),
            self.handle_device_results)
        self.observe_slot_status_message = observe_slot(
            self.slot_in_status_message,
            sm.string_message(),
            self.handle_status_message)
        self.observe_slot_ready.start()
        self.observe_slot_nconnections.start()
        self.observe_slot_device_results.start()
        self.observe_slot_status_message.start()

        # Set by handle_complete when the client reports a finished task.
        self.event = threading.Event()
        with self.wait_for_client(10):
            self.start_nest_client()
        with self.wait_for_client():
            self.reset_kernel()
        if self.device_projections != '[]':
            self.send_device_projections()
        with self.wait_for_client():
            self.make_network()

    def print(self, *args, **kwargs):
        """
        Wrapper around the print function to handle silent mode.
        """
        if not self.silent:
            print(*args, **kwargs)

    @contextlib.contextmanager
    def wait_for_client(self, timeout=None):
        """
        Context manager for waiting for the client.

        Clears the complete signal on entry, so a signal left over from an
        earlier request is not mistaken for this one, runs the body, then
        blocks until the client signals completion or ``timeout`` expires.

        :param timeout: Seconds to wait, or None to wait indefinitely
        """
        self.reset_complete_signal()
        yield
        self.wait_until_client_finishes(timeout)

    def get_valid_msg_value(self, observer):
        """
        Polls an observer for a message, up to 10 times at 0.1 s intervals.

        :param observer: An observe_slot instance
        :returns: The ``value`` of the observer's last message, or -1 if no
            message arrived within about one second
        """
        timeout = 10  # number of 0.1 s polls, not seconds
        n = 0
        while observer.get_last_message() is None:
            time.sleep(0.1)
            n += 1
            if n == timeout:
                return -1
        return observer.get_last_message().value

    def start_nest_client(self):
        """
        Starting the NEST client in a separate process using the subprocess
        module.
        """
        cmd = ['python', 'nest_client.py', str(self.user_id)]
        if self.silent:
            self.client = sp.Popen(cmd + ['-s'], stdout=sp.PIPE)
        else:
            self.client = sp.Popen(cmd)
        self.print('NEST client started')

    def terminate_nest_client(self):
        """
        Terminates the NEST client subprocess and reaps it.

        Safe to call when the client was never started (this method is
        registered with atexit before the client exists) or has already
        exited.
        """
        client = getattr(self, 'client', None)
        if client is None:
            return
        if client.poll() is None:
            try:
                client.terminate()
            except OSError:
                # The process exited between poll() and terminate().
                pass
        # Waits for the process and drains the stdout pipe used in silent
        # mode.
        client.communicate()

    def cease_threads(self):
        """
        Marks the current observing threads as obsolete, so they will stop at
        the first available opportunity, and waits for them to finish.

        :raises RuntimeError: If the thread count afterwards is not the count
            before minus the stopped observers
        """
        threads_before_cease = len(threading.enumerate())
        threads = [self.observe_slot_ready, self.observe_slot_nconnections,
                   self.observe_slot_device_results,
                   self.observe_slot_status_message]
        for thread in threads:
            thread.ceased = True
        # The threads are blocking until they receive a message. Therefore we
        # make the client ping all slots so that all threads are terminated.
        self.send_to_client('ping')
        for thread in threads:
            thread.join()
        # All four observer threads should be gone now.
        threads_after_cease = len(threading.enumerate())
        expected_num_threads = threads_before_cease - len(threads)
        if threads_after_cease != expected_num_threads:
            raise RuntimeError('Did not stop all threads! '
                               '(Threads: {}; Expected: {})'.format(
                                   threads_after_cease, expected_num_threads))

    def handle_complete(self, msg):
        """
        Handles receiving complete signal from the client.

        :param msg: The nett type message received (unused).
        """
        self.print('Received complete signal')
        self.event.set()

    def reset_complete_signal(self):
        """
        Resets the complete signal.
        """
        self.event.clear()

    def wait_until_client_finishes(self, timeout=None):
        """
        Blocks until complete signal from the client is received.

        :param timeout: Seconds to wait, or None to wait indefinitely. On
            timeout a warning is printed and execution continues.
        """
        self.print('Waiting for client...')
        recv_flag = self.event.wait(timeout)
        if not recv_flag:
            self.print('WARNING: Event timed out')

    def send_to_client(self, label, data=''):
        """
        Sends a command or data to the NEST client.

        The wire format is ``label`` alone, or ``label + ' ' + data`` when
        ``data`` is non-empty.

        :param label: Command or label for the data
        :param data: Data to send (string)
        """
        # TODO: check that label and data are strings
        msg = sm.string_message()
        msg.value = label + ' ' * bool(data) + data
        self.slot_out_data.send(msg.SerializeToString())

    def reset_kernel(self):
        """
        Resets the NEST kernel.
        """
        self.send_to_client('reset')
        self.print('Sent reset')

    def send_device_projections(self):
        """
        Sends projections to the NEST client and waits for it to finish.
        """
        with self.wait_for_client():
            self.send_to_client('projections', self.device_projections)
            self.print('Sent projections')

    def make_network(self):
        """
        Sends the network specifications to the NEST client, which then creates
        the layers and models of nodes.
        """
        self.send_to_client('make_network', self.networkSpecs)
        self.print('Sent make network')

    def printGIDs(self, selection):
        """
        Asks the NEST client to print the selected GIDs and waits for it to
        finish. Nothing is returned to the caller.

        :param selection: string-serialised specification of the selected
            areas
        """
        self.print('Sending get GIDs')
        with self.wait_for_client():
            self.send_to_client('get_gids', selection)

    def connect_all(self):
        """
        Connects both projections between layers and projections between layers
        and devices.
        """
        self.print('Sending connect')
        with self.wait_for_client():
            self.send_to_client('connect')
        self.print("Connection complete")

    def get_num_connections(self):
        """
        Gets the number of connections.

        :returns: number of connections, or -1 if the client did not report
            one in time
        """
        self.print('Sending get Nconnections')
        # Discard any count from a previous request so a stale value is not
        # returned if the new reply is late.
        self.observe_slot_nconnections.last_message = None
        with self.wait_for_client():
            self.send_to_client('get_nconnections')
        nconnections = int(
            self.get_valid_msg_value(self.observe_slot_nconnections))
        self.print("Nconnections: {}".format(nconnections))
        return nconnections

    def simulate(self, t):
        """
        Runs a simulation for a specified time. Does not wait for the client;
        results arrive asynchronously via handle_device_results.

        :param t: time to simulate
        """
        self.device_results = None  # Clear device results
        self.send_to_client('simulate', str(t))

    def handle_device_results(self, msg):
        """
        Handles receiving device results.

        :param msg: Nett type message with the device results
        """
        self.print('Received device results:\n' +
                   '{:>{width}}'.format(msg.value, width=len(msg.value) + 9))
        self.device_results = msg.value

    def get_device_results(self):
        """
        :returns: The latest device results string, or None if none have been
            received since the last call to simulate
        """
        return self.device_results

    def handle_status_message(self, msg):
        """
        Handles receiving status messages from the NEST client.

        The message value is expected to be ``"<user_id> <text...>"``. The
        text is printed and, if a Socket.IO server was supplied, emitted on
        the ``/message/<user_id>`` namespace.

        :param msg: Nett type message with the status message
        """
        parts = msg.value.split()
        if not parts:
            # No user id to route by; ignore rather than crash the observer
            # thread with an IndexError.
            self.print('Received empty status message; ignoring')
            return
        user_id = parts[0]
        message = ' '.join(parts[1:])
        self.print('Received status message:\n' +
                   '{:>{width}}'.format(message, width=len(message) + 9))
        if self.socketio is None:
            # No Socket.IO server was given to the constructor.
            return
        self.socketio.emit('message',
                           {'message': message},
                           namespace='/message/{}'.format(user_id))
        # TODO: Use namespace to send to different clients
        self.print('Sent socket msg')