Skip to content

Commit

Permalink
updates including tests (#6)
Browse files Browse the repository at this point in the history
* updates including tests

* add scipy test dep and make 3.10+

* fix docstring typos

* 3.10+ in CI

* fix tests
  • Loading branch information
alisterburt authored Nov 7, 2024
1 parent 6abf032 commit 093c87e
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 174 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
platform: [ubuntu-latest, ] # macos-latest, windows-latest]

steps:
Expand Down
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@ name = "torch-fourier-slice"
dynamic = ["version"]
description = "Fourier slice extraction/insertion in PyTorch."
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.10"
license = { text = "BSD-3-Clause" }
authors = [{ name = "Alister Burt", email = "[email protected]" }]
# https://pypi.org/classifiers/
classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand All @@ -40,13 +38,19 @@ dependencies = [
"numpy",
"einops",
"torch_image_lerp",
"torch_grid_utils",
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
# "extras" (e.g. for `pip install .[test]`)
[project.optional-dependencies]
# add dependencies used for testing here
test = ["pytest", "pytest-cov"]
test = [
"pytest",
"pytest-cov",
"torch-fourier-shell-correlation",
"scipy"
]
# add anything else you like to have in your dev environment here
dev = [
"ipython",
Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/backproject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_insertion import insert_central_slices_rfft_3d


Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/grids/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .fftfreq_grid import fftfreq_grid
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import einops
import torch

from .fftfreq_grid import _construct_fftfreq_grid_2d
from torch_grid_utils import fftfreq_grid
from ..dft_utils import rfft_shape, fftshift_2d


Expand All @@ -13,7 +13,7 @@ def central_slice_fftfreq_grid(
) -> torch.Tensor:
# generate 2d grid of DFT sample frequencies, shape (h, w, 2)
h, w = volume_shape[-2:]
grid = _construct_fftfreq_grid_2d(
grid = fftfreq_grid(
image_shape=(h, w),
rfft=rfft,
device=device
Expand Down
158 changes: 0 additions & 158 deletions src/torch_fourier_slice/grids/fftfreq_grid.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/torch_fourier_slice/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_extraction import extract_central_slices_rfft_3d


Expand All @@ -18,7 +18,7 @@ def project_3d_to_2d(
volume: torch.Tensor
`(d, d, d)` volume.
rotation_matrices: torch.Tensor
`(..., 3, 3)` array of rotation matrices for insert of `images`.
`(..., 3, 3)` array of rotation matrices for insertion of `images`.
Rotation matrices left-multiply column vectors containing xyz coordinates.
pad: bool
Whether to pad the volume 2x with zeros to increase sampling rate in the DFT.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import sample_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def extract_central_slices_rfft_3d(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import insert_into_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates, rfft_shape
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def insert_central_slices_rfft_3d(
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from pytest import fixture


@fixture
def cube() -> torch.Tensor:
volume = torch.zeros((32, 32, 32))
volume[8:24, 8:24, 8:24] = 1
volume[16, 16, 16] = 32
return volume
61 changes: 58 additions & 3 deletions tests/test_torch_fourier_slice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,59 @@
# temporary stub
import pytest
import torch

def test_something():
pass
from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d
from torch_fourier_shell_correlation import fsc
from scipy.stats import special_ortho_group


def test_project_3d_to_2d_rotation_center():
# rotation center should be at position of DC in DFT
volume = torch.zeros((32, 32, 32))
volume[16, 16, 16] = 1

# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=100)).float()
projections = project_3d_to_2d(
volume=volume,
rotation_matrices=rotation_matrices,
)

# check max is always at (16, 16), implying point (16, 16) never moves
for image in projections:
max = torch.argmax(image)
i, j = divmod(max.item(), 32)
assert (i, j) == (16, 16)


def test_3d_2d_projection_backprojection_cycle(cube):
# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float()
projections = project_3d_to_2d(
volume=cube,
rotation_matrices=rotation_matrices,
)

# reconstruct
volume = backproject_2d_to_3d(
images=projections,
rotation_matrices=rotation_matrices,
)

# calculate FSC between the projections and the reconstructions
_fsc = fsc(cube, volume.float())

assert torch.all(_fsc[-5:] > 0.99) # few low res shells at 0.98...


@pytest.mark.parametrize(
"images, rotation_matrices",
[
(
torch.rand((10, 28, 28)).float(),
torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float()
),
]
)
def test_dtypes_slice_insertion(images, rotation_matrices):
result = backproject_2d_to_3d(images, rotation_matrices)
assert result.dtype == torch.float64

0 comments on commit 093c87e

Please sign in to comment.