diff --git a/pyproject.toml b/pyproject.toml index fe287df..322ccda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ # "extras" (e.g. for `pip install .[test]`) [project.optional-dependencies] # add dependencies used for testing here -test = ["pytest", "pytest-cov"] +test = ["pytest", "pytest-cov", "mrcfile", "requests"] # add anything else you like to have in your dev environment here dev = [ "ipython", diff --git a/src/torch_fourier_slice/grids/central_slice_grid.py b/src/torch_fourier_slice/grids/central_slice_grid.py index 3b6f2c1..61f93da 100644 --- a/src/torch_fourier_slice/grids/central_slice_grid.py +++ b/src/torch_fourier_slice/grids/central_slice_grid.py @@ -1,10 +1,11 @@ import einops import torch - +from functools import lru_cache from .fftfreq_grid import _construct_fftfreq_grid_2d from ..dft_utils import rfft_shape, fftshift_2d +@lru_cache(1) #Alternativelly, we can have an argument that needs to be propagated to extract_central_slices_rfft_3d def central_slice_fftfreq_grid( volume_shape: tuple[int, int, int], rfft: bool, diff --git a/src/torch_fourier_slice/project.py b/src/torch_fourier_slice/project.py index 1072234..6c0d43f 100644 --- a/src/torch_fourier_slice/project.py +++ b/src/torch_fourier_slice/project.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from typing import List from .grids import fftfreq_grid from .slice_extraction import extract_central_slices_rfft_3d @@ -13,6 +14,51 @@ def project_3d_to_2d( ) -> torch.Tensor: """Project a cubic volume by sampling a central slice through its DFT. + Parameters + ---------- + volume: torch.Tensor + `(d, d, d)` volume. + rotation_matrices: torch.Tensor + `(..., 3, 3)` array of rotation matrices for insert of `images`. + Rotation matrices left-multiply column vectors containing xyz coordinates. + pad: bool + Whether to pad the volume 2x with zeros to increase sampling rate in the DFT. + fftfreq_max: float | None + Maximum frequency (cycles per pixel) included in the projection. + + Returns + ------- + projections: torch.Tensor + `(..., d, d)` array of projection images. + """ + + dft, pad_length = _compute_dft3d_for_project(volume, pad) + + # make projections by taking central slices + projections = extract_central_slices_rfft_3d( + volume_rfft=dft, + image_shape=volume.shape, + rotation_matrices=rotation_matrices, + fftfreq_max=fftfreq_max + ) # (..., h, w) rfft stack + + # transform back to real space + projections = torch.fft.ifftshift(projections, dim=(-2,)) # ifftshift of 2D rfft + projections = torch.fft.irfftn(projections, dim=(-2, -1)) + projections = torch.fft.ifftshift(projections, dim=(-2, -1)) # recenter 2D image in real space + + # unpad if required + if pad is True: + projections = projections[..., pad_length:-pad_length, pad_length:-pad_length] + return torch.real(projections) + + +def _compute_dft3d_for_project( + volume: torch.Tensor, + pad: bool = True, +) -> (torch.Tensor, int): + """Project a cubic volume by sampling a central slice through its DFT. + Parameters ---------- volume: torch.Tensor @@ -31,6 +77,7 @@ def project_3d_to_2d( `(..., d, d)` array of projection images. """ # padding + pad_length = 0 if pad is True: pad_length = volume.shape[-1] // 2 volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0) @@ -50,20 +97,8 @@ def project_3d_to_2d( dft = torch.fft.rfftn(dft, dim=(-3, -2, -1)) dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of 3D rfft - # make projections by taking central slices - projections = extract_central_slices_rfft_3d( - volume_rfft=dft, - image_shape=volume.shape, - rotation_matrices=rotation_matrices, - fftfreq_max=fftfreq_max - ) # (..., h, w) rfft stack - # transform back to real space - projections = torch.fft.ifftshift(projections, dim=(-2,)) # ifftshift of 2D rfft - projections = torch.fft.irfftn(projections, dim=(-2, -1)) - projections = torch.fft.ifftshift(projections, dim=(-2, -1)) # recenter 2D image in real space + return dft, pad_length + + - # unpad if required - if pad is True: - projections = projections[..., pad_length:-pad_length, pad_length:-pad_length] - return torch.real(projections) diff --git a/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py b/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py index d216aac..617fc9e 100644 --- a/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py +++ b/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py @@ -1,16 +1,18 @@ import torch import einops +from functools import lru_cache from torch_image_lerp import sample_image_3d from ..dft_utils import fftfreq_to_dft_coordinates from ..grids.central_slice_grid import central_slice_fftfreq_grid -def extract_central_slices_rfft_3d( - volume_rfft: torch.Tensor, +@lru_cache(1) +def _prepare_extract_central_slices_rfft_3d( image_shape: tuple[int, int, int], - rotation_matrices: torch.Tensor, # (..., 3, 3) - fftfreq_max: float | None = None, + rotation_matrices_shape: torch.Size, # (..., 3, 3) + fftfreq_max: float | None, + device: torch.device ): """Extract central slice from an fftshifted rfft.""" # generate grid of DFT sample frequencies for a central slice spanning the xy-plane @@ -18,11 +20,11 @@ def extract_central_slices_rfft_3d( volume_shape=image_shape, rfft=True, fftshift=True, - device=volume_rfft.device, + device=device, ) # (h, w, 3) zyx coords # keep track of some shapes - stack_shape = tuple(rotation_matrices.shape[:-2]) + stack_shape = tuple(rotation_matrices_shape[:-2]) rfft_shape = freq_grid.shape[-3], freq_grid.shape[-2] output_shape = (*stack_shape, *rfft_shape) @@ -34,7 +36,24 @@ def extract_central_slices_rfft_3d( else: valid_coords = einops.rearrange(freq_grid, 'h w zyx -> (h w) zyx') valid_coords = einops.rearrange(valid_coords, 'b zyx -> b zyx 1') - + return freq_grid, rfft_shape, output_shape, valid_coords + + +def extract_central_slices_rfft_3d( + volume_rfft: torch.Tensor, + image_shape: tuple[int, int, int], + rotation_matrices: torch.Tensor, # (..., 3, 3) + fftfreq_max: float | None = None, +): + """Extract central slice from an fftshifted rfft.""" + # generate grid of DFT sample frequencies for a central slice spanning the xy-plane + + freq_grid, rfft_shape, output_shape, valid_coords = _prepare_extract_central_slices_rfft_3d( + image_shape = image_shape, + rotation_matrices_shape=rotation_matrices.shape, + fftfreq_max = fftfreq_max, + device = volume_rfft.device + ) # rotation matrices rotate xyz coordinates, make them rotate zyx coordinates # xyz: # [a b c] [x] [ax + by + cz] @@ -46,7 +65,6 @@ def extract_central_slices_rfft_3d( # [f e d] [y] = [dx + ey + fz] # [c b a] [x] [ax + by + cz] rotation_matrices = torch.flip(rotation_matrices, dims=(-2, -1)) - # add extra dim to rotation matrices for broadcasting rotation_matrices = einops.rearrange(rotation_matrices, '... i j -> ... 1 i j') diff --git a/tests/test_torch_fourier_slice.py b/tests/test_torch_fourier_slice.py index 4c6e45f..c643433 100644 --- a/tests/test_torch_fourier_slice.py +++ b/tests/test_torch_fourier_slice.py @@ -1,4 +1,86 @@ -# temporary stub +from torch_fourier_shell_correlation import fsc -def test_something(): - pass + + +def test_central_slice(): + + import os + import mrcfile + import requests + import tempfile + import torch + from scipy.spatial.transform import Rotation as R + from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d + + tmpdir= tempfile.gettempdir() + fname = os.path.join(tmpdir, "emd_17129.map.gz") + if not os.path.isfile(fname): + response = requests.get("https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-17129/map/emd_17129.map.gz", stream=True) + if response.status_code == 200: + # Open a file in write-binary mode + + with open(fname, "wb") as f: + # Write the content of the response to the file in chunks + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print("Download completed successfully!") + else: + print(f"Failed to download the file. Status code: {response.status_code}") + raise RuntimeError() + volume = torch.as_tensor(mrcfile.read(fname), dtype=torch.float32) + + eulersDegs = [[0,0,0],[0,0,90], [0,90,0], [0,45,45]] + rot_mats = torch.as_tensor(R.from_euler("ZYZ", eulersDegs, degrees=True + ).as_matrix(), dtype=torch.float32) + projs = project_3d_to_2d( + volume=volume, + rotation_matrices=rot_mats, + pad=False, + fftfreq_max=None) + + + from torch_fourier_slice.grids.fftfreq_grid import fftfreq_grid + from torch_fourier_slice.dft_utils import fftfreq_to_dft_coordinates + + freq_grid = fftfreq_grid(image_shape = volume.shape, rfft = False, fftshift = True, spacing= 1, norm = False, device = "cpu") + + + rotation_matrices = torch.flip(rot_mats, dims=(-2, -1)) + + rotated_coords = torch.einsum("b q p, ... p -> b ... q", rotation_matrices, freq_grid) + _rotated_coords = fftfreq_to_dft_coordinates(frequencies=rotated_coords, image_shape=volume.shape, rfft=False) + + from torch_image_lerp import sample_image_3d + rot_vols = sample_image_3d(image=volume, coordinates=_rotated_coords) + + + projs_sum = rot_vols.sum(1) + + diff = torch.abs(projs - projs_sum) + from matplotlib import pyplot as plt + + + for i in range(projs.shape[0]): + _fsc = fsc(projs[i,...], projs_sum[i,...]) +# print(_fsc) + print(diff[i,...].mean(-1).mean(-1)) +# assert torch.isclose(projs[i], projs_sum[i], atol=1e-1).all(), f"Error, disagreement in projections {i}" +# breakpoint() + plt.plot(_fsc, label="euler degs %s"%(eulersDegs[i])) + plt.legend() + plt.show() + + + f, axes = plt.subplots(3,len(diff)) + for i in range(len(diff)): + axes[0,i].imshow(projs[i]) + axes[1,i].imshow(diff[i]) + axes[2,i].imshow(projs_sum[i]) + plt.show() + + +if __name__ == "__main__": + test_central_slice() + """ +PYTHONPATH=../torch-image-lerp/src:src/:$PYTHONPATH python tests/test_torch_fourier_slice.py + """