Skip to content

Commit

Permalink
Merge pull request #47 from IntelPython/adjust-workers-behavior-in-ff…
Browse files Browse the repository at this point in the history
…t-backend

Adjusted workes to threads logic to agree with what is in scipy.fft
  • Loading branch information
oleksandr-pavlyk authored Jan 14, 2020
2 parents 357a0b7 + 8b34758 commit 2d79309
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions mkl_fft/_scipy_fft_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,35 @@
)

from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod)
from os import cpu_count as os_cpu_count
import warnings

class _cpu_max_threads_count:
def __init__(self):
self.cpu_count = None
self.max_threads_count = None

def get_cpu_count(self):
max_threads = self.get_max_threads_count()
if self.cpu_count is None:
self.cpu_count = os_cpu_count()
if self.cpu_count > max_threads:
warnings.warn(
("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
"Using negative values of worker option may amount to requesting more threads than "
"Intel(R) MKL can acommodate."
).format(self.cpu_count, max_threads))
return self.cpu_count

def get_max_threads_count(self):
if self.max_threads_count is None:
self.max_threads_count = mkl.get_max_threads()

return self.max_threads_count


_hardware_counts = _cpu_max_threads_count()


__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
Expand Down Expand Up @@ -101,9 +130,20 @@ def _tot_size(x, axes):


def _workers_to_num_threads(w):
"""Handle conversion of workers to a positive number of threads in the
same way as scipy.fft.helpers._workers.
"""
if w is None:
return mkl.domain_get_max_threads(domain='fft')
return int(w)
return get_workers()
_w = int(w)
if (_w == 0):
raise ValueError("Number of workers must be nonzero")
if (_w < 0):
_w += _hardware_counts.get_cpu_count() + 1
if _w <= 0:
raise ValueError("workers value out of range; got {}, must not be"
" less than {}".format(w, -_hardware_counts.get_cpu_count()))
return _w


class Workers:
Expand All @@ -119,8 +159,8 @@ def __enter__(self):

def __exit__(self, *args):
# restore default
max_num_threads = mkl.domain_get_max_threads(domain='fft')
mkl.domain_set_num_threads(max_num_threads, domain='fft')
n_threads = _hardware_counts.get_max_threads_count()
mkl.domain_set_num_threads(n_threads, domain='fft')


@_implements(_fft.fft)
Expand Down

0 comments on commit 2d79309

Please sign in to comment.