Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate RequestDiffUpdate/RequestDiffRefresh to prefect #4733

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from pydantic import Field
from pydantic import BaseModel, Field

from infrahub.message_bus import InfrahubMessage


class RequestDiffUpdate(InfrahubMessage):
class RequestDiffUpdate(BaseModel):
"""
Request diff to be updated.

Expand All @@ -15,3 +13,10 @@ class RequestDiffUpdate(InfrahubMessage):
name: str | None = None
from_time: str | None = None
to_time: str | None = None


class RequestDiffRefresh(BaseModel):
"""Request diff be recalculated from scratch."""

branch_name: str = Field(..., description="The branch associated with the diff")
diff_id: str = Field(..., description="The id for this diff")
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,39 @@

from infrahub.core import registry
from infrahub.core.diff.coordinator import DiffCoordinator
from infrahub.core.diff.models import RequestDiffRefresh, RequestDiffUpdate
from infrahub.dependencies.registry import get_component_registry
from infrahub.log import get_logger
from infrahub.message_bus import messages
from infrahub.services import InfrahubServices
from infrahub.services import services

log = get_logger()


@flow(name="diff-update")
async def update(message: messages.RequestDiffUpdate, service: InfrahubServices) -> None:
async def update_diff(model: RequestDiffUpdate) -> None:
service = services.service
component_registry = get_component_registry()
base_branch = await registry.get_branch(db=service.database, branch=registry.default_branch)
diff_branch = await registry.get_branch(db=service.database, branch=message.branch_name)
diff_branch = await registry.get_branch(db=service.database, branch=model.branch_name)

diff_coordinator = await component_registry.get_component(DiffCoordinator, db=service.database, branch=diff_branch)

await diff_coordinator.run_update(
base_branch=base_branch,
diff_branch=diff_branch,
from_time=message.from_time,
to_time=message.to_time,
name=message.name,
from_time=model.from_time,
to_time=model.to_time,
name=model.name,
)


@flow(name="diff-refresh")
async def refresh(message: messages.RequestDiffRefresh, service: InfrahubServices) -> None:
async def refresh_diff(model: RequestDiffRefresh) -> None:
service = services.service

component_registry = get_component_registry()
base_branch = await registry.get_branch(db=service.database, branch=registry.default_branch)
diff_branch = await registry.get_branch(db=service.database, branch=message.branch_name)
diff_branch = await registry.get_branch(db=service.database, branch=model.branch_name)

diff_coordinator = await component_registry.get_component(DiffCoordinator, db=service.database, branch=diff_branch)
await diff_coordinator.recalculate(base_branch=base_branch, diff_branch=diff_branch, diff_id=message.diff_id)
await diff_coordinator.recalculate(base_branch=base_branch, diff_branch=diff_branch, diff_id=model.diff_id)
4 changes: 2 additions & 2 deletions backend/infrahub/graphql/mutations/artifact_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing_extensions import Self

from infrahub.core.schema import NodeSchema
from infrahub.git.models import RequestArtifactDefinitionGenerate
from infrahub.log import get_logger
from infrahub.workflows.catalogue import REQUEST_ARTIFACT_DEFINITION_GENERATE

from ...git.models import RequestArtifactDefinitionGenerate
from ...workflows.catalogue import REQUEST_ARTIFACT_DEFINITION_GENERATE
from .main import InfrahubMutationMixin, InfrahubMutationOptions

if TYPE_CHECKING:
Expand Down
7 changes: 4 additions & 3 deletions backend/infrahub/graphql/mutations/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from infrahub.core import registry
from infrahub.core.diff.coordinator import DiffCoordinator
from infrahub.core.diff.models import RequestDiffUpdate
from infrahub.dependencies.registry import get_component_registry
from infrahub.message_bus import messages
from infrahub.workflows.catalogue import REQUEST_DIFF_UPDATE

if TYPE_CHECKING:
from ..initialization import GraphqlContext
Expand Down Expand Up @@ -55,13 +56,13 @@ async def mutate(

return {"ok": True}

message = messages.RequestDiffUpdate(
model = RequestDiffUpdate(
branch_name=str(data.branch),
name=data.name,
from_time=from_timestamp_str,
to_time=to_timestamp_str,
)
if context.service:
await context.service.send(message=message)
await context.service.workflow.submit_workflow(workflow=REQUEST_DIFF_UPDATE, parameters={"model": model})

return {"ok": True}
4 changes: 0 additions & 4 deletions backend/infrahub/message_bus/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
from .refresh_registry_rebasedbranch import RefreshRegistryRebasedBranch
from .refresh_webhook_configuration import RefreshWebhookConfiguration
from .request_artifactdefinition_check import RequestArtifactDefinitionCheck
from .request_diff_refresh import RequestDiffRefresh
from .request_diff_update import RequestDiffUpdate
from .request_generatordefinition_check import RequestGeneratorDefinitionCheck
from .request_generatordefinition_run import RequestGeneratorDefinitionRun
from .request_graphqlquerygroup_update import RequestGraphQLQueryGroupUpdate
Expand Down Expand Up @@ -76,8 +74,6 @@
"refresh.registry.rebased_branch": RefreshRegistryRebasedBranch,
"refresh.webhook.configuration": RefreshWebhookConfiguration,
"request.artifact_definition.check": RequestArtifactDefinitionCheck,
"request.diff.update": RequestDiffUpdate,
"request.diff.refresh": RequestDiffRefresh,
"request.generator_definition.check": RequestGeneratorDefinitionCheck,
"request.generator_definition.run": RequestGeneratorDefinitionRun,
"request.graphql_query_group.update": RequestGraphQLQueryGroupUpdate,
Expand Down
2 changes: 0 additions & 2 deletions backend/infrahub/message_bus/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@
"refresh.registry.branches": refresh.registry.branches,
"refresh.registry.rebased_branch": refresh.registry.rebased_branch,
"refresh.webhook.configuration": refresh.webhook.configuration,
"request.diff.refresh": requests.diff.refresh,
"request.diff.update": requests.diff.update,
"request.generator_definition.check": requests.generator_definition.check,
"request.generator_definition.run": requests.generator_definition.run,
"request.graphql_query_group.update": requests.graphql_query_group.update,
Expand Down
15 changes: 13 additions & 2 deletions backend/infrahub/message_bus/operations/event/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from infrahub.core import registry
from infrahub.core.diff.model.path import BranchTrackingId
from infrahub.core.diff.models import RequestDiffRefresh, RequestDiffUpdate
from infrahub.core.diff.repository.repository import DiffRepository
from infrahub.dependencies.registry import get_component_registry
from infrahub.log import get_logger
from infrahub.message_bus import InfrahubMessage, messages
from infrahub.services import InfrahubServices
from infrahub.workflows.catalogue import (
GIT_REPOSITORIES_CREATE_BRANCH,
REQUEST_DIFF_REFRESH,
REQUEST_DIFF_UPDATE,
TRIGGER_ARTIFACT_DEFINITION_GENERATE,
)

Expand Down Expand Up @@ -72,7 +75,10 @@ async def merge(message: messages.EventBranchMerge, service: InfrahubServices) -
and diff_root.tracking_id
and isinstance(diff_root.tracking_id, BranchTrackingId)
):
events.append(messages.RequestDiffUpdate(branch_name=diff_root.diff_branch_name))
request_diff_update_model = RequestDiffUpdate(branch_name=diff_root.diff_branch_name)
await service.workflow.submit_workflow(
workflow=REQUEST_DIFF_UPDATE, parameters={"model": request_diff_update_model}
)

for event in events:
event.assign_meta(parent=message)
Expand All @@ -95,7 +101,12 @@ async def rebased(message: messages.EventBranchRebased, service: InfrahubService

for diff_root in diff_roots_to_refresh:
if diff_root.base_branch_name != diff_root.diff_branch_name:
events.append(messages.RequestDiffRefresh(branch_name=diff_root.diff_branch_name, diff_id=diff_root.uuid))
request_diff_refresh_model = RequestDiffRefresh(
branch_name=diff_root.diff_branch_name, diff_id=diff_root.uuid
)
await service.workflow.submit_workflow(
workflow=REQUEST_DIFF_REFRESH, parameters={"model": request_diff_refresh_model}
)

for event in events:
event.assign_meta(parent=message)
Expand Down
2 changes: 0 additions & 2 deletions backend/infrahub/message_bus/operations/requests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from . import (
artifact_definition,
diff,
generator_definition,
graphql_query_group,
proposed_change,
Expand All @@ -9,7 +8,6 @@

__all__ = [
"artifact_definition",
"diff",
"generator_definition",
"graphql_query_group",
"proposed_change",
Expand Down
16 changes: 16 additions & 0 deletions backend/infrahub/workflows/catalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@
function="generate_request_artifact_definition",
)

REQUEST_DIFF_UPDATE = WorkflowDefinition(
name="diff-update",
type=WorkflowType.INTERNAL,
module="infrahub.core.diff.tasks",
function="update_diff",
)

REQUEST_DIFF_REFRESH = WorkflowDefinition(
name="diff-refresh",
type=WorkflowType.INTERNAL,
module="infrahub.core.diff.tasks",
function="refresh_diff",
)

GIT_REPOSITORIES_SYNC = WorkflowDefinition(
name="git_repositories_sync",
type=WorkflowType.INTERNAL,
Expand Down Expand Up @@ -144,4 +158,6 @@
BRANCH_MERGE,
REQUEST_ARTIFACT_DEFINITION_GENERATE,
REQUEST_GENERATOR_RUN,
REQUEST_DIFF_UPDATE,
REQUEST_DIFF_REFRESH,
]
45 changes: 32 additions & 13 deletions backend/tests/unit/message_bus/operations/event/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

from infrahub.core.branch import Branch
from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRoot
from infrahub.core.diff.models import RequestDiffRefresh, RequestDiffUpdate
from infrahub.core.diff.repository.repository import DiffRepository
from infrahub.core.timestamp import Timestamp
from infrahub.dependencies.component.registry import ComponentDependencyRegistry
from infrahub.message_bus import messages
from infrahub.message_bus.operations.event.branch import delete, merge, rebased
from infrahub.services import InfrahubServices, services
from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
from infrahub.workflows.catalogue import TRIGGER_ARTIFACT_DEFINITION_GENERATE
from infrahub.workflows.catalogue import REQUEST_DIFF_REFRESH, REQUEST_DIFF_UPDATE, TRIGGER_ARTIFACT_DEFINITION_GENERATE
from tests.adapters.message_bus import BusRecorder


Expand Down Expand Up @@ -100,6 +101,14 @@ async def test_merged(default_branch: Branch, init_service: InfrahubServices, pr

expected_calls = [
call(workflow=TRIGGER_ARTIFACT_DEFINITION_GENERATE, parameters={"branch": message.target_branch}),
call(
workflow=REQUEST_DIFF_UPDATE,
parameters={"model": RequestDiffUpdate(branch_name=tracked_diff_roots[0].diff_branch_name)},
),
call(
workflow=REQUEST_DIFF_UPDATE,
parameters={"model": RequestDiffUpdate(branch_name=tracked_diff_roots[1].diff_branch_name)},
),
]
mock_submit_workflow.assert_has_calls(expected_calls)
assert mock_submit_workflow.call_count == len(expected_calls)
Expand All @@ -109,15 +118,9 @@ async def test_merged(default_branch: Branch, init_service: InfrahubServices, pr
)
diff_repo.get_empty_roots.assert_awaited_once_with(base_branch_names=[target_branch_name])

assert len(service.message_bus.messages) == 4
assert len(service.message_bus.messages) == 2
assert service.message_bus.messages[0] == messages.RefreshRegistryBranches()
assert service.message_bus.messages[1] == messages.TriggerGeneratorDefinitionRun(branch=target_branch_name)
assert service.message_bus.messages[2] == messages.RequestDiffUpdate(
branch_name=tracked_diff_roots[0].diff_branch_name
)
assert service.message_bus.messages[3] == messages.RequestDiffUpdate(
branch_name=tracked_diff_roots[1].diff_branch_name
)


async def test_rebased(default_branch: Branch, prefect_test_fixture):
Expand All @@ -128,7 +131,7 @@ async def test_rebased(default_branch: Branch, prefect_test_fixture):

recorder = BusRecorder()
database = MagicMock()
service = InfrahubServices(message_bus=recorder, database=database)
service = InfrahubServices(message_bus=recorder, database=database, workflow=WorkflowLocalExecution())
diff_roots = [
EnrichedDiffRoot(
base_branch_name="main",
Expand All @@ -146,14 +149,30 @@ async def test_rebased(default_branch: Branch, prefect_test_fixture):
mock_get_component_registry = MagicMock(return_value=mock_component_registry)
mock_component_registry.get_component.return_value = diff_repo

with patch("infrahub.message_bus.operations.event.branch.get_component_registry", new=mock_get_component_registry):
with (
patch("infrahub.message_bus.operations.event.branch.get_component_registry", new=mock_get_component_registry),
patch(
"infrahub.services.adapters.workflow.local.WorkflowLocalExecution.submit_workflow"
) as mock_submit_workflow,
):
await rebased(message=message, service=service)

expected_calls = [
call(
workflow=REQUEST_DIFF_REFRESH,
parameters={"model": RequestDiffRefresh(branch_name=branch_name, diff_id=diff_roots[0].uuid)},
),
call(
workflow=REQUEST_DIFF_REFRESH,
parameters={"model": RequestDiffRefresh(branch_name=branch_name, diff_id=diff_roots[1].uuid)},
),
]
mock_submit_workflow.assert_has_calls(expected_calls)
assert mock_submit_workflow.call_count == len(expected_calls)

mock_component_registry.get_component.assert_awaited_once_with(DiffRepository, db=database, branch=default_branch)
diff_repo.get_empty_roots.assert_awaited_once_with(diff_branch_names=[branch_name])
assert len(recorder.messages) == 3
assert len(recorder.messages) == 1
assert isinstance(recorder.messages[0], messages.RefreshRegistryRebasedBranch)
refresh_message: messages.RefreshRegistryRebasedBranch = recorder.messages[0]
assert refresh_message.branch == "cr1234"
assert recorder.messages[1] == messages.RequestDiffRefresh(branch_name=branch_name, diff_id=diff_roots[0].uuid)
assert recorder.messages[2] == messages.RequestDiffRefresh(branch_name=branch_name, diff_id=diff_roots[1].uuid)
Loading
Loading