Skip to content

Commit

Permalink
Initial masking draft
Browse files Browse the repository at this point in the history
  • Loading branch information
jojoelfe committed Aug 20, 2024
1 parent 95ca943 commit 9481961
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 10 deletions.
53 changes: 46 additions & 7 deletions src/ttfsc/_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional

Expand All @@ -10,6 +11,11 @@
cli = typer.Typer(name="ttfsc", no_args_is_help=True, add_completion=False)


class Masking(str, Enum):
none = "none"
sphere = "sphere"


@cli.command(no_args_is_help=True)
def ttfsc_cli(
map1: Annotated[Path, typer.Argument(show_default=False)],
Expand All @@ -22,8 +28,10 @@ def ttfsc_cli(
] = None,
plot: Annotated[bool, typer.Option("--plot")] = True,
plot_with_matplotlib: Annotated[bool, typer.Option("--plot-with-matplotlib")] = False,
fsc_threshold: Annotated[float, typer.Option("--fsc-threshold", show_default=False, help="FSC threshold")] = 0.143,
# mask: Annotated[]
fsc_threshold: Annotated[float, typer.Option("--fsc-threshold", help="FSC threshold")] = 0.143,
mask: Annotated[Masking, typer.Option("--mask")] = Masking.none,
mask_radius: Annotated[float, typer.Option("--mask-radius")] = 100.0,
mask_soft_edge_width: Annotated[int, typer.Option("--mask-soft-edge-width")] = 10,
) -> None:
with mrcfile.open(map1) as f:
map1_tensor = torch.tensor(f.data)
Expand All @@ -34,25 +42,56 @@ def ttfsc_cli(

frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0])
resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms
fsc_values = fsc(map1_tensor, map2_tensor)

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values < fsc_threshold).nonzero()[0] - 1])
fsc_values_unmasked = fsc(map1_tensor, map2_tensor)
fsc_values_masked = None

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])

if mask == Masking.sphere:
import numpy as np
from ttmask.box_setup import box_setup
from ttmask.soft_edge import add_soft_edge
# Taken from https://github.com/teamtomo/ttmask/blob/main/src/ttmask/sphere.py

# establish our coordinate system and empty mask
coordinates_centered, mask_tensor = box_setup(map1_tensor.shape[0])

# determine distances of each pixel to the center
distance_to_center = np.linalg.norm(coordinates_centered, axis=-1)

# set up criteria for which pixels are inside the sphere and modify values to 1.
inside_sphere = distance_to_center < (mask_radius / pixel_spacing_angstroms)
mask_tensor[inside_sphere] = 1

# if requested, a soft edge is added to the mask
mask_tensor = add_soft_edge(mask_tensor, mask_soft_edge_width)

map1_tensor = map1_tensor * mask_tensor
map2_tensor = map2_tensor * mask_tensor
fsc_values_masked = fsc(map1_tensor, map2_tensor)

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])

rprint(f"Estimated resolution using {fsc_threshold} criterion: {estimated_resolution_angstrom:.2f} Å")

if plot:
from ._plotting import plot_matplotlib, plot_plottile

if plot_with_matplotlib:
plot_matplotlib(
fsc_values=fsc_values,
fsc_values_unmasked=fsc_values_unmasked,
fsc_values_masked=fsc_values_masked,
resolution_angstroms=resolution_angstroms,
estimated_resolution_angstrom=estimated_resolution_angstrom,
fsc_threshold=fsc_threshold,
)
else:
plot_plottile(
fsc_values=fsc_values,
fsc_values=fsc_values_unmasked,
fsc_values_masked=fsc_values_masked,
frequency_pixels=frequency_pixels,
pixel_spacing_angstroms=pixel_spacing_angstroms,
estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel,
Expand Down
18 changes: 15 additions & 3 deletions src/ttfsc/_plotting.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from typing import Optional

import torch


def plot_matplotlib(
fsc_values: torch.Tensor, resolution_angstroms: torch.Tensor, estimated_resolution_angstrom: float, fsc_threshold: float
fsc_values_unmasked: torch.Tensor,
fsc_values_masked: Optional[torch.Tensor],
resolution_angstroms: torch.Tensor,
estimated_resolution_angstrom: float,
fsc_threshold: float,
) -> None:
from matplotlib import pyplot as plt

plt.hlines(0, resolution_angstroms[1], resolution_angstroms[-2], "black")
plt.plot(resolution_angstroms, fsc_values, label="FSC (unmasked)")
plt.plot(resolution_angstroms, fsc_values_unmasked, label="FSC (unmasked)")
if fsc_values_masked is not None:
plt.plot(resolution_angstroms, fsc_values_masked, label="FSC (masked)")

plt.xlabel("Resolution (Å)")
plt.ylabel("Correlation")
plt.xscale("log")
Expand All @@ -22,6 +31,7 @@ def plot_matplotlib(

def plot_plottile(
fsc_values: torch.Tensor,
fsc_values_masked: Optional[torch.Tensor],
frequency_pixels: torch.Tensor,
pixel_spacing_angstroms: float,
estimated_resolution_frequency_pixel: float,
Expand All @@ -39,7 +49,9 @@ def resolution_callback(x: float, _: float) -> str:
return f"{(1 / x) * pixel_spacing_angstroms:.2f}"

fig.x_ticks_fkt = resolution_callback
fig.plot(frequency_pixels[1:].numpy(), fsc_values[1:].numpy(), lc="blue", label="FSC")
fig.plot(frequency_pixels[1:].numpy(), fsc_values[1:].numpy(), lc="blue", label="FSC (unmasked)")
if fsc_values_masked is not None:
fig.plot(frequency_pixels[1:].numpy(), fsc_values_masked[1:].numpy(), lc="green", label="FSC (masked)")

fig.plot(
[float(frequency_pixels[1].numpy()), estimated_resolution_frequency_pixel],
Expand Down

0 comments on commit 9481961

Please sign in to comment.