diff --git a/src/robusta/core/model/base_params.py b/src/robusta/core/model/base_params.py index f0ebd1688..4ae29e382 100644 --- a/src/robusta/core/model/base_params.py +++ b/src/robusta/core/model/base_params.py @@ -109,6 +109,7 @@ class AIInvestigateParams(HolmesParams): runbooks: Optional[List[str]] ask: Optional[str] context: Optional[Dict[str, Any]] + stream: bool = False class HolmesToolsResult(BaseModel): diff --git a/src/robusta/core/model/events.py b/src/robusta/core/model/events.py index 9ec060e80..730dc2ea9 100644 --- a/src/robusta/core/model/events.py +++ b/src/robusta/core/model/events.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Callable from pydantic import BaseModel @@ -59,6 +59,7 @@ class ExecutionBaseEvent: _scheduler: Optional[PlaybooksScheduler] = None _context: Optional[ExecutionContext] = None _event_emitter: Optional[EventEmitter] = None + _ws: Optional[Callable[[str], None]] = None def set_context(self, context: ExecutionContext): self._context = context diff --git a/src/robusta/core/playbooks/internal/ai_integration.py b/src/robusta/core/playbooks/internal/ai_integration.py index 4c580b266..5a837e536 100644 --- a/src/robusta/core/playbooks/internal/ai_integration.py +++ b/src/robusta/core/playbooks/internal/ai_integration.py @@ -59,35 +59,43 @@ def ask_holmes(event: ExecutionBaseEvent, params: AIInvestigateParams): include_tool_calls=True, include_tool_call_results=True, ) - result = requests.post(f"{holmes_url}/api/investigate", data=holmes_req.json()) - result.raise_for_status() - holmes_result = HolmesResult(**json.loads(result.text)) - title_suffix = ( - f" on {params.resource.name}" - if params.resource and params.resource.name and params.resource.name.lower() != "unresolved" - else "" - ) - - kind = params.resource.kind if params.resource else None - finding = Finding( - title=f"AI Analysis of {investigation__title}{title_suffix}", - aggregation_key="HolmesInvestigationResult", - subject=FindingSubject( - name=params.resource.name if params.resource else "", - namespace=params.resource.namespace if params.resource else "", - subject_type=FindingSubjectType.from_kind(kind) if kind else FindingSubjectType.TYPE_NONE, - node=params.resource.node if params.resource else "", - container=params.resource.container if params.resource else "", - ), - finding_type=FindingType.AI_ANALYSIS, - failure=False, - ) - finding.add_enrichment( - [HolmesResultsBlock(holmes_result=holmes_result)], enrichment_type=EnrichmentType.ai_analysis - ) + if params.stream: + with requests.post(f"{holmes_url}/api/stream/investigate", data=holmes_req.json(), stream=True) as resp: + for line in resp.iter_content(chunk_size=None, decode_unicode=True): + event.ws(data=line) + return - event.add_finding(finding) + else: + result = requests.post(f"{holmes_url}/api/investigate", data=holmes_req.json()) + result.raise_for_status() + + holmes_result = HolmesResult(**json.loads(result.text)) + title_suffix = ( + f" on {params.resource.name}" + if params.resource and params.resource.name and params.resource.name.lower() != "unresolved" + else "" + ) + + kind = params.resource.kind if params.resource else None + finding = Finding( + title=f"AI Analysis of {investigation__title}{title_suffix}", + aggregation_key="HolmesInvestigationResult", + subject=FindingSubject( + name=params.resource.name if params.resource else "", + namespace=params.resource.namespace if params.resource else "", + subject_type=FindingSubjectType.from_kind(kind) if kind else FindingSubjectType.TYPE_NONE, + node=params.resource.node if params.resource else "", + container=params.resource.container if params.resource else "", + ), + finding_type=FindingType.AI_ANALYSIS, + failure=False, + ) + finding.add_enrichment( + [HolmesResultsBlock(holmes_result=holmes_result)], enrichment_type=EnrichmentType.ai_analysis + ) + + event.add_finding(finding) except Exception as e: logging.exception( diff --git a/src/robusta/core/playbooks/playbooks_event_handler.py b/src/robusta/core/playbooks/playbooks_event_handler.py index 2433689ac..fd0ee2e8e 100644 --- a/src/robusta/core/playbooks/playbooks_event_handler.py +++ b/src/robusta/core/playbooks/playbooks_event_handler.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from robusta.core.model.events import ExecutionBaseEvent from robusta.core.playbooks.base_trigger import TriggerEvent @@ -39,6 +39,13 @@ def run_external_action( """Execute an external action""" pass + @abstractmethod + def run_external_stream_action( + self, action_name: str, action_params: Optional[dict], stream: Callable[str, Optional[str]] + ) -> Optional[Dict[str, Any]]: + """Execute an external stream action""" + pass + @abstractmethod def get_global_config(self) -> dict: """Return runner global config""" diff --git a/src/robusta/core/playbooks/playbooks_event_handler_impl.py b/src/robusta/core/playbooks/playbooks_event_handler_impl.py index 412186d7e..53f7e3e1f 100644 --- a/src/robusta/core/playbooks/playbooks_event_handler_impl.py +++ b/src/robusta/core/playbooks/playbooks_event_handler_impl.py @@ -197,7 +197,7 @@ def __run_playbook_actions( start_time = time.time() source: str = ( "manual_action" - if any(name == SYNC_RESPONSE_SINK for name in getattr(execution_event, "named_sinks", [])) + if any(name == SYNC_RESPONSE_SINK for name in (execution_event.named_sinks or [])) else "" ) self.__prepare_execution_event(execution_event) @@ -368,3 +368,37 @@ def handle_sigint(self, sig, frame): self.set_cluster_active(False) sys.exit(0) + + def run_external_stream_action( + self, action_name: str, action_params: Optional[dict], ws + ) -> Optional[Dict[str, Any]]: + action_def = self.registry.get_actions().get_action(action_name) + if not action_def: + return self.__error_resp(f"External action not found {action_name}", ErrorCodes.ACTION_NOT_FOUND.value) + + if not action_def.from_params_func: + return self.__error_resp( + f"Action {action_name} cannot run using external event", ErrorCodes.NOT_EXTERNAL_ACTION.value + ) + + try: + instantiation_params = action_def.from_params_parameter_class(**action_params) + except Exception: + return self.__error_resp( + f"Failed to create execution instance for" + f" {action_name} {action_def.from_params_parameter_class}" + f" {action_params} {traceback.format_exc()}", + ErrorCodes.EVENT_PARAMS_INSTANTIATION_FAILED.value, + ) + + execution_event = action_def.from_params_func(instantiation_params) + if not execution_event: + return self.__error_resp( + f"Failed to create execution event for {action_name} {action_params}", + ErrorCodes.EVENT_INSTANTIATION_FAILED.value, + ) + + execution_event.ws = ws + playbook_action = PlaybookAction(action_name=action_name, action_params=action_params) + + return self.__run_playbook_actions(execution_event, [playbook_action]) diff --git a/src/robusta/core/reporting/action_requests.py b/src/robusta/core/reporting/action_requests.py index e6c02f0c6..e7002eb48 100644 --- a/src/robusta/core/reporting/action_requests.py +++ b/src/robusta/core/reporting/action_requests.py @@ -27,6 +27,7 @@ class ExternalActionRequest(BaseModel): partial_auth_b: str = "" # Auth for public key auth protocol option - should be added by the relay request_id: str = "" # If specified, should return a sync response using the specified request_id no_sinks: bool = False # Indicates not to send to sinks at all. The request body has a sink list, + stream: bool = False # however an empty sink list means using the server default sinks diff --git a/src/robusta/integrations/receiver.py b/src/robusta/integrations/receiver.py index e27c6d4ce..71d0a88c6 100644 --- a/src/robusta/integrations/receiver.py +++ b/src/robusta/integrations/receiver.py @@ -124,6 +124,12 @@ def stop(self): def __sync_response(cls, status_code: int, request_id: str, data) -> Dict: return {"action": "response", "request_id": request_id, "status_code": status_code, "data": data} + def __stream_response(self, request_id: str, data: str): + self.ws.send(data=json.dumps({"action": "stream", "request_id": request_id, "data": data})) + + def __close_stream_response(self, request_id: str, data: str): + self.ws.send(data=json.dumps({"action": "stream", "request_id": request_id, "data": data, "close": True})) + def __exec_external_request(self, action_request: ExternalActionRequest, validate_timestamp: bool): logging.debug(f"Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}") sync_response = action_request.request_id != "" # if request_id is set, we need to write back the response @@ -156,6 +162,23 @@ def __exec_external_request(self, action_request: ExternalActionRequest, validat http_code = 200 if response.get("success") else 500 self.ws.send(data=json.dumps(self.__sync_response(http_code, action_request.request_id, response))) + def __exec_external_stream_request(self, action_request: ExternalActionRequest, validate_timestamp: bool): + logging.debug(f"Callback `{action_request.body.action_name}` {to_safe_str(action_request.body.action_params)}") + + validation_response = self.__validate_request(action_request, validate_timestamp) + if validation_response.http_code != 200: + req_json = action_request.json(exclude={"body"}) + body_json = action_request.body.json(exclude={"action_params"}) # action params already printed above + logging.error(f"Failed to validate action request {req_json} {body_json}") + self.__close_stream_response(action_request.request_id, validation_response.dict(exclude={"http_code"})) + return + + res = self.event_handler.run_external_stream_action(action_request.body.action_name, + action_request.body.action_params, + lambda data: self.__stream_response(request_id=action_request.request_id, data=data)) + res = "" if res.get("success") else json.dumps(res) + self.__close_stream_response(action_request.request_id, res) + def _process_action(self, action: ExternalActionRequest, validate_timestamp: bool) -> None: self._executor.submit(self._process_action_sync, action, validate_timestamp) @@ -170,7 +193,10 @@ def _process_action_sync(self, action: ExternalActionRequest, validate_timestamp else: ctx = nullcontext() with ctx: - self.__exec_external_request(action, validate_timestamp) + if action.stream: + self.__exec_external_stream_request(action, validate_timestamp) + else: + self.__exec_external_request(action, validate_timestamp) except Exception: logging.error( f"Failed to run incoming event {self._stringify_incoming_event(action.dict())}",