diff --git a/backend/infrahub/core/validators/relationship/count.py b/backend/infrahub/core/validators/relationship/count.py index 85ab9f1d29..08044413c4 100644 --- a/backend/infrahub/core/validators/relationship/count.py +++ b/backend/infrahub/core/validators/relationship/count.py @@ -34,20 +34,22 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> No branch_filter, branch_params = self.branch.get_query_filter_path(at=self.at.to_string(), is_isolated=False) self.params.update(branch_params) - self.params["node_kind"] = self.node_schema.kind self.params["relationship_id"] = self.relationship_schema.identifier + self.params["relationship_direction"] = self.relationship_schema.direction.value self.params["min_count"] = ( self.min_count_override if self.min_count_override is not None else self.relationship_schema.min_count ) - self.params["max_count"] = ( - self.max_count_override if self.max_count_override is not None else self.relationship_schema.max_count - ) + max_count: int | None = self.relationship_schema.max_count + if self.max_count_override: + max_count = self.max_count_override + if max_count == 0: + max_count = None + self.params["max_count"] = max_count # ruff: noqa: E501 query = """ // get the nodes on these branches nodes - MATCH (n:Node) - WHERE $node_kind IN LABELS(n) + MATCH (n:%(node_kind)s) CALL { WITH n MATCH path = (root:Root)<-[rroot:IS_PART_OF]-(n) @@ -64,7 +66,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> No CALL { WITH active_node MATCH path = (active_node)-[rrel1:IS_RELATED]-(rel:Relationship { name: $relationship_id })-[rrel2:IS_RELATED]-(peer:Node) - WHERE all( + WHERE ($relationship_direction <> "outbound" OR (startNode(rrel1) = active_node AND startNode(rrel2) = rel)) + AND ($relationship_direction <> "inbound" OR (startNode(rrel1) = rel AND startNode(rrel2) = peer)) + AND all( r in relationships(path) WHERE (%(branch_filter)s) ) @@ -115,7 +119,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> No } // return a row for each node-branch combination with a count for that branch UNWIND violation_branches_and_counts as violation_branch_and_count - """ % {"branch_filter": branch_filter} + """ % {"branch_filter": branch_filter, "node_kind": self.node_schema.kind} self.add_to_query(query) self.return_labels = [ diff --git a/backend/tests/unit/core/constraint_validators/test_relationship_count.py b/backend/tests/unit/core/constraint_validators/test_relationship_count.py index ed92a51f3f..bcbf362255 100644 --- a/backend/tests/unit/core/constraint_validators/test_relationship_count.py +++ b/backend/tests/unit/core/constraint_validators/test_relationship_count.py @@ -48,7 +48,7 @@ async def test_query_failure_cardinality_one( ): person_schema = registry.schema.get(name="TestPerson") cars_rel = person_schema.get_relationship(name="cars") - cars_rel.cardinality = RelationshipCardinality.ONE + cars_rel.max_count = 1 schema_path = SchemaPath(path_type=SchemaPathType.RELATIONSHIP, schema_kind="TestPerson", field_name="cars") query = await RelationshipCountUpdateValidatorQuery.init( @@ -351,6 +351,88 @@ async def test_query_delete_on_branch_success( assert len(all_paths) == 0 +async def test_hierarchical_success(db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data_simple): + site_schema = registry.schema.get(name="LocationSite", duplicate=False) + + schema_path = SchemaPath(path_type=SchemaPathType.RELATIONSHIP, schema_kind="LocationSite", field_name="parent") + query = await RelationshipCountUpdateValidatorQuery.init( + db=db, branch=default_branch, node_schema=site_schema, schema_path=schema_path + ) + + await query.execute(db=db) + + grouped_paths = await query.get_paths() + all_paths = grouped_paths.get_all_data_paths() + assert len(all_paths) == 0 + + +async def test_hierarchical_failure(db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data_simple): + paris_site = hierarchical_location_data_simple["paris"] + branch = await create_branch(branch_name=str("branch2"), db=db) + schema_path = SchemaPath(path_type=SchemaPathType.RELATIONSHIP, schema_kind="LocationSite", field_name="children") + site_schema = registry.schema.get(name="LocationSite", branch=branch, duplicate=False) + + # check no violations to start with + query = await RelationshipCountUpdateValidatorQuery.init( + db=db, branch=branch, node_schema=site_schema, schema_path=schema_path + ) + + await query.execute(db=db) + grouped_paths = await query.get_paths() + all_paths = grouped_paths.get_all_data_paths() + assert len(all_paths) == 0 + + # add a violation + branch_rack = await NodeManager.get_one(db=db, branch=branch, id=paris_site.id) + extra_rack = await Node.init(db=db, branch=branch, schema="LocationRack") + await extra_rack.new(db=db, name="extra_rack", parent=branch_rack, status="online") + await extra_rack.save(db=db) + child_rel = site_schema.get_relationship(name="children") + child_rel.max_count = 2 + + query = await RelationshipCountUpdateValidatorQuery.init( + db=db, branch=branch, node_schema=site_schema, schema_path=schema_path + ) + await query.execute(db=db) + grouped_paths = await query.get_paths() + all_paths = grouped_paths.get_all_data_paths() + assert len(all_paths) == 2 + assert ( + DataPath( + branch=branch.name, + path_type=PathType.NODE, + node_id=paris_site.id, + kind="LocationSite", + field_name="children", + value=1, + ) + in all_paths + ) + assert ( + DataPath( + branch=default_branch.name, + path_type=PathType.NODE, + node_id=paris_site.id, + kind="LocationSite", + field_name="children", + value=2, + ) + in all_paths + ) + + # remove violation + branch_rack = await NodeManager.get_one(db=db, branch=branch, id=extra_rack.id) + await branch_rack.delete(db=db) + + query = await RelationshipCountUpdateValidatorQuery.init( + db=db, branch=branch, node_schema=site_schema, schema_path=schema_path + ) + await query.execute(db=db) + grouped_paths = await query.get_paths() + all_paths = grouped_paths.get_all_data_paths() + assert len(all_paths) == 0 + + async def test_validator( db: InfrahubDatabase, branch: Branch,