From 94819617df9edf72ce949db5e4ce14e6844a9237 Mon Sep 17 00:00:00 2001 From: Johannes Elferich Date: Tue, 20 Aug 2024 19:24:33 -0400 Subject: [PATCH] Initial masking draft --- src/ttfsc/_cli.py | 53 ++++++++++++++++++++++++++++++++++++------ src/ttfsc/_plotting.py | 18 +++++++++++--- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/ttfsc/_cli.py b/src/ttfsc/_cli.py index 5eee389..f584c08 100644 --- a/src/ttfsc/_cli.py +++ b/src/ttfsc/_cli.py @@ -1,3 +1,4 @@ +from enum import Enum from pathlib import Path from typing import Annotated, Optional @@ -10,6 +11,11 @@ cli = typer.Typer(name="ttfsc", no_args_is_help=True, add_completion=False) +class Masking(str, Enum): + none = "none" + sphere = "sphere" + + @cli.command(no_args_is_help=True) def ttfsc_cli( map1: Annotated[Path, typer.Argument(show_default=False)], @@ -22,8 +28,10 @@ def ttfsc_cli( ] = None, plot: Annotated[bool, typer.Option("--plot")] = True, plot_with_matplotlib: Annotated[bool, typer.Option("--plot-with-matplotlib")] = False, - fsc_threshold: Annotated[float, typer.Option("--fsc-threshold", show_default=False, help="FSC threshold")] = 0.143, - # mask: Annotated[] + fsc_threshold: Annotated[float, typer.Option("--fsc-threshold", help="FSC threshold")] = 0.143, + mask: Annotated[Masking, typer.Option("--mask")] = Masking.none, + mask_radius: Annotated[float, typer.Option("--mask-radius")] = 100.0, + mask_soft_edge_width: Annotated[int, typer.Option("--mask-soft-edge-width")] = 10, ) -> None: with mrcfile.open(map1) as f: map1_tensor = torch.tensor(f.data) @@ -34,10 +42,39 @@ def ttfsc_cli( frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0]) resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms - fsc_values = fsc(map1_tensor, map2_tensor) - estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values < fsc_threshold).nonzero()[0] - 1]) - estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values < fsc_threshold).nonzero()[0] - 1]) + fsc_values_unmasked = fsc(map1_tensor, map2_tensor) + fsc_values_masked = None + + estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1]) + estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1]) + + if mask == Masking.sphere: + import numpy as np + from ttmask.box_setup import box_setup + from ttmask.soft_edge import add_soft_edge + # Taken from https://github.com/teamtomo/ttmask/blob/main/src/ttmask/sphere.py + + # establish our coordinate system and empty mask + coordinates_centered, mask_tensor = box_setup(map1_tensor.shape[0]) + + # determine distances of each pixel to the center + distance_to_center = np.linalg.norm(coordinates_centered, axis=-1) + + # set up criteria for which pixels are inside the sphere and modify values to 1. + inside_sphere = distance_to_center < (mask_radius / pixel_spacing_angstroms) + mask_tensor[inside_sphere] = 1 + + # if requested, a soft edge is added to the mask + mask_tensor = add_soft_edge(mask_tensor, mask_soft_edge_width) + + map1_tensor = map1_tensor * mask_tensor + map2_tensor = map2_tensor * mask_tensor + fsc_values_masked = fsc(map1_tensor, map2_tensor) + + estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1]) + estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1]) + rprint(f"Estimated resolution using {fsc_threshold} criterion: {estimated_resolution_angstrom:.2f} Å") if plot: @@ -45,14 +82,16 @@ def ttfsc_cli( if plot_with_matplotlib: plot_matplotlib( - fsc_values=fsc_values, + fsc_values_unmasked=fsc_values_unmasked, + fsc_values_masked=fsc_values_masked, resolution_angstroms=resolution_angstroms, estimated_resolution_angstrom=estimated_resolution_angstrom, fsc_threshold=fsc_threshold, ) else: plot_plottile( - fsc_values=fsc_values, + fsc_values=fsc_values_unmasked, + fsc_values_masked=fsc_values_masked, frequency_pixels=frequency_pixels, pixel_spacing_angstroms=pixel_spacing_angstroms, estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel, diff --git a/src/ttfsc/_plotting.py b/src/ttfsc/_plotting.py index a03ba87..dc5f661 100644 --- a/src/ttfsc/_plotting.py +++ b/src/ttfsc/_plotting.py @@ -1,13 +1,22 @@ +from typing import Optional + import torch def plot_matplotlib( - fsc_values: torch.Tensor, resolution_angstroms: torch.Tensor, estimated_resolution_angstrom: float, fsc_threshold: float + fsc_values_unmasked: torch.Tensor, + fsc_values_masked: Optional[torch.Tensor], + resolution_angstroms: torch.Tensor, + estimated_resolution_angstrom: float, + fsc_threshold: float, ) -> None: from matplotlib import pyplot as plt plt.hlines(0, resolution_angstroms[1], resolution_angstroms[-2], "black") - plt.plot(resolution_angstroms, fsc_values, label="FSC (unmasked)") + plt.plot(resolution_angstroms, fsc_values_unmasked, label="FSC (unmasked)") + if fsc_values_masked is not None: + plt.plot(resolution_angstroms, fsc_values_masked, label="FSC (masked)") + plt.xlabel("Resolution (Å)") plt.ylabel("Correlation") plt.xscale("log") @@ -22,6 +31,7 @@ def plot_matplotlib( def plot_plottile( fsc_values: torch.Tensor, + fsc_values_masked: Optional[torch.Tensor], frequency_pixels: torch.Tensor, pixel_spacing_angstroms: float, estimated_resolution_frequency_pixel: float, @@ -39,7 +49,9 @@ def resolution_callback(x: float, _: float) -> str: return f"{(1 / x) * pixel_spacing_angstroms:.2f}" fig.x_ticks_fkt = resolution_callback - fig.plot(frequency_pixels[1:].numpy(), fsc_values[1:].numpy(), lc="blue", label="FSC") + fig.plot(frequency_pixels[1:].numpy(), fsc_values[1:].numpy(), lc="blue", label="FSC (unmasked)") + if fsc_values_masked is not None: + fig.plot(frequency_pixels[1:].numpy(), fsc_values_masked[1:].numpy(), lc="green", label="FSC (masked)") fig.plot( [float(frequency_pixels[1].numpy()), estimated_resolution_frequency_pixel],