From fb56be4b0d441c25467b37166b2dbded73d3e40c Mon Sep 17 00:00:00 2001 From: Antonia Elsen Date: Tue, 15 Oct 2024 21:11:17 -0700 Subject: [PATCH] fix: pack channel data tightly (again) --- synapse/tests/test_ndtp.py | 48 ++++++++++++++++-- synapse/tests/test_stream_out.py | 7 ++- synapse/utils/ndtp.pyx | 87 ++++++++++++++++---------------- 3 files changed, 92 insertions(+), 50 deletions(-) diff --git a/synapse/tests/test_ndtp.py b/synapse/tests/test_ndtp.py index 113527c..a83ba3e 100644 --- a/synapse/tests/test_ndtp.py +++ b/synapse/tests/test_ndtp.py @@ -144,6 +144,47 @@ def test_ndtp_payload_broadband(): payload = NDTPPayloadBroadband(is_signed, bit_width, sample_rate, channels) p = payload.pack() + assert p[0] == (bit_width << 1) | (is_signed << 0) + + # number of channels + assert p[1] == 0 + assert p[2] == 0 + assert p[3] == 3 + + # sample rate + assert p[4] == 0 + assert p[5] == 3 + + # ch 0 channel_id, 0 (24 bits, 3 bytes) + assert p[6] == 0 + assert p[7] == 0 + assert p[8] == 0 + + # ch 0 num_samples, 3 (16 bits, 2 bytes) + assert p[9] == 0 + assert p[10] == 3 + + # ch 0 channel_data, 1, 2, 3 (12 bits, 1.5 bytes each) + # 0000 0000 0001 0000 0000 0010 0000 0000 0011 .... + assert p[11] == 0 + assert p[12] == 16 + assert p[13] == 2 + assert p[14] == 0 + assert p[15] >= 3 + + # ch 1 channel_id, 1 (24 bits, 3 bytes, starting from 4 bit offset) + # 0011 0000 0000 0000 0000 0000 0001 .... + assert p[15] == 48 + assert p[16] == 0 + assert p[17] == 0 + assert p[18] >= 16 + + # ch 1 num_samples, 3 (16 bits, 2 bytes, starting from 4 bit offset) + # 0001 0000 0000 0000 0011 .... + assert p[18] == 16 + assert p[19] == 0 + assert p[20] >= 48 + u = NDTPPayloadBroadband.unpack(p) assert u.bit_width == bit_width assert u.is_signed == is_signed @@ -200,7 +241,7 @@ def test_ndtp_header(): with pytest.raises(ValueError): NDTPHeader.unpack( struct.pack(">B", NDTP_VERSION) - + struct.pack(">I", DataType.kBroadband) + + struct.pack(">B", DataType.kBroadband) + struct.pack(">Q", 123) ) @@ -209,12 +250,12 @@ def test_ndtp_message(): header = NDTPHeader(DataType.kBroadband, timestamp=1234567890, seq_number=42) payload = NDTPPayloadBroadband( bit_width=12, - sample_rate=100, + sample_rate=3, is_signed=False, channels=[ NDTPPayloadBroadbandChannelData( channel_id=c, - channel_data=[c * 100 for _ in range(c + 1)], + channel_data=[c * 3 for _ in range(c + 1)], ) for c in range(3) ], @@ -222,6 +263,7 @@ def test_ndtp_message(): message = NDTPMessage(header, payload) packed = message.pack() + unpacked = NDTPMessage.unpack(packed) assert unpacked.header == message.header diff --git a/synapse/tests/test_stream_out.py b/synapse/tests/test_stream_out.py index 109b0f8..56ca3bc 100644 --- a/synapse/tests/test_stream_out.py +++ b/synapse/tests/test_stream_out.py @@ -43,9 +43,9 @@ def test_packing_broadband_data(): # Unsigned sample_data = [ - (1, np.array([1000, 2000, 3000], dtype=np.int16)), - (2, np.array([1234, 1234, 1234, 1234], dtype=np.int16)), - (3, np.array([1000, 2000, 3000, 4000, 3000], dtype=np.int16)), + (1, np.array([1000, 2000, 3000], dtype=np.uint16)), + (2, np.array([1234, 1234, 1234, 1234], dtype=np.uint16)), + (3, np.array([1000, 2000, 3000, 4000, 3000], dtype=np.uint16)), ] bdata = ElectricalBroadbandData( bit_width=12, @@ -61,7 +61,6 @@ def test_packing_broadband_data(): unpacked = NDTPMessage.unpack(p) assert unpacked.header.timestamp == bdata.t0 - assert unpacked.payload.bit_width == 12 assert unpacked.payload.channels[0].channel_id == bdata.samples[i][0] assert list(unpacked.payload.channels[0].channel_data) == list( diff --git a/synapse/utils/ndtp.pyx b/synapse/utils/ndtp.pyx index b0fca4b..d3e8bd3 100644 --- a/synapse/utils/ndtp.pyx +++ b/synapse/utils/ndtp.pyx @@ -14,8 +14,6 @@ from synapse.api.datatype_pb2 import DataType cdef int DATA_TYPE_K_BROADBAND = DataType.kBroadband cdef int DATA_TYPE_K_SPIKETRAIN = DataType.kSpiketrain -cdef object NDTPHeader_STRUCT = struct.Struct(">BIQH") - NDTP_VERSION = 0x01 cdef int NDTPPayloadSpiketrain_BIT_WIDTH = 2 @@ -44,9 +42,16 @@ def to_bytes( else: buffer = existing buffer_length = len(buffer) - bit_offset = (buffer_length - 1) * 8 + writing_bit_offset if buffer_length > 0 else 0 + if buffer_length <= 0: + bit_offset = 0 + else: + if writing_bit_offset > 0: + bit_offset = (buffer_length - 1) * 8 + writing_bit_offset + else: + bit_offset = (buffer_length) * 8 cdef int total_bits_needed = bit_offset + num_bits_to_write + cdef int total_bytes_needed = (total_bits_needed + 7) // 8 # Extend buffer if necessary @@ -162,12 +167,11 @@ def to_ints( cdef int value_index = 0 cdef int max_values = count if count > 0 else (data_len * 8) // bit_width cdef int[::1] values_array = cython.view.array(shape=(max_values,), itemsize=cython.sizeof(cython.int), format="i") - cdef int bit_width_minus1 = bit_width - 1 - cdef int sign_bit = 1 << bit_width_minus1 - cdef uint8_t byte # Declare byte here, outside the loop + cdef int sign_bit = 1 << (bit_width - 1) + cdef uint8_t byte for byte_index in range(data_len): - byte = data_view[byte_index] # Initialize byte inside the loop + byte = data_view[byte_index] if byteorder == 'little': start = start_bit if byte_index == 0 else 0 @@ -286,18 +290,32 @@ cdef class NDTPPayloadBroadband: payload += struct.pack(">H", self.sample_rate) cdef NDTPPayloadBroadbandChannelData c + bit_offset = 0 for c in self.channels: - # Pack channel_id (3 bytes, 24 bits) - payload += c.channel_id.to_bytes(3, byteorder='big', signed=False) + payload, bit_offset = to_bytes( + values=[c.channel_id], + bit_width=24, + is_signed=False, + existing=payload, + writing_bit_offset=bit_offset, + ) - # Pack number of samples (2 bytes, 16 bits) - payload += struct.pack(">H", len(c.channel_data)) - # Pack channel_data - channel_data_bytes, _ = to_bytes( - c.channel_data, self.bit_width, is_signed=self.is_signed + payload, bit_offset = to_bytes( + values=[len(c.channel_data)], + bit_width=16, + is_signed=False, + existing=payload, + writing_bit_offset=bit_offset, + ) + + payload, bit_offset = to_bytes( + values=c.channel_data, + bit_width=self.bit_width, + is_signed=self.is_signed, + existing=payload, + writing_bit_offset=bit_offset, ) - payload += channel_data_bytes return payload @@ -324,33 +342,16 @@ cdef class NDTPPayloadBroadband: cdef list channel_data cdef NDTPPayloadBroadbandChannelData channel + data = data[6:] + bit_offset = 0 for _ in range(num_channels): - # Unpack channel_id (3 bytes, big-endian) - if pos + 3 > len(data): - raise ValueError("Incomplete data for channel_id") - channel_id = int.from_bytes(data[pos:pos+3], 'big') - pos += 3 - - # Unpack num_samples (2 bytes, big-endian) - if pos + 2 > len(data): - raise ValueError("Incomplete data for num_samples") - num_samples = struct.unpack(">H", data[pos:pos+2])[0] - pos += 2 - - # Calculate the number of bits and bytes needed for channel data - total_bits = num_samples * bit_width - bytes_needed = (total_bits + 7) // 8 # Round up to the nearest byte - - # Ensure we have enough data - if pos + bytes_needed > len(data): - raise ValueError("Incomplete data for channel_data") - channel_data_bytes = data[pos:pos + bytes_needed] - pos += bytes_needed - - # Unpack channel_data - channel_data, _, _ = to_ints( - channel_data_bytes, bit_width, num_samples, is_signed=is_signed - ) + a_channel_id, bit_offset, data = to_ints(data=data, bit_width=24, count=1, start_bit=bit_offset, is_signed=is_signed) + channel_id = a_channel_id[0] + + a_num_samples, bit_offset, data = to_ints(data=data, bit_width=16, count=1, start_bit=bit_offset, is_signed=is_signed) + num_samples = a_num_samples[0] + + channel_data, bit_offset, data = to_ints(data=data, bit_width=bit_width, count=num_samples, start_bit=bit_offset, is_signed=is_signed) channel = NDTPPayloadBroadbandChannelData(channel_id, channel_data) channels.append(channel) @@ -452,7 +453,7 @@ cdef class NDTPHeader: cdef public long long timestamp cdef public int seq_number - STRUCT = struct.Struct(">BIQH") # Define as a Python class attribute + STRUCT = struct.Struct(">BBQH") # Define as a Python class attribute def __init__(self, int data_type, long long timestamp, int seq_number): self.data_type = data_type @@ -549,7 +550,7 @@ cdef class NDTPMessage: if isinstance(data, bytes): data = bytearray(data) - cdef int header_size = NDTPHeader_STRUCT.size + cdef int header_size = NDTPHeader.STRUCT.size cdef NDTPHeader header cdef int crc16_value cdef object pbytes