From d2d9d447feaeb106adb7a6df4ad600b7059fbf76 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 17 Jul 2024 16:16:21 +0000 Subject: [PATCH] fix(edits/split): filter out inactive cross edges AT EACH LAYER --- pychunkedgraph/__init__.py | 2 +- pychunkedgraph/graph/chunkedgraph.py | 14 +++++--------- pychunkedgraph/graph/edges/utils.py | 22 +++++++++++++++++++++- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/pychunkedgraph/__init__.py b/pychunkedgraph/__init__.py index 8d1c8625f..528787cfc 100644 --- a/pychunkedgraph/__init__.py +++ b/pychunkedgraph/__init__.py @@ -1 +1 @@ -__version__ = "3.0.3" +__version__ = "3.0.0" diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py index 1836094f0..7823695db 100644 --- a/pychunkedgraph/graph/chunkedgraph.py +++ b/pychunkedgraph/graph/chunkedgraph.py @@ -3,6 +3,8 @@ import time import typing import datetime +from itertools import chain +from functools import reduce import numpy as np from pychunkedgraph import __version__ @@ -667,8 +669,6 @@ def get_l2_agglomerations( Children of Level 2 Node IDs and edges. Edges are read from cloud storage. """ - from itertools import chain - from functools import reduce from .misc import get_agglomerations chunk_ids = np.unique(self.get_chunk_ids_from_node_ids(level2_ids)) @@ -708,13 +708,9 @@ def get_l2_agglomerations( sv_parent_d.update(dict(zip(svs.tolist(), [l2id] * len(svs)))) if active: - n1, n2 = all_chunk_edges.node_ids1, all_chunk_edges.node_ids2 - layers = self.get_cross_chunk_edges_layer(all_chunk_edges.get_pairs()) - max_layer = np.max(layers) + 1 - parents1 = self.get_roots(n1, stop_layer=max_layer, time_stamp=time_stamp) - parents2 = self.get_roots(n2, stop_layer=max_layer, time_stamp=time_stamp) - mask = parents1 == parents2 - all_chunk_edges = all_chunk_edges[mask] + all_chunk_edges = edge_utils.filter_inactive_cross_edges( + self, all_chunk_edges, time_stamp=time_stamp + ) in_edges, out_edges, cross_edges = edge_utils.categorize_edges_v2( self.meta, all_chunk_edges, sv_parent_d diff --git a/pychunkedgraph/graph/edges/utils.py b/pychunkedgraph/graph/edges/utils.py index cd0e85fe8..76f8ea1d8 100644 --- a/pychunkedgraph/graph/edges/utils.py +++ b/pychunkedgraph/graph/edges/utils.py @@ -9,6 +9,7 @@ from typing import Iterable from typing import Optional from collections import defaultdict +from functools import reduce import fastremap import numpy as np @@ -46,7 +47,9 @@ def concatenate_chunk_edges(chunk_edge_dicts: Iterable) -> Dict: return edges_dict -def concatenate_cross_edge_dicts(edges_ds: Iterable[Dict], unique: bool = False) -> Dict: +def concatenate_cross_edge_dicts( + edges_ds: Iterable[Dict], unique: bool = False +) -> Dict: """Combines cross chunk edge dicts of form {layer id : edge list}.""" result_d = defaultdict(list) for edges_d in edges_ds: @@ -182,3 +185,20 @@ def get_edges_status(cg, edges: Iterable, time_stamp: Optional[float] = None): active_status.extend(mask) active_status = np.array(active_status, dtype=bool) return existence_status, active_status + + +def filter_inactive_cross_edges( + cg, all_chunk_edges: Edges, time_stamp: Optional[float] = None +): + result = [] + layers = cg.get_cross_chunk_edges_layer(all_chunk_edges.get_pairs()) + for layer in np.unique(layers): + layer_mask = layers == layer + parent_layer = layer + 1 + layer_edges = all_chunk_edges[layer_mask] + n1, n2 = layer_edges.node_ids1, layer_edges.node_ids2 + parents1 = cg.get_roots(n1, stop_layer=parent_layer, time_stamp=time_stamp) + parents2 = cg.get_roots(n2, stop_layer=parent_layer, time_stamp=time_stamp) + mask = parents1 == parents2 + result.append(layer_edges[mask]) + return reduce(lambda x, y: x + y, result, Edges([], []))