Defining generic wrapper decorators #69

Open
rsanchezgarc opened this issue Mar 6, 2024 · 3 comments
Labels
enhancement New feature or request

Comments

@rsanchezgarc
Collaborator

I have been doing some experiments using different decorators to speed up the code (e.g., functools.lru_cache, torch.compile...), and I think we should consider a generic mechanism that lets the user optionally activate and deactivate this kind of decorator. Something like this:

from functools import lru_cache

import torch

# This should be stored in some global config. Potentially accessible using an ENV_VAR.
decorators_blacklist = ["torch.compile"]

def decorator_manager(*decorators):
    """Apply each (name, decorator) pair to the function unless the name is blacklisted."""
    def wrapper(func):
        for name, decorator in reversed(decorators):
            if name not in decorators_blacklist:
                func = decorator(func)
        return func
    return wrapper


@decorator_manager(
    ("lru_cache", lru_cache(maxsize=32)),
    ("torch.no_grad", torch.no_grad()),
    ("torch.compile", torch.compile),
)
def torch_op(i):
    return torch.rand(i)

torch_op(1)
print(torch_op(2))

Done this way, the programmer can choose which decorators are potentially useful for a given function, while the user can disable those that are not useful for them.

What do you think?

Do you already have some way of defining/modifying config variables?
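
As a rough sketch of the ENV_VAR idea mentioned in the comment above (the variable name LIBTILT_DISABLED_DECORATORS is just a placeholder, not an existing setting):

import os

# Hypothetical: read a comma-separated list of decorator names to disable,
# e.g. LIBTILT_DISABLED_DECORATORS="torch.compile,lru_cache"
_disabled = os.environ.get("LIBTILT_DISABLED_DECORATORS", "")
decorators_blacklist = [name.strip() for name in _disabled.split(",") if name.strip()]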

@rsanchezgarc rsanchezgarc added the enhancement New feature or request label Mar 6, 2024
@alisterburt
Member

alisterburt commented Mar 8, 2024

Interesting! I agree this could provide major performance benefits for minimal cost - could you outline a few places this might be useful specifically?

I haven't personally used this in projects but have heard good things about hydra https://hydra.cc/docs/intro/ - might be too much here

I'm a bit hesitant to introduce state for config into libtilt as it can make it harder to help anyone debug things, let's move cautiously here

@rsanchezgarc
Collaborator Author

Interesting! I agree this could provide major performance benefits for minimal cost - could you outline a few places this might be useful specifically?

The perfect example is the circle() function used to create masks, or anywhere things are initialized based on the image shape, which will probably be the same many times. I will provide a list of all the functions that I have spotted next week.
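
For illustration, a minimal sketch of the kind of shape-keyed helper that would benefit (the signature is hypothetical and does not necessarily match libtilt's actual circle()):

from functools import lru_cache

import torch

@lru_cache(maxsize=1)
def circle_mask(image_shape: tuple[int, int], radius: float) -> torch.Tensor:
    # Hypothetical helper: boolean circular mask keyed only on shape and radius.
    h, w = image_shape
    y = torch.arange(h) - (h - 1) / 2
    x = torch.arange(w) - (w - 1) / 2
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    return (yy ** 2 + xx ** 2) <= radius ** 2

One caveat: lru_cache returns the same tensor object on every cache hit, so callers must treat the cached mask as read-only (or clone it before modifying).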

I haven't personally used this in projects but have heard good things about hydra https://hydra.cc/docs/intro/ - might be too much here

Same. The closest thing I use is the PyTorch Lightning CLI, which uses config files and command-line arguments, but so far we don't really need much configuration, so we probably don't need something that advanced.

I'm a bit hesitant to introduce state for config into libtilt as it can make it harder to help anyone debug things, let's move cautiously here

That is understandable. I guess that, for the moment, we could just have a plain config.py meant to contain "constants" that are not to be touched. What would you think of this?
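
A minimal sketch of what such a module could look like (file and constant names are hypothetical, not existing libtilt code):

# src/libtilt/config.py (hypothetical)
# Plain module-level "constants": read at import time, not meant to be mutated.

# Decorator names that decorator_manager should skip.
DISABLED_DECORATORS: tuple[str, ...] = ()

# Default cache size for shape-dependent helpers.
LRU_CACHE_MAXSIZE: int = 1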

@rsanchezgarc
Collaborator Author

@alisterburt these are the places where I used lru_cache. The idea is to cache everything that depends only on the image shape, as it is very unlikely that you used two different image sizes

src/libtilt/grids/central_slice_grid.py:
@lru_cache(1)
def central_slice_grid(
    image_shape: tuple[int, int, int],
    rotation_matrix_zyx: bool,
    rfft: bool,
    fftshift: bool = False,
    device: torch.device | None = None,
) -> torch.Tensor:
src/libtilt/grids/fftfreq_grid.py:
@functools.lru_cache(maxsize=1)
def fftfreq_grid(
    image_shape: tuple[int, int] | tuple[int, int, int],
    rfft: bool,
    fftshift: bool = False,
    spacing: float | tuple[float, float] | tuple[float, float, float] = 1,
    norm: bool = False,
    device: torch.device | None = None,
):

src/libtilt/fft_utils.py:

@functools.lru_cache(1)
def rfft_shape(input_shape: Sequence[int]) -> Tuple[int]:
    """Get the output shape of an rfft on an input with input_shape."""
    rfft_shape = list(input_shape)
    rfft_shape[-1] = int((rfft_shape[-1] / 2) + 1)
    return tuple(rfft_shape)

@functools.lru_cache(maxsize=1)
def fft_sizes(lower_bound: int = 0) -> torch.LongTensor:


## The next function was extracted from fftfreq_to_dft_coordinates so it can be cached. It is not in the original code.
@functools.lru_cache(1)
def _prepare_shapes(image_shape, device, dtype):
    _image_shape = image_shape
    image_shape = torch.as_tensor(
        _image_shape, device=device, dtype=dtype
    )
    _rfft_shape = rfft_shape(_image_shape)
    _rfft_shape = torch.as_tensor(
        _rfft_shape, device=device, dtype=dtype
    )
    return image_shape, _rfft_shape
src/libtilt/ctf/ctf_2d.py:
## I extracted this function from calculate_ctf to implement the cache.
@lru_cache(1)
def prepare_near_constant_params(
        voltage: float,
        spherical_aberration: float,
        amplitude_contrast: float,
        b_factor: float,
        phase_shift: float,
        pixel_size: float,
        image_shape: Tuple[int, int],
        rfft: bool,
        fftshift: bool,
        device: torch.device | None = None):
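
To tie this back to the first comment, a sketch of how one of these functions could pick up both the cache and the opt-out toggle, assuming a decorator_manager helper like the one sketched above were added to libtilt (names are illustrative, not an existing API):

import functools

import torch

# decorator_manager and decorators_blacklist are the helpers from the snippet
# in the first comment, not part of libtilt today.
@decorator_manager(
    ("lru_cache", functools.lru_cache(maxsize=1)),
    ("torch.compile", torch.compile),
)
def fftfreq_grid(
    image_shape: tuple[int, int] | tuple[int, int, int],
    rfft: bool,
    fftshift: bool = False,
    spacing: float | tuple[float, float] | tuple[float, float, float] = 1,
    norm: bool = False,
    device: torch.device | None = None,
):
    ...

Note that lru_cache only works here because every argument (tuples, bools, floats, torch.device) is hashable.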
