diff --git a/core/notification/mocks/receiver_service.go b/core/notification/mocks/receiver_service.go index 10199362..297fd1d6 100644 --- a/core/notification/mocks/receiver_service.go +++ b/core/notification/mocks/receiver_service.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.43.2. DO NOT EDIT. +// Code generated by mockery v2.42.3. DO NOT EDIT. package mocks @@ -23,6 +23,80 @@ func (_m *ReceiverService) EXPECT() *ReceiverService_Expecter { return &ReceiverService_Expecter{mock: &_m.Mock} } +// Get provides a mock function with given fields: ctx, id, opts +func (_m *ReceiverService) Get(ctx context.Context, id uint64, opts ...receiver.GetOption) (*receiver.Receiver, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, id) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *receiver.Receiver + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uint64, ...receiver.GetOption) (*receiver.Receiver, error)); ok { + return rf(ctx, id, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, uint64, ...receiver.GetOption) *receiver.Receiver); ok { + r0 = rf(ctx, id, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*receiver.Receiver) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, uint64, ...receiver.GetOption) error); ok { + r1 = rf(ctx, id, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReceiverService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type ReceiverService_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - id uint64 +// - opts ...receiver.GetOption +func (_e *ReceiverService_Expecter) Get(ctx interface{}, id interface{}, opts ...interface{}) *ReceiverService_Get_Call { + return &ReceiverService_Get_Call{Call: _e.mock.On("Get", + append([]interface{}{ctx, id}, opts...)...)} +} + +func (_c *ReceiverService_Get_Call) Run(run func(ctx context.Context, id uint64, opts ...receiver.GetOption)) *ReceiverService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]receiver.GetOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(receiver.GetOption) + } + } + run(args[0].(context.Context), args[1].(uint64), variadicArgs...) + }) + return _c +} + +func (_c *ReceiverService_Get_Call) Return(_a0 *receiver.Receiver, _a1 error) *ReceiverService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReceiverService_Get_Call) RunAndReturn(run func(context.Context, uint64, ...receiver.GetOption) (*receiver.Receiver, error)) *ReceiverService_Get_Call { + _c.Call.Return(run) + return _c +} + // List provides a mock function with given fields: ctx, flt func (_m *ReceiverService) List(ctx context.Context, flt receiver.Filter) ([]receiver.Receiver, error) { ret := _m.Called(ctx, flt) diff --git a/core/notification/notification.go b/core/notification/notification.go index c11153b4..fd55e343 100644 --- a/core/notification/notification.go +++ b/core/notification/notification.go @@ -36,16 +36,16 @@ type Transactor interface { // Notification is a model of notification type Notification struct { - ID string `json:"id"` - NamespaceID uint64 `json:"namespace_id"` - Type string `json:"type"` - Data map[string]any `json:"data"` - Labels map[string]string `json:"labels"` - ValidDuration time.Duration `json:"valid_duration"` - Template string `json:"template"` - UniqueKey string `json:"unique_key"` - ReceiverSelectors []map[string]string `json:"receiver_selectors"` - CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` + NamespaceID uint64 `json:"namespace_id"` + Type string `json:"type"` + Data map[string]any `json:"data"` + Labels map[string]string `json:"labels"` + ValidDuration time.Duration `json:"valid_duration"` + Template string `json:"template"` + UniqueKey string `json:"unique_key"` + ReceiverSelectors []map[string]interface{} `json:"receiver_selectors"` + CreatedAt time.Time `json:"created_at"` // won't be stored in notification table, only to propagate this to notification_subscriber AlertIDs []int64 diff --git a/core/notification/notification_test.go b/core/notification/notification_test.go index 6da50a90..3b66fce4 100644 --- a/core/notification/notification_test.go +++ b/core/notification/notification_test.go @@ -41,7 +41,7 @@ func TestNotification_Validate(t *testing.T) { Labels: map[string]string{ "receiver_id": "2", }, - ReceiverSelectors: []map[string]string{ + ReceiverSelectors: []map[string]interface{}{ { "varkey1": "value1", }, diff --git a/core/notification/router_receiver_service.go b/core/notification/router_receiver_service.go index b129df1b..137c818f 100644 --- a/core/notification/router_receiver_service.go +++ b/core/notification/router_receiver_service.go @@ -2,6 +2,8 @@ package notification import ( "context" + "fmt" + "strconv" "github.com/goto/siren/core/log" "github.com/goto/siren/core/receiver" @@ -26,10 +28,26 @@ func (s *RouterReceiverService) PrepareMetaMessages(ctx context.Context, n Notif return nil, nil, errors.ErrInvalid.WithMsgf("number of receiver selectors should be less than or equal threshold %d", s.deps.Cfg.MaxNumReceiverSelectors) } - rcvs, err := s.deps.ReceiverService.List(ctx, receiver.Filter{ - MultipleLabels: n.ReceiverSelectors, - Expanded: true, - }) + var rcvs []receiver.Receiver + var userConfigs map[uint64]map[string]interface{} + + // Check if any selector contains a config + hasConfig := false + for _, selector := range n.ReceiverSelectors { + if _, ok := selector["config"]; ok { + hasConfig = true + break + } + } + + if hasConfig { + // Handle case when config is provided + rcvs, userConfigs, err = s.handleConfigCase(ctx, n.ReceiverSelectors) + } else { + // Handle case when only receiver IDs are provided + rcvs, err = s.handleIDOnlyCase(ctx, n.ReceiverSelectors) + } + if err != nil { return nil, nil, err } @@ -38,11 +56,31 @@ func (s *RouterReceiverService) PrepareMetaMessages(ctx context.Context, n Notif return nil, nil, errors.ErrNotFound } + // Check if the number of receivers exceeds the max messages receiver flow + if len(rcvs) > s.deps.Cfg.MaxMessagesReceiverFlow { + return nil, nil, errors.ErrInvalid.WithMsgf("sending %d messages exceed max messages receiver flow threshold %d. this will spam and broadcast to %d channel. found %d receiver selectors passed, you might want to check your receiver selectors configuration", len(rcvs), s.deps.Cfg.MaxMessagesReceiverFlow, len(rcvs), len(n.ReceiverSelectors)) + } + for _, rcv := range rcvs { - var rcvView = &subscription.ReceiverView{} + rcvView := &subscription.ReceiverView{} rcvView.FromReceiver(rcv) - metaMessages = append(metaMessages, n.MetaMessage(*rcvView)) + if config, ok := userConfigs[rcv.ID]; ok { + // Merge user-provided config with receiver config + for k, v := range config { + rcvView.Configurations[k] = v + } + } + + // Ensure required fields are set + if err := s.validateConfigurations(rcvView.Configurations); err != nil { + return nil, nil, err + } + + metaMessage := n.MetaMessage(*rcvView) + metaMessage.NotificationIDs = []string{n.ID} + + metaMessages = append(metaMessages, metaMessage) notificationLogs = append(notificationLogs, log.Notification{ NamespaceID: n.NamespaceID, NotificationID: n.ID, @@ -51,10 +89,65 @@ func (s *RouterReceiverService) PrepareMetaMessages(ctx context.Context, n Notif }) } - var metaMessagesNum = len(metaMessages) - if metaMessagesNum > s.deps.Cfg.MaxMessagesReceiverFlow { - return nil, nil, errors.ErrInvalid.WithMsgf("sending %d messages exceed max messages receiver flow threshold %d. this will spam and broadcast to %d channel. found %d receiver selectors passed, you might want to check your receiver selectors configuration", metaMessagesNum, s.deps.Cfg.MaxMessagesReceiverFlow, metaMessagesNum, len(n.ReceiverSelectors)) + return metaMessages, notificationLogs, nil +} + +func (s *RouterReceiverService) handleConfigCase(ctx context.Context, selectors []map[string]interface{}) ([]receiver.Receiver, map[uint64]map[string]interface{}, error) { + var receiverIDs []uint64 + userConfigs := make(map[uint64]map[string]interface{}) + + for _, selector := range selectors { + if idStr, ok := selector["id"].(string); ok { + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return nil, nil, errors.ErrInvalid.WithMsgf("invalid receiver id: %s", idStr) + } + receiverIDs = append(receiverIDs, id) + } + if config, ok := selector["config"].(map[string]interface{}); ok { + for _, id := range receiverIDs { + if userConfigs[id] == nil { + userConfigs[id] = make(map[string]interface{}) + } + for k, v := range config { + userConfigs[id][k] = v + } + } + } } - return metaMessages, notificationLogs, nil + rcvs := make([]receiver.Receiver, 0, len(receiverIDs)) + for _, id := range receiverIDs { + rcv, err := s.deps.ReceiverService.Get(ctx, id) + if err != nil { + return nil, nil, err + } + rcvs = append(rcvs, *rcv) + } + + return rcvs, userConfigs, nil +} + +func (s *RouterReceiverService) handleIDOnlyCase(ctx context.Context, selectors []map[string]interface{}) ([]receiver.Receiver, error) { + convertedSelectors := make([]map[string]string, len(selectors)) + for i, selector := range selectors { + convertedSelectors[i] = make(map[string]string) + for k, v := range selector { + convertedSelectors[i][k] = fmt.Sprint(v) + } + } + return s.deps.ReceiverService.List(ctx, receiver.Filter{ + MultipleLabels: convertedSelectors, + Expanded: true, + }) +} + +func (s *RouterReceiverService) validateConfigurations(configs map[string]interface{}) error { + requiredFields := []string{"token", "workspace", "channel_name"} + for _, field := range requiredFields { + if _, ok := configs[field]; !ok { + return errors.ErrInvalid.WithMsgf("%s is required in the config", field) + } + } + return nil } diff --git a/core/notification/router_receiver_service_test.go b/core/notification/router_receiver_service_test.go index 60443c4c..e2532304 100644 --- a/core/notification/router_receiver_service_test.go +++ b/core/notification/router_receiver_service_test.go @@ -26,7 +26,7 @@ func TestRouterReceiverService_PrepareMetaMessage(t *testing.T) { { name: "should return error if number of receiver selector is more than threshold", n: notification.Notification{ - ReceiverSelectors: []map[string]string{ + ReceiverSelectors: []map[string]interface{}{ { "k1": "v1", }, @@ -73,28 +73,61 @@ func TestRouterReceiverService_PrepareMetaMessage(t *testing.T) { }, { name: "should return no error if succeed", - n: notification.Notification{}, + n: notification.Notification{ + ID: "test-notification-id", + NamespaceID: 123, + }, setup: func(rs *mocks.ReceiverService, n *mocks.Notifier) { rs.EXPECT().List(mock.AnythingOfType("context.todoCtx"), mock.AnythingOfType("receiver.Filter")).Return([]receiver.Receiver{ { ID: 1, + Configurations: map[string]interface{}{ + "token": "token1", + "workspace": "workspace1", + "channel_name": "channel1", + }, }, { ID: 2, + Configurations: map[string]interface{}{ + "token": "token2", + "workspace": "workspace2", + "channel_name": "channel2", + }, }, }, nil) }, want: []notification.MetaMessage{ { - ReceiverID: 1, + ReceiverID: 1, + NotificationIDs: []string{"test-notification-id"}, + ReceiverConfigs: map[string]interface{}{ + "token": "token1", + "workspace": "workspace1", + "channel_name": "channel1", + }, }, { - ReceiverID: 2, + ReceiverID: 2, + NotificationIDs: []string{"test-notification-id"}, + ReceiverConfigs: map[string]interface{}{ + "token": "token2", + "workspace": "workspace2", + "channel_name": "channel2", + }, }, }, want1: []log.Notification{ - {ReceiverID: 1}, - {ReceiverID: 2}, + { + ReceiverID: 1, + NotificationID: "test-notification-id", + NamespaceID: 123, + }, + { + ReceiverID: 2, + NotificationID: "test-notification-id", + NamespaceID: 123, + }, }, }, } diff --git a/core/notification/service.go b/core/notification/service.go index 468b5cbd..0ee350ca 100644 --- a/core/notification/service.go +++ b/core/notification/service.go @@ -33,6 +33,7 @@ type SubscriptionService interface { type ReceiverService interface { List(ctx context.Context, flt receiver.Filter) ([]receiver.Receiver, error) + Get(ctx context.Context, id uint64, opts ...receiver.GetOption) (*receiver.Receiver, error) } type SilenceService interface { @@ -237,6 +238,7 @@ func (s *Service) dispatchInternal(ctx context.Context, ns []Notification) (noti } if err := s.deps.Q.Enqueue(ctx, messages...); err != nil { + fmt.Printf("Context: %+v\n", ctx) return nil, fmt.Errorf("failed enqueuing messages: %w", err) } diff --git a/internal/api/v1beta1/notification.go b/internal/api/v1beta1/notification.go index 5ffa8bc6..84fc455b 100644 --- a/internal/api/v1beta1/notification.go +++ b/internal/api/v1beta1/notification.go @@ -19,7 +19,11 @@ import ( const notificationAPIScope = "notification_api" -func (s *GRPCServer) validatePostNotificationPayload(receiverSelectors []map[string]string, labels map[string]string) error { +func (s *GRPCServer) validatePostNotificationPayload(receiverSelectors []map[string]interface{}, labels map[string]string) error { + if len(receiverSelectors) == 0 && len(labels) == 0 { + return errors.ErrInvalid.WithMsgf("receivers or labels must be provided") + } + if len(receiverSelectors) > 0 && len(labels) > 0 { return errors.ErrInvalid.WithMsgf("receivers and labels cannot being used at the same time, should be used either one of them") } @@ -27,6 +31,28 @@ func (s *GRPCServer) validatePostNotificationPayload(receiverSelectors []map[str return nil } +func (s *GRPCServer) parseReceivers(pbSelectors []*structpb.Struct) ([]map[string]interface{}, error) { + var receiverSelectors []map[string]interface{} + + for _, pbSelector := range pbSelectors { + selector := make(map[string]interface{}) + for k, v := range pbSelector.AsMap() { + if k == "config" { + configMap, ok := v.(map[string]interface{}) + if !ok { + return nil, errors.ErrInvalid.WithMsgf("invalid config format, expected map[string]interface{}") + } + selector[k] = configMap + } else { + selector[k] = v + } + } + receiverSelectors = append(receiverSelectors, selector) + } + + return receiverSelectors, nil +} + func (s *GRPCServer) PostNotification(ctx context.Context, req *sirenv1beta1.PostNotificationRequest) (*sirenv1beta1.PostNotificationResponse, error) { idempotencyScope := api.GetHeaderString(ctx, s.headers.IdempotencyScope) if idempotencyScope == "" { @@ -35,29 +61,17 @@ func (s *GRPCServer) PostNotification(ctx context.Context, req *sirenv1beta1.Pos idempotencyKey := api.GetHeaderString(ctx, s.headers.IdempotencyKey) if idempotencyKey != "" { - if notificationID, err := s.notificationService.CheckIdempotency(ctx, idempotencyScope, idempotencyKey); notificationID != "" { + if notificationID, err := s.notificationService.CheckIdempotency(ctx, idempotencyScope, idempotencyKey); err == nil { return &sirenv1beta1.PostNotificationResponse{ NotificationId: notificationID, }, nil - } else if errors.Is(err, errors.ErrNotFound) { - s.logger.Debug("no idempotency found with detail", "scope", idempotencyScope, "key", idempotencyKey) - } else { + } else if !errors.Is(err, errors.ErrNotFound) { return nil, api.GenerateRPCErr(s.logger, fmt.Errorf("error when checking idempotency: %w", err)) } } - - var receiverSelectors = []map[string]string{} - for _, pbSelector := range req.GetReceivers() { - var mss = make(map[string]string) - for k, v := range pbSelector.AsMap() { - vString, ok := v.(string) - if !ok { - err := errors.ErrInvalid.WithMsgf("invalid receiver selectors, value must be string but found %v", v) - return nil, api.GenerateRPCErr(s.logger, err) - } - mss[k] = vString - } - receiverSelectors = append(receiverSelectors, mss) + receiverSelectors, err := s.parseReceivers(req.GetReceivers()) + if err != nil { + return nil, api.GenerateRPCErr(s.logger, fmt.Errorf("error while parsing receivers: %w", err)) } if err := s.validatePostNotificationPayload(receiverSelectors, req.GetLabels()); err != nil { diff --git a/internal/api/v1beta1/notification_test.go b/internal/api/v1beta1/notification_test.go index 892d8f82..b44b484e 100644 --- a/internal/api/v1beta1/notification_test.go +++ b/internal/api/v1beta1/notification_test.go @@ -25,12 +25,26 @@ func TestGRPCServer_PostNotification(t *testing.T) { testCases := []struct { name string idempotencyKey string + request *sirenv1beta1.PostNotificationRequest setup func(*mocks.NotificationService) + expectedID string errString string }{ + { + name: "should return invalid argument if no receivers or labels provided", + idempotencyKey: "test", + request: &sirenv1beta1.PostNotificationRequest{}, + setup: func(ns *mocks.NotificationService) { + ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) + }, + errString: "rpc error: code = InvalidArgument desc = receivers or labels must be provided", + }, { name: "should return invalid argument if post notification return invalid argument", idempotencyKey: "test", + request: &sirenv1beta1.PostNotificationRequest{ + Labels: map[string]string{"key": "value"}, + }, setup: func(ns *mocks.NotificationService) { ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return(nil, errors.ErrInvalid) @@ -40,15 +54,21 @@ func TestGRPCServer_PostNotification(t *testing.T) { { name: "should return internal error if post notification return some error", idempotencyKey: "test", + request: &sirenv1beta1.PostNotificationRequest{ + Labels: map[string]string{"key": "value"}, + }, setup: func(ns *mocks.NotificationService) { ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) - ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return(nil, errors.New("some error")) + ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return(nil, errors.New("some unexpected error")) }, errString: "rpc error: code = Internal desc = some unexpected error occurred", }, { - name: "should return invalid error if post notification return err no message", + name: "should return invalid error if post notification return err_no_message", idempotencyKey: "test", + request: &sirenv1beta1.PostNotificationRequest{ + Labels: map[string]string{"key": "value"}, + }, setup: func(ns *mocks.NotificationService) { ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return(nil, notification.ErrNoMessage) @@ -57,10 +77,15 @@ func TestGRPCServer_PostNotification(t *testing.T) { }, { name: "should return success if request is idempotent", - idempotencyKey: "test", + idempotencyKey: "test-idempotent", + request: &sirenv1beta1.PostNotificationRequest{ + Labels: map[string]string{"key": "value"}, + }, setup: func(ns *mocks.NotificationService) { - ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(notificationID, nil) + ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), "test-idempotent").Return(notificationID, nil) }, + expectedID: notificationID, + errString: "", }, { name: "should return error if idempotency checking return error", @@ -71,24 +96,33 @@ func TestGRPCServer_PostNotification(t *testing.T) { errString: "rpc error: code = Internal desc = some unexpected error occurred", }, { - name: "should return error if error inserting idempotency", - idempotencyKey: "test", - setup: func(ns *mocks.NotificationService) { - ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) - ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return([]string{notificationID}, nil) - ns.EXPECT().InsertIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(errors.New("some error")) - }, - errString: "rpc error: code = Internal desc = some unexpected error occurred", - }, + name: "should return error if error inserting idempotency", + idempotencyKey: "test", + request: &sirenv1beta1.PostNotificationRequest{ + Labels: map[string]string{"key": "value"}, + }, + setup: func(ns *mocks.NotificationService) { + ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) + ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return([]string{notificationID}, nil) + ns.EXPECT().InsertIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), notificationID).Return(errors.New("some error")) + }, + errString: "rpc error: code = Internal desc = some unexpected error occurred", + }, + { - name: "should return OK response if post notification succeed", - idempotencyKey: "test", - setup: func(ns *mocks.NotificationService) { - ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) - ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return([]string{notificationID}, nil) - ns.EXPECT().InsertIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return(nil) - }, - }, + name: "should return OK response if post notification succeed", + idempotencyKey: "test", + request: &sirenv1beta1.PostNotificationRequest{ + Labels: map[string]string{"key": "value"}, + }, + setup: func(ns *mocks.NotificationService) { + ns.EXPECT().CheckIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return("", errors.ErrNotFound) + ns.EXPECT().Dispatch(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("[]notification.Notification")).Return([]string{notificationID}, nil) + ns.EXPECT().InsertIdempotency(mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), notificationID).Return(nil) + }, + expectedID: notificationID, + errString: "", + }, } for _, tc := range testCases { @@ -109,10 +143,19 @@ func TestGRPCServer_PostNotification(t *testing.T) { ctx := metadata.NewIncomingContext(context.TODO(), metadata.New(map[string]string{ idempotencyHeaderKey: tc.idempotencyKey, })) - _, err = dummyGRPCServer.PostNotification(ctx, &sirenv1beta1.PostNotificationRequest{}) + resp, err := dummyGRPCServer.PostNotification(ctx, tc.request) - if (err != nil) && tc.errString != err.Error() { - t.Errorf("PostNotification() error = %v, wantErr %v", err, tc.errString) + if tc.errString != "" { + if err == nil || err.Error() != tc.errString { + t.Errorf("PostNotification() error = %v, wantErr %v", err, tc.errString) + } + } else { + if err != nil { + t.Errorf("PostNotification() unexpected error = %v", err) + } + if resp == nil || resp.NotificationId != tc.expectedID { + t.Errorf("PostNotification() got notification ID = %v, want %v", resp.GetNotificationId(), tc.expectedID) + } } mockNotificationService.AssertExpectations(t) @@ -133,7 +176,7 @@ func TestGRPCServer_ListNotifications(t *testing.T) { "data-key": "data-value", }, Labels: map[string]string{}, - ReceiverSelectors: []map[string]string{}, + ReceiverSelectors: []map[string]interface{}{}, }, } diff --git a/internal/store/model/notification.go b/internal/store/model/notification.go index d1c4f03f..94e26445 100644 --- a/internal/store/model/notification.go +++ b/internal/store/model/notification.go @@ -9,16 +9,16 @@ import ( ) type Notification struct { - ID string `db:"id"` - NamespaceID sql.NullInt64 `db:"namespace_id"` - Type string `db:"type"` - Data pgc.StringAnyMap `db:"data"` - Labels pgc.StringStringMap `db:"labels"` - ValidDuration pgc.TimeDuration `db:"valid_duration"` - UniqueKey sql.NullString `db:"unique_key"` - Template sql.NullString `db:"template"` - CreatedAt time.Time `db:"created_at"` - ReceiverSelectors pgc.ListStringStringMap `db:"receiver_selectors"` + ID string `db:"id"` + NamespaceID sql.NullInt64 `db:"namespace_id"` + Type string `db:"type"` + Data pgc.StringAnyMap `db:"data"` + Labels pgc.StringStringMap `db:"labels"` + ValidDuration pgc.TimeDuration `db:"valid_duration"` + UniqueKey sql.NullString `db:"unique_key"` + Template sql.NullString `db:"template"` + CreatedAt time.Time `db:"created_at"` + ReceiverSelectors pgc.ListStringAnyMap `db:"receiver_selectors"` } func (n *Notification) FromDomain(d notification.Notification) { @@ -51,6 +51,10 @@ func (n *Notification) FromDomain(d notification.Notification) { } func (n *Notification) ToDomain() *notification.Notification { + receiverSelectors := n.ReceiverSelectors + if receiverSelectors == nil { + receiverSelectors = []map[string]interface{}{} + } return ¬ification.Notification{ ID: n.ID, NamespaceID: uint64(n.NamespaceID.Int64), @@ -61,6 +65,6 @@ func (n *Notification) ToDomain() *notification.Notification { Template: n.Template.String, UniqueKey: n.UniqueKey.String, CreatedAt: n.CreatedAt, - ReceiverSelectors: n.ReceiverSelectors, + ReceiverSelectors: receiverSelectors, } } diff --git a/internal/store/postgres/notification.go b/internal/store/postgres/notification.go index fb66fabe..2bd3804c 100644 --- a/internal/store/postgres/notification.go +++ b/internal/store/postgres/notification.go @@ -176,7 +176,15 @@ func (r *NotificationRepository) List(ctx context.Context, flt notification.Filt if err := rows.StructScan(¬ificationModel); err != nil { return nil, err } - notificationsDomain = append(notificationsDomain, *notificationModel.ToDomain()) + + notificationDomain := notificationModel.ToDomain() + + // Always ensure ReceiverSelectors is initialized + if notificationDomain.ReceiverSelectors == nil { + notificationDomain.ReceiverSelectors = []map[string]interface{}{} + } + + notificationsDomain = append(notificationsDomain, *notificationDomain) } return notificationsDomain, nil } diff --git a/internal/store/postgres/notification_test.go b/internal/store/postgres/notification_test.go index f209884d..31771ce8 100644 --- a/internal/store/postgres/notification_test.go +++ b/internal/store/postgres/notification_test.go @@ -227,7 +227,7 @@ func (s *NotificationRepositoryTestSuite) TestList() { "label-key": "label-value", }, Template: "", - ReceiverSelectors: []map[string]string{ + ReceiverSelectors: []map[string]interface{}{ { "team": "gotocompany-infra", "severity": "WARNING", @@ -252,10 +252,11 @@ func (s *NotificationRepositoryTestSuite) TestList() { Data: map[string]interface{}{ "data-key": "data-value", }, - Labels: map[string]string{}, - Template: "", - ValidDuration: 0, - UniqueKey: "", + Labels: map[string]string{}, + Template: "", + ValidDuration: 0, + UniqueKey: "", + ReceiverSelectors: []map[string]interface{}{}, }, { ID: "10911", @@ -264,10 +265,11 @@ func (s *NotificationRepositoryTestSuite) TestList() { Data: map[string]any{ "data-key": "data-value", }, - Labels: map[string]string{}, - ValidDuration: time.Duration(0), - Template: "expiry-alert", - UniqueKey: "", + Labels: map[string]string{}, + ValidDuration: time.Duration(0), + Template: "expiry-alert", + UniqueKey: "", + ReceiverSelectors: []map[string]interface{}{}, }, }, }, @@ -284,15 +286,16 @@ func (s *NotificationRepositoryTestSuite) TestList() { Data: map[string]any{ "data-key": "data-value", }, - Labels: map[string]string{}, - ValidDuration: time.Duration(0), - Template: "expiry-alert", - UniqueKey: "", + Labels: map[string]string{}, + ValidDuration: time.Duration(0), + Template: "expiry-alert", + UniqueKey: "", + ReceiverSelectors: []map[string]interface{}{}, }, }, }, { - Description: "should get all notifications with lable filter", + Description: "should get all notifications with label filter", Filter: notification.Filter{ Labels: map[string]string{ "label-key": "label-value", @@ -306,8 +309,11 @@ func (s *NotificationRepositoryTestSuite) TestList() { Data: map[string]interface{}{ "data-key": "data-value", }, - Labels: map[string]string{"label-key": "label-value"}, - Template: "", + Labels: map[string]string{"label-key": "label-value"}, + Template: "", + ReceiverSelectors: []map[string]interface{}{}, // Initialize this + UniqueKey: "", + ValidDuration: 0, }, { ID: "789", @@ -320,7 +326,7 @@ func (s *NotificationRepositoryTestSuite) TestList() { "label-key": "label-value", }, Template: "", - ReceiverSelectors: []map[string]string{ + ReceiverSelectors: []map[string]interface{}{ { "team": "gotocompany-infra", "severity": "WARNING", @@ -329,6 +335,8 @@ func (s *NotificationRepositoryTestSuite) TestList() { "id": "2", }, }, + UniqueKey: "", + ValidDuration: 0, }, }, }, diff --git a/pkg/pgc/type.go b/pkg/pgc/type.go index 25509497..d9ed524d 100644 --- a/pkg/pgc/type.go +++ b/pkg/pgc/type.go @@ -117,3 +117,24 @@ func (a ListString) Value() (driver.Value, error) { } return json.Marshal(a) } + +type ListStringAnyMap []map[string]interface{} + +func (m *ListStringAnyMap) Scan(value interface{}) error { + if value == nil { + *m = make(ListStringAnyMap, 0) + return nil + } + b, ok := value.([]byte) + if !ok { + return errors.New("failed type assertion to []byte") + } + return json.Unmarshal(b, m) +} + +func (a ListStringAnyMap) Value() (driver.Value, error) { + if len(a) == 0 { + return nil, nil + } + return json.Marshal(a) +}