diff --git a/src/decolace/processing/cli_match_template.py b/src/decolace/processing/cli_match_template.py index cb8b85b..25e72d7 100644 --- a/src/decolace/processing/cli_match_template.py +++ b/src/decolace/processing/cli_match_template.py @@ -23,11 +23,11 @@ def run_matchtemplate( from decolace.processing.project_managment import MatchTemplateRun from pycistem.programs import match_template, generate_gpu_prefix, generate_num_procs import pycistem - pycistem.set_cistem_path("/groups/elferich/cistem_binaries/") + pycistem.set_cistem_path(ctx.obj.cistem_path) from pycistem.database import get_already_processed_images import pandas as pd logging.basicConfig( - level=logging.INFO, + level=logging.DEBUG, format="%(message)s", handlers=[ RichHandler(), @@ -47,6 +47,12 @@ def run_matchtemplate( "milano": 8, "sofia": 8, "manchester": 8, + "quebec": 8, + "boston": 8, + "orleans": 8, + "brno": 8, + "hannover": 8, + "muenchen": 8, } if run_on != "all": run_profile = {run_on: run_profile[run_on]} @@ -99,7 +105,7 @@ def run_matchtemplate( par.in_plane_angular_step = in_plane_angular_step par.defocus_step = defocus_step par.defocus_search_range = defocus_range - par.max_threads = 2 + par.max_threads = 6 par.my_symmetry = symmetry if save_mip: par.mip_output_file = par.scaled_mip_output_file.replace("_scaled_mip.mrc", "_mip.mrc") @@ -114,7 +120,7 @@ def run_matchtemplate( all_image_info = pd.concat(all_image_info) typer.echo(f"Total of {len(all_image_info)} tiles to process") - res = match_template.run(all_image_info,num_procs=generate_num_procs(run_profile),cmd_prefix=list(generate_gpu_prefix(run_profile)),cmd_suffix='"', sleep_time=1.0, write_directly_to_db=True) + res = match_template.run(all_image_info,num_procs=generate_num_procs(run_profile),cmd_prefix=list(generate_gpu_prefix(run_profile)),cmd_suffix=f'" 2>> /tmp/tmerror.txt 1>> /tmp/tmlog.txt', sleep_time=1.0, write_directly_to_db=True, save_output=False, save_output_path="/tmp/tm_debug/") #typer.echo(f"Writing results for {aa.area_name}") #match_template.write_results_to_database(aa.cistem_project,pars,res,image_info) @@ -238,9 +244,11 @@ def precompute_filters( typer.echo(f"Filter values for matches already exist for {aa.area_name}") continue refined_matches = starfile.read(refined_matches_starfile) - with concurrent.futures.ThreadPoolExecutor() as executor: - fn = partial(get_distance_to_edge, refined_matches=refined_matches, binning_boxsize=binning_boxsize) - executor.map(fn, refined_matches["cisTEMOriginalImageFilename"].unique()) + for filename in refined_matches["cisTEMOriginalImageFilename"].unique(): + get_distance_to_edge(filename, refined_matches, binning_boxsize) + #with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + # fn = partial(get_distance_to_edge, refined_matches=refined_matches, binning_boxsize=binning_boxsize) + # executor.map(fn, refined_matches["cisTEMOriginalImageFilename"].unique()) starfile.write(refined_matches, filtered_matches_starfile, overwrite=True) @@ -269,6 +277,7 @@ def filter_matches( typer.echo(f"No refined matches for {aa.area_name}") continue refined_matches = starfile.read(refined_matches_starfile) + refined_matches["LACEBeamEdgeDistance"] = refined_matches["LACEBeamEdgeDistance"].astype(float) print(f"Starting with {len(refined_matches)} matches") refined_matches = refined_matches[refined_matches["cisTEMScore"] >= refined_score_cutoff] print(f"After {refined_score_cutoff} score criterion {len(refined_matches)} matches") @@ -322,7 +331,7 @@ def join_matches( combined_matches.iloc[i,combined_matches.columns.get_loc('cisTEMOriginalImageFilename')] = "'"+str(new_filename)+"'" if use_different_pixel_size is not None: combined_matches['cisTEMPixelSize'] = use_different_pixel_size - starfile.write(combined_matches, combined_matches_starfile, overwrite=True) + starfile.write(combined_matches, combined_matches_starfile, overwrite=True, quote_character="'", quote_all_strings=True) database = create_project(f"processing_{name}", ctx.obj.project.project_path.absolute() / "cistem_projects") print(f"{combined_matches_starfile}") diff --git a/src/decolace/processing/cli_montaging.py b/src/decolace/processing/cli_montaging.py index 67f769a..2f6cd14 100644 --- a/src/decolace/processing/cli_montaging.py +++ b/src/decolace/processing/cli_montaging.py @@ -31,6 +31,7 @@ def run_montage( redo_montage: bool = typer.Option(False, help="Redo only the creatin of the montage even if it already exists"), max_mean_density: Optional[float] = typer.Option(None, help="Maximum mean density of the tiles"), cc_cutoff_as_fraction_of_median: float = typer.Option(0.5, help="Cutoff for the cross-correlation as a fraction of the median cross-correlation"), + debug: bool = typer.Option(False, help="Debug mode") ): from rich.console import Console import starfile diff --git a/src/decolace/processing/cli_preprocessing.py b/src/decolace/processing/cli_preprocessing.py index 387b9f0..e10858a 100644 --- a/src/decolace/processing/cli_preprocessing.py +++ b/src/decolace/processing/cli_preprocessing.py @@ -81,6 +81,7 @@ def run_ctffind( cmd_suffix: str = typer.Option("", help="Suffix of run command"), num_cores: int = typer.Option(10, help="Number of cores to use"), fit_nodes: bool = typer.Option(True, help="Fit nodes"), + tilt: bool = typer.Option(True, help="Tilt"), fit_nodes_brute_force: bool = typer.Option(True, help="Fit nodes brute force"), fit_nodes_lowres: float = typer.Option(30.0, help="Fit nodes lowres"), fit_nodes_highres: float = typer.Option(4.0, help="Fit nodes highres"), @@ -89,7 +90,8 @@ def run_ctffind( Run ctffind for each acquisition area """ from pycistem.programs import ctffind - + import pycistem + pycistem.set_cistem_path(ctx.obj.cistem_path) for aa in ctx.obj.acquisition_areas: if aa.ctffind_run: @@ -98,6 +100,9 @@ def run_ctffind( pars, image_info = ctffind.parameters_from_database(aa.cistem_project,decolace=True) for par in pars: + par.determine_tilt = tilt + par.minimum_defocus = 10000 + par.maximum_defocus = 50000 par.fit_nodes = fit_nodes par.fit_nodes_1D_brute_force = fit_nodes_brute_force par.fit_nodes_low_resolution_limit = fit_nodes_lowres @@ -141,19 +146,16 @@ def redo_projects( @app.command() def redo_unblur( - project_main: Path = typer.Option(None, help="Path to wanted project file") + ctx: typer.Context, ): - if project_main is None: - project_path = Path(glob.glob("*.decolace")[0]) - project = ProcessingProject.read(project_path) - - for aa in project.acquisition_areas: + for aa in ctx.obj.acquisition_areas: aa.unblur_run = False - project.write() + ctx.obj.project.write() -@app.command() + +app.command() def redo_ctffind( project_main: Path = typer.Option(None, help="Path to wanted project file") ): diff --git a/src/decolace/processing/cli_project_managment.py b/src/decolace/processing/cli_project_managment.py index 0620984..f70f01d 100644 --- a/src/decolace/processing/cli_project_managment.py +++ b/src/decolace/processing/cli_project_managment.py @@ -39,13 +39,13 @@ def add_acquisition_area( aa = AcquisitionAreaSingle(area_name, area_directory.as_posix()) aa.load_from_disk() - if np.sum(aa.state['positions_acquired']) == 0: + if np.sum(aa.state.positions_acquired) == 0: print(f"{aa.name}: No Data") aa_pre = AcquisitionAreaPreProcessing( area_name = f"{session_name}_{grid_name}_{aa.name}", decolace_acquisition_area_info_path = sorted(Path(aa.directory).glob(f"{aa.name}*.npy"))[-1], decolace_grid_info_path = None, - decolace_session_info_path = None, + decolace_session_info_path = sorted(Path(aa.directory).parent.parent.glob(f"{session_name}*.npy"))[-1], frames_folder = aa.frames_directory, ) ctx.obj.project.acquisition_areas.append(aa_pre) diff --git a/src/decolace/processing/decolace_processing.py b/src/decolace/processing/decolace_processing.py index 4bcbb9a..875f6fe 100644 --- a/src/decolace/processing/decolace_processing.py +++ b/src/decolace/processing/decolace_processing.py @@ -12,8 +12,10 @@ import starfile from rich.progress import track from scipy import optimize -from scipy.ndimage import binary_erosion +from scipy.ndimage import binary_erosion, mean from scipy.spatial import cKDTree +from scipy.signal import savgol_filter +from scipy.interpolate import interp1d from skimage import filters, transform from skimage.registration._masked_phase_cross_correlation import cross_correlate_masked from skimage.transform import resize @@ -204,12 +206,12 @@ def create_montage_metadata( unbinned_size_x = ( tile_data["tile_image_shift_pixel_x"].max() + tile_data["tile_x_size"].max() - - tile_data["tile_image_shift_pixel_x"].min() + - tile_data["tile_image_shift_pixel_x"].min() + binning*5 ) unbinned_size_y = ( tile_data["tile_image_shift_pixel_y"].max() + tile_data["tile_y_size"].max() - - tile_data["tile_image_shift_pixel_y"].min() + - tile_data["tile_image_shift_pixel_y"].min() + binning*5 ) x_offset = tile_data["tile_image_shift_pixel_x"].min() @@ -240,8 +242,21 @@ def create_montage_metadata( starfile.write(results, output_path_metadata, overwrite=True) return results +def calculate_diagonal_radius(box_size: int = 512) -> int: + return int(np.around(np.sqrt(2 * (box_size / 2 - 0.5) ** 2))) -def create_montage(montage_metadata: dict, output_path_montage: Path, erode_mask: int = 0): +def distance_from_center_array(shape) -> np.ndarray: + x, y = np.ogrid[0:shape[0], 0:shape[1]] + size = np.min(shape) + r = np.hypot(x - (size - 1) / 2, y - (size - 1) / 2) + return r + +def radial_average(spectrum: np.ndarray) -> np.ndarray: + distance_from_center = distance_from_center_array(spectrum.shape) + bins = np.around(distance_from_center).astype(np.int32) + return mean(spectrum, labels=bins, index=np.arange(1, calculate_diagonal_radius(spectrum.shape[0])+1)) + +def create_montage(montage_metadata: dict, output_path_montage: Path, erode_mask: int = 0, correct_dark_ring: bool = True, dark_ring_start: float = 0.9, dark_ring_windowlength: int = 50): import time # Create the montage prev = time.perf_counter() @@ -292,8 +307,28 @@ def create_montage(montage_metadata: dict, output_path_montage: Path, erode_mask ) #print(f"Opening took {time.perf_counter() - prev} seconds") prev = time.perf_counter() + if correct_dark_ring and tile.shape[0] > 4000 and tile.shape[1] > 4000: + ra = radial_average(tile) + shape = tile.shape + x, y = np.ogrid[0:shape[0], 0:shape[1]] + size = np.min(shape) + r = np.hypot(x - (size - 1) / 2, y - (size - 1) / 2) + correction_image = np.ones_like(r) + + smoothed_curve = savgol_filter(ra[int((shape[0]//2)*0.9):(shape[0]//2)], dark_ring_windowlength, 3) + smoothed_curve = smoothed_curve / smoothed_curve[0] + values_to_correct = r[np.where(np.logical_and(r>=int((shape[0]//2)*dark_ring_start), r < shape[0//2]))] + oldvalues = radial_average(r)[int((shape[0]//2)*0.9):(shape[0]//2)] + correction_image[np.where(np.logical_and(r>=int((shape[0]//2)*dark_ring_start), r < shape[0//2]))] = np.interp(values_to_correct,oldvalues, smoothed_curve) + + #correction_image[r>=int((tile.shape[0]//2)*dark_ring_start)].flatten() = interp1d(correction_image[r>=int((tile.shape[0]//2)*dark_ring_start)].flatten(), + tile = tile / correction_image + tile = resize(tile, tile_binned_dimensions, anti_aliasing=True) - mask_float = resize(mask_float, tile_binned_dimensions, anti_aliasing=True) + mask_float = resize(mask_float, tile_binned_dimensions, anti_aliasing=False) + mask_float = mask_float > 0.5 + mask_float = mask_float.astype(np.float32) + mask_float *= 1.0 #print(f"Resizing took {time.perf_counter() - prev} seconds") prev = time.perf_counter() insertion_slice = ( @@ -307,6 +342,9 @@ def create_montage(montage_metadata: dict, output_path_montage: Path, erode_mask ), ) tile *= mask_float + + + existing_mask = 1.0 - mask_montage[insertion_slice] tile *= existing_mask mask_float *= existing_mask @@ -371,7 +409,7 @@ def calculate_shifts(row_pairs: list, num_proc: int = 1, erode_mask: int = 0, fi # map the worker function to the input data using the pool results = pool.imap_unordered( - partial(determine_shift_by_cc, erode_mask=erode_mask, filter_cutoff_frequency_ratio=filter_cutoff_frequency_ratio, filter_order=filter_order,mask_size_cutoff=mask_size_cutoff, overlap_ratio=overlap_ratio), row_pairs + partial(determine_shift_by_cc2, erode_mask=erode_mask, filter_cutoff_frequency_ratio=filter_cutoff_frequency_ratio, filter_order=filter_order,mask_size_cutoff=mask_size_cutoff, overlap_ratio=overlap_ratio), row_pairs ) shifts = [] # use the rich.progress module to track the progress of the results @@ -394,7 +432,8 @@ def determine_shift_by_cc( filter_order=4.0, mask_size_cutoff: int = 100, overlap_ratio: float = 0.1, - debug: bool = True, + debug: bool = False, + debug_object = {} ): import time # Create the montage @@ -410,6 +449,8 @@ def determine_shift_by_cc( order=filter_order, high_pass=False, ) + if debug: + debug_object["reference"] = reference.copy() with mrcfile.open(im2["tile_filename"]) as mrc: moving = np.copy(mrc.data[0]) moving = filters.butterworth( @@ -418,6 +459,8 @@ def determine_shift_by_cc( order=filter_order, high_pass=False, ) + if debug: + debug_object["moving"] = moving.copy() #print(f"Loading images took {time.perf_counter() - prev} seconds") prev = time.perf_counter() diff = ( @@ -426,6 +469,8 @@ def determine_shift_by_cc( ) tform = transform.SimilarityTransform(translation=(diff[0], diff[1])).inverse moving = transform.warp(moving, tform, output_shape=reference.shape) + if debug: + debug_object["moving_moved"] = moving.copy() #print(f"Transforming images took {time.perf_counter() - prev} seconds") prev = time.perf_counter() with mrcfile.open(im1["tile_mask_filename"]) as mrc: @@ -447,6 +492,8 @@ def determine_shift_by_cc( prev = time.perf_counter() moving_mask = transform.warp(moving_mask, tform, output_shape=reference_mask.shape) mask = np.minimum(reference_mask, moving_mask) > 0.9 + if debug: + debug_object["mask"] = mask.copy() if np.sum(mask) < mask_size_cutoff: return None reference *= mask @@ -463,10 +510,7 @@ def determine_shift_by_cc( overlap_ratio=overlap_ratio, ) if debug: - with mrcfile.new( - f"debug_{Path(im1['tile_filename']).name}_vs_{Path(im2['tile_filename']).name}.mrc", overwrite=True - ) as mrc: - mrc.set_data(xcorr) + debug_object["xcorr"] = xcorr.copy() #print(f"Cross took {time.perf_counter() - prev} seconds") prev = time.perf_counter() # Generalize to the average of multiple equal maxima @@ -493,6 +537,118 @@ def determine_shift_by_cc( "image_2": im2["tile_filename"], } +def determine_shift_by_cc2( + doubled, + erode_mask: float = 0, + filter_cutoff_frequency_ratio: float = 0.02, + filter_order=4.0, + mask_size_cutoff: int = 100, + overlap_ratio: float = 0.1, + debug: bool = False, + debug_object = {} +): + # Given the infow of two images, calculate the refined relative shifts by crosscorrelation return + im1, im2 = doubled + + # Open the masks + with mrcfile.open(im1["tile_mask_filename"]) as mrc: + reference_mask = np.copy(mrc.data[0]) + reference_mask.dtype = np.uint8 + reference_mask = reference_mask / 255.0 + with mrcfile.open(im2["tile_mask_filename"]) as mrc: + moving_mask = np.copy(mrc.data[0]) + moving_mask.dtype = np.uint8 + moving_mask = moving_mask / 255.0 + + if erode_mask > 0: + reference_mask = reference_mask > 0.5 + moving_mask = moving_mask > 0.5 + reference_mask = binary_erosion(reference_mask, iterations=erode_mask) + moving_mask = binary_erosion(moving_mask, iterations=erode_mask) + + # Transform mask2 + diff = ( + int(im2["tile_image_shift_pixel_x"] - im1["tile_image_shift_pixel_x"]), + int(im2["tile_image_shift_pixel_y"] - im1["tile_image_shift_pixel_y"]), + ) + tform = transform.SimilarityTransform(translation=(diff[0], diff[1])).inverse + moving_mask = transform.warp(moving_mask, tform, output_shape=reference_mask.shape) + mask = np.minimum(reference_mask, moving_mask) > 0.9 + if np.sum(mask) < mask_size_cutoff: + return None + # Get bounding box of the mask + bbox = np.array([np.min(np.nonzero(mask)[0]), np.max(np.nonzero(mask)[0]), np.min(np.nonzero(mask)[1]), np.max(np.nonzero(mask)[1])]) + # Calculate bounding box for moving + bbox_moving = np.array([bbox[0] - diff[1], bbox[1] - diff[1], bbox[2] - diff[0], bbox[3] - diff[0]]) + mask = mask[bbox[0]:bbox[1], bbox[2]:bbox[3]] + + + with mrcfile.open(im1["tile_filename"]) as mrc: + reference = np.copy(mrc.data[0]) + # Cut out the bounding box + reference = reference[bbox[0]:bbox[1], bbox[2]:bbox[3]] + reference = filters.butterworth( + reference, + cutoff_frequency_ratio=filter_cutoff_frequency_ratio, + order=filter_order, + high_pass=False, + ) + if debug: + debug_object["reference"] = reference.copy() + with mrcfile.open(im2["tile_filename"]) as mrc: + moving = np.copy(mrc.data[0]) + # Cut out the bounding box + moving = moving[bbox_moving[0]:bbox_moving[1], bbox_moving[2]:bbox_moving[3]] + moving = filters.butterworth( + moving, + cutoff_frequency_ratio=filter_cutoff_frequency_ratio, + order=filter_order, + high_pass=False, + ) + if debug: + debug_object["moving"] = moving.copy() + + + + + + reference *= mask + moving *= mask + + xcorr = cross_correlate_masked( + moving, + reference, + mask, + mask, + axes=tuple(range(moving.ndim)), + mode="full", + overlap_ratio=overlap_ratio, + ) + if debug: + debug_object["xcorr"] = xcorr.copy() + #print(f"Cross took {time.perf_counter() - prev} seconds") + # Generalize to the average of multiple equal maxima + maxima = np.stack(np.nonzero(xcorr == xcorr.max()), axis=1) + center = np.mean(maxima, axis=0) + shift = center - np.array(reference.shape) + 1 + shift = -shift + + with np.errstate(all="raise"): + try: + ratio = np.sum(reference) / np.sum(moving) + except FloatingPointError: + ratio = 1 + return { + "shift_x": diff[0] + shift[1], + "shift_y": diff[1] + shift[0], + "initial_area": np.sum(mask), + "max_cc": xcorr.max(), + "add_shift": np.linalg.norm(shift), + "int_ratio": ratio, + "image_1": im1["tile_filename"], + "image_2": im2["tile_filename"], + } + def _position_residuals(is_pixel, index_image_1, index_image_2, shifts): distance = (is_pixel[index_image_2] - is_pixel[index_image_1]) - shifts diff --git a/src/decolace/processing/match_filtering.py b/src/decolace/processing/match_filtering.py index 2cfde7e..ecd4015 100644 --- a/src/decolace/processing/match_filtering.py +++ b/src/decolace/processing/match_filtering.py @@ -29,7 +29,9 @@ def get_distance_to_edge(orig_image_filename,refined_matches,binning_boxsize): pixel_position_x, ] except IndexError: + print(f"IndexError for {orig_image_filename}") refined_matches.loc[i, "LACEBeamEdgeDistance"] = 0 + #refined_matches["LACEBeamEdgeDistance"] = refined_matches["LACEBeamEdgeDistance"].fillna(0) # Compute variance after binning with mrcfile.open(image_filename) as image: micrograph = image.data diff --git a/src/decolace/processing/project_managment.py b/src/decolace/processing/project_managment.py index 73224ca..5cf7edc 100644 --- a/src/decolace/processing/project_managment.py +++ b/src/decolace/processing/project_managment.py @@ -62,13 +62,18 @@ class MatchTemplateRun(BaseModel): symmetry: str = "C1" +class RunProfileCommand(BaseModel): + pass +class RunProfile(BaseModel): + pass class ProcessingProject(BaseModel): project_name: str project_path: Path processing_pixel_size: float = 2.0 + acquisition_areas: List[AcquisitionAreaPreProcessing] = [] match_template_runs: List[MatchTemplateRun] = []