diff --git a/xobjects/context_cupy.py b/xobjects/context_cupy.py index 8c6e3ce..a2b58b2 100644 --- a/xobjects/context_cupy.py +++ b/xobjects/context_cupy.py @@ -632,14 +632,14 @@ def to_pointer_arg(self, offset, nbytes): class KernelCupy(object): - def __init__( - self, function, description, block_size, context, shared_mem_size_bytes - ): + def __init__(self, function, description, block_size, context, + shared_mem_size_bytes): + self.function = function self.description = description self.block_size = block_size - self.context = context self.shared_mem_size_bytes = shared_mem_size_bytes + self.context = context def to_function_arg(self, arg, value): if arg.pointer: @@ -671,7 +671,7 @@ def to_function_arg(self, arg, value): def num_args(self): return len(self.description.args) - def __call__(self, **kwargs): + def __call__(self, shared_mem_size_bytes=None, **kwargs): assert len(kwargs.keys()) == self.num_args arg_list = [] for arg in self.description.args: @@ -683,14 +683,12 @@ def __call__(self, **kwargs): else: n_threads = self.description.n_threads - grid_size = int(np.ceil(n_threads / self.block_size)) - self.function( - (grid_size,), - (self.block_size,), - arg_list, - shared_mem=self.shared_mem_size_bytes, - ) + if shared_mem_size_bytes is None: + shared_mem_size_bytes = self.shared_mem_size_bytes + grid_size = int(np.ceil(n_threads / self.block_size)) + self.function((grid_size,), (self.block_size,), arg_list, + shared_mem=shared_mem_size_bytes) class FFTCupy(object): def __init__(self, context, data, axes):