diff --git a/grandcypher/__init__.py b/grandcypher/__init__.py index 63fc906..6628353 100644 --- a/grandcypher/__init__.py +++ b/grandcypher/__init__.py @@ -83,11 +83,12 @@ return_clause : "return"i distinct_return? return_item ("," return_item)* -return_item : entity_id | aggregation_function | entity_id "." attribute_id +return_item : (entity_id | aggregation_function | entity_id "." attribute_id) ( "AS"i alias )? aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")" AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN" attribute_id : CNAME +alias : CNAME distinct_return : "DISTINCT"i limit_clause : "limit"i NUMBER @@ -97,7 +98,7 @@ order_items : order_item ("," order_item)* -order_item : entity_id order_direction? +order_item : (entity_id | aggregation_function) order_direction? order_direction : "ASC"i -> asc | "DESC"i -> desc @@ -363,7 +364,7 @@ def inner( def _data_path_to_entity_name_attribute(data_path): - if not isinstance(data_path, str): + if isinstance(data_path, Token): data_path = data_path.value if "." in data_path: entity_name, entity_attribute = data_path.split(".") @@ -376,7 +377,9 @@ def _data_path_to_entity_name_attribute(data_path): class _GrandCypherTransformer(Transformer): def __init__(self, target_graph: nx.Graph, limit=None): - self._target_graph = target_graph + self._target_graph = nx.MultiDiGraph(target_graph) + self._entity2alias = dict() + self._alias2entity = dict() self._paths = [] self._where_condition: CONDITION = None self._motif = nx.MultiDiGraph() @@ -385,6 +388,7 @@ def __init__(self, target_graph: nx.Graph, limit=None): self._return_requests = [] self._return_edges = {} self._aggregate_functions = [] + self._aggregation_attributes = set() self._distinct = False self._order_by = None self._order_by_attributes = set() @@ -491,12 +495,15 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: ret_with_attr = [] for r in ret: r_attr = {} - for i, v in r.items(): - r_attr[(i, list(v.get("__labels__"))[0])] = v.get( - entity_attribute, None - ) - # eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}] - ret_with_attr.append(r_attr) + if isinstance(r, dict): + r = [r] + for el in r: + for i, v in el.items(): + r_attr[(i, list(v.get("__labels__", [i]))[0])] = v.get( + entity_attribute, None + ) + # eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}] + ret_with_attr.append(r_attr) ret = ret_with_attr @@ -508,23 +515,66 @@ def return_clause(self, clause): # collect all entity identifiers to be returned for item in clause: if item: + alias = self._extract_alias(item) item = item.children[0] if isinstance(item, Tree) else item if isinstance(item, Tree) and item.data == "aggregation_function": - func = str(item.children[0].value) # AGGREGATE_FUNC - entity = str(item.children[1].value) - if len(item.children) > 2: - entity += "." + str(item.children[2].children[0].value) + func, entity = self._parse_aggregation_token(item) + if alias: + self._entity2alias[self._format_aggregation_key(func, entity)] = alias + self._aggregation_attributes.add(entity) self._aggregate_functions.append((func, entity)) - self._return_requests.append(entity) else: if not isinstance(item, str): item = str(item.value) + + if alias: + self._entity2alias[item] = alias self._return_requests.append(item) + self._alias2entity.update({v: k for k, v in self._entity2alias.items()}) + + def _extract_alias(self, item: Tree): + ''' + Extract the alias from the return item (if it exists) + ''' + + if len(item.children) == 1: + return None + item_keys = [it.data if isinstance(it, Tree) else None for it in item.children] + if any(k == 'alias' for k in item_keys): + # get the index of the alias + alias_index = item_keys.index('alias') + return str(item.children[alias_index].children[0].value) + + return None + + def _parse_aggregation_token(self, item: Tree): + ''' + Parse the aggregation function token and return the function and entity + input: Tree('aggregation_function', [Token('AGGREGATE_FUNC', 'SUM'), Token('CNAME', 'r'), Tree('attribute_id', [Token('CNAME', 'value')])]) + output: ('SUM', 'r.value') + ''' + func = str(item.children[0].value) # AGGREGATE_FUNC + entity = str(item.children[1].value) + if len(item.children) > 2: + entity += "." + str(item.children[2].children[0].value) + + return func, entity + + def _format_aggregation_key(self, func, entity): + return f"{func}({entity})" + def order_clause(self, order_clause): self._order_by = [] for item in order_clause[0].children: - field = str(item.children[0]) # assuming the field name is the first child + if isinstance(item.children[0], Tree) and item.children[0].data == "aggregation_function": + func, entity = self._parse_aggregation_token(item.children[0]) + field = self._format_aggregation_key(func, entity) + self._order_by_attributes.add(entity) + else: + field = str(item.children[0]) # assuming the field name is the first child + self._order_by_attributes.add(field) + # Default to 'ASC' if not specified if len(item.children) > 1 and str(item.children[1].data).lower() != "desc": direction = "ASC" @@ -532,7 +582,6 @@ def order_clause(self, order_clause): direction = "DESC" self._order_by.append((field, direction)) # [('n.age', 'DESC'), ...] - self._order_by_attributes.add(field) def distinct_return(self, distinct): self._distinct = True @@ -616,8 +665,11 @@ def _collate_data(data, unique_labels, func): def returns(self, ignore_limit=False): + data_paths = self._return_requests + list(self._order_by_attributes) + list(self._aggregation_attributes) + # aliases should already be requested in their original form, so we will remove them for lookup + data_paths = [d for d in data_paths if d not in self._alias2entity] results = self._lookup( - self._return_requests + list(self._order_by_attributes), + data_paths, offset_limit=slice(0, None), ) if len(self._aggregate_functions) > 0: @@ -630,21 +682,25 @@ def returns(self, ignore_limit=False): aggregated_results = {} for func, entity in self._aggregate_functions: aggregated_data = self.aggregate(func, results, entity, group_keys) - func_key = f"{func}({entity})" + func_key = self._format_aggregation_key(func, entity) aggregated_results[func_key] = aggregated_data self._return_requests.append(func_key) results.update(aggregated_results) + + # update the results with the given alias(es) + results = {self._entity2alias.get(k, k): v for k, v in results.items()} + if self._order_by: results = self._apply_order_by(results) if self._distinct: results = self._apply_distinct(results) results = self._apply_pagination(results, ignore_limit) - # Exclude order-by-only attributes from the final results + # Only include keys that were asked for in `RETURN` in the final results results = { key: values for key, values in results.items() - if key in self._return_requests + if self._alias2entity.get(key, key) in self._return_requests } return results @@ -652,9 +708,8 @@ def returns(self, ignore_limit=False): def _apply_order_by(self, results): if self._order_by: sort_lists = [ - (results[field], direction) + (results[field], field, direction) for field, direction in self._order_by - if field in results ] if sort_lists: @@ -662,14 +717,40 @@ def _apply_order_by(self, results): indices = range( len(next(iter(results.values()))) ) # Safe because all lists are assumed to be of the same length - for sort_list, direction in reversed( + for (sort_list, field, direction) in reversed( sort_lists ): # reverse to ensure the first sort key is primary - indices = sorted( - indices, - key=lambda i: sort_list[i], - reverse=(direction == "DESC"), - ) + + if all(isinstance(item, dict) for item in sort_list): + # (for edge attributes) If all items in sort_list are dictionaries + # example: ([{(0, 'paid'): 9, (1, 'paid'): 40}, {(0, 'paid'): 14}], 'DESC') + + # sort within each edge first + sorted_sublists = [] + for sublist in sort_list: + sorted_sublist = sorted( + sublist.items(), + key=lambda x: x[1] or 0, # 0 if `None` + reverse=(direction == "DESC"), + ) + sorted_sublists.append({k: v for k, v in sorted_sublist}) + sort_list = sorted_sublists + + # then sort the indices based on the sorted sublists + indices = sorted( + indices, + key=lambda i: list(sort_list[i].values())[0] or 0, # 0 if `None` + reverse=(direction == "DESC"), + ) + # update results with sorted edge attributes list + results[field] = sort_list + else: + # (for node attributes) single values + indices = sorted( + indices, + key=lambda i: sort_list[i], + reverse=(direction == "DESC"), + ) # Reorder all lists in results using sorted indices for key in results: diff --git a/grandcypher/test_queries.py b/grandcypher/test_queries.py index afea528..2fdca30 100644 --- a/grandcypher/test_queries.py +++ b/grandcypher/test_queries.py @@ -769,6 +769,124 @@ def test_order_by_single_field_no_direction_provided(self, graph_type): res = GrandCypher(host).run(qry) assert res["n.name"] == ["Carol", "Alice", "Bob"] + def test_order_by_edge_attribute1(self): + host = nx.DiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Carol", age=20) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"paid"}, value=9) + host.add_edge("a", "c", __labels__={"paid"}, value=4) + + qry = """ + MATCH (n)-[r]->(m) + RETURN n.name, r.value, m.name + ORDER BY r.value ASC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Alice', 'Alice', 'Bob'] + assert res['m.name'] == ['Carol', 'Bob', 'Alice'] + assert res['r.value'] == [{(0, 'paid'): 4}, {(0, 'paid'): 9}, {(0, 'paid'): 14}] + + qry = """ + MATCH (n)-[r]->() + RETURN n.name, r.value + ORDER BY r.value DESC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Bob', 'Alice', 'Alice'] + assert res['r.value'] == [{(0, 'paid'): 14}, {(0, 'paid'): 9}, {(0, 'paid'): 4}] + + + def test_order_by_edge_attribute2(self): + host = nx.DiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Carol", age=20) + host.add_edge("b", "a", __labels__={"paid"}, amount=14) # different attribute name + host.add_edge("a", "b", __labels__={"paid"}, value=9) + host.add_edge("c", "b", __labels__={"paid"}, value=980) + host.add_edge("b", "c", __labels__={"paid"}, value=11) + + qry = """ + MATCH (n)-[r]->(m) + RETURN n.name, r.value, m.name + ORDER BY r.value ASC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Bob', 'Alice', 'Bob', 'Carol'] + assert res['r.value'] == [ + {(0, 'paid'): None}, # None for the different attribute edge + {(0, 'paid'): 9}, # within edges, the attributes are ordered + {(0, 'paid'): 11}, + {(0, 'paid'): 980} + ] + assert res['m.name'] == ['Alice', 'Bob', 'Carol', 'Bob'] + + def test_order_by_aggregation_function(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Carol", age=20) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"paid"}, value=9) + host.add_edge("a", "b", __labels__={"paid"}, amount=96) + host.add_edge("a", "b", __labels__={"paid"}, value=40) + + # SUM + qry = """ + MATCH (n)-[r]->() + RETURN n.name, SUM(r.value) + ORDER BY SUM(r.value) ASC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Bob', 'Alice'] + assert res['SUM(r.value)'] == [{'paid': 14}, {'paid': 49}] + + # AVG + qry = """ + MATCH (n)-[r]->() + RETURN n.name, AVG(r.value), r.value + ORDER BY AVG(r.value) DESC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Alice', 'Bob'] + assert res['AVG(r.value)'] == [{'paid': 16.333333333333332}, {'paid': 14.0}] + assert res['r.value'] == [{(0, 'paid'): 9, (1, 'paid'): None, (2, 'paid'): 40}, {(0, 'paid'): 14}] + + # MIN, MAX, and COUNT + qry = """ + MATCH (n)-[r]->() + RETURN n.name, MIN(r.value), MAX(r.value), COUNT(r.value) + ORDER BY MAX(r.value) DESC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Alice', 'Bob'] + assert res['MIN(r.value)'] == [{'paid': 9}, {'paid': 14}] + assert res['MAX(r.value)'] == [{'paid': 40}, {'paid': 14}] + assert res['COUNT(r.value)'] == [{'paid': 3}, {'paid': 1}] + + + @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) + def test_order_by_aggregation_fails_if_not_requested_in_return(self, graph_type): + host = graph_type() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Carol", age=20) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"paid"}, value=9) + host.add_edge("a", "b", __labels__={"paid"}, amount=96) + host.add_edge("a", "b", __labels__={"paid"}, value=40) + + qry = """ + MATCH (n)-[r]->() + RETURN n.name, r.value + ORDER BY SUM(r.value) ASC + """ + with pytest.raises(Exception): + GrandCypher(host).run(qry) + + @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) def test_order_by_multiple_fields(self, graph_type): host = graph_type() @@ -1031,6 +1149,62 @@ def test_multigraph_multiple_same_edge_labels(self): # the second "paid" edge between Bob -> Alice has no "amount" attribute, so it should be None assert res["r.amount"] == [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}] + def test_order_by_edge_attribute1(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Carol", age=20) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"paid"}, value=9) + host.add_edge("a", "b", __labels__={"paid"}, value=40) + + qry = """ + MATCH (n)-[r]->() + RETURN n.name, r.value + ORDER BY r.value ASC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Alice', 'Bob'] + assert res['r.value'] == [{(0, 'paid'): 9, (1, 'paid'): 40}, {(0, 'paid'): 14}] + + qry = """ + MATCH (n)-[r]->() + RETURN n.name, r.value + ORDER BY r.value DESC + """ + res = GrandCypher(host).run(qry) + assert res['n.name'] == ['Alice', 'Bob'] + assert res['r.value'] == [{(1, 'paid'): 40, (0, 'paid'): 9}, {(0, 'paid'): 14}] + + def test_order_by_edge_attribute2(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Carol", age=20) + host.add_edge("b", "a", __labels__={"paid"}, amount=14) # different attribute name + host.add_edge("a", "b", __labels__={"paid"}, value=9) + host.add_edge("c", "b", __labels__={"paid"}, value=980) + host.add_edge("c", "b", __labels__={"paid"}, value=4) + host.add_edge("b", "c", __labels__={"paid"}, value=11) + host.add_edge("a", "b", __labels__={"paid"}, value=40) + host.add_edge("b", "a", __labels__={"paid"}, value=14) # duplicate edge + host.add_edge("a", "b", __labels__={"paid"}, value=9) # duplicate edge + host.add_edge("a", "b", __labels__={"paid"}, value=40) # duplicate edge + + qry = """ + MATCH (n)-[r]->(m) + RETURN n.name, r.value, m.name + ORDER BY r.value ASC + """ + res = GrandCypher(host).run(qry) + assert res['r.value'] == [ + {(0, 'paid'): None, (1, 'paid'): 14}, # None for the different attribute edge + {(1, 'paid'): 4, (0, 'paid'): 980}, # within edges, the attributes are ordered + {(0, 'paid'): 9, (2, 'paid'): 9, (1, 'paid'): 40, (3, 'paid'): 40}, + {(0, 'paid'): 11} + ] + assert res['m.name'] == ['Alice', 'Bob', 'Bob', 'Carol'] + def test_multigraph_aggregation_function_sum(self): host = nx.MultiDiGraph() host.add_node("a", name="Alice", age=25) @@ -1139,6 +1313,79 @@ def test_multigraph_multiple_aggregation_functions(self): assert res["SUM(r.amount)"] == [{'paid': 52}, {'paid': 6}] +class TestAlias: + @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) + def test_alias_with_single_variable_length_relationship(self, graph_type): + host = graph_type() + host.add_node("x", foo=12) + host.add_node("y", foo=13) + host.add_node("z", foo=16) + host.add_edge("x", "y", bar="1") + host.add_edge("y", "z", bar="2") + host.add_edge("z", "x", bar="3") + + qry = """ + MATCH (A)-[r*0]->(B) + RETURN A AS ayy, B AS bee, r + """ + + res = GrandCypher(host).run(qry) + assert len(res) == 3 + assert res["ayy"] == ["x", "y", "z"] + assert res["bee"] == ["x", "y", "z"] + assert res["r"] == [[None], [None], [None]] + + qry = """ + MATCH (A)-[r*1]->(B) + RETURN A, B, r AS arr + """ + + res = GrandCypher(host).run(qry) + assert len(res) == 3 + assert res["A"] == ["x", "y", "z"] + assert res["B"] == ["y", "z", "x"] + assert graph_type in ACCEPTED_GRAPH_TYPES + assert res["arr"] == [[{0: {'bar': '1'}}], [{0: {'bar': '2'}}], [{0: {'bar': '3'}}]] + + def test_alias_with_order_by(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Carol", age=20) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"paid"}, value=9) + host.add_edge("a", "b", __labels__={"paid"}, amount=96) + host.add_edge("a", "b", __labels__={"paid"}, value=40) + + qry = """ + MATCH (n)-[r]->(m) + RETURN n.name, AVG(r.value) AS average, m.name, r.value + ORDER BY average ASC + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ['Bob', 'Alice'] + assert res["m.name"] == ['Alice', 'Bob'] + assert res['r.value'] == [ + {(0, 'paid'): 14}, + {(0, 'paid'): 9, (1, 'paid'): None, (2, 'paid'): 40} + ] + assert res["average"] == [{'paid': 14.0}, {'paid': 16.333333333333332}] + + qry = """ + MATCH (n)-[r]->(m) + RETURN n.name, m.name, AVG(r.value) AS total, r.value as myvalue + ORDER BY myvalue ASC + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ['Alice', 'Bob'] + assert res["m.name"] == ['Bob', 'Alice'] + assert res['total'] == [{'paid': 16.333333333333332}, {'paid': 14.0}] + assert res["myvalue"] == [ + {(1, 'paid'): None, (0, 'paid'): 9, (2, 'paid'): 40}, + {(0, 'paid'): 14} + ] + + class TestVariableLengthRelationship: @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) def test_single_variable_length_relationship(self, graph_type): @@ -1171,11 +1418,7 @@ def test_single_variable_length_relationship(self, graph_type): assert res["A"] == ["x", "y", "z"] assert res["B"] == ["y", "z", "x"] assert graph_type in ACCEPTED_GRAPH_TYPES - if graph_type is nx.DiGraph: - assert res["r"] == [[{"bar": "1"}], [{"bar": "2"}], [{"bar": "3"}]] - elif graph_type is nx.MultiDiGraph: - # MultiDiGraphs return a list of dictionaries to accommodate multiple edges between nodes - assert res["r"] == [[{0: {'bar': '1'}}], [{0: {'bar': '2'}}], [{0: {'bar': '3'}}]] + assert res["r"] == [[{0: {'bar': '1'}}], [{0: {'bar': '2'}}], [{0: {'bar': '3'}}]] qry = """ MATCH (A)-[r*2]->(B) @@ -1187,18 +1430,11 @@ def test_single_variable_length_relationship(self, graph_type): assert res["A"] == ["x", "y", "z"] assert res["B"] == ["z", "x", "y"] assert graph_type in ACCEPTED_GRAPH_TYPES - if graph_type is nx.DiGraph: - assert res["r"] == [ - [{"bar": "1"}, {"bar": "2"}], - [{"bar": "2"}, {"bar": "3"}], - [{"bar": "3"}, {"bar": "1"}], - ] - elif graph_type is nx.MultiGraph: - assert res["r"] == [ - [{0: {'bar': '1'}}, {1: {'bar': '2'}}], - [{0: {'bar': '2'}}, {1: {'bar': '3'}}], - [{0: {'bar': '3'}}, {1: {'bar': '1'}}], - ] + assert res["r"] == [ + [{0: {'bar': '1'}}, {0: {'bar': '2'}}], + [{0: {'bar': '2'}}, {0: {'bar': '3'}}], + [{0: {'bar': '3'}}, {0: {'bar': '1'}}] + ] @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) def test_complex_variable_length_relationship(self, graph_type): @@ -1220,28 +1456,15 @@ def test_complex_variable_length_relationship(self, graph_type): assert res["A"] == ["x", "y", "z", "x", "y", "z", "x", "y", "z"] assert res["B"] == ["x", "y", "z", "y", "z", "x", "z", "x", "y"] assert graph_type in ACCEPTED_GRAPH_TYPES - if graph_type is nx.DiGraph: - assert res["r"] == [ - [None], - [None], - [None], - [{"bar": "1"}], - [{"bar": "2"}], - [{"bar": "3"}], - [{"bar": "1"}, {"bar": "2"}], - [{"bar": "2"}, {"bar": "3"}], - [{"bar": "3"}, {"bar": "1"}], - ] - elif graph_type is nx.MultiDiGraph: - assert res["r"] == [ - [None], [None], [None], - [{0: {'bar': '1'}}], - [{0: {'bar': '2'}}], - [{0: {'bar': '3'}}], - [{0: {'bar': '1'}}, {0: {'bar': '2'}}], - [{0: {'bar': '2'}}, {0: {'bar': '3'}}], - [{0: {'bar': '3'}}, {0: {'bar': '1'}}] - ] + assert res["r"] == [ + [None], [None], [None], + [{0: {'bar': '1'}}], + [{0: {'bar': '2'}}], + [{0: {'bar': '3'}}], + [{0: {'bar': '1'}}, {0: {'bar': '2'}}], + [{0: {'bar': '2'}}, {0: {'bar': '3'}}], + [{0: {'bar': '3'}}, {0: {'bar': '1'}}] + ] class TestType: @@ -1347,30 +1570,17 @@ def test_edge_type_hop(self, graph_type): assert res["A"] == ["x", "y", "z", "x", "y", "z", "x", "y", "z"] assert res["B"] == ["x", "y", "z", "y", "z", "x", "z", "x", "y"] assert graph_type in ACCEPTED_GRAPH_TYPES - if graph_type is nx.DiGraph: - assert res["r"] == [ - [None], - [None], - [None], - [{"__labels__": {"Edge", "XY"}}], - [{"__labels__": {"Edge", "YZ"}}], - [{"__labels__": {"Edge", "ZX"}}], - [{"__labels__": {"Edge", "XY"}}, {"__labels__": {"Edge", "YZ"}}], - [{"__labels__": {"Edge", "YZ"}}, {"__labels__": {"Edge", "ZX"}}], - [{"__labels__": {"Edge", "ZX"}}, {"__labels__": {"Edge", "XY"}}], - ] - elif graph_type is nx.MultiDiGraph: - assert res["r"] == [ - [None], - [None], - [None], - [{0: {'__labels__': {'Edge', 'XY'}}}], - [{0: {'__labels__': {'Edge', 'YZ'}}}], - [{0: {'__labels__': {'Edge', 'ZX'}}}], - [{0: {'__labels__': {'Edge', 'XY'}}}, {0: {'__labels__': {'Edge', 'YZ'}}}], - [{0: {'__labels__': {'Edge', 'YZ'}}}, {0: {'__labels__': {'Edge', 'ZX'}}}], - [{0: {'__labels__': {'Edge', 'ZX'}}}, {0: {'__labels__': {'Edge', 'XY'}}}] - ] + assert res["r"] == [ + [None], + [None], + [None], + [{0: {'__labels__': {'Edge', 'XY'}}}], + [{0: {'__labels__': {'Edge', 'YZ'}}}], + [{0: {'__labels__': {'Edge', 'ZX'}}}], + [{0: {'__labels__': {'Edge', 'XY'}}}, {0: {'__labels__': {'Edge', 'YZ'}}}], + [{0: {'__labels__': {'Edge', 'YZ'}}}, {0: {'__labels__': {'Edge', 'ZX'}}}], + [{0: {'__labels__': {'Edge', 'ZX'}}}, {0: {'__labels__': {'Edge', 'XY'}}}] + ] @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) def test_host_no_node_type(self, graph_type):