Skip to content

Commit

Permalink
Improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jojoelfe committed Feb 13, 2024
1 parent 3358323 commit 8218f8f
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 25 deletions.
21 changes: 16 additions & 5 deletions src/decolace/processing/cli_match_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def run_matchtemplate(
defocus_step: float = typer.Option(0.0, help="Defocus step for template matching"),
defocus_range: float = typer.Option(0.0, help="Defocus range for template matching"),
save_mip: bool = typer.Option(False, help="Save MIP of template"),
symmetry: str = typer.Option("C1", help="Symmetry of the template"),
run_on: str = typer.Option("all", help="Run on all or a subset of acquisition areas"),
):
"""Runs match template on all images in the acquisition areas"""
from decolace.processing.project_managment import MatchTemplateRun
Expand All @@ -25,7 +27,7 @@ def run_matchtemplate(
from pycistem.database import get_already_processed_images
import pandas as pd
logging.basicConfig(
level=logging.DEBUG,
level=logging.INFO,
format="%(message)s",
handlers=[
RichHandler(),
Expand All @@ -45,7 +47,9 @@ def run_matchtemplate(
"milano": 8,
"sofia": 8,
"manchester": 8,
}
}
if run_on != "all":
run_profile = {run_on: run_profile[run_on]}

new_mtr = True
if ctx.obj.match_template_job is None:
Expand All @@ -65,6 +69,7 @@ def run_matchtemplate(
defocus_step = ctx.obj.match_template_job.defocus_step
defocus_range = ctx.obj.match_template_job.defocus_range
match_template_job_id = ctx.obj.match_template_job.run_id
symmetry = ctx.obj.match_template_job.symmetry
typer.echo(f"Template match job id already exists, continuing job {ctx.obj.match_template_job.run_name}")
typer.echo(f"template_filename={template.absolute().as_posix()}, angular_step={angular_step}, in_plane_angular_step={in_plane_angular_step} defous_step={defocus_step}, defocus_range={defocus_range}, decolace=True)")

Expand All @@ -78,6 +83,7 @@ def run_matchtemplate(
in_plane_angular_step=in_plane_angular_step,
defocus_step=defocus_step,
defocus_range=defocus_range,
symmetry=symmetry
)
)
ctx.obj.project.write()
Expand All @@ -94,6 +100,7 @@ def run_matchtemplate(
par.defocus_step = defocus_step
par.defocus_search_range = defocus_range
par.max_threads = 2
par.my_symmetry = symmetry
if save_mip:
par.mip_output_file = par.scaled_mip_output_file.replace("_scaled_mip.mrc", "_mip.mrc")

Expand All @@ -117,6 +124,8 @@ def run_matchtemplate(
@app.command()
def create_tmpackage(
ctx: typer.Context,
force_incomplete: bool = typer.Option(False, help="Force creation of tm package even if not all images have been processed"),
swap_phi_and_psi: bool = typer.Option(False, help="Swap phi and psi in tm package"),
):
"""Creates a Template-Matches Package for each acquisition area (i.e. a star file containing all matches)"""
from pycistem.database import get_num_already_processed_images, write_match_template_to_starfile, insert_tmpackage_into_db, get_num_images
Expand All @@ -125,14 +134,16 @@ def create_tmpackage(
for aa in ctx.obj.acquisition_areas:
if get_num_already_processed_images(aa.cistem_project, ctx.obj.match_template_job.run_id) != get_num_images(aa.cistem_project):
typer.echo(f"No images processed for {aa.area_name}")
continue
typer.echo(f"Processed {get_num_already_processed_images(aa.cistem_project, ctx.obj.match_template_job.run_id)} out of {get_num_images(aa.cistem_project)}")
if not force_incomplete:
continue

typer.echo(f"Creating tm package for {aa.area_name}")
output_filename = Path(aa.cistem_project).parent / "Assets" / "TemplateMatching" / f"{aa.area_name}_{ctx.obj.match_template_job.run_id}_tm_package.star"
if output_filename.exists():
typer.echo(f"Package already exists for {aa.area_name}")
continue
write_match_template_to_starfile(aa.cistem_project, ctx.obj.match_template_job.run_id, output_filename)
write_match_template_to_starfile(aa.cistem_project, ctx.obj.match_template_job.run_id, output_filename, switch_phi_psi=swap_phi_and_psi)
insert_tmpackage_into_db(aa.cistem_project, f"DeCOoLACE_{aa.area_name}_TMRun_{ctx.obj.match_template_job.run_id}", output_filename)

@app.command()
Expand All @@ -153,7 +164,7 @@ def run_refinetemplate(
input_starfile=tm_package_file.as_posix(),
output_starfile=tm_package_file.with_suffix('').as_posix()+'_refined.star',
input_template=Path(ctx.obj.match_template_job.template_path).as_posix(),
num_threads=10
num_threads=10,
)
if Path(par.output_starfile).exists():
typer.echo(f"Refined tm package already exists for {aa.area_name}")
Expand Down
27 changes: 11 additions & 16 deletions src/decolace/processing/cli_montaging.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import typer
from pathlib import Path
from typing import List
from typing import List, Optional

app = typer.Typer()

@app.command()
def reset_montage(
project_main: Path = typer.Option(None, help="Path to wanted project file"),
ctx: typer.Context,
):
if project_main is None:
project_path = Path(glob.glob("*.decolace")[0])
project = ProcessingProject.read(project_path)

for aa in project.acquisition_areas:
for aa in ctx.obj.acquisition_areas:
aa.montage_image = None
aa.montage_star = None
aa.initial_tile_star = None
aa.refined_tile_star = None
project.write()
ctx.obj.project.write()

@app.command()
def run_montage(
Expand All @@ -33,20 +29,18 @@ def run_montage(
overlap_ratio: float = typer.Option(0.2, help="Overlap ratio parameter for masked crosscorrelation"),
redo: bool = typer.Option(False, help="Redo the montage even if it already exists"),
redo_montage: bool = typer.Option(False, help="Redo only the creatin of the montage even if it already exists"),
max_mean_density: Optional[float] = typer.Option(None, help="Maximum mean density of the tiles"),
cc_cutoff_as_fraction_of_median: float = typer.Option(0.5, help="Cutoff for the cross-correlation as a fraction of the median cross-correlation"),
):
from rich.console import Console
import starfile
from numpy.linalg import LinAlgError
import pandas as pd
import glob
from decolace.processing.project_managment import ProcessingProject


console = Console()

from decolace.processing.decolace_processing import read_data_from_cistem, read_decolace_data, create_tile_metadata, find_tile_pairs, calculate_shifts, calculate_refined_image_shifts, calculate_refined_intensity, create_montage_metadata, create_montage, prune_bad_shifts
from decolace.processing.decolace_processing import read_data_from_cistem, read_decolace_data, create_tile_metadata, find_tile_pairs, calculate_shifts, calculate_refined_image_shifts, calculate_refined_intensity, create_montage_metadata, create_montage, prune_bad_shifts, prune_tiles
import numpy as np

for i, aa in enumerate(ctx.obj.acquisition_areas):
if aa.montage_image is not None and not redo and not redo_montage:
continue
Expand All @@ -69,7 +63,9 @@ def run_montage(
if type(decolace_data) is not dict:
decolace_data = dict(decolace_data)
tile_metadata_result = create_tile_metadata(cistem_data, decolace_data, output_path)
tile_metadata = tile_metadata_result["tiles"]
tile_metadata = tile_metadata_result["tiles"]
tile_metadata, message = prune_tiles(tile_metadata, max_mean_density=max_mean_density)
console.log(message)
aa.initial_tile_star = output_path
typer.echo(f"Wrote tile metadata to {output_path}.")
else:
Expand All @@ -78,7 +74,6 @@ def run_montage(
typer.echo(f"Read tile metadata from {aa.initial_tile_star}.")

# I should sort by acquisition time here

if aa.refined_tile_star is None or redo:
estimated_distance_threshold = np.median(
tile_metadata["tile_x_size"] * tile_metadata["tile_pixel_size"]
Expand All @@ -98,7 +93,7 @@ def run_montage(
f"Shifts were adjusted. Mean: {np.mean(shifts['add_shift']):.2f} A, Median: {np.median(shifts['add_shift']):.2f} A, Std: {np.std(shifts['add_shift']):.2f} A. Min: {np.min(shifts['add_shift']):.2f} A, Max: {np.max(shifts['add_shift']):.2f} A."
)
# TODO: prune bad shifts and then tiles with no shifts
shifts, message = prune_bad_shifts(shifts)
shifts, message = prune_bad_shifts(shifts,cc_cutoff_as_fraction_of_media=cc_cutoff_as_fraction_of_median)
shifts = shifts.copy()
console.log(message)

Expand Down
12 changes: 11 additions & 1 deletion src/decolace/processing/cli_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def run_ctffind(
cmd_prefix: str = typer.Option("", help="Prefix of run command"),
cmd_suffix: str = typer.Option("", help="Suffix of run command"),
num_cores: int = typer.Option(10, help="Number of cores to use"),
fit_nodes: bool = typer.Option(True, help="Fit nodes"),
fit_nodes_brute_force: bool = typer.Option(True, help="Fit nodes brute force"),
fit_nodes_lowres: float = typer.Option(30.0, help="Fit nodes lowres"),
fit_nodes_highres: float = typer.Option(4.0, help="Fit nodes highres"),
):
"""
Run ctffind for each acquisition area
Expand All @@ -93,6 +97,11 @@ def run_ctffind(
typer.echo(f"Running ctffind for {aa.area_name}")
pars, image_info = ctffind.parameters_from_database(aa.cistem_project,decolace=True)

for par in pars:
par.fit_nodes = fit_nodes
par.fit_nodes_1D_brute_force = fit_nodes_brute_force
par.fit_nodes_low_resolution_limit = fit_nodes_lowres
par.fit_nodes_high_resolution_limit = fit_nodes_highres
res = ctffind.run(pars,num_procs=num_cores,cmd_prefix=cmd_prefix,cmd_suffix=cmd_suffix)

ctffind.write_results_to_database(aa.cistem_project,pars,res,image_info)
Expand All @@ -104,7 +113,8 @@ def update_database(
project_main: Path = typer.Option(None, help="Path to wanted project file")
):
from pycistem.core import Project

import glob
from decolace.processing.project_managment import ProcessingProject
if project_main is None:
project_path = Path(glob.glob("*.decolace")[0])
project = ProcessingProject.read(project_path)
Expand Down
2 changes: 2 additions & 0 deletions src/decolace/processing/cli_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from rich import print
from rich.logging import RichHandler
from typer.core import TyperGroup
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from decolace.processing.project_managment import ProcessingProject, DLContext, DLGlobals

Expand Down
6 changes: 5 additions & 1 deletion src/decolace/processing/cli_project_managment.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ def status(
status_cache[f"{aa.area_name}_match_template"] = [cistem_project_size, cistem_project_mtime, mtm_status]
json.dump(status_cache, open(status_chache_file, "w"))
for i, mtm in enumerate(ctx.obj.project.match_template_runs):
mtm_totals[i] += int(mtm_status[i].split(" ")[1])
if i >= len(mtm_status):
continue
split_line = mtm_status[i].split(" ")
if len(split_line) > 1:
mtm_totals[i] += int(split_line[1])
match_template_table.add_row(
aa.area_name, *mtm_status)
mtm_totals = [str(mtm_total) for mtm_total in mtm_totals]
Expand Down
80 changes: 80 additions & 0 deletions src/decolace/processing/cli_single_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,86 @@

app = typer.Typer()

@app.command()
def split_particles_into_experimental_groups(
ctx: typer.Context,
tmpackage_star_file: Path = typer.Argument(..., help="Path to the tmpackage star file"),
variables: list[str] = typer.Argument(..., help="Variables to split by"),
extract: bool = typer.Option(False, help="Extract the particles"),
):
"""
Splits particles into experimental groups based on specified variables.
Args:
ctx (typer.Context): The Typer context object.
tmpackage_star_file (Path): Path to the tmpackage star file.
variables (list[str]): Variables to split the particles by.
Returns:
None
"""
import starfile
import pandas as pd
from rich.progress import track
from decolace.processing.project_managment import generate_aa_dataframe
particle_info = starfile.read(tmpackage_star_file)
aa_info = generate_aa_dataframe(ctx.obj.acquisition_areas)
aa_info["image_folder"] = aa_info["cistem_project"].map(lambda x: str(Path(x).parent / "Assets" / "Images" ))
particle_info["image_folder"] = particle_info["cisTEMOriginalImageFilename"].str.strip("'").map(lambda x: str(Path(x).parent))
particle_info = particle_info.join(aa_info.set_index("image_folder"), on="image_folder", rsuffix="_aa")
(tmpackage_star_file.parent / tmpackage_star_file.stem).mkdir(exist_ok=True)
def write_and_print(x):
name = "_".join(x.name)
print(f"Writing {tmpackage_star_file.parent / tmpackage_star_file.stem / name}.star")
# Only keep columns starting with cisTEM
x = x.filter(like="cisTEM")
x["cisTEMPositionInStack"] = [i+1 for i in range(len(x))]
x["cisTEMOccupancy"] = 1.0
x["cisTEMImageActivity"] = 1
starfile.write(x, tmpackage_star_file.parent / tmpackage_star_file.stem / f"{name}.star", overwrite=True, quote_all_strings=True, quote_character="'")

if extract:
from pycistem.utils import extract_particles

for _ in track(extract_particles(tmpackage_star_file.parent / tmpackage_star_file.stem / f"{name}.star",tmpackage_star_file.parent / tmpackage_star_file.stem / f"{name}.mrc"), description=f"Extracting {name}", total=len(x)):
pass

particle_info.groupby(variables).apply(write_and_print)

@app.command()
def reconstruct_split_particles(
ctx: typer.Context,
star_directory: Path = typer.Argument(..., help="Path to the directory containing the split star files"),
pixel_size: float = typer.Option(2.0, help="Pixel size of the particles"),
):
#logging.basicConfig(
# level=logging.DEBUG,
# format="%(message)s",
# handlers=[
# RichHandler(),
# #logging.FileHandler(current_output_directory / "log.log")
# ]
#)
from pycistem.programs.reconstruct3d import Reconstruct3dParameters, run
starfiles = star_directory.glob("*.star")
pars = []
for star in starfiles:
reconstructPar = Reconstruct3dParameters(
input_star_filename=str(star),
input_particle_stack=str(star.with_suffix(".mrc")),
output_reconstruction_1=str(star.parent / f"{star.stem}_reconstruction1.mrc"),
output_reconstruction_2=str(star.parent / f"{star.stem}_reconstruction2.mrc"),
output_reconstruction_filtered=str(star.parent / f"{star.stem}_reconstruction_filtered.mrc"),
output_resolution_statistics=str(star.parent / f"{star.stem}_resolution_statistics.txt"),
pixel_size=pixel_size,
molecular_mass_kDa=3000.0
)
pars.append(reconstructPar)
print(pars)
run(pars, num_threads=10)



@app.command()
def split_particles_into_optical_groups(
ctx: typer.Context,
Expand Down
21 changes: 21 additions & 0 deletions src/decolace/processing/decolace_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sqlite3
from functools import partial
from pathlib import Path
from typing import Optional

import mrcfile
import numpy as np
Expand Down Expand Up @@ -70,6 +71,20 @@ def read_decolace_data(decolace_filename: Path) -> dict:

return np.load(decolace_filename, allow_pickle=True).item()

def prune_tiles(tile_metadata, max_mean_density: Optional[float] = None) -> tuple[pd.DataFrame, str]:
if max_mean_density is None:
return tile_metadata, ""
good_tiles = []
initial_len = len(tile_metadata)
for i, tile in tile_metadata.iterrows():
with mrcfile.open(tile["tile_filename"]) as mrc:
mean_density = mrc.header["dmean"]
if mean_density < max_mean_density:
good_tiles.append(i)
tile_metadata = tile_metadata.loc[good_tiles]
message = f"Pruned {initial_len-len(good_tiles)} tiles with a mean density higher than {max_mean_density}.\n"
return tile_metadata.copy(), message

def prune_bad_shifts(shifts: pd.DataFrame, inital_area_cutoff_as_fraction_of_median:float = 0.2, cc_cutoff_as_fraction_of_media:float=0.5):

init_len = len(shifts)
Expand Down Expand Up @@ -379,6 +394,7 @@ def determine_shift_by_cc(
filter_order=4.0,
mask_size_cutoff: int = 100,
overlap_ratio: float = 0.1,
debug: bool = True,
):
import time
# Create the montage
Expand Down Expand Up @@ -446,6 +462,11 @@ def determine_shift_by_cc(
mode="full",
overlap_ratio=overlap_ratio,
)
if debug:
with mrcfile.new(
f"debug_{Path(im1['tile_filename']).name}_vs_{Path(im2['tile_filename']).name}.mrc", overwrite=True
) as mrc:
mrc.set_data(xcorr)
#print(f"Cross took {time.perf_counter() - prev} seconds")
prev = time.perf_counter()
# Generalize to the average of multiple equal maxima
Expand Down
Loading

0 comments on commit 8218f8f

Please sign in to comment.