From 93b263d9950c2c70a94b3183a596a4177a358d76 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Tue, 20 Aug 2024 18:54:51 +0100 Subject: [PATCH 1/6] refactor for fourier use --- .../grids/central_slice_grid.py | 3 +- src/torch_fourier_slice/project.py | 65 ++++++++++++++----- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/torch_fourier_slice/grids/central_slice_grid.py b/src/torch_fourier_slice/grids/central_slice_grid.py index 3b6f2c1..61f93da 100644 --- a/src/torch_fourier_slice/grids/central_slice_grid.py +++ b/src/torch_fourier_slice/grids/central_slice_grid.py @@ -1,10 +1,11 @@ import einops import torch - +from functools import lru_cache from .fftfreq_grid import _construct_fftfreq_grid_2d from ..dft_utils import rfft_shape, fftshift_2d +@lru_cache(1) #Alternativelly, we can have an argument that needs to be propagated to extract_central_slices_rfft_3d def central_slice_fftfreq_grid( volume_shape: tuple[int, int, int], rfft: bool, diff --git a/src/torch_fourier_slice/project.py b/src/torch_fourier_slice/project.py index 1072234..6c0d43f 100644 --- a/src/torch_fourier_slice/project.py +++ b/src/torch_fourier_slice/project.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from typing import List from .grids import fftfreq_grid from .slice_extraction import extract_central_slices_rfft_3d @@ -13,6 +14,51 @@ def project_3d_to_2d( ) -> torch.Tensor: """Project a cubic volume by sampling a central slice through its DFT. + Parameters + ---------- + volume: torch.Tensor + `(d, d, d)` volume. + rotation_matrices: torch.Tensor + `(..., 3, 3)` array of rotation matrices for insert of `images`. + Rotation matrices left-multiply column vectors containing xyz coordinates. + pad: bool + Whether to pad the volume 2x with zeros to increase sampling rate in the DFT. + fftfreq_max: float | None + Maximum frequency (cycles per pixel) included in the projection. + + Returns + ------- + projections: torch.Tensor + `(..., d, d)` array of projection images. + """ + + dft, pad_length = _compute_dft3d_for_project(volume, pad) + + # make projections by taking central slices + projections = extract_central_slices_rfft_3d( + volume_rfft=dft, + image_shape=volume.shape, + rotation_matrices=rotation_matrices, + fftfreq_max=fftfreq_max + ) # (..., h, w) rfft stack + + # transform back to real space + projections = torch.fft.ifftshift(projections, dim=(-2,)) # ifftshift of 2D rfft + projections = torch.fft.irfftn(projections, dim=(-2, -1)) + projections = torch.fft.ifftshift(projections, dim=(-2, -1)) # recenter 2D image in real space + + # unpad if required + if pad is True: + projections = projections[..., pad_length:-pad_length, pad_length:-pad_length] + return torch.real(projections) + + +def _compute_dft3d_for_project( + volume: torch.Tensor, + pad: bool = True, +) -> (torch.Tensor, int): + """Project a cubic volume by sampling a central slice through its DFT. + Parameters ---------- volume: torch.Tensor @@ -31,6 +77,7 @@ def project_3d_to_2d( `(..., d, d)` array of projection images. """ # padding + pad_length = 0 if pad is True: pad_length = volume.shape[-1] // 2 volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0) @@ -50,20 +97,8 @@ def project_3d_to_2d( dft = torch.fft.rfftn(dft, dim=(-3, -2, -1)) dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of 3D rfft - # make projections by taking central slices - projections = extract_central_slices_rfft_3d( - volume_rfft=dft, - image_shape=volume.shape, - rotation_matrices=rotation_matrices, - fftfreq_max=fftfreq_max - ) # (..., h, w) rfft stack - # transform back to real space - projections = torch.fft.ifftshift(projections, dim=(-2,)) # ifftshift of 2D rfft - projections = torch.fft.irfftn(projections, dim=(-2, -1)) - projections = torch.fft.ifftshift(projections, dim=(-2, -1)) # recenter 2D image in real space + return dft, pad_length + + - # unpad if required - if pad is True: - projections = projections[..., pad_length:-pad_length, pad_length:-pad_length] - return torch.real(projections) From 53487a8bd46099df17631902730b8786419c6906 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Tue, 20 Aug 2024 18:55:07 +0100 Subject: [PATCH 2/6] not satisfactory test --- tests/test_torch_fourier_slice.py | 77 +++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_torch_fourier_slice.py b/tests/test_torch_fourier_slice.py index 4c6e45f..6917bbc 100644 --- a/tests/test_torch_fourier_slice.py +++ b/tests/test_torch_fourier_slice.py @@ -2,3 +2,80 @@ def test_something(): pass + + + +def test_central_slice(): + + import os + import mrcfile + import requests + import tempfile + import torch + from scipy.spatial.transform import Rotation as R + from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d + + tmpdir= tempfile.gettempdir() + fname = os.path.join(tmpdir, "emd_17129.map.gz") + if not os.path.isfile(fname): + response = requests.get("https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-17129/map/emd_17129.map.gz", stream=True) + if response.status_code == 200: + # Open a file in write-binary mode + + with open(fname, "wb") as f: + # Write the content of the response to the file in chunks + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print("Download completed successfully!") + else: + print(f"Failed to download the file. Status code: {response.status_code}") + raise RuntimeError() + volume = torch.as_tensor(mrcfile.read(fname), dtype=torch.float32) + + rot_mats = torch.as_tensor(R.from_euler("ZYZ", [[0,0,0],[0,0,90], [0,90,0]], degrees=True + ).as_matrix(), dtype=torch.float32) + projs = project_3d_to_2d( + volume=volume, + rotation_matrices=rot_mats, + pad=False, + fftfreq_max=None) + + + affine_mats = torch.zeros(rot_mats.shape[0], 3, 4) + affine_mats[:,:3,:3] = rot_mats +# affine_mats[:,:3,-1] += 1./volume.shape[-1] #TODO: It seems that the projections may be off by half a pixel + + + volume = volume[None,None,...].expand(rot_mats.shape[0], -1, -1, -1, -1) + rot_vols = torch.nn.functional.grid_sample(volume, torch.nn.functional.affine_grid(affine_mats, size=volume.shape), align_corners=False) + projs_sum = rot_vols.sum(2).squeeze(1) + + for i in range(projs.shape[0]): + assert torch.isclose(projs[i], projs_sum[i], atol=1e-1).all(), f"Error, disagreement in projections {i}" + + diff = torch.abs(projs - projs_sum) + print(diff.mean(-1).mean(-1)) + +# from matplotlib import pyplot as plt +# f, axes = plt.subplots(3,3) +# axes[0,0].imshow(projs[0]) +# axes[0,1].imshow(projs[1]) +# axes[0,2].imshow(projs[2]) +# +# axes[1,0].imshow(diff[0]) +# axes[1,1].imshow(diff[1]) +# axes[1,2].imshow(diff[2]) +# +# axes[2,0].imshow(projs_sum[0]) +# axes[2,1].imshow(projs_sum[1]) +# axes[2,2].imshow(projs_sum[2]) +# plt.show() + +""" + PYTHONPATH=../torch-image-lerp/src:$PYTHONPATH python -m src.torch_fourier_slice.project +""" +if __name__ == "__main__": + test_central_slice() + """ +PYTHONPATH=../torch-image-lerp/src:src/:$PYTHONPATH python tests/test_torch_fourier_slice.py + """ From ebdbe4125ff863c813d40e4a7c775fe8a69aab23 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Tue, 20 Aug 2024 19:04:41 +0100 Subject: [PATCH 3/6] add dependencies for test --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fe287df..b459e09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ # "extras" (e.g. for `pip install .[test]`) [project.optional-dependencies] # add dependencies used for testing here -test = ["pytest", "pytest-cov"] +test = ["pytest", "pytest-cov", "mrcfile"] # add anything else you like to have in your dev environment here dev = [ "ipython", From fbe4c68088acc951db2edeebaf9a5b3a92458926 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Mon, 16 Sep 2024 19:13:56 +0100 Subject: [PATCH 4/6] add mrcfile for test --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b459e09..322ccda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ # "extras" (e.g. for `pip install .[test]`) [project.optional-dependencies] # add dependencies used for testing here -test = ["pytest", "pytest-cov", "mrcfile"] +test = ["pytest", "pytest-cov", "mrcfile", "requests"] # add anything else you like to have in your dev environment here dev = [ "ipython", From 0bacaff87b97d00b7e589b18ae2ee105a5e1eca5 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Mon, 16 Sep 2024 19:14:19 +0100 Subject: [PATCH 5/6] add cache for feq_grid and valid_coords --- .../_extract_central_slices_rfft_3d.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py b/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py index d216aac..617fc9e 100644 --- a/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py +++ b/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py @@ -1,16 +1,18 @@ import torch import einops +from functools import lru_cache from torch_image_lerp import sample_image_3d from ..dft_utils import fftfreq_to_dft_coordinates from ..grids.central_slice_grid import central_slice_fftfreq_grid -def extract_central_slices_rfft_3d( - volume_rfft: torch.Tensor, +@lru_cache(1) +def _prepare_extract_central_slices_rfft_3d( image_shape: tuple[int, int, int], - rotation_matrices: torch.Tensor, # (..., 3, 3) - fftfreq_max: float | None = None, + rotation_matrices_shape: torch.Size, # (..., 3, 3) + fftfreq_max: float | None, + device: torch.device ): """Extract central slice from an fftshifted rfft.""" # generate grid of DFT sample frequencies for a central slice spanning the xy-plane @@ -18,11 +20,11 @@ def extract_central_slices_rfft_3d( volume_shape=image_shape, rfft=True, fftshift=True, - device=volume_rfft.device, + device=device, ) # (h, w, 3) zyx coords # keep track of some shapes - stack_shape = tuple(rotation_matrices.shape[:-2]) + stack_shape = tuple(rotation_matrices_shape[:-2]) rfft_shape = freq_grid.shape[-3], freq_grid.shape[-2] output_shape = (*stack_shape, *rfft_shape) @@ -34,7 +36,24 @@ def extract_central_slices_rfft_3d( else: valid_coords = einops.rearrange(freq_grid, 'h w zyx -> (h w) zyx') valid_coords = einops.rearrange(valid_coords, 'b zyx -> b zyx 1') - + return freq_grid, rfft_shape, output_shape, valid_coords + + +def extract_central_slices_rfft_3d( + volume_rfft: torch.Tensor, + image_shape: tuple[int, int, int], + rotation_matrices: torch.Tensor, # (..., 3, 3) + fftfreq_max: float | None = None, +): + """Extract central slice from an fftshifted rfft.""" + # generate grid of DFT sample frequencies for a central slice spanning the xy-plane + + freq_grid, rfft_shape, output_shape, valid_coords = _prepare_extract_central_slices_rfft_3d( + image_shape = image_shape, + rotation_matrices_shape=rotation_matrices.shape, + fftfreq_max = fftfreq_max, + device = volume_rfft.device + ) # rotation matrices rotate xyz coordinates, make them rotate zyx coordinates # xyz: # [a b c] [x] [ax + by + cz] @@ -46,7 +65,6 @@ def extract_central_slices_rfft_3d( # [f e d] [y] = [dx + ey + fz] # [c b a] [x] [ax + by + cz] rotation_matrices = torch.flip(rotation_matrices, dims=(-2, -1)) - # add extra dim to rotation matrices for broadcasting rotation_matrices = einops.rearrange(rotation_matrices, '... i j -> ... 1 i j') From 634178c7e85abbb0f3ffd5f57e368b3282166347 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Mon, 16 Sep 2024 19:16:28 +0100 Subject: [PATCH 6/6] update test to use torch_image_lerp --- tests/test_torch_fourier_slice.py | 69 +++++++++++++++++-------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/tests/test_torch_fourier_slice.py b/tests/test_torch_fourier_slice.py index 6917bbc..c643433 100644 --- a/tests/test_torch_fourier_slice.py +++ b/tests/test_torch_fourier_slice.py @@ -1,7 +1,4 @@ -# temporary stub - -def test_something(): - pass +from torch_fourier_shell_correlation import fsc @@ -31,8 +28,9 @@ def test_central_slice(): print(f"Failed to download the file. Status code: {response.status_code}") raise RuntimeError() volume = torch.as_tensor(mrcfile.read(fname), dtype=torch.float32) - - rot_mats = torch.as_tensor(R.from_euler("ZYZ", [[0,0,0],[0,0,90], [0,90,0]], degrees=True + + eulersDegs = [[0,0,0],[0,0,90], [0,90,0], [0,45,45]] + rot_mats = torch.as_tensor(R.from_euler("ZYZ", eulersDegs, degrees=True ).as_matrix(), dtype=torch.float32) projs = project_3d_to_2d( volume=volume, @@ -40,40 +38,47 @@ def test_central_slice(): pad=False, fftfreq_max=None) + + from torch_fourier_slice.grids.fftfreq_grid import fftfreq_grid + from torch_fourier_slice.dft_utils import fftfreq_to_dft_coordinates + + freq_grid = fftfreq_grid(image_shape = volume.shape, rfft = False, fftshift = True, spacing= 1, norm = False, device = "cpu") - affine_mats = torch.zeros(rot_mats.shape[0], 3, 4) - affine_mats[:,:3,:3] = rot_mats -# affine_mats[:,:3,-1] += 1./volume.shape[-1] #TODO: It seems that the projections may be off by half a pixel + rotation_matrices = torch.flip(rot_mats, dims=(-2, -1)) - volume = volume[None,None,...].expand(rot_mats.shape[0], -1, -1, -1, -1) - rot_vols = torch.nn.functional.grid_sample(volume, torch.nn.functional.affine_grid(affine_mats, size=volume.shape), align_corners=False) - projs_sum = rot_vols.sum(2).squeeze(1) + rotated_coords = torch.einsum("b q p, ... p -> b ... q", rotation_matrices, freq_grid) + _rotated_coords = fftfreq_to_dft_coordinates(frequencies=rotated_coords, image_shape=volume.shape, rfft=False) - for i in range(projs.shape[0]): - assert torch.isclose(projs[i], projs_sum[i], atol=1e-1).all(), f"Error, disagreement in projections {i}" + from torch_image_lerp import sample_image_3d + rot_vols = sample_image_3d(image=volume, coordinates=_rotated_coords) + + projs_sum = rot_vols.sum(1) + diff = torch.abs(projs - projs_sum) - print(diff.mean(-1).mean(-1)) + from matplotlib import pyplot as plt + + + for i in range(projs.shape[0]): + _fsc = fsc(projs[i,...], projs_sum[i,...]) +# print(_fsc) + print(diff[i,...].mean(-1).mean(-1)) +# assert torch.isclose(projs[i], projs_sum[i], atol=1e-1).all(), f"Error, disagreement in projections {i}" +# breakpoint() + plt.plot(_fsc, label="euler degs %s"%(eulersDegs[i])) + plt.legend() + plt.show() + + + f, axes = plt.subplots(3,len(diff)) + for i in range(len(diff)): + axes[0,i].imshow(projs[i]) + axes[1,i].imshow(diff[i]) + axes[2,i].imshow(projs_sum[i]) + plt.show() -# from matplotlib import pyplot as plt -# f, axes = plt.subplots(3,3) -# axes[0,0].imshow(projs[0]) -# axes[0,1].imshow(projs[1]) -# axes[0,2].imshow(projs[2]) -# -# axes[1,0].imshow(diff[0]) -# axes[1,1].imshow(diff[1]) -# axes[1,2].imshow(diff[2]) -# -# axes[2,0].imshow(projs_sum[0]) -# axes[2,1].imshow(projs_sum[1]) -# axes[2,2].imshow(projs_sum[2]) -# plt.show() -""" - PYTHONPATH=../torch-image-lerp/src:$PYTHONPATH python -m src.torch_fourier_slice.project -""" if __name__ == "__main__": test_central_slice() """