diff --git a/docs/source/protocol_proposed.rst b/docs/source/protocol_proposed.rst index 369d4eb..cced67a 100644 --- a/docs/source/protocol_proposed.rst +++ b/docs/source/protocol_proposed.rst @@ -20,14 +20,13 @@ The suggested changes to the protocol are as follows: (bit) (imdpull: true or false) (bit) (imdwait: true or false) (bit) (imdterm: true or false) - (bit) (wrapped positions: true or false) - (bit) (energies included: true or false) - (bit) (dimensions included: true or false) - (bit) (positions included: true or false) - (bit) (velocities included: true or false) - - (bit) (forces included: true or false) - (7 bits) (unused) + (bit) (wrapped positions: true or false. if positions rate is 0, this is a placeholder value) + + (int32) (energies rate: number of steps that elapse between steps that contain energy data. 0 means never) + (int32) (dimensions rate: number of steps that elapse between steps that contain dimension data. 0 means never) + (int32) (positions rate: number of steps that elapse between steps that contain position data. 0 means never) + (int32) (velocities rate: number of steps that elapse between steps that contain velocity data. 0 means never) + (int32) (forces rate: number of steps that elapse between steps that contain force data. 0 means never) "wrapped positions" will be a new ``.mdp`` setting which specifies whether the atoms' positions should be adjusted to fit within the simulation box before sending. This is useful for visualization purposes. 
diff --git a/imdreader/IMDProtocol.py b/imdreader/IMDProtocol.py index bc86c2e..9fac5d9 100644 --- a/imdreader/IMDProtocol.py +++ b/imdreader/IMDProtocol.py @@ -2,6 +2,10 @@ import struct import logging from enum import Enum, auto +from typing import Union +from dataclasses import dataclass +import abc +import threading """ IMD Packets have an 8 byte header and a variable length payload @@ -17,7 +21,8 @@ """ IMDHEADERSIZE = 8 IMDENERGYPACKETLENGTH = 40 -IMDVERSION = 2 +IMDBOXPACKETLENGTH = 36 +IMDVERSIONS = {2, 3} class IMDType(Enum): @@ -31,6 +36,11 @@ class IMDType(Enum): IMD_PAUSE = 7 IMD_TRATE = 8 IMD_IOERROR = 9 + # New in IMD v3 + IMD_BOX = 10 + IMD_VELS = 11 + IMD_FORCES = 12 + IMD_EOS = 13 class IMDHeader: @@ -41,6 +51,32 @@ def __init__(self, msg_type: IMDType, length: int): self.length = length +@dataclass +class IMDSessionInfo: + """Convenience class to represent the session information of an IMD connection + + '<' represents little endian and '>' represents big endian + + Data should be loaded into and out of buffers in the order of the fields in this class + if present in the session for that step, i.e. + 1. energies, + 2. dimensions, + etc. + """ + + version: int + endianness: str + imdterm: Union[bool, None] + imdwait: Union[bool, None] + imdpull: Union[bool, None] + wrapped_coords: bool + energies: int + dimensions: int + positions: int + velocities: int + forces: int + + def create_header_bytes(msg_type: IMDType, length: int): # NOTE: add error checking for invalid packet msg_type here type = msg_type.value diff --git a/imdreader/IMDREADER.py b/imdreader/IMDREADER.py index d023b36..2de57d5 100644 --- a/imdreader/IMDREADER.py +++ b/imdreader/IMDREADER.py @@ -49,6 +49,7 @@ FrameIteratorAll, FrameIteratorSliced, ) +from MDAnalysis.coordinates import core from MDAnalysis.lib.util import store_init_arguments from .IMDProtocol import * from .util import * @@ -93,16 +94,11 @@ def __init__( is the port number. 
""" + self._producer = None super(IMDReader, self).__init__(filename, **kwargs) - - # NOTE: Replace me with header packet which contains this information - # OR get this information from the topology? if not n_atoms: raise ValueError("`n_atoms` kwarg must be specified") self.n_atoms = n_atoms - self.ts = self._Timestep( - self.n_atoms, positions=True, **self._ts_kwargs - ) self.units = { "time": "ps", @@ -110,32 +106,34 @@ def __init__( "force": "kJ/(mol*nm)", } - self._buffer = CircularByteBuf(buffer_size, self.n_atoms, self.ts) - - self._attempt_event = threading.Event() - self._success_event = threading.Event() - self._producer = IMDProducer( - filename, - self._buffer, - self.n_atoms, - self._attempt_event, - self._success_event, - socket_bufsize=socket_bufsize, - ) + self._host, self._port = parse_host_port(filename) + self._buffer_size = buffer_size + self._socket_bufsize = socket_bufsize self._frame = -1 def _read_next_timestep(self): if self._frame == -1: - self._producer.start() - # Wait for producer to connect to server - self._attempt_event.wait() - logger.debug("IMDReader: Waiting for producer to connect to server") - if not self._success_event.wait(timeout=5): - raise ConnectionRefusedError( - "IMDReader: Failed to connect to server" + # Reader is responsible for performing handshake + # and parsing the configuration before + # passing the connection off to the appropriate producer + # and allocating an appropriate buffer + conn = self._connect_to_server() + imdsinfo = self._await_IMD_handshake(conn) + + if imdsinfo.version == 2: + self._ts = self._Timestep(self.n_atoms, positions=True) + self._buffer = CircularByteBuf( + self._buffer_size, self.n_atoms, self._ts, imdsinfo + ) + self._producer = IMDv2Producer( + conn, + self._buffer, + imdsinfo, + self.n_atoms, ) - logger.debug("IMDReader: Producer connected to server") + # Producer responsible for sending go packet + self._producer.start() return self._read_frame(self._frame + 1) @@ -143,15 +141,86 @@ 
def _read_frame(self, frame): # loads the timestep with step, positions, and energy self._buffer.consume_next_timestep() - self.ts.frame = self._frame + 1 - self.ts.dimensions = None + self._ts.frame = self._frame + 1 # Must set frame after read occurs successfully # Since buffer raises IO error # after producer is finished and there are no more frames self._frame = frame - return self.ts + return self._ts + + def _connect_to_server(self): + """ + Establish connection with the server, failing out if this + does not occur within 5 seconds. + """ + conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self._socket_bufsize is not None: + conn.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVBUF, self._socket_bufsize + ) + conn.settimeout(5) + try: + conn.connect((self._host, self._port)) + except ConnectionRefusedError: + logger.error( + f"IMDReader: Connection to {self._host}:{self._port} refused" + ) + raise ConnectionRefusedError( + f"IMDReader: Connection to {self._host}:{self._port} refused" + ) + conn.settimeout(None) + return conn + + def _await_IMD_handshake(self, conn) -> IMDSessionInfo: + """ + Wait for the server to send a handshake packet, then parse + endianness and version information and IMD session configuration. + """ + end = ">" + ver = None + + handshake = parse_header_bytes(conn.recv(IMDHEADERSIZE)) + if handshake.type != IMDType.IMD_HANDSHAKE: + raise ValueError( + f"Expected packet type {IMDType.IMD_HANDSHAKE}, got {handshake.type}" + ) + + if handshake.length not in IMDVERSIONS: + # Try swapping endianness + swapped = struct.unpack("<i", struct.pack(">i", handshake.length))[ + 0 + ] + if swapped not in IMDVERSIONS: + err_version = min(swapped, handshake.length) + # Don't call stop, simulation hasn't started yet + raise ValueError( + f"Incompatible IMD version.
Expected version in {IMDVERSIONS}, got {err_version}" + ) + else: + end = "<" + ver = swapped + else: + ver = handshake.length + + sinfo = None + if ver == 2: + # IMD v2 does not send a configuration packet + sinfo = IMDSessionInfo( + version=ver, + endianness=end, + imdterm=None, + imdwait=None, + imdpull=None, + wrapped_coords=False, + energies=1, + dimensions=0, + positions=1, + velocities=0, + forces=0, + ) + return sinfo @property def n_frames(self): @@ -171,8 +240,9 @@ def _format_hint(thing): def close(self): """Gracefully shut down the reader. Stops the producer thread.""" - self._producer.stop() - self._producer.join() + if self._producer is not None: + self._producer.stop() + self._producer.join() print("IMDReader shut down gracefully.") # Incompatible methods @@ -189,75 +259,46 @@ def __getitem__(self, frame): raise RuntimeError("IMDReader: Trajectory can only be read in for loop") -class IMDProducer(threading.Thread): - """ - Producer thread for IMDReader. Reads packets from the socket - and places them into the shared buffer. 
- """ +class AbstractIMDProducer(abc.ABC, threading.Thread): - def __init__( - self, - filename, - buffer, - n_atoms, - attempt_event=None, - success_event=None, - socket_bufsize=None, - pausable=True, - ): - super(IMDProducer, self).__init__() - self._host, self._port = parse_host_port(filename) - self._conn = None + def __init__(self, conn, buffer, imdsinfo, n_atoms): + # call threading.Thread init + super(AbstractIMDProducer, self).__init__() + self._conn = conn + self.imdsinfo = imdsinfo self.running = False - - self._attempt_event = attempt_event - self._success_event = success_event - - self._buffer = buffer - self.n_atoms = n_atoms - self._expected_data_bytes = 12 * n_atoms - self._socket_bufsize = socket_bufsize - self.pausable = pausable self.paused = False - # < represents little endian and > represents big endian - # we assume big by default and use handshake to check + self._is_disconnected = False + self._buffer = buffer + self.parsed_frames = 0 self._full_frames = 0 self._parse_frame_time = 0 - # Saving memory by preallocating space for the frame - # we're loading into the buffer - self._energy_byte_buf = bytearray(IMDENERGYPACKETLENGTH) - self._energy_byte_view = memoryview(self._energy_byte_buf) - self._body_byte_buf = bytearray(self._expected_data_bytes) - self._body_byte_view = memoryview(self._body_byte_buf) - - self._is_disconnected = False # The body of a force or position packet should contain # (4 bytes per float * 3 atoms * n_atoms) bytes - self._expected_data_bytes = 12 * self.n_atoms + self.n_atoms = n_atoms + self._data_bytes = 12 * n_atoms + + self._byte_dict = {} + if self.imdsinfo.energies > 0: + self._byte_dict["energies"] = bytearray(40) + if self.imdsinfo.dimensions > 0: + self._byte_dict["dimensions"] = bytearray(36) + if self.imdsinfo.positions > 0: + self._byte_dict["positions"] = bytearray(self._data_bytes) + if self.imdsinfo.velocities > 0: + self._byte_dict["velocities"] = bytearray(self._data_bytes) + if self.imdsinfo.forces 
> 0: + self._byte_dict["forces"] = bytearray(self._data_bytes) + + @abc.abstractmethod + def _parse_and_insert_frame(self): + pass - def _await_IMD_handshake(self): - """ - Wait for the server to send a handshake packet, set endianness, - and check IMD Protocol version. - """ - print("waiting for handshake...") - handshake = self._expect_header(expected_type=IMDType.IMD_HANDSHAKE) - if handshake.length != IMDVERSION: - # Try swapping endianness - swapped = struct.unpack("i", handshake.length))[ - 0 - ] - if swapped != IMDVERSION: - err_version = min(swapped, handshake.length) - # Don't call stop, simulation hasn't started yet - raise ValueError( - f"Incompatible IMD version. Expected {IMDVERSION}, got {err_version}" - ) - else: - self._buffer.inform_endianness("<") - print("handshake received") + @abc.abstractmethod + def _check_end_of_simulation(self): + pass def _send_go_packet(self): """ @@ -267,6 +308,7 @@ def _send_go_packet(self): print("sending go packet...") go = create_header_bytes(IMDType.IMD_GO, 0) self._conn.sendall(go) + logger.debug("IMDProducer: Sent go packet to server") def _pause_simulation(self): """ @@ -278,6 +320,9 @@ def _pause_simulation(self): logger.debug( "IMDProducer: Pausing simulation because buffer is almost full" ) + # make the socket non-blocking so we can use BlockingIOError to check + # if we can unpause + self._conn.settimeout(0) self.paused = True except ConnectionResetError as e: self._is_disconnected = True @@ -289,6 +334,9 @@ def _pause_simulation(self): def _unpause_simulation(self): try: + logger.debug( + "IMDProducer: Waiting to unpause until buffer almost empty" + ) self._buffer.wait_almost_empty() logger.debug("IMDProducer: Unpausing simulation, buffer has space") pause = create_header_bytes(IMDType.IMD_PAUSE, 0) @@ -303,52 +351,22 @@ def _unpause_simulation(self): "data likely lost in frame {}".format(self.parsed_frames) ) - def _connection_sequence(self): - """ - Establish connection with the server and perform - the 
handshake and go packet exchange. - """ - self._conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if self._socket_bufsize is not None: - self._conn.setsockopt( - socket.SOL_SOCKET, socket.SO_RCVBUF, self._socket_bufsize - ) - self._conn.settimeout(5) - try: - self._attempt_event.set() - self._conn.connect((self._host, self._port)) - except ConnectionRefusedError: - self.stop() - logger.error( - f"IMDProducer: Connection to {self._host}:{self._port} refused" - ) - raise ConnectionRefusedError( - f"IMDProducer: Connection to {self._host}:{self._port} refused" - ) - self._success_event.set() - self._conn.settimeout(None) - self._await_IMD_handshake() - self._send_go_packet() - self.running = True - self._is_disconnected = False - def run(self): """ Producer thread method. Reads from the socket and sends a 'pause' signal if needed. """ - self._connection_sequence() - + self._send_go_packet() + self.running = True while self.running: self._parse_and_insert_frame() # If buffer is more than 50% full, pause the simulation - if self.pausable: - if ( - not self.paused - and self._full_frames >= self._buffer.capacity // 2 - ): - self._pause_simulation() + if ( + not self.paused + and self._full_frames >= self._buffer.capacity // 2 + ): + self._pause_simulation() # Simulation will unpause during frame parsing if # 1. 
buffer is empty @@ -359,34 +377,8 @@ def run(self): if not self.paused: self._check_end_of_simulation() - # Reset timeout if it was changed during the loop - self._conn.settimeout(None) - return - def _parse_and_insert_frame(self): - with timeit() as parse_frame: - self._expect_header( - expected_type=IMDType.IMD_ENERGIES, expected_value=1 - ) - energies = self._recv_data("energy") - self._expect_header( - expected_type=IMDType.IMD_FCOORDS, - expected_value=self.n_atoms, - ) - pos = self._recv_data("body") - self._full_frames = self._buffer.insert(energies, pos) - - self.parsed_frames += 1 - # Use the longest parse frame time to calculate - # the timeout for the next frame - self._parse_frame_time = max( - parse_frame.elapsed, self._parse_frame_time - ) - logger.debug( - f"IMDProducer: Added frame #{self.parsed_frames - 1} to buffer in {parse_frame.elapsed} seconds" - ) - def stop(self): # Tell reader not to expect more frames to be added self._buffer.producer_finished = True @@ -394,74 +386,27 @@ def stop(self): self._ensure_disconnect() self.running = False - def _check_end_of_simulation(self): - # It is the server's reponsibility - # to close the connection, but this may not happen in time - # for us to check during the last frame - # Therefore, use a timeout on a peek to check if the server has closed - try: - logger.debug( - f"IMDProducer: Checking for frame #{self.parsed_frames}" - ) - # Continue reading if socket contains bytes - self._conn.settimeout(5 * self._parse_frame_time) - b = self._conn.recv(1, socket.MSG_PEEK) - # Two cases for ending producer reads from socket - - # case 1: server has closed the connection and we have no more data to read - if not b: - logger.debug( - "IMDProducer: Assuming simulation is over at frame " - "{} due to closed connection".format(self.parsed_frames - 1) - ) - self.running = False - self._buffer.producer_finished = True - # case 2: server has not closed the connection and we have no more data to read - except 
socket.timeout: - logger.debug( - "IMDProducer: Assuming simulation is over at frame " - "#{} due to read timeout,".format(self.parsed_frames - 1) - ) - self.running = False - self._buffer.producer_finished = True - def _handle_signal(self, signum, frame): """Handle SIGINT and SIGTERM signals.""" self.running = False self._ensure_disconnect() - def _recv_data(self, type): - """Used to receive headers and data packets from the socket. - For energies and positions, the data is stored in a preallocated - buffer to avoid memory allocation during the run loop. + def _recv_data(self, bytearray): + """Used to receive headers and data packets from the socket + into self._frame_buf at a specified offset. This method will behave differently is self.paused is True, since having no more data available in the socket doesn't indicate the connection is closed, but rather that the simulation is now ready to be unpaused - - ``type`` can be one of: "header", "body", "energy" """ + logger.debug(f"IMDProducer: Receiving {len(bytearray)} bytes") - if type == "header": - data = bytearray(IMDHEADERSIZE) - view = memoryview(data) - - elif type == "body": - data = self._body_byte_buf - view = self._body_byte_view - - elif type == "energy": - data = self._energy_byte_buf - view = self._energy_byte_view - - if self.paused: - self._conn.settimeout(0) - + memview = memoryview(bytearray) total_received = 0 - while total_received < len(data): + while total_received < len(bytearray): try: - chunk = self._conn.recv(len(data) - total_received) + chunk = self._conn.recv(len(bytearray) - total_received) if not chunk: self._is_disconnected = True self.stop() @@ -476,16 +421,23 @@ def _recv_data(self, type): if self.paused: self._unpause_simulation() continue - view[total_received : total_received + len(chunk)] = chunk + except Exception as e: + logger.error( + f"IMDProducer: Error receiving data: {e}. 
Stopping producer" + ) + memview[total_received : total_received + len(chunk)] = chunk total_received += len(chunk) - - return data + logger.debug( + f"IMDProducer: Receiving data. Total received: {total_received}" + ) def _expect_header(self, expected_type=None, expected_value=None): """ Read a header packet from the socket. """ - header = parse_header_bytes(self._recv_data("header")) + header_bytes = bytearray(IMDHEADERSIZE) + self._recv_data(header_bytes) + header = parse_header_bytes(header_bytes) if expected_type is not None and header.type != expected_type: self.stop() raise ValueError( @@ -520,29 +472,210 @@ def _disconnect(self): ) +class IMDv2Producer(AbstractIMDProducer): + """ + Producer thread for IMDReader. Reads packets from the socket + and places them into the shared buffer. + """ + + def __init__(self, conn, buffer, imdsinfo, n_atoms): + super(IMDv2Producer, self).__init__(conn, buffer, imdsinfo, n_atoms) + self._data_elements = {"positions", "energies"} + + def _parse_and_insert_frame(self): + with timeit() as parse_frame: + self._expect_header( + expected_type=IMDType.IMD_ENERGIES, expected_value=1 + ) + logger.debug("IMDProducer: Received energies packet") + self._recv_data(self._byte_dict["energies"]) + self._expect_header( + expected_type=IMDType.IMD_FCOORDS, + expected_value=self.n_atoms, + ) + logger.debug("IMDProducer: Received positions packet") + self._recv_data(self._byte_dict["positions"]) + self._full_frames = self._buffer.insert( + self._byte_dict, self._data_elements + ) + + self.parsed_frames += 1 + # Use the longest parse frame time to calculate + # the timeout for the next frame + self._parse_frame_time = max( + parse_frame.elapsed, self._parse_frame_time + ) + logger.debug( + f"IMDProducer: Added frame #{self.parsed_frames - 1} to buffer in {parse_frame.elapsed} seconds" + ) + + def _check_end_of_simulation(self): + # It is the server's responsibility + # to close the connection, but this may not happen in time + # for us to check
during the last frame + # Therefore, use a timeout on a peek to check if the server has closed + try: + logger.debug( + f"IMDProducer: Checking for frame #{self.parsed_frames}" + ) + # Continue reading if socket contains bytes + self._conn.settimeout(5 * self._parse_frame_time) + b = self._conn.recv(1, socket.MSG_PEEK) + # Two cases for ending producer reads from socket + + # case 1: server has closed the connection and we have no more data to read + if not b: + logger.debug( + "IMDProducer: Assuming simulation is over at frame " + "{} due to closed connection".format(self.parsed_frames - 1) + ) + self.running = False + self._buffer.producer_finished = True + # case 2: server has not closed the connection and we have no more data to read + except socket.timeout: + logger.debug( + "IMDProducer: Assuming simulation is over at frame " + "#{} due to read timeout,".format(self.parsed_frames - 1) + ) + self.running = False + self._buffer.producer_finished = True + + self._conn.settimeout(None) + + +class IMDv3Producer(AbstractIMDProducer): + def _parse_and_insert_frame(self): + with timeit() as parse_frame: + + data_elements = set() + if ( + "energies" in self._byte_dict + and (self.parsed_frames + 1 % self.imdsinfo.energies) == 0 + ): + data_elements.add("energies") + self._expect_header( + expected_type=IMDType.IMD_ENERGIES, expected_value=1 + ) + self._recv_data(self._byte_dict["energies"]) + + if ( + "dimensions" in self._byte_dict + and (self.parsed_frames + 1 % self.imdsinfo.dimensions) == 0 + ): + data_elements.add("dimensions") + self._expect_header( + expected_type=IMDType.IMD_BOX, + ) + + if ( + "positions" in self._byte_dict + and (self.parsed_frames + 1 % self.imdsinfo.positions) == 0 + ): + data_elements.add("positions") + self._expect_header( + expected_type=IMDType.IMD_FCOORDS, + expected_value=self.n_atoms, + ) + self._recv_data(self._byte_dict["positions"]) + + if ( + "velocities" in self._byte_dict + and (self.parsed_frames + 1 % 
self.imdsinfo.velocities) == 0 + ): + data_elements.add("velocities") + self._expect_header( + expected_type=IMDType.IMD_VELS, + expected_value=self.n_atoms, + ) + + if ( + "forces" in self._byte_dict + and (self.parsed_frames + 1 % self.imdsinfo.forces) == 0 + ): + data_elements.add("forces") + self._expect_header( + expected_type=IMDType.IMD_FORCES, + expected_value=self.n_atoms, + ) + + self._full_frames = self._buffer.insert( + self._byte_dict, data_elements + ) + + self.parsed_frames += 1 + # Use the longest parse frame time to calculate + # the timeout for the next frame + self._parse_frame_time = max( + parse_frame.elapsed, self._parse_frame_time + ) + logger.debug( + f"IMDProducer: Added frame #{self.parsed_frames - 1} to buffer in {parse_frame.elapsed} seconds" + ) + + def _check_end_of_simulation(self): + # Peek for an EOS Header packet + header_bytes = self._conn.recv(8, socket.MSG_PEEK) + header = parse_header_bytes(header_bytes) + if header.type == IMDType.IMD_EOS: + logger.debug( + "IMDProducer: Received end of simulation packet, stopping producer" + ) + self.running = False + self._buffer.producer_finished = True + + class CircularByteBuf: """ Acts as interface between producer and consumer threads """ # NOTE: Use 1 buffer for pos, vel, force rather than 3 - def __init__(self, buffer_size, n_atoms, ts): - # a frame is the number of bytes needed to hold - # energies + positions - self._frame_size = 40 + (n_atoms * 12) + def __init__(self, buffer_size, n_atoms, timestep, imdsinfo): self._n_atoms = n_atoms - self._ts = ts + self._buffer_size = buffer_size + # Syncing reader and producer + self._producer_finished = False + + self._ts = timestep + self.imdsinfo = imdsinfo + + self._frame_size = 1 + # framesize is the number of bytes needed to hold 1 simulation frame in the buffer + # one byte is used to hold flags for the type of data in the frame + + # Even though every simulation step might not contain all data, + # we allocate space for all data in 
every frame. This is memory-inefficient, + # but works for the majority of use cases, so we can optimize later if needed + # this could be optimized by creating a timestep buffer, some kind of byte queue, + # or a circular buffer with wrapped inserts and reads + if self.imdsinfo.energies > 0: + self._frame_size += 40 + if self.imdsinfo.dimensions > 0: + self._frame_size += 36 + if self.imdsinfo.positions > 0: + self._frame_size += self._n_atoms * 12 + if self.imdsinfo.velocities > 0: + self._frame_size += self._n_atoms * 12 + if self.imdsinfo.forces > 0: + self._frame_size += self._n_atoms * 12 + self._ts_buf = bytearray(self._frame_size) - self._body_bytes = n_atoms * 12 + self._data_bytes = self._n_atoms * 12 - if self._frame_size > buffer_size: + # offsets for each data element in the frame + self._energy_offset = 1 + self._dim_offset = 41 + self._pos_offset = 77 + self._vel_offset = self._pos_offset + (self._data_bytes) + self._force_offset = self._vel_offset + (self._data_bytes) + + if self._frame_size > self._buffer_size: raise MemoryError( - f"Requested buffer size of {buffer_size} " - + f"doesn't meet memory requirement of {self._frame_size} " - + f"(energy and position data for {n_atoms})" + f"Requested buffer size of {self._buffer_size} " + + f"doesn't meet memory requirement of {self._frame_size}" ) - self._buf_frames = (buffer_size) // self._frame_size + self._buf_frames = (self._buffer_size) // self._frame_size self._buf = bytearray(self._buf_frames * self._frame_size) self._memview = memoryview(self._buf) @@ -564,28 +697,57 @@ def __init__(self, buffer_size, n_atoms, ts): self._t2 = None self._start = True self._analyze_frame_time = None - self._frame = 0 - # Syncing reader and producer - self._producer_finished = False + self._frame = 0 - def inform_endianness(self, endianness): - """Producer thread must determine - endianness before sending data to the buffer""" - self._end = endianness + def insert(self, byte_dict, data_elements): + """byte_dict
is a dictionary of bytearrays + where the keys are the names of the data elements + and the values are the bytearrays of the data elements for that frame - def insert(self, energy, pos): + data_elements is a set of the names of the data elements that are present in the frame + """ with self._not_full: while not self._empty: self._not_full.wait() - self._memview[self._fill : self._fill + IMDENERGYPACKETLENGTH] = ( - energy - ) - self._fill += IMDENERGYPACKETLENGTH + flags = 0 + self._memview[self._fill] = flags + self._fill += 1 + if "energies" in data_elements: + self._memview[ + self._fill : self._fill + IMDENERGYPACKETLENGTH + ] = byte_dict["energies"][:] + self._fill += IMDENERGYPACKETLENGTH + flags |= 1 << 4 + + if "dimensions" in data_elements: + self._memview[self._fill : self._fill + 36] = byte_dict[ + "dimensions" + ][:] + self._fill += 36 + flags |= 1 << 3 + + if "positions" in data_elements: + self._memview[self._fill : self._fill + self._data_bytes] = ( + byte_dict["positions"][:] + ) + self._fill += self._data_bytes + flags |= 1 << 2 - self._memview[self._fill : self._fill + self._body_bytes] = pos - self._fill += self._body_bytes + if "velocities" in data_elements: + self._memview[self._fill : self._fill + self._data_bytes] = ( + byte_dict["velocities"][:] + ) + self._fill += self._data_bytes + flags |= 1 << 1 + + if "forces" in data_elements: + self._memview[self._fill : self._fill + self._data_bytes] = ( + byte_dict["forces"][:] + ) + self._fill += self._data_bytes + flags |= 1 self._fill %= len(self._buf) @@ -611,6 +773,8 @@ def consume_next_timestep(self): ) self._analyze_frame_time = self._t2 - self._t1 + self._frame += 1 + with self._not_empty: while not self._full and not self.producer_finished: self._not_empty.wait() @@ -619,6 +783,7 @@ def consume_next_timestep(self): if self.producer_finished and not self._full: raise IOError from None + # quickly unload memory to free mutex and later parse self._ts_buf[:] = self._memview[ self._use : 
self._use + self._frame_size ] @@ -628,43 +793,61 @@ def consume_next_timestep(self): self._empty += 1 self._not_full.notify() - self._ts.data["step"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}i4", offset=0, count=1 - )[0] - # absolute temperature - self._ts.data["temperature"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=4, count=1 - )[0] - self._ts.data["total_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=8, count=1 - )[0] - self._ts.data["potential_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=12, count=1 - )[0] - self._ts.data["van_der_walls_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=16, count=1 - )[0] - self._ts.data["coulomb_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=20, count=1 - )[0] - self._ts.data["bonds_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=24, count=1 - )[0] - self._ts.data["angles_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=28, count=1 - )[0] - self._ts.data["dihedrals_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=32, count=1 - )[0] - self._ts.data["improper_dihedrals_energy"] = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=36, count=1 - )[0] - - self._ts.positions = np.frombuffer( - self._ts_buf, dtype=f"{self._end}f4", offset=40 - ).reshape((self._n_atoms, 3)) + flags = self._ts_buf[0] + + if flags >> 4 & 1: + self._ts.data["step"] = np.frombuffer( + self._ts_buf, + dtype=f"{self.imdinfo.endianness}i4", + offset=self._energy_offset, + )[0] + energydata = ( + "temperature", + "total_energy", + "potential_energy", + "van_der_walls_energy", + "coulomb_energy", + "bonds_energy", + "angles_energy", + "dihedrals_energy", + "improper_dihedrals_energy", + ) + for i, name in enumerate(energydata): + self._ts.data[name] = np.frombuffer( + self._ts_buf, + dtype=f"{self.imdinfo.endianness}f4", + 
offset=(self._energy_offset + 4 + (i * 4)), + count=1, + )[0] + + if flags >> 3 & 1: + dim = np.frombuffer( + self._ts_buf, dtype=f"{self._end}f4", offset=self._dim_offset + ).reshape((3, 3)) + self._ts.dimensions = core.triclinic_box(*dim) + else: + self._ts.dimensions = None - self._frame += 1 + if flags >> 2 & 1: + self._ts.positions = np.frombuffer( + self._ts_buf, dtype=f"{self._end}f4", offset=self._pos_offset + ).reshape((self._n_atoms, 3)) + else: + self._ts.has_positions = False + + if flags >> 1 & 1: + self._ts.velocities = np.frombuffer( + self._ts_buf, dtype=f"{self._end}f4", offset=self._vel_offset + ).reshape((self._n_atoms, 3)) + else: + self._ts.has_velocities = False + + if flags & 1: + self._ts.forces = np.frombuffer( + self._ts_buf, dtype=f"{self._end}f4", offset=self._force_offset + ).reshape((self._n_atoms, 3)) + else: + self._ts.has_forces = False def wait_almost_empty(self): with self._not_full: diff --git a/imdreader/tests/test_imdreader.py b/imdreader/tests/test_imdreader.py index 107c4e3..bb85fee 100644 --- a/imdreader/tests/test_imdreader.py +++ b/imdreader/tests/test_imdreader.py @@ -88,9 +88,6 @@ def run_gmx(tmpdir): try: p.wait(timeout=10) except subprocess.TimeoutExpired: - logger.error( - "Process did not terminate in time, killing it." - ) p.kill() p.wait() @@ -195,10 +192,10 @@ def test_no_connection(caplog): buffer_size=62000, ) for ts in u.trajectory: - with pytest.raises(ConnectionError): + with pytest.raises(ConnectionRefusedError): pass # NOTE: assert this in output: No connection received. Pausing simulation. - assert "IMDProducer: Connection to localhost:8888 refused" in caplog.text + assert "IMDReader: Connection to localhost:8888 refused" in caplog.text """