Skip to content

Commit

Permalink
use os.cpu_count to translate negative value of worker keyword, but i…
Browse files Browse the repository at this point in the history
…ssue a warning if cpu_count(0 ends up being higher than MKL's max_threads. The warning is only issued once
  • Loading branch information
oleksandr-pavlyk committed Jan 13, 2020
1 parent 26670a4 commit 8b34758
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions mkl_fft/_scipy_fft_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,35 @@
)

from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod)
from os import cpu_count as os_cpu_count
import warnings

_max_threads_count = mkl.get_max_threads()
class _cpu_max_threads_count:
def __init__(self):
self.cpu_count = None
self.max_threads_count = None

def get_cpu_count(self):
max_threads = self.get_max_threads_count()
if self.cpu_count is None:
self.cpu_count = os_cpu_count()
if self.cpu_count > max_threads:
warnings.warn(
("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
"Using negative values of worker option may amount to requesting more threads than "
"Intel(R) MKL can acommodate."
).format(self.cpu_count, max_threads))
return self.cpu_count

def get_max_threads_count(self):
if self.max_threads_count is None:
self.max_threads_count = mkl.get_max_threads()

return self.max_threads_count


_hardware_counts = _cpu_max_threads_count()


__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
Expand Down Expand Up @@ -113,10 +139,10 @@ def _workers_to_num_threads(w):
if (_w == 0):
raise ValueError("Number of workers must be nonzero")
if (_w < 0):
_w += _max_threads_count + 1
_w += _hardware_counts.get_cpu_count() + 1
if _w <= 0:
raise ValueError("workers value out of range; got {}, must not be"
" less than {}".format(w, -_max_threads_count))
" less than {}".format(w, -_hardware_counts.get_cpu_count()))
return _w


Expand All @@ -133,7 +159,8 @@ def __enter__(self):

def __exit__(self, *args):
# restore default
mkl.domain_set_num_threads(_max_threads_count, domain='fft')
n_threads = _hardware_counts.get_max_threads_count()
mkl.domain_set_num_threads(n_threads, domain='fft')


@_implements(_fft.fft)
Expand Down

0 comments on commit 8b34758

Please sign in to comment.