
Commit

feat: Allowed option to reflectively pad projection DFT for reduced high res artifacts
jdickerson95 committed Nov 7, 2024
1 parent 093c87e commit 35677d6
Showing 6 changed files with 178 additions and 55 deletions.
21 changes: 16 additions & 5 deletions src/torch_fourier_slice/__init__.py
@@ -1,4 +1,4 @@
"""Fourier slice slice_extraction/slice_insertion from 2D images and 3D volumes in PyTorch."""
"""Fourier slice extraction/insertion from 2D images and 3D volumes in PyTorch."""

from importlib.metadata import PackageNotFoundError, version

@@ -9,7 +9,18 @@
__author__ = "Alister Burt"
__email__ = "[email protected]"

from .project import project_3d_to_2d
from .backproject import backproject_2d_to_3d
from .slice_insertion import insert_central_slices_rfft_3d
from .slice_extraction import extract_central_slices_rfft_3d
__all__ = [
"backproject_2d_to_3d",
"project_3d_to_2d",
"extract_central_slices_rfft_3d",
"insert_central_slices_rfft_3d",
]

from torch_fourier_slice.backproject import backproject_2d_to_3d
from torch_fourier_slice.project import project_3d_to_2d
from torch_fourier_slice.slice_extraction._extract_central_slices_rfft_3d import (
extract_central_slices_rfft_3d,
)
from torch_fourier_slice.slice_insertion._insert_central_slices_rfft_3d import (
insert_central_slices_rfft_3d,
)
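For context, a minimal usage sketch of the top-level API re-exported above; the tensor size and the single identity rotation are illustrative assumptions, not values taken from this commit:

```python
import torch

from torch_fourier_slice import backproject_2d_to_3d, project_3d_to_2d

volume = torch.rand(32, 32, 32)                # cubic volume
rotation_matrices = torch.eye(3).unsqueeze(0)  # (b, 3, 3), here b = 1

# forward: central-slice projection of the volume through its DFT
projections = project_3d_to_2d(volume, rotation_matrices)  # (1, 32, 32)

# inverse: reconstruct a volume from the projection images
reconstruction = backproject_2d_to_3d(projections, rotation_matrices)
```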
26 changes: 16 additions & 10 deletions src/torch_fourier_slice/backproject.py
@@ -1,16 +1,20 @@
"""Module for 2D to 3D backprojection operations using Fourier slice theorem."""

import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .slice_insertion import insert_central_slices_rfft_3d
from torch_fourier_slice.slice_insertion._insert_central_slices_rfft_3d import (
insert_central_slices_rfft_3d,
)


def backproject_2d_to_3d(
images: torch.Tensor, # (b, h, w)
rotation_matrices: torch.Tensor, # (b, 3, 3)
pad: bool = True,
fftfreq_max: float | None = None,
):
) -> torch.Tensor:
"""Perform a 3D reconstruction from a set of 2D projection images.
Parameters
@@ -32,7 +36,7 @@ def backproject_2d_to_3d(
"""
b, h, w = images.shape
if h != w:
raise ValueError('images must be square.')
raise ValueError("images must be square.")
if pad is True:
p = images.shape[-1] // 4
images = F.pad(images, pad=[p] * 4)
@@ -51,25 +55,27 @@
image_rfft=images,
volume_shape=volume_shape,
rotation_matrices=rotation_matrices,
fftfreq_max=fftfreq_max
fftfreq_max=fftfreq_max,
)

# reweight reconstruction
valid_weights = weights > 1e-3
dft[valid_weights] /= weights[valid_weights]

# back to real space
dft = torch.fft.ifftshift(dft, dim=(-3, -2,)) # actual ifftshift
dft = torch.fft.ifftshift(
dft,
dim=(
-3,
-2,
),
) # actual ifftshift
dft = torch.fft.irfftn(dft, dim=(-3, -2, -1))
dft = torch.fft.ifftshift(dft, dim=(-3, -2, -1)) # center in real space

# correct for convolution with linear interpolation kernel
grid = fftfreq_grid(
image_shape=dft.shape,
rfft=False,
fftshift=True,
norm=True,
device=dft.device
image_shape=dft.shape, rfft=False, fftshift=True, norm=True, device=dft.device
)
dft = dft / torch.sinc(grid) ** 2

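A hedged usage sketch for backproject_2d_to_3d with the fftfreq_max low-pass option shown in the signature above; sampling random orientations with scipy is an illustrative choice, not something this commit depends on:

```python
import torch
from scipy.spatial.transform import Rotation

from torch_fourier_slice import backproject_2d_to_3d, project_3d_to_2d

volume = torch.rand(32, 32, 32)
rotations = torch.tensor(
    Rotation.random(200).as_matrix(), dtype=torch.float32
)  # (200, 3, 3) random orientations

projections = project_3d_to_2d(volume, rotations)  # (200, 32, 32)
reconstruction = backproject_2d_to_3d(
    projections,
    rotations,
    pad=True,          # pad images by 25% per side before slice insertion
    fftfreq_max=0.25,  # insert only frequencies up to half Nyquist
)
```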
26 changes: 22 additions & 4 deletions src/torch_fourier_slice/dft_utils.py
@@ -1,3 +1,5 @@
"""Utility functions for Fourier slice operations."""

from typing import Sequence, Tuple

import torch
@@ -11,6 +13,7 @@ def rfft_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:


def fftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
"""Forward fftshift of a 2D tensor."""
if rfft is False:
output = torch.fft.fftshift(input, dim=(-2, -1))
else:
@@ -19,6 +22,7 @@ def fftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:


def ifftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
"""Inverse fftshift of a 2D tensor."""
if rfft is False:
output = torch.fft.ifftshift(input, dim=(-2, -1))
else:
@@ -27,24 +31,38 @@ def ifftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:


def fftshift_3d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
"""Forward fftshift of a 3D tensor."""
if rfft is False:
output = torch.fft.fftshift(input, dim=(-3, -2, -1))
else:
output = torch.fft.fftshift(input, dim=(-3, -2,))
output = torch.fft.fftshift(
input,
dim=(
-3,
-2,
),
)
return output


def ifftshift_3d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
"""Inverse fftshift of a 3D tensor."""
if rfft is False:
output = torch.fft.ifftshift(input, dim=(-3, -2, -1))
else:
output = torch.fft.ifftshift(input, dim=(-3, -2,))
output = torch.fft.ifftshift(
input,
dim=(
-3,
-2,
),
)
return output


def fftfreq_to_dft_coordinates(
frequencies: torch.Tensor, image_shape: tuple[int, ...], rfft: bool
):
) -> torch.Tensor:
"""Convert DFT sample frequencies into array coordinates in a fftshifted DFT.
Parameters
@@ -90,7 +108,7 @@ def dft_center(
if rfft is True:
image_shape = torch.tensor(rfft_shape(image_shape))
if fftshifted is True:
fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
fft_center = torch.divide(image_shape, 2, rounding_mode="floor")
if rfft is True:
fft_center[-1] = 0
return fft_center.long()
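A small sketch of why the rfft branches above shift only dims (-3, -2): the last axis of an rfft is half-length and non-redundant, with DC already at index 0, so it is left unshifted (the volume size here is illustrative):

```python
import torch

volume = torch.rand(32, 32, 32)
full_dft = torch.fft.fftn(volume, dim=(-3, -2, -1))   # (32, 32, 32)
half_dft = torch.fft.rfftn(volume, dim=(-3, -2, -1))  # (32, 32, 17)

# equivalent to fftshift_3d(..., rfft=False) and fftshift_3d(..., rfft=True)
shifted_full = torch.fft.fftshift(full_dft, dim=(-3, -2, -1))
shifted_half = torch.fft.fftshift(half_dft, dim=(-3, -2))
print(shifted_full.shape, shifted_half.shape)  # (32, 32, 32) (32, 32, 17)
```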
54 changes: 48 additions & 6 deletions src/torch_fourier_slice/project.py
@@ -1,15 +1,26 @@
"""Module for 3D to 2D projection operations using Fourier slice theorem.
This module provides functionality to project 3D volumes into 2D images using
the Fourier slice theorem, which states that the 2D Fourier transform of a
projection is equal to a central slice through the 3D Fourier transform of
the volume.
"""

import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .slice_extraction import extract_central_slices_rfft_3d
from torch_fourier_slice.slice_extraction._extract_central_slices_rfft_3d import (
extract_central_slices_rfft_3d,
)


def project_3d_to_2d(
volume: torch.Tensor,
rotation_matrices: torch.Tensor,
pad: bool = True,
fftfreq_max: float | None = None,
edge_pad_fraction: float = 0.0,
) -> torch.Tensor:
"""Project a cubic volume by sampling a central slice through its DFT.
@@ -24,6 +35,10 @@ def project_3d_to_2d(
Whether to pad the volume 2x with zeros to increase sampling rate in the DFT.
fftfreq_max: float | None
Maximum frequency (cycles per pixel) included in the projection.
edge_pad_fraction: float
Fraction of the dft edge to pad with reflective padding.
This is useful for reducing artifacts from the DFT edge.
Recommended values are 0.1 - 0.25.
Returns
-------
@@ -33,37 +48,64 @@
# padding
if pad is True:
pad_length = volume.shape[-1] // 2
volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0)
volume = F.pad(volume, pad=[pad_length] * 6, mode="constant", value=0)

# premultiply by sinc2
grid = fftfreq_grid(
image_shape=volume.shape,
rfft=False,
fftshift=True,
norm=True,
device=volume.device
device=volume.device,
)
volume = volume * torch.sinc(grid) ** 2

# calculate DFT
dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin
dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))
dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of 3D rfft
dft = torch.fft.fftshift(
dft,
dim=(
-3,
-2,
),
) # actual fftshift of 3D rfft

# make projections by taking central slices
projections = extract_central_slices_rfft_3d(
volume_rfft=dft,
image_shape=volume.shape,
rotation_matrices=rotation_matrices,
fftfreq_max=fftfreq_max
fftfreq_max=fftfreq_max,
) # (..., h, w) rfft stack

# edge padding
if edge_pad_fraction > 0:
# Use the longer dimension (height) for calculating pad size
edge_pad = int(projections.shape[-2] * edge_pad_fraction)
# For rfft, pad left, top and bottom (the right edge of the half transform is left unpadded)
projections = torch.nn.functional.pad(
projections,
pad=[
edge_pad, # left
0, # right (no pad as rfft is symmetric)
edge_pad, # top
edge_pad, # bottom
],
mode="reflect",
)

# transform back to real space
projections = torch.fft.ifftshift(projections, dim=(-2,)) # ifftshift of 2D rfft
projections = torch.fft.irfftn(projections, dim=(-2, -1))
projections = torch.fft.ifftshift(projections, dim=(-2, -1)) # recenter 2D image in real space
projections = torch.fft.ifftshift(
projections, dim=(-2, -1)
) # recenter 2D image in real space

# unpad if required
if pad is True:
projections = projections[..., pad_length:-pad_length, pad_length:-pad_length]
if edge_pad_fraction > 0:
projections = projections[..., edge_pad:-edge_pad, edge_pad:-edge_pad]

return torch.real(projections)
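A usage sketch of the new edge_pad_fraction option introduced in this commit; the volume size and single identity rotation are assumptions for illustration:

```python
import torch

from torch_fourier_slice import project_3d_to_2d

volume = torch.rand(64, 64, 64)
rotation_matrices = torch.eye(3).unsqueeze(0)  # (1, 3, 3)

projections = project_3d_to_2d(
    volume,
    rotation_matrices,
    pad=True,
    edge_pad_fraction=0.1,  # reflectively pad ~10% of the projection rfft edge
)
print(projections.shape)  # (1, 64, 64); the edge padding is cropped off again
```

Per the docstring above, values in the 0.1 to 0.25 range are recommended for reducing high-resolution artifacts from the DFT edge.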
src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py
@@ -1,18 +1,37 @@
import torch
import einops
import torch
from torch_image_lerp import sample_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid
from torch_fourier_slice.dft_utils import fftfreq_to_dft_coordinates
from torch_fourier_slice.grids.central_slice_fftfreq_grid import (
central_slice_fftfreq_grid,
)


def extract_central_slices_rfft_3d(
volume_rfft: torch.Tensor,
image_shape: tuple[int, int, int],
rotation_matrices: torch.Tensor, # (..., 3, 3)
fftfreq_max: float | None = None,
):
"""Extract central slice from an fftshifted rfft."""
) -> torch.Tensor:
"""Extract central slice from an fftshifted rfft.
Parameters
----------
volume_rfft: torch.Tensor
`(d, h, w)` fftshifted rfft of the 3D volume from which slices are extracted.
image_shape: tuple[int, int, int]
Shape of the 3D volume from which `volume_rfft` was generated.
rotation_matrices: torch.Tensor
`(..., 3, 3)` array of rotation matrices defining the orientations of the central slices to extract.
fftfreq_max: float | None
Maximum frequency (cycles per pixel) included in the projection.
Returns
-------
projection_image_dfts: torch.Tensor
`(..., h, w)` array of central slices extracted from `volume_rfft`.
"""
# generate grid of DFT sample frequencies for a central slice spanning the xy-plane
freq_grid = central_slice_fftfreq_grid(
volume_shape=image_shape,
@@ -28,12 +47,14 @@ def extract_central_slices_rfft_3d(

# get (b, 3, 1) array of zyx coordinates to rotate
if fftfreq_max is not None:
normed_grid = einops.reduce(freq_grid ** 2, 'h w zyx -> h w', reduction='sum') ** 0.5
normed_grid = (
einops.reduce(freq_grid**2, "h w zyx -> h w", reduction="sum") ** 0.5
)
freq_grid_mask = normed_grid <= fftfreq_max
valid_coords = freq_grid[freq_grid_mask, ...] # (b, zyx)
else:
valid_coords = einops.rearrange(freq_grid, 'h w zyx -> (h w) zyx')
valid_coords = einops.rearrange(valid_coords, 'b zyx -> b zyx 1')
valid_coords = einops.rearrange(freq_grid, "h w zyx -> (h w) zyx")
valid_coords = einops.rearrange(valid_coords, "b zyx -> b zyx 1")

# rotation matrices rotate xyz coordinates, make them rotate zyx coordinates
# xyz:
@@ -48,33 +69,37 @@
rotation_matrices = torch.flip(rotation_matrices, dims=(-2, -1))

# add extra dim to rotation matrices for broadcasting
rotation_matrices = einops.rearrange(rotation_matrices, '... i j -> ... 1 i j')
rotation_matrices = einops.rearrange(rotation_matrices, "... i j -> ... 1 i j")

# rotate all valid coordinates by each rotation matrix
rotated_coords = rotation_matrices @ valid_coords # (..., b, zyx, 1)

# remove last dim of size 1
rotated_coords = einops.rearrange(rotated_coords, '... b zyx 1 -> ... b zyx')
rotated_coords = einops.rearrange(rotated_coords, "... b zyx 1 -> ... b zyx")

# flip coordinates that ended up in redundant half transform after rotation
conjugate_mask = rotated_coords[..., 2] < 0
rotated_coords[conjugate_mask, ...] *= -1

# convert frequencies to array coordinates in fftshifted DFT
rotated_coords = fftfreq_to_dft_coordinates(
frequencies=rotated_coords,
image_shape=image_shape,
rfft=True
frequencies=rotated_coords, image_shape=image_shape, rfft=True
)
samples = sample_image_3d(image=volume_rfft, coordinates=rotated_coords) # (...) rfft
samples = sample_image_3d(
image=volume_rfft, coordinates=rotated_coords
) # (...) rfft

# take complex conjugate of values from redundant half transform
samples[conjugate_mask] = torch.conj(samples[conjugate_mask])

# insert samples back into DFTs
projection_image_dfts = torch.zeros(output_shape, device=volume_rfft.device, dtype=volume_rfft.dtype)
projection_image_dfts = torch.zeros(
output_shape, device=volume_rfft.device, dtype=volume_rfft.dtype
)
if fftfreq_max is None:
freq_grid_mask = torch.ones(size=rfft_shape, dtype=torch.bool, device=volume_rfft.device)
freq_grid_mask = torch.ones(
size=rfft_shape, dtype=torch.bool, device=volume_rfft.device
)

projection_image_dfts[..., freq_grid_mask] = samples

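A sketch of calling extract_central_slices_rfft_3d directly on a prepared, fftshifted volume rfft, mirroring the preprocessing done in project_3d_to_2d; sizes and the identity rotation are assumptions:

```python
import torch

from torch_fourier_slice import extract_central_slices_rfft_3d

volume = torch.rand(32, 32, 32)
rotation_matrices = torch.eye(3).unsqueeze(0)  # (1, 3, 3)

dft = torch.fft.fftshift(volume, dim=(-3, -2, -1))  # volume center to array origin
dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))        # (32, 32, 17)
dft = torch.fft.fftshift(dft, dim=(-3, -2))         # shift only the two full-length dims

slices = extract_central_slices_rfft_3d(
    volume_rfft=dft,
    image_shape=tuple(volume.shape),
    rotation_matrices=rotation_matrices,
)
print(slices.shape)  # (1, 32, 17), fftshifted rfft of each projection
```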