feat: add bin_size_ms to Spiketrain data (#27)
antoniaelsen authored Oct 18, 2024
2 parents b6048f9 + d305709 commit 91b6c85
Showing 6 changed files with 56 additions and 29 deletions.
synapse/server/nodes/spike_detect.py (2 changes: 1 addition & 1 deletion)
@@ -91,5 +91,5 @@ async def run(self):
spike_counts.append(spike_count)

await self.emit_data(
SpiketrainData(t0=data.t0, spike_counts=spike_counts)
SpiketrainData(t0=data.t0, bin_size_ms=self.bin_size_ms, spike_counts=spike_counts)
)
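For context, the bin size forwarded here lets a consumer interpret the emitted counts without out-of-band configuration. A minimal sketch of that downstream use, assuming t0 is a nanosecond timestamp (the unit is not stated in this diff) and with purely illustrative rate math:

```python
def bin_start_times_ns(t0_ns: int, bin_size_ms: int, n_bins: int):
    """Sketch: start timestamp of each bin, assuming t0 is in nanoseconds."""
    return [t0_ns + i * bin_size_ms * 1_000_000 for i in range(n_bins)]

def firing_rates_hz(spike_counts, bin_size_ms: int):
    """Spikes per second for each bin."""
    return [count * 1000.0 / bin_size_ms for count in spike_counts]

counts = [0, 1, 2, 3, 2]
print(bin_start_times_ns(1234567890, 10, len(counts))[:2])  # [1234567890, 1244567890]
print(firing_rates_hz(counts, 10))                          # [0.0, 100.0, 200.0, 300.0, 200.0]
```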
synapse/simulator/__init__.py (5 changes: 4 additions & 1 deletion)
@@ -1,14 +1,17 @@
from synapse.api.node_pb2 import NodeType
from synapse.server.entrypoint import ENTRY_DEFAULTS, main as server
from synapse.server.nodes.spectral_filter import SpectralFilter
from synapse.server.nodes.spike_detect import SpikeDetect
from synapse.server.nodes.stream_in import StreamIn
from synapse.server.nodes.stream_out import StreamOut
from synapse.simulator.nodes.electrical_broadband import ElectricalBroadband
from synapse.simulator.nodes.optical_stimulation import OpticalStimulation


SIMULATOR_NODE_OBJECT_MAP = {
NodeType.kStreamIn: StreamIn,
NodeType.kStreamOut: StreamOut,
NodeType.kSpectralFilter: SpectralFilter,
NodeType.kSpikeDetect: SpikeDetect,
NodeType.kElectricalBroadband: ElectricalBroadband,
NodeType.kOpticalStimulation: OpticalStimulation
}
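For reference, a sketch of how this map is typically consumed: given a NodeType, look up and instantiate the matching simulator node class. The helper name and constructor signature below are hypothetical; the concrete node classes may take different arguments.

```python
from synapse.api.node_pb2 import NodeType
from synapse.simulator import SIMULATOR_NODE_OBJECT_MAP

def make_node(node_type, node_id):
    # Hypothetical helper: resolve a node class from the map and construct it.
    node_cls = SIMULATOR_NODE_OBJECT_MAP.get(node_type)
    if node_cls is None:
        raise ValueError("no simulator implementation for node type " + str(node_type))
    return node_cls(node_id)  # assumed constructor; check the concrete node classes

# e.g. make_node(NodeType.kSpikeDetect, 1)
```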
synapse/tests/test_ndtp.py (6 changes: 4 additions & 2 deletions)
@@ -287,11 +287,13 @@ def test_ndtp_payload_spiketrain():
def test_ndtp_payload_spiketrain():
samples = [0, 1, 2, 3, 2]

payload = NDTPPayloadSpiketrain(samples)
payload = NDTPPayloadSpiketrain(10, samples)
packed = payload.pack()
unpacked = NDTPPayloadSpiketrain.unpack(packed)

assert unpacked == payload
assert unpacked.bin_size_ms == 10
assert list(unpacked.spike_counts) == samples


def test_ndtp_header():
@@ -379,7 +381,7 @@ def test_ndtp_message_broadband_large():

def test_ndtp_message_spiketrain():
header = NDTPHeader(DataType.kSpiketrain, timestamp=1234567890, seq_number=42)
payload = NDTPPayloadSpiketrain(spike_counts=[1, 2, 3, 2, 1])
payload = NDTPPayloadSpiketrain(bin_size_ms=10, spike_counts=[1, 2, 3, 2, 1])
message = NDTPMessage(header, payload)

packed = message.pack()
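With this change, the spiketrain payload starts with a 4-byte big-endian spike count followed by a 1-byte bin size; the counts themselves are then bit-packed at NDTPPayloadSpiketrain_BIT_WIDTH bits each (handled by to_ints in ndtp.pyx below). A pure-Python sketch of parsing just that prefix, with hypothetical packed-count bytes at the end:

```python
import struct

def parse_spiketrain_prefix(payload: bytes):
    """Sketch: read the uint32 (big-endian) spike count and the uint8 bin_size_ms,
    returning them with the remaining bit-packed count bytes."""
    if len(payload) < 5:
        raise ValueError("expected at least 5 bytes")
    num_spikes = struct.unpack(">I", payload[:4])[0]
    bin_size_ms = struct.unpack(">B", payload[4:5])[0]
    return num_spikes, bin_size_ms, payload[5:]

# 5 spikes, 10 ms bins, followed by hypothetical packed-count bytes
num_spikes, bin_size_ms, packed_counts = parse_spiketrain_prefix(b"\x00\x00\x00\x05\x0a\x6d\x90")
assert (num_spikes, bin_size_ms) == (5, 10)
```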
synapse/tests/test_stream_out.py (2 changes: 2 additions & 0 deletions)
@@ -140,13 +140,15 @@ def test_packing_spiketrain_data():

sdata = SpiketrainData(
t0=1234567890,
bin_size_ms=10,
spike_counts=[0, 1, 2, 3, 2, 1, 0],
)

packed = node._pack(sdata)[0]
unpacked = NDTPMessage.unpack(packed)

assert unpacked.header.timestamp == sdata.t0
assert unpacked.payload.bin_size_ms == sdata.bin_size_ms
assert len(unpacked.payload.spike_counts) == len(sdata.spike_counts)

assert list(unpacked.payload.spike_counts) == list(sdata.spike_counts)
synapse/utils/ndtp.pyx (57 changes: 36 additions & 21 deletions)
@@ -76,12 +76,12 @@ def to_bytes(
elif byteorder == 'big':
byteorder_is_little = False
else:
raise ValueError(f"Invalid byteorder: {byteorder}")
raise ValueError("Invalid byteorder: " + byteorder)

for py_value in values:
value = py_value
if not (min_value <= value <= max_value):
raise ValueError(f"Value {value} cannot be represented in {bit_width} bits")
raise ValueError("Value " + str(value) + " cannot be represented in " + str(bit_width) + " bits")

# Handle negative values for signed integers
if is_signed and value < 0:
@@ -144,14 +144,14 @@ def to_ints(
if isinstance(data, (bytes, bytearray)):
data_view = data
else:
raise TypeError(f"Unsupported data type: {type(data)}")
raise TypeError("Unsupported data type: " + str(type(data)))

cdef Py_ssize_t data_len = len(data_view)

if count > 0 and data_len < (bit_width * count + 7) // 8:
raise ValueError(
f"insufficient data for {count} x {bit_width} bit values "
f"(expected {(bit_width * count + 7) // 8} bytes, given {data_len} bytes)"
"insufficient data for " + str(count) + " x " + str(bit_width) + " bit values " +
"(expected " + str((bit_width * count + 7) // 8) + " bytes, given " + str(data_len) + " bytes)"
)

cdef int current_value = 0
@@ -163,7 +163,7 @@
cdef int value_index = 0
cdef int max_values = count if count > 0 else (data_len * 8) // bit_width
if max_values == 0:
raise ValueError(f"max_values must be > 0 (got {len(data)} data, {count} count, bit width {bit_width})")
raise ValueError("max_values must be > 0 (got " + str(len(data)) + " data, " + str(count) + " count, bit width " + str(bit_width) + ")")
cdef int[::1] values_array = cython.view.array(shape=(max_values,), itemsize=cython.sizeof(cython.int), format="i")
cdef int sign_bit = 1 << (bit_width - 1)
cdef uint8_t byte
@@ -218,7 +218,7 @@ def to_ints(
return [values_array[i] for i in range(value_index)], end_bit, data

else:
raise ValueError(f"Invalid byteorder: {byteorder}")
raise ValueError("Invalid byteorder: " + byteorder)

if bits_in_current_value > 0:
if bits_in_current_value == bit_width:
@@ -230,7 +230,7 @@
value_index += 1
elif count == 0:
raise ValueError(
f"{bits_in_current_value} bits left over, not enough to form a complete value of bit width {bit_width}"
str(bits_in_current_value) + " bits left over, not enough to form a complete value of bit width " + str(bit_width)
)

if count > 0:
@@ -325,7 +325,7 @@ cdef class NDTPPayloadBroadband:
cdef int len_data = len(data)
if len_data < payload_h_size:
raise ValueError(
f"Invalid broadband data size {len_data}: expected at least {payload_h_size} bytes"
"Invalid broadband data size " + str(len_data) + ": expected at least " + str(payload_h_size) + " bytes"
)

cdef int bit_width = data[0] >> 1
@@ -370,10 +370,12 @@ cdef class NDTPPayloadBroadband:


cdef class NDTPPayloadSpiketrain:
cdef public int bin_size_ms
cdef public int[::1] spike_counts # Memoryview of integers

def __init__(self, spike_counts):
def __init__(self, bin_size_ms, spike_counts):
cdef int size, i
self.bin_size_ms = bin_size_ms
self.spike_counts = None

if isinstance(spike_counts, list):
@@ -403,6 +405,9 @@ cdef class NDTPPayloadSpiketrain:
# Pack the number of spikes (4 bytes)
payload += struct.pack(">I", spike_counts_len)

# Pack the bin_size (1 byte)
payload += struct.pack(">B", self.bin_size_ms)

# Pack clamped spike counts
spike_counts_bytes, _ = to_bytes(
clamped_counts, NDTPPayloadSpiketrain_BIT_WIDTH, is_signed=False
@@ -415,26 +420,36 @@
if isinstance(data, bytes):
data = bytearray(data)

cdef str msg;
cdef int len_data = len(data)
if len_data < 4:
raise ValueError(
f"Invalid spiketrain data size {len_data}: expected at least 4 bytes"
)
if len_data < 5:
msg = "Invalid spiketrain data size "
msg += str(len_data)
msg += " bytes: expected at least 5 bytes"
raise ValueError(msg)

cdef int num_spikes = struct.unpack(">I", data[:4])[0]
cdef bytearray payload = data[4:]
cdef int bin_size_ms = struct.unpack(">B", data[4:5])[0]
cdef bytearray payload = data[5:]
cdef int bits_needed = num_spikes * NDTPPayloadSpiketrain_BIT_WIDTH
cdef int bytes_needed = (bits_needed + 7) // 8

if len(payload) < bytes_needed:
raise ValueError("Insufficient data for spike_counts")
msg = "Insufficient data for spiketrain data (expected "
msg += str(bytes_needed)
msg += "bytes for "
msg += str(num_spikes)
msg += " spikes, got "
msg += str(len(payload))
msg += ")"
raise ValueError(msg)

# Unpack spike_counts
spike_counts, _, _ = to_ints(
payload[:bytes_needed], NDTPPayloadSpiketrain_BIT_WIDTH, num_spikes, is_signed=False
)

return NDTPPayloadSpiketrain(spike_counts)
return NDTPPayloadSpiketrain(bin_size_ms, spike_counts)

def __eq__(self, other):
if not isinstance(other, NDTPPayloadSpiketrain):
@@ -483,13 +498,13 @@ cdef class NDTPHeader:
cdef int expected_size = NDTPHeader.STRUCT.size
if len(data) < expected_size:
raise ValueError(
f"Invalid header size {len(data)}: expected {expected_size}"
"Invalid header size " + str(len(data)) + ": expected " + str(expected_size)
)

version, data_type, timestamp, seq_number = NDTPHeader.STRUCT.unpack(bytes(data[:expected_size]))
if version != NDTP_VERSION:
raise ValueError(
f"Incompatible version {version}: expected {hex(NDTP_VERSION)}, got {hex(version)}"
"Incompatible version " + str(version) + ": expected " + hex(NDTP_VERSION) + ", got " + hex(version)
)

return NDTPHeader(data_type, timestamp, seq_number)
@@ -566,10 +581,10 @@ cdef class NDTPMessage:
elif pdtype == DataType.kSpiketrain:
payload = NDTPPayloadSpiketrain.unpack(pbytes)
else:
raise ValueError(f"unknown data type {pdtype}")
raise ValueError("unknown data type " + str(pdtype))

if not NDTPMessage.crc16_verify(data[:-2], crc16_value):
raise ValueError(f"CRC16 verification failed (expected {crc16_value})")
raise ValueError("CRC16 verification failed (expected " + str(crc16_value) + ")")

msg = NDTPMessage(header, payload)
msg._crc16 = crc16_value
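The to_bytes/to_ints helpers above do fixed-width bit packing; below is a pure-Python sketch of the unsigned, MSB-first case, including the same (bit_width * count + 7) // 8 size check used in to_ints. This is a simplified model, not the Cython implementation, which also handles signed values and little-endian order:

```python
def pack_unsigned(values, bit_width):
    """Sketch: pack unsigned ints into bytes, bit_width bits each, MSB-first."""
    bits, n_bits, out = 0, 0, bytearray()
    for value in values:
        if not (0 <= value < (1 << bit_width)):
            raise ValueError("value " + str(value) + " cannot be represented in " + str(bit_width) + " bits")
        bits = (bits << bit_width) | value
        n_bits += bit_width
        while n_bits >= 8:
            n_bits -= 8
            out.append((bits >> n_bits) & 0xFF)
            bits &= (1 << n_bits) - 1
    if n_bits:  # pad the final partial byte with zeros
        out.append((bits << (8 - n_bits)) & 0xFF)
    return bytes(out)

def unpack_unsigned(data, bit_width, count):
    """Sketch: inverse of pack_unsigned, with the size check from to_ints."""
    if len(data) < (bit_width * count + 7) // 8:
        raise ValueError("insufficient data for " + str(count) + " x " + str(bit_width) + " bit values")
    total, shift = int.from_bytes(data, "big"), len(data) * 8
    values = []
    for _ in range(count):
        shift -= bit_width
        values.append((total >> shift) & ((1 << bit_width) - 1))
    return values

assert unpack_unsigned(pack_unsigned([0, 1, 2, 3, 2], 2), 2, 5) == [0, 1, 2, 3, 2]
```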
synapse/utils/ndtp_types.py (13 changes: 9 additions & 4 deletions)
@@ -100,11 +100,12 @@ def to_list(self):


class SpiketrainData:
__slots__ = ["data_type", "t0", "spike_counts"]
__slots__ = ["data_type", "t0", "bin_size_ms", "spike_counts"]

def __init__(self, t0, spike_counts):
def __init__(self, t0, bin_size_ms, spike_counts):
self.data_type = DataType.kSpiketrain
self.t0 = t0
self.bin_size_ms = bin_size_ms
self.spike_counts = spike_counts

def pack(self, seq_number: int):
@@ -114,7 +115,10 @@ def pack(self, seq_number: int):
timestamp=self.t0,
seq_number=seq_number,
),
payload=NDTPPayloadSpiketrain(spike_counts=self.spike_counts),
payload=NDTPPayloadSpiketrain(
bin_size_ms=self.bin_size_ms,
spike_counts=self.spike_counts
),
)

return [message.pack()]
@@ -123,6 +127,7 @@ def pack(self, seq_number: int):
def from_ndtp_message(msg: NDTPMessage):
return SpiketrainData(
t0=msg.header.timestamp,
bin_size_ms=msg.payload.bin_size_ms,
spike_counts=msg.payload.spike_counts,
)

@@ -132,7 +137,7 @@ def unpack(data):
return SpiketrainData.from_ndtp_message(u)

def to_list(self):
return [self.t0, list(self.spike_counts)]
return [self.t0, self.bin_size_ms, list(self.spike_counts)]


SynapseData = Union[SpiketrainData, ElectricalBroadbandData]
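With the field threaded through SpiketrainData, a pack/unpack round trip preserves it. A sketch, assuming the synapse package from this repo is importable and built, using the pack/unpack methods shown in this diff:

```python
from synapse.utils.ndtp_types import SpiketrainData

original = SpiketrainData(t0=1234567890, bin_size_ms=10, spike_counts=[0, 1, 2, 3, 2, 1, 0])
frames = original.pack(seq_number=42)      # pack() returns a list of NDTP frames
restored = SpiketrainData.unpack(frames[0])

assert restored.t0 == original.t0
assert restored.bin_size_ms == original.bin_size_ms
assert list(restored.spike_counts) == list(original.spike_counts)
```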
