diff --git a/go.mod b/go.mod index b30095519..38070c77f 100644 --- a/go.mod +++ b/go.mod @@ -213,3 +213,5 @@ replace github.com/robfig/cron/v3 => github.com/unionai/cron/v3 v3.0.2-0.2022091 // Retracted versions // This was published in error when attempting to create 1.5.1 Flyte release. retract v1.1.94 + +replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v1.5.17-0.20230821222808-485ae223c7f1 diff --git a/go.sum b/go.sum index ed4b6ac85..1300e184b 100644 --- a/go.sum +++ b/go.sum @@ -293,8 +293,8 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.5.14 h1:+3ewipoOp82fPyIVgvvrMq1lorl5Kz3Lh6sh/a9+loI= -github.com/flyteorg/flyteidl v1.5.14/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= +github.com/flyteorg/flyteidl v1.5.17-0.20230821222808-485ae223c7f1 h1:B0OlEJujyYKWVkgJ7cqsIVbhaRGwMgdToCIAau8JmAY= +github.com/flyteorg/flyteidl v1.5.17-0.20230821222808-485ae223c7f1/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.0.67 h1:d2FXpwxQwX/k4YdmhuusykOemHb/cUTPEob4WBmdpjE= github.com/flyteorg/flyteplugins v1.0.67/go.mod h1:HHt4nKDKVwrZPKDsj99dNtDSIJL378xNotYMA3a/TFA= github.com/flyteorg/flytepropeller v1.1.98 h1:Zk2ENYB9VZRT5tFUIFjm+aCkr0TU2EuyJ5gh52fpLoA= diff --git a/pkg/async/notifications/email.go b/pkg/async/notifications/email.go index 063acf335..7abdcb507 100644 --- a/pkg/async/notifications/email.go +++ b/pkg/async/notifications/email.go @@ -101,7 +101,7 @@ var getTemplateValueFuncs = map[string]GetTemplateValue{ launchPlanVersion: getLaunchPlanVersion, } -func substituteEmailParameters(message string, request admin.WorkflowExecutionEventRequest, execution *admin.Execution) 
string { +func SubstituteParameters(message string, request admin.WorkflowExecutionEventRequest, execution *admin.Execution) string { for template, function := range getTemplateValueFuncs { message = strings.Replace(message, fmt.Sprintf(substitutionParam, template), function(request, execution), replaceAllInstances) message = strings.Replace(message, fmt.Sprintf(substitutionParamNoSpaces, template), function(request, execution), replaceAllInstances) @@ -118,9 +118,9 @@ func ToEmailMessageFromWorkflowExecutionEvent( execution *admin.Execution) *admin.EmailMessage { return &admin.EmailMessage{ - SubjectLine: substituteEmailParameters(config.NotificationsEmailerConfig.Subject, request, execution), + SubjectLine: SubstituteParameters(config.NotificationsEmailerConfig.Subject, request, execution), SenderEmail: config.NotificationsEmailerConfig.Sender, RecipientsEmail: emailNotification.GetRecipientsEmail(), - Body: substituteEmailParameters(config.NotificationsEmailerConfig.Body, request, execution), + Body: SubstituteParameters(config.NotificationsEmailerConfig.Body, request, execution), } } diff --git a/pkg/async/notifications/email_test.go b/pkg/async/notifications/email_test.go index 612cc4adf..0829c5e33 100644 --- a/pkg/async/notifications/email_test.go +++ b/pkg/async/notifications/email_test.go @@ -59,14 +59,14 @@ func TestSubstituteEmailParameters(t *testing.T) { }, } assert.Equal(t, "{{ unused }}. {{project }} and prod and e124 ended up in succeeded.", - substituteEmailParameters(message, request, workflowExecution)) + SubstituteParameters(message, request, workflowExecution)) request.Event.OutputResult = &event.WorkflowExecutionEvent_Error{ Error: &core.ExecutionError{ Message: "uh-oh", }, } assert.Equal(t, "{{ unused }}. {{project }} and prod and e124 ended up in succeeded. 
The execution failed with error: [uh-oh].", - substituteEmailParameters(message, request, workflowExecution)) + SubstituteParameters(message, request, workflowExecution)) } func TestSubstituteAllTemplates(t *testing.T) { @@ -95,7 +95,7 @@ func TestSubstituteAllTemplates(t *testing.T) { }, } assert.Equal(t, strings.Join(desiredResult, ","), - substituteEmailParameters(strings.Join(messageTemplate, ","), request, workflowExecution)) + SubstituteParameters(strings.Join(messageTemplate, ","), request, workflowExecution)) } func TestSubstituteAllTemplatesNoSpaces(t *testing.T) { @@ -124,7 +124,7 @@ func TestSubstituteAllTemplatesNoSpaces(t *testing.T) { }, } assert.Equal(t, strings.Join(desiredResult, ","), - substituteEmailParameters(strings.Join(messageTemplate, ","), request, workflowExecution)) + SubstituteParameters(strings.Join(messageTemplate, ","), request, workflowExecution)) } func TestToEmailMessageFromWorkflowExecutionEvent(t *testing.T) { diff --git a/pkg/async/notifications/implementations/aws_processor.go b/pkg/async/notifications/implementations/aws_processor.go index 2cbf6d971..1800c854e 100644 --- a/pkg/async/notifications/implementations/aws_processor.go +++ b/pkg/async/notifications/implementations/aws_processor.go @@ -2,8 +2,6 @@ package implementations import ( "context" - "encoding/base64" - "encoding/json" "time" "github.com/NYTimes/gizmo/pubsub" @@ -17,12 +15,11 @@ import ( // TODO: Add a counter that encompasses the publisher stats grouped by project and domain. type Processor struct { - sub pubsub.Subscriber - email interfaces.Emailer - systemMetrics processorSystemMetrics + email interfaces.Emailer + interfaces.BaseProcessor } -// Currently only email is the supported notification because slack and pagerduty both use +// StartProcessing Currently only email is the supported notification because slack and pagerduty both use // email client to trigger those notifications. 
// When Pagerduty and other notifications are supported, a publisher per type should be created. func (p *Processor) StartProcessing() { @@ -37,74 +34,37 @@ func (p *Processor) StartProcessing() { func (p *Processor) run() error { var emailMessage admin.EmailMessage var err error - for msg := range p.sub.Start() { - p.systemMetrics.MessageTotal.Inc() - // Currently this is safe because Gizmo takes a string and casts it to a byte array. + for msg := range p.Sub.Start() { + p.SystemMetrics.MessageTotal.Inc() stringMsg := string(msg.Message()) - var snsJSONFormat map[string]interface{} - - // At Lyft, SNS populates SQS. This results in the message body of SQS having the SNS message format. - // The message format is documented here: https://docs.aws.amazon.com/sns/latest/dg/sns-message-and-json-formats.html - // The notification published is stored in the message field after unmarshalling the SQS message. - if err := json.Unmarshal(msg.Message(), &snsJSONFormat); err != nil { - p.systemMetrics.MessageDecodingError.Inc() - logger.Errorf(context.Background(), "failed to unmarshall JSON message [%s] from processor with err: %v", stringMsg, err) - p.markMessageDone(msg) - continue - } - - var value interface{} - var ok bool - var valueString string - - if value, ok = snsJSONFormat["Message"]; !ok { - logger.Errorf(context.Background(), "failed to retrieve message from unmarshalled JSON object [%s]", stringMsg) - p.systemMetrics.MessageDataError.Inc() - p.markMessageDone(msg) + _, messageByte, ok := p.FromSQSMessage(msg) + if !ok { continue } - if valueString, ok = value.(string); !ok { - p.systemMetrics.MessageDataError.Inc() - logger.Errorf(context.Background(), "failed to retrieve notification message (in string format) from unmarshalled JSON object for message [%s]", stringMsg) - p.markMessageDone(msg) - continue - } - - // The Publish method for SNS Encodes the notification using Base64 then stringifies it before - // setting that as the message body for SNS. 
Do the inverse to retrieve the notification. - notificationBytes, err := base64.StdEncoding.DecodeString(valueString) - if err != nil { - logger.Errorf(context.Background(), "failed to Base64 decode from message string [%s] from message [%s] with err: %v", valueString, stringMsg, err) - p.systemMetrics.MessageDecodingError.Inc() - p.markMessageDone(msg) - continue - } - - if err = proto.Unmarshal(notificationBytes, &emailMessage); err != nil { - logger.Debugf(context.Background(), "failed to unmarshal to notification object from decoded string[%s] from message [%s] with err: %v", valueString, stringMsg, err) - p.systemMetrics.MessageDecodingError.Inc() - p.markMessageDone(msg) + if err = proto.Unmarshal(messageByte, &emailMessage); err != nil { + logger.Debugf(context.Background(), "failed to unmarshal to notification object from decoded string from message [%s] with err: %v", stringMsg, err) + p.SystemMetrics.MessageDecodingError.Inc() + p.MarkMessageDone(msg) continue } if err = p.email.SendEmail(context.Background(), emailMessage); err != nil { - p.systemMetrics.MessageProcessorError.Inc() + p.SystemMetrics.MessageProcessorError.Inc() logger.Errorf(context.Background(), "Error sending an email message for message [%s] with emailM with err: %v", emailMessage.String(), err) } else { - p.systemMetrics.MessageSuccess.Inc() + p.SystemMetrics.MessageSuccess.Inc() } - p.markMessageDone(msg) - + p.MarkMessageDone(msg) } // According to https://github.com/NYTimes/gizmo/blob/f2b3deec03175b11cdfb6642245a49722751357f/pubsub/pubsub.go#L36-L39, // the channel backing the subscriber will just close if there is an error. The call to Err() is needed to identify // there was an error in the channel or there are no more messages left (resulting in no errors when calling Err()). 
- if err = p.sub.Err(); err != nil { - p.systemMetrics.ChannelClosedError.Inc() + if err = p.Sub.Err(); err != nil { + p.SystemMetrics.ChannelClosedError.Inc() logger.Warningf(context.Background(), "The stream for the subscriber channel closed with err: %v", err) } @@ -112,27 +72,12 @@ func (p *Processor) run() error { return err } -func (p *Processor) markMessageDone(message pubsub.SubscriberMessage) { - if err := message.Done(); err != nil { - p.systemMetrics.MessageDoneError.Inc() - logger.Errorf(context.Background(), "failed to mark message as Done() in processor with err: %v", err) - } -} - -func (p *Processor) StopProcessing() error { - // Note: If the underlying channel is already closed, then Stop() will return an error. - err := p.sub.Stop() - if err != nil { - p.systemMetrics.StopError.Inc() - logger.Errorf(context.Background(), "Failed to stop the subscriber channel gracefully with err: %v", err) - } - return err -} - func NewProcessor(sub pubsub.Subscriber, emailer interfaces.Emailer, scope promutils.Scope) interfaces.Processor { return &Processor{ - sub: sub, - email: emailer, - systemMetrics: newProcessorSystemMetrics(scope.NewSubScope("processor")), + email: emailer, + BaseProcessor: interfaces.BaseProcessor{ + Sub: sub, + SystemMetrics: interfaces.NewProcessorSystemMetrics(scope.NewSubScope("processor")), + }, } } diff --git a/pkg/async/notifications/implementations/gcp_processor.go b/pkg/async/notifications/implementations/gcp_processor.go index 0ce49371e..a07e30e1c 100644 --- a/pkg/async/notifications/implementations/gcp_processor.go +++ b/pkg/async/notifications/implementations/gcp_processor.go @@ -15,16 +15,17 @@ import ( // TODO: Add a counter that encompasses the publisher stats grouped by project and domain. 
type GcpProcessor struct { - sub pubsub.Subscriber - email interfaces.Emailer - systemMetrics processorSystemMetrics + email interfaces.Emailer + interfaces.BaseProcessor } func NewGcpProcessor(sub pubsub.Subscriber, emailer interfaces.Emailer, scope promutils.Scope) interfaces.Processor { return &GcpProcessor{ - sub: sub, - email: emailer, - systemMetrics: newProcessorSystemMetrics(scope.NewSubScope("gcp_processor")), + email: emailer, + BaseProcessor: interfaces.BaseProcessor{ + Sub: sub, + SystemMetrics: interfaces.NewProcessorSystemMetrics(scope.NewSubScope("gcp_processor")), + }, } } @@ -40,52 +41,34 @@ func (p *GcpProcessor) StartProcessing() { func (p *GcpProcessor) run() error { var emailMessage admin.EmailMessage - for msg := range p.sub.Start() { - p.systemMetrics.MessageTotal.Inc() + for msg := range p.Sub.Start() { + p.SystemMetrics.MessageTotal.Inc() if err := proto.Unmarshal(msg.Message(), &emailMessage); err != nil { logger.Debugf(context.Background(), "failed to unmarshal to notification object message [%s] with err: %v", string(msg.Message()), err) - p.systemMetrics.MessageDecodingError.Inc() - p.markMessageDone(msg) + p.SystemMetrics.MessageDecodingError.Inc() + p.MarkMessageDone(msg) continue } if err := p.email.SendEmail(context.Background(), emailMessage); err != nil { - p.systemMetrics.MessageProcessorError.Inc() + p.SystemMetrics.MessageProcessorError.Inc() logger.Errorf(context.Background(), "Error sending an email message for message [%s] with emailM with err: %v", emailMessage.String(), err) } else { - p.systemMetrics.MessageSuccess.Inc() + p.SystemMetrics.MessageSuccess.Inc() } - p.markMessageDone(msg) + p.MarkMessageDone(msg) } // According to https://github.com/NYTimes/gizmo/blob/f2b3deec03175b11cdfb6642245a49722751357f/pubsub/pubsub.go#L36-L39, // the channel backing the subscriber will just close if there is an error. 
The call to Err() is needed to identify // there was an error in the channel or there are no more messages left (resulting in no errors when calling Err()). - if err := p.sub.Err(); err != nil { - p.systemMetrics.ChannelClosedError.Inc() + if err := p.Sub.Err(); err != nil { + p.SystemMetrics.ChannelClosedError.Inc() logger.Warningf(context.Background(), "The stream for the subscriber channel closed with err: %v", err) return err } return nil } - -func (p *GcpProcessor) markMessageDone(message pubsub.SubscriberMessage) { - if err := message.Done(); err != nil { - p.systemMetrics.MessageDoneError.Inc() - logger.Errorf(context.Background(), "failed to mark message as Done() in processor with err: %v", err) - } -} - -func (p *GcpProcessor) StopProcessing() error { - // Note: If the underlying channel is already closed, then Stop() will return an error. - if err := p.sub.Stop(); err != nil { - p.systemMetrics.StopError.Inc() - logger.Errorf(context.Background(), "Failed to stop the subscriber channel gracefully with err: %v", err) - return err - } - - return nil -} diff --git a/pkg/async/notifications/implementations/gcp_processor_test.go b/pkg/async/notifications/implementations/gcp_processor_test.go index 2aa216d29..92658c93b 100644 --- a/pkg/async/notifications/implementations/gcp_processor_test.go +++ b/pkg/async/notifications/implementations/gcp_processor_test.go @@ -45,7 +45,7 @@ func TestGcpProcessor_StartProcessing(t *testing.T) { // Check fornumber of messages processed. m := &dto.Metric{} - err := testGcpProcessor.(*GcpProcessor).systemMetrics.MessageSuccess.Write(m) + err := testGcpProcessor.(*GcpProcessor).SystemMetrics.MessageSuccess.Write(m) assert.Nil(t, err) assert.Equal(t, "counter: ", m.String()) } @@ -60,7 +60,7 @@ func TestGcpProcessor_StartProcessingNoMessages(t *testing.T) { // Check fornumber of messages processed. 
m := &dto.Metric{} - err := testGcpProcessor.(*GcpProcessor).systemMetrics.MessageSuccess.Write(m) + err := testGcpProcessor.(*GcpProcessor).SystemMetrics.MessageSuccess.Write(m) assert.Nil(t, err) assert.Equal(t, "counter: ", m.String()) } @@ -93,7 +93,7 @@ func TestGcpProcessor_StartProcessingEmailError(t *testing.T) { // Check for an email error stat. m := &dto.Metric{} - err := testGcpProcessor.(*GcpProcessor).systemMetrics.MessageProcessorError.Write(m) + err := testGcpProcessor.(*GcpProcessor).SystemMetrics.MessageProcessorError.Write(m) assert.Nil(t, err) assert.Equal(t, "counter: ", m.String()) } diff --git a/pkg/async/notifications/interfaces/processor.go b/pkg/async/notifications/interfaces/processor.go index b7ce30d56..27041998d 100644 --- a/pkg/async/notifications/interfaces/processor.go +++ b/pkg/async/notifications/interfaces/processor.go @@ -1,5 +1,15 @@ package interfaces +import ( + "context" + "encoding/base64" + "encoding/json" + + "github.com/NYTimes/gizmo/pubsub" + gizmoGCP "github.com/NYTimes/gizmo/pubsub/gcp" + "github.com/flyteorg/flytestdlib/logger" +) + // Exposes the common methods required for a subscriber. // There is one ProcessNotification per type. type Processor interface { @@ -15,3 +25,98 @@ type Processor interface { // the channel was already closed. StopProcessing() error } + +type BaseProcessor struct { + Sub pubsub.Subscriber + SystemMetrics ProcessorSystemMetrics +} + +// FromPubSubMessage Parse the message from GCP PubSub and return the message subject and the message body. 
+func (p *BaseProcessor) FromPubSubMessage(msg pubsub.SubscriberMessage) (string, []byte, bool) { + var gcpMsg *gizmoGCP.SubMessage + var ok bool + // MessageTotal is counted once per message by the caller's receive loop, as with FromSQSMessage. + if gcpMsg, ok = msg.(*gizmoGCP.SubMessage); !ok { + logger.Errorf(context.Background(), "failed to cast message [%v] to gizmoGCP.SubMessage", msg) + p.SystemMetrics.MessageDataError.Inc() + p.MarkMessageDone(msg) + return "", nil, false + } + subject := gcpMsg.Attributes["key"] + return subject, gcpMsg.Message(), true +} + +// FromSQSMessage Parse the message from AWS SQS and return the message subject and the message body. +func (p *BaseProcessor) FromSQSMessage(msg pubsub.SubscriberMessage) (string, []byte, bool) { + // Currently this is safe because Gizmo takes a string and casts it to a byte array. + stringMsg := string(msg.Message()) + var snsJSONFormat map[string]interface{} + + if err := json.Unmarshal(msg.Message(), &snsJSONFormat); err != nil { + p.SystemMetrics.MessageDecodingError.Inc() + logger.Errorf(context.Background(), "failed to unmarshall JSON message [%s] from processor with err: %v", stringMsg, err) + p.MarkMessageDone(msg) + return "", nil, false + } + + var value interface{} + var ok bool + var valueString string + var subject string + + if value, ok = snsJSONFormat["Message"]; !ok { + logger.Errorf(context.Background(), "failed to retrieve message from unmarshalled JSON object [%s]", stringMsg) + p.SystemMetrics.MessageDataError.Inc() + p.MarkMessageDone(msg) + return "", nil, false + } + + if valueString, ok = value.(string); !ok { + p.SystemMetrics.MessageDataError.Inc() + logger.Errorf(context.Background(), "failed to retrieve notification message (in string format) from unmarshalled JSON object for message [%s]", stringMsg) + p.MarkMessageDone(msg) + return "", nil, false + } + + if value, ok = snsJSONFormat["Subject"]; !ok { + logger.Errorf(context.Background(), "failed to retrieve message type from unmarshalled JSON object [%s]", stringMsg) + 
p.SystemMetrics.MessageDataError.Inc() + p.MarkMessageDone(msg) + return "", nil, false + } + + if subject, ok = value.(string); !ok { + p.SystemMetrics.MessageDataError.Inc() + logger.Errorf(context.Background(), "failed to retrieve notification message type (in string format) from unmarshalled JSON object for message [%s]", stringMsg) + p.MarkMessageDone(msg) + return "", nil, false + } + + // The Publish method for SNS Encodes the notification using Base64 then stringifies it before + // setting that as the message body for SNS. Do the inverse to retrieve the notification. + messageBytes, err := base64.StdEncoding.DecodeString(valueString) + if err != nil { + logger.Errorf(context.Background(), "failed to Base64 decode from message string [%s] from message [%s] with err: %v", valueString, stringMsg, err) + p.SystemMetrics.MessageDecodingError.Inc() + p.MarkMessageDone(msg) + return "", nil, false + } + return subject, messageBytes, true +} + +func (p *BaseProcessor) StopProcessing() error { + // Note: If the underlying channel is already closed, then Stop() will return an error. 
+ err := p.Sub.Stop() + if err != nil { + p.SystemMetrics.StopError.Inc() + logger.Errorf(context.Background(), "Failed to stop the subscriber channel gracefully with err: %v", err) + } + return err +} + +func (p *BaseProcessor) MarkMessageDone(message pubsub.SubscriberMessage) { + if err := message.Done(); err != nil { + p.SystemMetrics.MessageDoneError.Inc() + logger.Errorf(context.Background(), "failed to mark message as Done() in processor with err: %v", err) + } +} diff --git a/pkg/async/notifications/implementations/processor_metrics.go b/pkg/async/notifications/interfaces/processor_metrics.go similarity index 90% rename from pkg/async/notifications/implementations/processor_metrics.go rename to pkg/async/notifications/interfaces/processor_metrics.go index adc4219f9..ab796a2ca 100644 --- a/pkg/async/notifications/implementations/processor_metrics.go +++ b/pkg/async/notifications/interfaces/processor_metrics.go @@ -1,11 +1,11 @@ -package implementations +package interfaces import ( "github.com/flyteorg/flytestdlib/promutils" "github.com/prometheus/client_golang/prometheus" ) -type processorSystemMetrics struct { +type ProcessorSystemMetrics struct { Scope promutils.Scope MessageTotal prometheus.Counter MessageDoneError prometheus.Counter @@ -17,8 +17,8 @@ type processorSystemMetrics struct { StopError prometheus.Counter } -func newProcessorSystemMetrics(scope promutils.Scope) processorSystemMetrics { - return processorSystemMetrics{ +func NewProcessorSystemMetrics(scope promutils.Scope) ProcessorSystemMetrics { + return ProcessorSystemMetrics{ Scope: scope, MessageTotal: scope.MustNewCounter("message_total", "overall count of messages processed"), MessageDecodingError: scope.MustNewCounter("message_decoding_error", "count of messages with decoding errors"), diff --git a/pkg/async/webhook/factory.go b/pkg/async/webhook/factory.go new file mode 100644 index 000000000..fb38d35c0 --- /dev/null +++ b/pkg/async/webhook/factory.go @@ -0,0 +1,92 @@ +package webhook + 
+import ( + "context" + "fmt" + "time" + + gizmoGCP "github.com/NYTimes/gizmo/pubsub/gcp" + "github.com/flyteorg/flyteadmin/pkg/common" + + repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + + "github.com/NYTimes/gizmo/pubsub" + gizmoAWS "github.com/NYTimes/gizmo/pubsub/aws" + "github.com/flyteorg/flyteadmin/pkg/async" + notificationsImplementations "github.com/flyteorg/flyteadmin/pkg/async/notifications/implementations" + "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flyteadmin/pkg/async/webhook/implementations" + "github.com/flyteorg/flytestdlib/logger" + + webhookInterfaces "github.com/flyteorg/flyteadmin/pkg/async/webhook/interfaces" + runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flytestdlib/promutils" +) + +var enable64decoding = false + +func GetWebhook(config runtimeInterfaces.WebHookConfig, scope promutils.Scope) webhookInterfaces.Webhook { + switch config.Name { + case implementations.Slack: + return implementations.NewSlackWebhook(config, scope) + default: + panic(fmt.Errorf("no matching webhook implementation for %s", config.Name)) + } +} + +func NewWebhookProcessors(db repoInterfaces.Repository, config runtimeInterfaces.WebhookNotificationsConfig, scope promutils.Scope) []interfaces.Processor { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + var sub pubsub.Subscriber + var processors []interfaces.Processor + var err error + + for _, cfg := range config.WebhooksConfig { + var processor interfaces.Processor + switch config.Type { + case common.AWS: + sqsConfig := gizmoAWS.SQSConfig{ + QueueName: cfg.NotificationsProcessorConfig.QueueName, + QueueOwnerAccountID: cfg.NotificationsProcessorConfig.AccountID, + // The AWS configuration type uses SNS to SQS for notifications. + // Gizmo by default will decode the SQS message using Base64 decoding. 
+ // However, the message body of SQS is the SNS message format which isn't Base64 encoded. + ConsumeBase64: &enable64decoding, + } + sqsConfig.Region = config.AWSConfig.Region + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoAWS.NewSubscriber(sqsConfig) + if err != nil { + logger.Errorf(context.TODO(), "Failed to initialize new gizmo aws subscriber with config [%+v] and err: %v", sqsConfig, err) + } + return err + }) + if err != nil { + panic(err) + } + processor = implementations.NewWebhookProcessor(common.AWS, sub, GetWebhook(cfg, scope), db, scope) + case common.GCP: + projectID := config.GCPConfig.ProjectID + subscription := cfg.NotificationsProcessorConfig.QueueName + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoGCP.NewSubscriber(context.TODO(), projectID, subscription) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo gcp subscriber with config [ProjectID: %s, Subscription: %s] and err: %v", projectID, subscription, err) + } + return err + }) + if err != nil { + panic(err) + } + processor = implementations.NewWebhookProcessor(common.GCP, sub, GetWebhook(cfg, scope), db, scope) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications processor implementation for config type [%s]", config.Type) + processor = notificationsImplementations.NewNoopProcess() + } + processors = append(processors, processor) + } + return processors +} diff --git a/pkg/async/webhook/factory_test.go b/pkg/async/webhook/factory_test.go new file mode 100644 index 000000000..be078e35a --- /dev/null +++ b/pkg/async/webhook/factory_test.go @@ -0,0 +1,17 @@ +package webhook + +import ( + "testing" + + runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + + "github.com/flyteorg/flyteadmin/pkg/async/webhook/implementations" + "github.com/flyteorg/flytestdlib/promutils" +) + +func TestGetWebhook(t 
*testing.T) { + cfg := runtimeInterfaces.WebHookConfig{ + Name: implementations.Slack, + } + GetWebhook(cfg, promutils.NewTestScope()) +} diff --git a/pkg/async/webhook/implementations/processer.go b/pkg/async/webhook/implementations/processer.go new file mode 100644 index 000000000..6a347e4cb --- /dev/null +++ b/pkg/async/webhook/implementations/processer.go @@ -0,0 +1,125 @@ +package implementations + +import ( + "context" + "time" + + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" + repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/golang/protobuf/proto" + + "github.com/NYTimes/gizmo/pubsub" + "github.com/flyteorg/flyteadmin/pkg/async" + "github.com/flyteorg/flyteadmin/pkg/async/notifications" + "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" + webhookInterfaces "github.com/flyteorg/flyteadmin/pkg/async/webhook/interfaces" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" +) + +type Processor struct { + subType string + webhook webhookInterfaces.Webhook + db repoInterfaces.Repository + interfaces.BaseProcessor +} + +func (p *Processor) StartProcessing() { + for { + logger.Warningf(context.Background(), "Starting webhook processor") + err := p.run() + logger.Errorf(context.Background(), "error with running processor err: [%v] ", err) + time.Sleep(async.RetryDelay) + } +} + +func (p *Processor) run() error { + var payload admin.WebhookPayload + var request admin.WorkflowExecutionEventRequest + var err error + var subject string + var messageByte []byte + var ok bool + + for msg := range p.Sub.Start() { + p.SystemMetrics.MessageTotal.Inc() + stringMsg := string(msg.Message()) + if p.subType == common.AWS { + subject, messageByte, ok = p.FromSQSMessage(msg) + } else { + subject, messageByte, ok 
= p.FromPubSubMessage(msg) + } + if !ok { + continue + } + + if subject != proto.MessageName(&admin.WorkflowExecutionEventRequest{}) { + p.MarkMessageDone(msg) + continue + } + + if err = proto.Unmarshal(messageByte, &request); err != nil { + logger.Errorf(context.Background(), "failed to unmarshal to notification object from decoded string from message [%s] with err: %v", stringMsg, err) + p.SystemMetrics.MessageDecodingError.Inc() + p.MarkMessageDone(msg) + continue + } + + if !common.IsExecutionTerminal(request.GetEvent().GetPhase()) || request.GetEvent().GetExecutionId() == nil { + p.MarkMessageDone(msg) + continue + } + + executionModel, err := util.GetExecutionModel(context.Background(), p.db, *request.Event.ExecutionId) + if err != nil { + p.SystemMetrics.MessageProcessorError.Inc() + logger.Errorf(context.Background(), "failed to retrieve execution model for execution [%+v] from message [%s] with err: %v", request.Event.ExecutionId, stringMsg, err) + p.MarkMessageDone(msg) + continue + } + adminExecution, err := transformers.FromExecutionModel(context.Background(), *executionModel, transformers.DefaultExecutionTransformerOptions) + if err != nil { + p.SystemMetrics.MessageProcessorError.Inc() + logger.Errorf(context.Background(), "failed to transform execution model [%+v] from message [%s] with err: %v", executionModel, stringMsg, err) + p.MarkMessageDone(msg) + continue + } + + payload.Message = notifications.SubstituteParameters(p.webhook.GetConfig().Payload, request, adminExecution) + logger.Info(context.Background(), "Processor is sending message to webhook endpoint") + if err = p.webhook.Post(context.Background(), payload); err != nil { + p.SystemMetrics.MessageProcessorError.Inc() + logger.Errorf(context.Background(), "Error sending an message [%v] to webhook endpoint with err: %v", payload, err) + } else { + p.SystemMetrics.MessageSuccess.Inc() + } + + p.MarkMessageDone(msg) + } + + if err = p.Sub.Err(); err != nil { + p.SystemMetrics.ChannelClosedError.Inc() + logger.Errorf(context.Background(), "The 
stream for the subscriber channel closed with err: %v", err)
+	}
+
+	return err
+}
+
+func NewWebhookProcessor(subType string, sub pubsub.Subscriber, webhook webhookInterfaces.Webhook, db repoInterfaces.Repository, scope promutils.Scope) interfaces.Processor {
+	if subType != common.AWS && subType != common.GCP {
+		panic("unknown subscriber type [" + subType + "]")
+	}
+
+	return &Processor{
+		subType: subType,
+		webhook: webhook,
+		db:      db,
+		BaseProcessor: interfaces.BaseProcessor{
+			Sub:           sub,
+			SystemMetrics: interfaces.NewProcessorSystemMetrics(scope.NewSubScope("webhook_processor")),
+		},
+	}
+}
diff --git a/pkg/async/webhook/implementations/processor_test.go b/pkg/async/webhook/implementations/processor_test.go
new file mode 100644
index 000000000..f3b959dbf
--- /dev/null
+++ b/pkg/async/webhook/implementations/processor_test.go
@@ -0,0 +1,225 @@
+package implementations
+
+import (
+	"context"
+	"encoding/base64"
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/flyteorg/flyteadmin/pkg/common"
+
+	"github.com/NYTimes/gizmo/pubsub"
+	"github.com/NYTimes/gizmo/pubsub/pubsubtest"
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/flyteorg/flyteadmin/pkg/async/webhook/mocks"
+	"github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils"
+	"github.com/flyteorg/flyteadmin/pkg/repositories/interfaces"
+	"github.com/flyteorg/flyteadmin/pkg/repositories/models"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"
+	"github.com/flyteorg/flytestdlib/promutils"
+	"github.com/golang/protobuf/proto"
+	"github.com/stretchr/testify/assert"
+
+	repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks"
+)
+
+var (
+	mockWebhook     = mocks.MockWebhook{}
+	repo            = repositoryMocks.NewMockRepository()
+	testWebhook     = admin.WebhookPayload{Message: "hello world"}
+	workflowRequest = &admin.WorkflowExecutionEventRequest{
+		Event: &event.WorkflowExecutionEvent{
+			Phase: core.WorkflowExecution_FAILED,
+			ExecutionId: &core.WorkflowExecutionIdentifier{
+				Project: "project",
+				Domain:  "domain",
+				Name:    "name",
+			},
+		},
+	}
+	msg, _                = proto.Marshal(workflowRequest)
+	testSubscriberMessage = map[string]interface{}{
+		"Type":             "Notification",
+		"MessageId":        "1-a-3-c",
+		"TopicArn":         "arn:aws:sns:my-region:123:flyte-test-notifications",
+		"Subject":          "flyteidl.admin.WorkflowExecutionEventRequest",
+		"Message":          aws.String(base64.StdEncoding.EncodeToString(msg)),
+		"Timestamp":        "2019-01-04T22:59:32.849Z",
+		"SignatureVersion": "1",
+		"Signature":        "some&ignature==",
+		"SigningCertURL":   "https://sns.my-region.amazonaws.com/afdaf",
+		"UnsubscribeURL":   "https://sns.my-region.amazonaws.com/sns:my-region:123:flyte-test-notifications:1-2-3-4-5"}
+	testSubscriber pubsubtest.TestSubscriber
+	mockSub        pubsub.Subscriber = &testSubscriber
+)
+
+// initializeProcessor resets the shared test subscriber and webhook mock.
+// It should be invoked at the start of every test that uses them.
+func initializeProcessor() {
+	testSubscriber.GivenStopError = nil
+	testSubscriber.GivenErrError = nil
+	testSubscriber.FoundError = nil
+	testSubscriber.ProtoMessages = nil
+	testSubscriber.JSONMessages = nil
+	// Drop any post callback left behind by an earlier test so it cannot leak
+	// into this one (the mock is a package-level variable).
+	mockWebhook.SetWebhookPostFunc(nil)
+}
+
+// Feeds one well-formed notification through run() and checks the message handed to the webhook.
+func TestProcessor_StartProcessing(t *testing.T) {
+	initializeProcessor()
+	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testSubscriberMessage)
+
+	sendWebhookValidationFunc := func(ctx context.Context, payload admin.WebhookPayload) error {
+		// testify expects (expected, actual); the order matters for the failure message.
+		assert.Equal(t, testWebhook.Message, payload.Message)
+		return nil
+	}
+	mockWebhook.SetWebhookPostFunc(sendWebhookValidationFunc)
+	occurredAt := time.Now().UTC()
+	closure := &admin.ExecutionClosure{
+		Phase: core.WorkflowExecution_RUNNING,
+		StateChangeDetails: &admin.ExecutionStateChangeDetails{
+			State:      admin.ExecutionState_EXECUTION_ACTIVE,
+			OccurredAt: testutils.MockCreatedAtProto,
+		},
+		WorkflowId: &core.Identifier{
+			Project: "project",
+			Domain:  "domain",
+			Name:    "name",
+			Version: "version",
+		},
+	}
+	closureBytes, err := proto.Marshal(closure)
+	assert.Nil(t, err)
+	spec := &admin.ExecutionSpec{
+		LaunchPlan: &core.Identifier{
+			Project: "project",
+			Domain:  "domain",
+			Name:    "name",
+			Version: "version",
+		},
+	}
+	specBytes, err := proto.Marshal(spec)
+	assert.Nil(t, err)
+	repo.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(
+		func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
+			return models.Execution{
+				ExecutionKey: models.ExecutionKey{
+					Project: "project",
+					Domain:  "domain",
+					Name:    "name",
+				},
+				BaseModel: models.BaseModel{
+					ID: uint(8),
+				},
+				Spec:         specBytes,
+				Phase:        core.WorkflowExecution_SUCCEEDED.String(),
+				Closure:      closureBytes,
+				LaunchPlanID: uint(1),
+				WorkflowID:   uint(2),
+				StartedAt:    &occurredAt,
+			}, nil
+		},
+	)
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.(*Processor).run())
+}
+
+// run() is expected to return nil when the subscriber has no messages.
+func TestProcessor_StartProcessingNoMessages(t *testing.T) {
+	initializeProcessor()
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.(*Processor).run())
+}
+
+// run() is expected to return nil for a message that has no "Message" field.
+func TestProcessor_StartProcessingNoNotificationMessage(t *testing.T) {
+	var testMessage = map[string]interface{}{
+		"Type":      "Not a real notification",
+		"MessageId": "1234",
+	}
+	initializeProcessor()
+	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.(*Processor).run())
+}
+
+// run() is expected to return nil when the "Message" field is an integer.
+func TestProcessor_StartProcessingMessageWrongDataType(t *testing.T) {
+	var testMessage = map[string]interface{}{
+		"Type":      "Not a real notification",
+		"MessageId": "1234",
+		"Message":   12,
+	}
+	initializeProcessor()
+	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.(*Processor).run())
+}
+
+// run() is expected to return nil when "Message" is an arbitrary plain string rather than an encoded request.
+func TestProcessor_StartProcessingBase64DecodeError(t *testing.T) {
+	var testMessage = map[string]interface{}{
+		"Type":      "Not a real notification",
+		"MessageId": "1234",
+		"Message":   "NotBase64encoded",
+	}
+	initializeProcessor()
+	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.(*Processor).run())
+}
+
+// run() is expected to return nil when "Message" holds base64-encoded bytes that did not come from proto.Marshal.
+func TestProcessor_StartProcessingProtoMarshallError(t *testing.T) {
+	var badByte = []byte("atreyu")
+	var testMessage = map[string]interface{}{
+		"Type":      "Not a real notification",
+		"MessageId": "1234",
+		"Message":   aws.String(base64.StdEncoding.EncodeToString(badByte)),
+	}
+	initializeProcessor()
+	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testMessage)
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.(*Processor).run())
+}
+
+// run() is expected to return the error configured as the subscriber's GivenErrError.
+func TestProcessor_StartProcessingError(t *testing.T) {
+	initializeProcessor()
+	var ret = errors.New("err() returned an error")
+	testSubscriber.GivenErrError = ret
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Equal(t, ret, testProcessor.(*Processor).run())
+}
+
+// run() is expected to return nil even when the webhook's Post callback returns an error.
+func TestProcessor_StartProcessingWebhookError(t *testing.T) {
+	initializeProcessor()
+	webhookError := errors.New("webhook error")
+	sendWebhookErrorFunc := func(ctx context.Context, payload admin.WebhookPayload) error {
+		return webhookError
+	}
+	mockWebhook.SetWebhookPostFunc(sendWebhookErrorFunc)
+	testSubscriber.JSONMessages = append(testSubscriber.JSONMessages, testSubscriberMessage)
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.(*Processor).run())
+}
+
+func TestProcessor_StopProcessing(t *testing.T) {
+	initializeProcessor()
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Nil(t, testProcessor.StopProcessing())
+}
+
+func TestProcessor_StopProcessingError(t *testing.T) {
+	initializeProcessor()
+	var stopError = errors.New("stop() returns an error")
+	testSubscriber.GivenStopError = stopError
+	testProcessor := NewWebhookProcessor(common.AWS, mockSub, &mockWebhook, repo, promutils.NewTestScope())
+	assert.Equal(t, stopError, testProcessor.StopProcessing())
+}
diff --git a/pkg/async/webhook/implementations/slack_webhook.go b/pkg/async/webhook/implementations/slack_webhook.go
new file mode 100644
index 000000000..2a0d87b99
--- /dev/null
+++ b/pkg/async/webhook/implementations/slack_webhook.go
@@ -0,0 +1,105 @@
+package implementations
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"time"
+
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager"
+
+	"github.com/flyteorg/flytestdlib/logger"
+
+	"github.com/flyteorg/flyteadmin/pkg/async/webhook/interfaces"
+	runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
+	"github.com/flyteorg/flytestdlib/promutils"
+)
+
+// Slack is the WebHookConfig.Name value used for Slack webhooks.
+const Slack = "slack"
+
+// slackPostTimeout bounds a single POST so that an unresponsive endpoint
+// cannot block the caller forever.
+const slackPostTimeout = 30 * time.Second
+
+// slackMessage is the JSON body posted to the webhook URL.
+type slackMessage struct {
+	Text string `json:"text"`
+}
+
+// SlackWebhook posts webhook payloads to a Slack webhook URL.
+type SlackWebhook struct {
+	Config        runtimeInterfaces.WebHookConfig
+	systemMetrics webhookMetrics
+}
+
+// GetConfig returns the configuration this webhook was created with.
+func (s *SlackWebhook) GetConfig() runtimeInterfaces.WebHookConfig {
+	return s.Config
+}
+
+// Post sends payload.Message to Slack as a JSON body of the form {"text": ...}.
+//
+// The webhook URL and, when TokenSecretName is set, a bearer token are read
+// from the secret manager. An error is returned when a secret cannot be read,
+// the request cannot be built or sent, or the response status is 400 or above.
+func (s *SlackWebhook) Post(ctx context.Context, payload admin.WebhookPayload) error {
+	sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig())
+	webhookURL, err := sm.Get(ctx, s.Config.URLSecretName)
+	if err != nil {
+		logger.Errorf(ctx, "Failed to get url from secret manager with error: %v", err)
+		return err
+	}
+	// Marshal rather than format by hand: JSON strings must be double-quoted,
+	// and the message may itself contain quotes, backslashes or newlines.
+	data, err := json.Marshal(slackMessage{Text: payload.Message})
+	if err != nil {
+		logger.Errorf(ctx, "Failed to encode Slack webhook payload with error: %v", err)
+		return err
+	}
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookURL, bytes.NewReader(data))
+	if err != nil {
+		logger.Errorf(ctx, "Failed to create request to Slack webhook with error: %v", err)
+		return err
+	}
+	request.Header.Add("Content-Type", "application/json")
+	if len(s.Config.TokenSecretName) != 0 {
+		token, err := sm.Get(ctx, s.Config.TokenSecretName)
+		if err != nil {
+			logger.Errorf(ctx, "Failed to get bearer token from secret manager with error: %v", err)
+			return err
+		}
+		request.Header.Add("Authorization", "Bearer "+token)
+	}
+
+	client := &http.Client{Timeout: slackPostTimeout}
+	resp, err := client.Do(request)
+	if err != nil {
+		logger.Errorf(ctx, "Failed to post to Slack webhook with error: %v", err)
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode >= 400 {
+		// Best effort: the body is only used to enrich the error message.
+		respBody, _ := ioutil.ReadAll(resp.Body)
+		return fmt.Errorf("received an error response (%d): %s",
+			resp.StatusCode,
+			string(respBody),
+		)
+	}
+
+	return nil
+}
+
+// NewSlackWebhook returns a Webhook that posts to Slack using config; its
+// metrics are registered under the "slack_webhook" sub-scope of scope.
+func NewSlackWebhook(config runtimeInterfaces.WebHookConfig, scope promutils.Scope) interfaces.Webhook {
+	return &SlackWebhook{
+		Config:        config,
+		systemMetrics: newWebhookMetrics(scope.NewSubScope("slack_webhook")),
+	}
+}
diff --git a/pkg/async/webhook/implementations/slack_webhook_test.go b/pkg/async/webhook/implementations/slack_webhook_test.go
new file mode 100644
index 000000000..ab21da36d
--- /dev/null
+++ b/pkg/async/webhook/implementations/slack_webhook_test.go
@@ -0,0 +1,15 @@
+package implementations
+
+import (
+	"testing"
+
+	runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
+	"github.com/flyteorg/flytestdlib/promutils"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestSlackWebhook(t *testing.T) {
+	cfg := runtimeInterfaces.WebHookConfig{Name: Slack}
+	webhook := NewSlackWebhook(cfg, promutils.NewTestScope())
+	assert.Equal(t, cfg.Name, webhook.GetConfig().Name)
+}
diff --git a/pkg/async/webhook/implementations/webhook_metrics.go b/pkg/async/webhook/implementations/webhook_metrics.go
new file mode 100644
index 000000000..72be315ef
--- /dev/null
+++
b/pkg/async/webhook/implementations/webhook_metrics.go
@@ -0,0 +1,25 @@
+package implementations
+
+import (
+	"github.com/flyteorg/flytestdlib/promutils"
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+// webhookMetrics groups the counters kept for a webhook.
+// NOTE(review): nothing visible in this change increments them — confirm they are wired up.
+type webhookMetrics struct {
+	Scope       promutils.Scope
+	SendSuccess prometheus.Counter
+	SendError   prometheus.Counter
+	SendTotal   prometheus.Counter
+}
+
+// newWebhookMetrics registers the send_success, send_error and send_total counters on scope.
+func newWebhookMetrics(scope promutils.Scope) webhookMetrics {
+	return webhookMetrics{
+		Scope:       scope,
+		SendSuccess: scope.MustNewCounter("send_success", "Number of webhook payloads posted successfully"),
+		SendError:   scope.MustNewCounter("send_error", "Number of errors when posting a webhook payload"),
+		SendTotal:   scope.MustNewCounter("send_total", "Total number of webhook posts attempted"),
+	}
+}
diff --git a/pkg/async/webhook/interfaces/webhook.go b/pkg/async/webhook/interfaces/webhook.go
new file mode 100644
index 000000000..95deca050
--- /dev/null
+++ b/pkg/async/webhook/interfaces/webhook.go
@@ -0,0 +1,24 @@
+package interfaces
+
+import (
+	"context"
+
+	runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
+)
+
+//go:generate mockery -name=Webhook -output=../mocks -case=underscore
+
+// Payload carries a single string value.
+// NOTE(review): it is not referenced by any code visible in this change, and the struct tag uses "value=" where
+// generated protobuf tags use "name=" — confirm whether this type is needed.
+type Payload struct {
+	Value string `protobuf:"bytes,1,opt,value=value"`
+}
+
+// Webhook defines the interface for publishing execution events to other services, such as Slack.
+type Webhook interface {
+	// Post delivers payload to the webhook's target service.
+	Post(ctx context.Context, payload admin.WebhookPayload) error
+	GetConfig() runtimeInterfaces.WebHookConfig
+}
diff --git a/pkg/async/webhook/mocks/processor.go b/pkg/async/webhook/mocks/processor.go
new file mode 100644
index 000000000..dbc6713b0
--- /dev/null
+++ b/pkg/async/webhook/mocks/processor.go
@@ -0,0 +1,63 @@
+package mocks
+
+import (
+	"context"
+
+	runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
+
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
+)
+
+// RunFunc is the callback type that MockSubscriber.Run delegates to.
+type RunFunc func() error
+
+// StopFunc is the callback type that MockSubscriber.Stop delegates to.
+type StopFunc func() error
+
+// WebhookPostFunc is the callback type that MockWebhook.Post delegates to.
+type WebhookPostFunc func(ctx context.Context, payload admin.WebhookPayload) error
+
+// MockSubscriber exposes Run and Stop, each backed by an optional callback field.
+type MockSubscriber struct {
+	runFunc  RunFunc
+	stopFunc StopFunc
+}
+
+// Run returns nil when no run callback is set, otherwise the callback's result.
+func (m *MockSubscriber) Run() error {
+	if m.runFunc == nil {
+		return nil
+	}
+	return m.runFunc()
+}
+
+// Stop returns nil when no stop callback is set, otherwise the callback's result.
+func (m *MockSubscriber) Stop() error {
+	if m.stopFunc == nil {
+		return nil
+	}
+	return m.stopFunc()
+}
+
+// MockWebhook is a hand-written webhook test double whose Post behaviour can be injected.
+type MockWebhook struct {
+	post WebhookPostFunc
+}
+
+// GetConfig returns a fixed config whose Payload is "hello world".
+func (m *MockWebhook) GetConfig() runtimeInterfaces.WebHookConfig {
+	return runtimeInterfaces.WebHookConfig{Payload: "hello world"}
+}
+
+// SetWebhookPostFunc installs the callback that Post delegates to; nil clears it.
+func (m *MockWebhook) SetWebhookPostFunc(webhookPostFunc WebhookPostFunc) {
+	m.post = webhookPostFunc
+}
+
+// Post returns nil when no callback is installed, otherwise the callback's result.
+func (m *MockWebhook) Post(ctx context.Context, payload admin.WebhookPayload) error {
+	if m.post == nil {
+		return nil
+	}
+	return m.post(ctx, payload)
+}
diff --git a/pkg/async/webhook/mocks/publisher.go b/pkg/async/webhook/mocks/publisher.go
new file mode 100644
index 000000000..ccfa041eb
--- /dev/null
+++ b/pkg/async/webhook/mocks/publisher.go
@@ -0,0 +1,28 @@
+package mocks
+
+import (
+	"context"
+
+	"github.com/golang/protobuf/proto"
+)
+
+// PublishFunc is the callback type that MockPublisher.Publish delegates to.
+type PublishFunc func(ctx context.Context, key string, msg proto.Message) error
+
+// MockPublisher is a hand-written publisher test double backed by an optional callback.
+type MockPublisher struct {
+	publishFunc PublishFunc
+}
+
+// SetPublishCallback installs the callback that Publish delegates to.
+func (m *MockPublisher) SetPublishCallback(publishFunction PublishFunc) {
+	m.publishFunc = publishFunction
+}
+
+// Publish forwards to the installed callback, or returns nil when none is set.
+func (m *MockPublisher) Publish(ctx
context.Context, notificationType string, msg proto.Message) error { + if m.publishFunc != nil { + return m.publishFunc(ctx, notificationType, msg) + } + return nil +} diff --git a/pkg/async/webhook/mocks/webhook.go b/pkg/async/webhook/mocks/webhook.go new file mode 100644 index 000000000..b0f84498c --- /dev/null +++ b/pkg/async/webhook/mocks/webhook.go @@ -0,0 +1,82 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + + interfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// Webhook is an autogenerated mock type for the Webhook type +type Webhook struct { + mock.Mock +} + +type Webhook_GetConfig struct { + *mock.Call +} + +func (_m Webhook_GetConfig) Return(_a0 interfaces.WebHookConfig) *Webhook_GetConfig { + return &Webhook_GetConfig{Call: _m.Call.Return(_a0)} +} + +func (_m *Webhook) OnGetConfig() *Webhook_GetConfig { + c_call := _m.On("GetConfig") + return &Webhook_GetConfig{Call: c_call} +} + +func (_m *Webhook) OnGetConfigMatch(matchers ...interface{}) *Webhook_GetConfig { + c_call := _m.On("GetConfig", matchers...) 
+ return &Webhook_GetConfig{Call: c_call} +} + +// GetConfig provides a mock function with given fields: +func (_m *Webhook) GetConfig() interfaces.WebHookConfig { + ret := _m.Called() + + var r0 interfaces.WebHookConfig + if rf, ok := ret.Get(0).(func() interfaces.WebHookConfig); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(interfaces.WebHookConfig) + } + + return r0 +} + +type Webhook_Post struct { + *mock.Call +} + +func (_m Webhook_Post) Return(_a0 error) *Webhook_Post { + return &Webhook_Post{Call: _m.Call.Return(_a0)} +} + +func (_m *Webhook) OnPost(ctx context.Context, payload admin.WebhookPayload) *Webhook_Post { + c_call := _m.On("Post", ctx, payload) + return &Webhook_Post{Call: c_call} +} + +func (_m *Webhook) OnPostMatch(matchers ...interface{}) *Webhook_Post { + c_call := _m.On("Post", matchers...) + return &Webhook_Post{Call: c_call} +} + +// Post provides a mock function with given fields: ctx, payload +func (_m *Webhook) Post(ctx context.Context, payload admin.WebhookPayload) error { + ret := _m.Called(ctx, payload) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, admin.WebhookPayload) error); ok { + r0 = rf(ctx, payload) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index e881ef10e..8ee1cb7cf 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -1289,16 +1289,18 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi if err != nil { // The only errors that publishNotifications will forward are those related // to unexpected data and transformation errors. 
-			logger.Debugf(ctx, "failed to publish notifications for CreateWorkflowEvent [%+v] due to err: %v",
+			logger.Errorf(ctx, "failed to publish notifications for CreateWorkflowEvent [%+v] due to err: %v",
 				request, err)
 			return nil, err
 		}
 	}
 
