-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* updates including tests * add scipy test dep and make 3.10+ * fix docstring typos * 3.10+ in CI * fix tests
- Loading branch information
1 parent
6abf032
commit 093c87e
Showing
11 changed files
with
85 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,16 +19,14 @@ name = "torch-fourier-slice" | |
dynamic = ["version"] | ||
description = "Fourier slice extraction/insertion in PyTorch." | ||
readme = "README.md" | ||
requires-python = ">=3.8" | ||
requires-python = ">=3.10" | ||
license = { text = "BSD-3-Clause" } | ||
authors = [{ name = "Alister Burt", email = "[email protected]" }] | ||
# https://pypi.org/classifiers/ | ||
classifiers = [ | ||
"Development Status :: 3 - Alpha", | ||
"License :: OSI Approved :: BSD License", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
|
@@ -40,13 +38,19 @@ dependencies = [ | |
"numpy", | ||
"einops", | ||
"torch_image_lerp", | ||
"torch_grid_utils", | ||
] | ||
|
||
# https://peps.python.org/pep-0621/#dependencies-optional-dependencies | ||
# "extras" (e.g. for `pip install .[test]`) | ||
[project.optional-dependencies] | ||
# add dependencies used for testing here | ||
test = ["pytest", "pytest-cov"] | ||
test = [ | ||
"pytest", | ||
"pytest-cov", | ||
"torch-fourier-shell-correlation", | ||
"scipy" | ||
] | ||
# add anything else you like to have in your dev environment here | ||
dev = [ | ||
"ipython", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .fftfreq_grid import fftfreq_grid | ||
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
from pytest import fixture | ||
|
||
|
||
@fixture | ||
def cube() -> torch.Tensor: | ||
volume = torch.zeros((32, 32, 32)) | ||
volume[8:24, 8:24, 8:24] = 1 | ||
volume[16, 16, 16] = 32 | ||
return volume |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,59 @@ | ||
# temporary stub | ||
import pytest | ||
import torch | ||
|
||
def test_something(): | ||
pass | ||
from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d | ||
from torch_fourier_shell_correlation import fsc | ||
from scipy.stats import special_ortho_group | ||
|
||
|
||
def test_project_3d_to_2d_rotation_center(): | ||
# rotation center should be at position of DC in DFT | ||
volume = torch.zeros((32, 32, 32)) | ||
volume[16, 16, 16] = 1 | ||
|
||
# make projections | ||
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=100)).float() | ||
projections = project_3d_to_2d( | ||
volume=volume, | ||
rotation_matrices=rotation_matrices, | ||
) | ||
|
||
# check max is always at (16, 16), implying point (16, 16) never moves | ||
for image in projections: | ||
max = torch.argmax(image) | ||
i, j = divmod(max.item(), 32) | ||
assert (i, j) == (16, 16) | ||
|
||
|
||
def test_3d_2d_projection_backprojection_cycle(cube): | ||
# make projections | ||
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float() | ||
projections = project_3d_to_2d( | ||
volume=cube, | ||
rotation_matrices=rotation_matrices, | ||
) | ||
|
||
# reconstruct | ||
volume = backproject_2d_to_3d( | ||
images=projections, | ||
rotation_matrices=rotation_matrices, | ||
) | ||
|
||
# calculate FSC between the projections and the reconstructions | ||
_fsc = fsc(cube, volume.float()) | ||
|
||
assert torch.all(_fsc[-5:] > 0.99) # few low res shells at 0.98... | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"images, rotation_matrices", | ||
[ | ||
( | ||
torch.rand((10, 28, 28)).float(), | ||
torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float() | ||
), | ||
] | ||
) | ||
def test_dtypes_slice_insertion(images, rotation_matrices): | ||
result = backproject_2d_to_3d(images, rotation_matrices) | ||
assert result.dtype == torch.float64 |