Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IFC-841 fixes for relationship count constraint validator #4743

Merged
merged 6 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions backend/infrahub/core/validators/relationship/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,22 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> No
branch_filter, branch_params = self.branch.get_query_filter_path(at=self.at.to_string(), is_isolated=False)
self.params.update(branch_params)

self.params["node_kind"] = self.node_schema.kind
self.params["relationship_id"] = self.relationship_schema.identifier
self.params["relationship_direction"] = self.relationship_schema.direction.value
self.params["min_count"] = (
self.min_count_override if self.min_count_override is not None else self.relationship_schema.min_count
)
self.params["max_count"] = (
self.max_count_override if self.max_count_override is not None else self.relationship_schema.max_count
)
max_count: int | None = self.relationship_schema.max_count
if self.max_count_override:
max_count = self.max_count_override
if max_count == 0:
max_count = None
self.params["max_count"] = max_count

# ruff: noqa: E501
query = """
// get the nodes on these branches nodes
MATCH (n:Node)
WHERE $node_kind IN LABELS(n)
MATCH (n:%(node_kind)s)
CALL {
WITH n
MATCH path = (root:Root)<-[rroot:IS_PART_OF]-(n)
Expand All @@ -64,7 +66,9 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> No
CALL {
WITH active_node
MATCH path = (active_node)-[rrel1:IS_RELATED]-(rel:Relationship { name: $relationship_id })-[rrel2:IS_RELATED]-(peer:Node)
WHERE all(
WHERE ($relationship_direction <> "outbound" OR (startNode(rrel1) = active_node AND startNode(rrel2) = rel))
AND ($relationship_direction <> "inbound" OR (startNode(rrel1) = rel AND startNode(rrel2) = peer))
AND all(
r in relationships(path)
WHERE (%(branch_filter)s)
)
Expand Down Expand Up @@ -115,7 +119,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> No
}
// return a row for each node-branch combination with a count for that branch
UNWIND violation_branches_and_counts as violation_branch_and_count
""" % {"branch_filter": branch_filter}
""" % {"branch_filter": branch_filter, "node_kind": self.node_schema.kind}

self.add_to_query(query)
self.return_labels = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def test_query_failure_cardinality_one(
):
person_schema = registry.schema.get(name="TestPerson")
cars_rel = person_schema.get_relationship(name="cars")
cars_rel.cardinality = RelationshipCardinality.ONE
cars_rel.max_count = 1

schema_path = SchemaPath(path_type=SchemaPathType.RELATIONSHIP, schema_kind="TestPerson", field_name="cars")
query = await RelationshipCountUpdateValidatorQuery.init(
Expand Down Expand Up @@ -351,6 +351,88 @@ async def test_query_delete_on_branch_success(
assert len(all_paths) == 0


async def test_hierarchical_success(db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data_simple):
site_schema = registry.schema.get(name="LocationSite", duplicate=False)

schema_path = SchemaPath(path_type=SchemaPathType.RELATIONSHIP, schema_kind="LocationSite", field_name="parent")
query = await RelationshipCountUpdateValidatorQuery.init(
db=db, branch=default_branch, node_schema=site_schema, schema_path=schema_path
)

await query.execute(db=db)

grouped_paths = await query.get_paths()
all_paths = grouped_paths.get_all_data_paths()
assert len(all_paths) == 0


async def test_hierarchical_failure(db: InfrahubDatabase, default_branch: Branch, hierarchical_location_data_simple):
paris_site = hierarchical_location_data_simple["paris"]
branch = await create_branch(branch_name=str("branch2"), db=db)
schema_path = SchemaPath(path_type=SchemaPathType.RELATIONSHIP, schema_kind="LocationSite", field_name="children")
site_schema = registry.schema.get(name="LocationSite", branch=branch, duplicate=False)

# check no violations to start with
query = await RelationshipCountUpdateValidatorQuery.init(
db=db, branch=branch, node_schema=site_schema, schema_path=schema_path
)

await query.execute(db=db)
grouped_paths = await query.get_paths()
all_paths = grouped_paths.get_all_data_paths()
assert len(all_paths) == 0

# add a violation
branch_rack = await NodeManager.get_one(db=db, branch=branch, id=paris_site.id)
extra_rack = await Node.init(db=db, branch=branch, schema="LocationRack")
await extra_rack.new(db=db, name="extra_rack", parent=branch_rack, status="online")
await extra_rack.save(db=db)
child_rel = site_schema.get_relationship(name="children")
child_rel.max_count = 2

query = await RelationshipCountUpdateValidatorQuery.init(
db=db, branch=branch, node_schema=site_schema, schema_path=schema_path
)
await query.execute(db=db)
grouped_paths = await query.get_paths()
all_paths = grouped_paths.get_all_data_paths()
assert len(all_paths) == 2
assert (
DataPath(
branch=branch.name,
path_type=PathType.NODE,
node_id=paris_site.id,
kind="LocationSite",
field_name="children",
value=1,
)
in all_paths
)
assert (
DataPath(
branch=default_branch.name,
path_type=PathType.NODE,
node_id=paris_site.id,
kind="LocationSite",
field_name="children",
value=2,
)
in all_paths
)

# remove violation
branch_rack = await NodeManager.get_one(db=db, branch=branch, id=extra_rack.id)
await branch_rack.delete(db=db)

query = await RelationshipCountUpdateValidatorQuery.init(
db=db, branch=branch, node_schema=site_schema, schema_path=schema_path
)
await query.execute(db=db)
grouped_paths = await query.get_paths()
all_paths = grouped_paths.get_all_data_paths()
assert len(all_paths) == 0


async def test_validator(
db: InfrahubDatabase,
branch: Branch,
Expand Down
Loading