-	if err := m.eventPublisher.Publish(ctx, proto.MessageName(&request), &request); err != nil {
-		m.systemMetrics.PublishEventError.Inc()
-		logger.Infof(ctx, "error publishing event [%+v] with err: [%v]", request.RequestId, err)
-	}
+	go func() { // NOTE(review): ctx is the caller's context; confirm Publish tolerates it being cancelled after this method returns.
+		if err := m.eventPublisher.Publish(ctx, proto.MessageName(&request), &request); err != nil {
+			m.systemMetrics.PublishEventError.Inc()
+			logger.Infof(ctx, "error publishing event [%+v] with err: [%v]", request.RequestId, err)
+		}
+	}()
 
 	go func() {
 		if err := m.cloudEventPublisher.Publish(ctx, proto.MessageName(&request), &request); err != nil {
diff --git a/pkg/rpc/adminservice/base.go b/pkg/rpc/adminservice/base.go
index 77c78f480..4046ea1a3 100644
--- a/pkg/rpc/adminservice/base.go
+++ b/pkg/rpc/adminservice/base.go
@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"runtime/debug"
 
+	"github.com/flyteorg/flyteadmin/pkg/async/webhook"
+
 	"github.com/flyteorg/flyteadmin/plugins"
 
 	"github.com/flyteorg/flyteadmin/pkg/async/cloudevent"
@@ -107,6 +109,15 @@ func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, confi
 		processor.StartProcessing()
 	}()
 
+	webhookProcessors := webhook.NewWebhookProcessors(repo, *configuration.ApplicationConfiguration().GetWebhookNotificationConfig(), adminScope)
+	logger.Info(ctx, "Started processing webhook events.")
+	for _, webhookProcessor := range webhookProcessors {
+		// One goroutine per processor: StartProcessing is given its own goroutine
+		// for the notifications processor above, so it presumably blocks; called
+		// back to back, only the first webhook would ever be processed.
+		go webhookProcessor.StartProcessing()
+	}
+
 	// Configure workflow scheduler async processes.
 	schedulerConfig := configuration.ApplicationConfiguration().GetSchedulerConfig()
 	workflowScheduler := schedule.NewWorkflowScheduler(repo, schedule.WorkflowSchedulerConfig{
diff --git a/pkg/runtime/application_config_provider.go b/pkg/runtime/application_config_provider.go
index 3b8b0a270..91aa8dda9 100644
--- a/pkg/runtime/application_config_provider.go
+++ b/pkg/runtime/application_config_provider.go
@@ -12,6 +12,7 @@ const flyteAdmin = "flyteadmin"
 const scheduler = "scheduler"
 const remoteData = "remoteData"
 const notifications = "notifications"
+const webhookNotifications = "webhookNotifications"
 const domains = "domains"
 const externalEvents = "externalEvents"
 const cloudEvents = "cloudEvents"
@@ -62,6 +63,9 @@ var remoteDataConfig = config.MustRegisterSection(remoteData, &interfaces.Remote
 var notificationsConfig = config.MustRegisterSection(notifications, &interfaces.NotificationsConfig{
 	Type: common.Local,
 })
+var webhookNotificationsConfig = config.MustRegisterSection(webhookNotifications, &interfaces.WebhookNotificationsConfig{
+	Type: common.Local,
+})
 var domainsConfig = config.MustRegisterSection(domains, &interfaces.DomainsConfig{
 	{
 		ID: "development",
@@ -119,6 +123,11 @@ func (p *ApplicationConfigurationProvider) GetCloudEventsConfig() *interfaces.Cl
 	return cloudEventsConfig.GetConfig().(*interfaces.CloudEventsConfig)
 }
 
+// GetWebhookNotificationConfig returns the registered "webhookNotifications" config section.
+func (p *ApplicationConfigurationProvider) GetWebhookNotificationConfig() *interfaces.WebhookNotificationsConfig {
+	return webhookNotificationsConfig.GetConfig().(*interfaces.WebhookNotificationsConfig)
+}
+
 func NewApplicationConfigurationProvider() interfaces.ApplicationConfiguration {
 	return &ApplicationConfigurationProvider{}
 }
diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go
index cf9bf2e9e..d02ef27a9 100644
--- a/pkg/runtime/interfaces/application_configuration.go
+++ b/pkg/runtime/interfaces/application_configuration.go
@@ -552,6 +552,30 @@ type
NotificationsConfig struct {
 	ReconnectDelaySeconds int `json:"reconnectDelaySeconds"`
 }
 
+type WebHookConfig struct {
+	// Type of webhook service to use. Currently only "slack" is supported.
+	Name                         string                       `json:"name"`
+	URLSecretName                string                       `json:"urlSecretName" pflag:",Secret name to use for the webhook URL"`
+	Payload                      string                       `json:"payload"` // NOTE(review): presumably the message template sent to the webhook; confirm against the processor.
+	TokenSecretName              string                       `json:"tokenSecretName" pflag:",Secret name to use for the bearer token"`
+	NotificationsProcessorConfig NotificationsProcessorConfig `json:"processor"`
+}
+
+// WebhookNotificationsConfig defines the configuration for the webhook service.
+type WebhookNotificationsConfig struct {
+	// Defines the cloud provider that backs the webhook notifications. In the absence of a specification the no-op,
+	// 'local' scheme is used.
+	Type      string    `json:"type"`
+	AWSConfig AWSConfig `json:"aws"`
+	GCPConfig GCPConfig `json:"gcp"`
+	// Defines the list of webhooks to be configured.
+	WebhooksConfig []WebHookConfig `json:"webhooks"`
+	// Number of times to attempt recreating a notifications processor client should there be any disruptions.
+	ReconnectAttempts int `json:"reconnectAttempts"`
+	// Specifies the time interval to wait before attempting to reconnect the notifications processor client.
+	ReconnectDelaySeconds int `json:"reconnectDelaySeconds"`
+}
+
 // Domains are always globally set in the application config, whereas individual projects can be individually registered.
 type Domain struct {
 	// Unique identifier for a domain.
@@ -572,4 +596,5 @@ type ApplicationConfiguration interface {
 	GetDomainsConfig() *DomainsConfig
 	GetExternalEventsConfig() *ExternalEventsConfig
 	GetCloudEventsConfig() *CloudEventsConfig
+	GetWebhookNotificationConfig() *WebhookNotificationsConfig
 }
diff --git a/pkg/runtime/mocks/mock_application_provider.go b/pkg/runtime/mocks/mock_application_provider.go
index 13079619c..f32c13b49 100644
--- a/pkg/runtime/mocks/mock_application_provider.go
+++ b/pkg/runtime/mocks/mock_application_provider.go
@@ -6,14 +6,20 @@ import (
 )
 
 type MockApplicationProvider struct {
-	dbConfig             database.DbConfig
-	topLevelConfig       interfaces.ApplicationConfig
-	schedulerConfig      interfaces.SchedulerConfig
-	remoteDataConfig     interfaces.RemoteDataConfig
-	notificationsConfig  interfaces.NotificationsConfig
-	domainsConfig        interfaces.DomainsConfig
-	externalEventsConfig interfaces.ExternalEventsConfig
-	cloudEventConfig     interfaces.CloudEventsConfig
+	dbConfig                   database.DbConfig
+	topLevelConfig             interfaces.ApplicationConfig
+	schedulerConfig            interfaces.SchedulerConfig
+	remoteDataConfig           interfaces.RemoteDataConfig
+	notificationsConfig        interfaces.NotificationsConfig
+	webhookNotificationsConfig interfaces.WebhookNotificationsConfig
+	domainsConfig              interfaces.DomainsConfig
+	externalEventsConfig       interfaces.ExternalEventsConfig
+	cloudEventConfig           interfaces.CloudEventsConfig
+}
+
+// GetWebhookNotificationConfig returns a pointer to the mock's webhookNotificationsConfig field.
+func (p *MockApplicationProvider) GetWebhookNotificationConfig() *interfaces.WebhookNotificationsConfig {
+	return &p.webhookNotificationsConfig
 }
 
 func (p *MockApplicationProvider) GetDbConfig() *database.DbConfig {