diff --git a/mkl_fft/_scipy_fft_backend.py b/mkl_fft/_scipy_fft_backend.py
index 646a093..752fd51 100644
--- a/mkl_fft/_scipy_fft_backend.py
+++ b/mkl_fft/_scipy_fft_backend.py
@@ -39,9 +39,51 @@
 )
 
 from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod)
+from os import cpu_count as os_cpu_count
+import warnings
 
-_max_threads_count = mkl.get_max_threads()
+class _cpu_max_threads_count:
+    """Lazily query and cache os.cpu_count() and mkl.get_max_threads()."""
+
+    def __init__(self):
+        # Populated on first use by the getters below.
+        self.cpu_count = None
+        self.max_threads_count = None
+
+    def get_cpu_count(self):
+        """Return the CPU count, queried once and cached.
+
+        Falls back to MKL's maximum thread count when os.cpu_count()
+        returns None, and warns on the first query if the CPU count
+        exceeds MKL's maximum thread count.
+        """
+        max_threads = self.get_max_threads_count()
+        if self.cpu_count is None:
+            cpu_count = os_cpu_count()
+            if cpu_count is None:
+                # Undetermined CPU count: keep the result an int so the
+                # comparison below and callers' arithmetic cannot fail.
+                cpu_count = max_threads
+            self.cpu_count = cpu_count
+            if self.cpu_count > max_threads:
+                warnings.warn(
+                    ("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
+                     "Using negative values of worker option may amount to requesting more threads than "
+                     "Intel(R) MKL can acommodate."
+                    ).format(self.cpu_count, max_threads))
+        return self.cpu_count
+
+    def get_max_threads_count(self):
+        """Return mkl.get_max_threads(), queried once and cached."""
+        if self.max_threads_count is None:
+            self.max_threads_count = mkl.get_max_threads()
+
+        return self.max_threads_count
+
+
+_hardware_counts = _cpu_max_threads_count()
+
 
 
 __all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
            'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
@@ -113,10 +155,10 @@ def _workers_to_num_threads(w):
     if (_w == 0):
         raise ValueError("Number of workers must be nonzero")
     if (_w < 0):
-        _w += _max_threads_count + 1
+        _w += _hardware_counts.get_cpu_count() + 1
         if _w <= 0:
             raise ValueError("workers value out of range; got {}, must not be"
-                             " less than {}".format(w, -_max_threads_count))
+                             " less than {}".format(w, -_hardware_counts.get_cpu_count()))
     return _w
 
 
@@ -133,7 +175,8 @@ def __enter__(self):
 
     def __exit__(self, *args):
         # restore default
-        mkl.domain_set_num_threads(_max_threads_count, domain='fft')
+        n_threads = _hardware_counts.get_max_threads_count()
+        mkl.domain_set_num_threads(n_threads, domain='fft')
 
 
 @_implements(_fft.fft)