From 26670a444c98020d8c6489f72534fe43ef621fb6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 13 Jan 2020 11:58:43 -0600 Subject: [PATCH 1/2] Adjusted workes to threads logic to agree with what is in scipy.fft --- mkl_fft/_scipy_fft_backend.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/mkl_fft/_scipy_fft_backend.py b/mkl_fft/_scipy_fft_backend.py index 102a577..646a093 100644 --- a/mkl_fft/_scipy_fft_backend.py +++ b/mkl_fft/_scipy_fft_backend.py @@ -40,6 +40,9 @@ from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod) +_max_threads_count = mkl.get_max_threads() + + __all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn', 'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn', 'hfft', 'ihfft', 'hfft2', 'ihfft2', 'hfftn', 'ihfftn', @@ -101,9 +104,20 @@ def _tot_size(x, axes): def _workers_to_num_threads(w): + """Handle conversion of workers to a positive number of threads in the + same way as scipy.fft.helpers._workers. + """ if w is None: - return mkl.domain_get_max_threads(domain='fft') - return int(w) + return get_workers() + _w = int(w) + if (_w == 0): + raise ValueError("Number of workers must be nonzero") + if (_w < 0): + _w += _max_threads_count + 1 + if _w <= 0: + raise ValueError("workers value out of range; got {}, must not be" + " less than {}".format(w, -_max_threads_count)) + return _w class Workers: @@ -119,8 +133,7 @@ def __enter__(self): def __exit__(self, *args): # restore default - max_num_threads = mkl.domain_get_max_threads(domain='fft') - mkl.domain_set_num_threads(max_num_threads, domain='fft') + mkl.domain_set_num_threads(_max_threads_count, domain='fft') @_implements(_fft.fft) From 8b34758e3a31acb00f2228ef983bdf03ea993f36 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 13 Jan 2020 14:40:04 -0600 Subject: [PATCH 2/2] use os.cpu_count to translate negative value of worker keyword, but issue a warning if cpu_count(0 ends up being higher than MKL's max_threads. The warning is only issued once --- mkl_fft/_scipy_fft_backend.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/mkl_fft/_scipy_fft_backend.py b/mkl_fft/_scipy_fft_backend.py index 646a093..752fd51 100644 --- a/mkl_fft/_scipy_fft_backend.py +++ b/mkl_fft/_scipy_fft_backend.py @@ -39,9 +39,35 @@ ) from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod) +from os import cpu_count as os_cpu_count +import warnings -_max_threads_count = mkl.get_max_threads() +class _cpu_max_threads_count: + def __init__(self): + self.cpu_count = None + self.max_threads_count = None + def get_cpu_count(self): + max_threads = self.get_max_threads_count() + if self.cpu_count is None: + self.cpu_count = os_cpu_count() + if self.cpu_count > max_threads: + warnings.warn( + ("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. " + "Using negative values of worker option may amount to requesting more threads than " + "Intel(R) MKL can acommodate." + ).format(self.cpu_count, max_threads)) + return self.cpu_count + + def get_max_threads_count(self): + if self.max_threads_count is None: + self.max_threads_count = mkl.get_max_threads() + + return self.max_threads_count + + +_hardware_counts = _cpu_max_threads_count() + __all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn', 'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn', @@ -113,10 +139,10 @@ def _workers_to_num_threads(w): if (_w == 0): raise ValueError("Number of workers must be nonzero") if (_w < 0): - _w += _max_threads_count + 1 + _w += _hardware_counts.get_cpu_count() + 1 if _w <= 0: raise ValueError("workers value out of range; got {}, must not be" - " less than {}".format(w, -_max_threads_count)) + " less than {}".format(w, -_hardware_counts.get_cpu_count())) return _w @@ -133,7 +159,8 @@ def __enter__(self): def __exit__(self, *args): # restore default - mkl.domain_set_num_threads(_max_threads_count, domain='fft') + n_threads = _hardware_counts.get_max_threads_count() + mkl.domain_set_num_threads(n_threads, domain='fft') @_implements(_fft.fft)