Skip to content

Commit

Permalink
Added lock to capture_sys_stdout to make it thread safe. Cached XTBAd…
Browse files Browse the repository at this point in the history
…apter.program_version calls since this was consuming too much with with os.listdir() calls inside of importlib.metadata.version.
  • Loading branch information
coltonbh committed Jul 16, 2024
1 parent 80ed659 commit 6d41335
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 21 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [unreleased]

### Changed

- Added `threading.Lock()` to `capture_sys_stdout` so that it is thread safe.
- Cached `XTBAdapter.program_version` result since calls to `importlib.metadata.version(...)` call `os.listdir()` and when doing very many `xtb` calls these became substantial (almost 1/2 the execution time). We have to use `importlib.metadata.version` rather than `xtb.__version__` directly because their `__version__` string is wrong.

## [0.7.3] - 2024-07-12

### Added
Expand Down
49 changes: 29 additions & 20 deletions qcop/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import sys
import tempfile
import threading
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
Expand Down Expand Up @@ -178,36 +179,44 @@ def capture_logs(
logger.removeHandler(handler)


lock = threading.Lock()


@contextmanager
def capture_sys_stdout():
"""Capture stdout from a program that bypasses the sys.stdout object.
Useful for capturing logs written by C/C++ libraries such as xtb.
The lock is necessary since we are modifying a global resource, the stdout file
descriptor, which is shared between threads. Without this lock race conditions
cause random output to be printed to the console and may freeze the program.
"""
# Create a pipe to capture output
r, w = os.pipe()
with lock:
# Create a pipe to capture output
r, w = os.pipe()

# Save the original stdout and stderr file descriptors
try:
stdout_fd = sys.stdout.fileno()
except AttributeError: # Handles case where celery LoggingProxy is sys.stdout
stdout_fd = sys.__stdout__.fileno()
old_stdout = os.dup(stdout_fd)
# Save the original stdout and stderr file descriptors
try:
stdout_fd = sys.stdout.fileno()
except AttributeError: # Handles case where celery LoggingProxy is sys.stdout
stdout_fd = sys.__stdout__.fileno()
old_stdout = os.dup(stdout_fd)

# Redirect stdout and stderr to the write end of the pipe
os.dup2(w, stdout_fd)
try:
yield r # Allow code to be executed within the context manager block
finally:
# Restore stdout and stderr
os.dup2(old_stdout, stdout_fd)
# Redirect stdout and stderr to the write end of the pipe
os.dup2(w, stdout_fd)
try:
yield r # Allow code to be executed within the context manager block
finally:
# Restore stdout and stderr
os.dup2(old_stdout, stdout_fd)

# Close the duplicated fds
os.close(old_stdout)
os.close(w)
# Close the duplicated fds
os.close(old_stdout)
os.close(w)

# Close the read ends of the pipe
os.close(r)
# Close the read ends of the pipe
os.close(r)


def construct_provenance(
Expand Down
13 changes: 12 additions & 1 deletion qcop/adapters/xtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
from .base import ProgramAdapter
from .utils import capture_sys_stdout, set_env_variable

# NOTE: Calls to importlib.metadata.version("xtb") are slow due to underlying calls to
# os.listdir(), so we cache the version number here. This is necessary because the
# __version__ string in xtb.__init__.py is wrong so we have to look this up dynamically
# in the program_version method.
CACHED_XTB_VERSION = None


class XTBAdapter(ProgramAdapter[ProgramInput, SinglePointResults]):
"""Adapter for xtb-python."""
Expand Down Expand Up @@ -54,7 +60,12 @@ def program_version(self, stdout: Optional[str] = None) -> str:
Returns:
The program version.
"""
return importlib.metadata.version(self.program)
global CACHED_XTB_VERSION
if CACHED_XTB_VERSION:
return CACHED_XTB_VERSION
else:
CACHED_XTB_VERSION = importlib.metadata.version(self.program)
return CACHED_XTB_VERSION

@staticmethod
def _ensure_xtb():
Expand Down

0 comments on commit 6d41335

Please sign in to comment.