diff --git a/CHANGELOG.md b/CHANGELOG.md index 11663d0..d23b843 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [unreleased] +### Changed + +- Added `threading.Lock()` to `capture_sys_stdout` so that it is thread-safe. +- Cached the `XTBAdapter.program_version` result because every `importlib.metadata.version(...)` call invokes `os.listdir()`, which became a substantial cost (almost half the execution time) when making very many `xtb` calls. We have to use `importlib.metadata.version` rather than `xtb.__version__` directly because the `__version__` string shipped by `xtb` is wrong. + ## [0.7.3] - 2024-07-12 ### Added diff --git a/qcop/adapters/utils.py b/qcop/adapters/utils.py index baf1c45..c0bac49 100644 --- a/qcop/adapters/utils.py +++ b/qcop/adapters/utils.py @@ -10,6 +10,7 @@ import subprocess import sys import tempfile +import threading from contextlib import contextmanager from io import StringIO from pathlib import Path @@ -178,36 +179,44 @@ def capture_logs( logger.removeHandler(handler) +lock = threading.Lock() + + @contextmanager def capture_sys_stdout(): """Capture stdout from a program that bypasses the sys.stdout object. Useful for capturing logs written by C/C++ libraries such as xtb. + + The lock is necessary since we are modifying a global resource, the stdout file + descriptor, which is shared between threads. Without this lock, race conditions + cause random output to be printed to the console and may freeze the program. 
""" - # Create a pipe to capture output - r, w = os.pipe() + with lock: + # Create a pipe to capture output + r, w = os.pipe() - # Save the original stdout and stderr file descriptors - try: - stdout_fd = sys.stdout.fileno() - except AttributeError: # Handles case where celery LoggingProxy is sys.stdout - stdout_fd = sys.__stdout__.fileno() - old_stdout = os.dup(stdout_fd) + # Save the original stdout file descriptor + try: + stdout_fd = sys.stdout.fileno() + except AttributeError: # Handles case where celery LoggingProxy is sys.stdout + stdout_fd = sys.__stdout__.fileno() + old_stdout = os.dup(stdout_fd) - # Redirect stdout and stderr to the write end of the pipe - os.dup2(w, stdout_fd) - try: - yield r # Allow code to be executed within the context manager block - finally: - # Restore stdout and stderr - os.dup2(old_stdout, stdout_fd) + # Redirect stdout to the write end of the pipe + os.dup2(w, stdout_fd) + try: + yield r # Allow code to be executed within the context manager block + finally: + # Restore stdout + os.dup2(old_stdout, stdout_fd) - # Close the duplicated fds - os.close(old_stdout) - os.close(w) + # Close the duplicated fds + os.close(old_stdout) + os.close(w) - # Close the read ends of the pipe - os.close(r) + # Close the read end of the pipe + os.close(r) def construct_provenance( diff --git a/qcop/adapters/xtb.py b/qcop/adapters/xtb.py index 7546046..96b9947 100644 --- a/qcop/adapters/xtb.py +++ b/qcop/adapters/xtb.py @@ -20,6 +20,12 @@ from .base import ProgramAdapter from .utils import capture_sys_stdout, set_env_variable +# NOTE: Calls to importlib.metadata.version("xtb") are slow due to underlying calls to +# os.listdir(), so we cache the version number here. The dynamic lookup in +# program_version is needed because the __version__ string in xtb's __init__.py is +# wrong. 
+CACHED_XTB_VERSION = None + class XTBAdapter(ProgramAdapter[ProgramInput, SinglePointResults]): """Adapter for xtb-python.""" @@ -54,7 +60,12 @@ def program_version(self, stdout: Optional[str] = None) -> str: Returns: The program version. """ - return importlib.metadata.version(self.program) + global CACHED_XTB_VERSION + if CACHED_XTB_VERSION: + return CACHED_XTB_VERSION + else: + CACHED_XTB_VERSION = importlib.metadata.version(self.program) + return CACHED_XTB_VERSION @staticmethod def _ensure_xtb():