From 18029cc487a19537b02568c17525549c8339eb79 Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Mon, 9 Sep 2024 17:07:12 +0100 Subject: [PATCH 1/7] Fix pagerduty AMR test to prevent flakiness --- .../access/pagerduty/testlib/suite.go | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index c379c85219a5..cdf8a105f16b 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -453,24 +453,27 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { }) assert.NoError(t, err) - // Test execution: create an access request - req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) - - // Validate the incident has been created in Pagerduty and its ID is stored - // in the plugin_data. - pluginData := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool { - return data.IncidentID != "" - }) + // Incident creation may happen before plugins Access Monitoring Rule cache + // has been updated with new rule. Retry until the new rule starts applying. + require.Eventually(t, func() bool { + // Test execution: create an access request + req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) + + // Validate the incident has been created in Pagerduty and its ID is stored + // in the plugin_data. + pluginData := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool { + return data.IncidentID != "" + }) - incident, err := s.fakePagerduty.CheckNewIncident(ctx) - require.NoError(t, err, "no new incidents stored") + incident, err := s.fakePagerduty.CheckNewIncident(ctx) + assert.NoError(t, err, "no new incidents stored") + assert.Equal(t, incident.ID, pluginData.IncidentID) - assert.Equal(t, incident.ID, pluginData.IncidentID) - assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID) - - assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey) - assert.Equal(t, "triggered", incident.Status) + assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey) + assert.Equal(t, "triggered", incident.Status) + return s.pdNotifyService2.ID == pluginData.ServiceID + }, 10*time.Second, time.Second, "new access monitoring rule did not begin applying") assert.NoError(t, s.ClientByName(integration.RulerUserName). AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-pagerduty-amr")) } From 64620a0a50140225d945bf254ec3933f9b4002ba Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Tue, 10 Sep 2024 16:49:36 +0100 Subject: [PATCH 2/7] Update integrations/access/pagerduty/testlib/suite.go Co-authored-by: Zac Bergquist --- integrations/access/pagerduty/testlib/suite.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index cdf8a105f16b..995841a2ab18 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -461,7 +461,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { // Validate the incident has been created in Pagerduty and its ID is stored // in the plugin_data. - pluginData := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool { + hasIncidentID := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool { return data.IncidentID != "" }) From 1b4576496f06ea9216694632f56c93588e659e08 Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Tue, 10 Sep 2024 16:53:04 +0100 Subject: [PATCH 3/7] Swap pagerduty test to use EventuallyWith --- integrations/access/pagerduty/testlib/suite.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index 995841a2ab18..6f5c94f0527d 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -455,7 +455,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { // Incident creation may happen before plugins Access Monitoring Rule cache // has been updated with new rule. Retry until the new rule starts applying. - require.Eventually(t, func() bool { + require.EventuallyWithT(t, func(t *assert.CollectT) { // Test execution: create an access request req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) @@ -467,12 +467,12 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { incident, err := s.fakePagerduty.CheckNewIncident(ctx) assert.NoError(t, err, "no new incidents stored") - assert.Equal(t, incident.ID, pluginData.IncidentID) + assert.Equal(t, incident.ID, hasIncidentID.IncidentID) assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey) assert.Equal(t, "triggered", incident.Status) - return s.pdNotifyService2.ID == pluginData.ServiceID + assert.Equal(t, s.pdNotifyService2.ID, hasIncidentID.ServiceID) }, 10*time.Second, time.Second, "new access monitoring rule did not begin applying") assert.NoError(t, s.ClientByName(integration.RulerUserName). AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-pagerduty-amr")) From 0fbc2bab41291cac9f885c2e5915d690e47935b6 Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Mon, 7 Oct 2024 17:53:54 +0100 Subject: [PATCH 4/7] Update pagerduty tests to not create several access requests --- .../access_monitoring_rules.go | 10 +++++ integrations/access/pagerduty/app.go | 15 +++++++ integrations/access/pagerduty/config.go | 4 ++ .../access/pagerduty/testlib/suite.go | 41 +++++++++++-------- 4 files changed, 54 insertions(+), 16 deletions(-) diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go index 72eb921f3f4e..5004f81665cf 100644 --- a/integrations/access/accessmonitoring/access_monitoring_rules.go +++ b/integrations/access/accessmonitoring/access_monitoring_rules.go @@ -48,6 +48,7 @@ type RuleHandler struct { pluginName string fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) + onCacheUpdateCallback func(name string) error } // RuleMap is a concurrent map for access monitoring rules. @@ -65,6 +66,8 @@ type RuleHandlerConfig struct { // FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients. FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) + // OnCacheUpdateCallback is a callback that is called when a rule in the cache is created or updated. + OnCacheUpdateCallback func(name string) error } // NewRuleHandler returns a new RuleHandler. @@ -77,6 +80,7 @@ func NewRuleHandler(conf RuleHandlerConfig) *RuleHandler { pluginType: conf.PluginType, pluginName: conf.PluginName, fetchRecipientCallback: conf.FetchRecipientCallback, + onCacheUpdateCallback: conf.OnCacheUpdateCallback, } } @@ -93,6 +97,9 @@ func (amrh *RuleHandler) InitAccessMonitoringRulesCache(ctx context.Context) err continue } amrh.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr + if amrh.onCacheUpdateCallback != nil { + amrh.onCacheUpdateCallback(amr.GetMetadata().Name) + } } return nil } @@ -123,6 +130,9 @@ func (amrh *RuleHandler) HandleAccessMonitoringRule(ctx context.Context, event t return nil } amrh.accessMonitoringRules.rules[req.Metadata.Name] = req + if amrh.onCacheUpdateCallback != nil { + amrh.onCacheUpdateCallback(req.Metadata.Name) + } return nil case types.OpDelete: delete(amrh.accessMonitoringRules.rules, event.Resource.GetName()) diff --git a/integrations/access/pagerduty/app.go b/integrations/access/pagerduty/app.go index abe1fdf2f8ce..ca3592468d46 100644 --- a/integrations/access/pagerduty/app.go +++ b/integrations/access/pagerduty/app.go @@ -79,6 +79,21 @@ func NewApp(conf Config) (*App, error) { statusSink: conf.StatusSink, } + amrhConf := accessmonitoring.RuleHandlerConfig{ + Client: conf.Client, + PluginType: types.PluginTypePagerDuty, + FetchRecipientCallback: func(_ context.Context, name string) (*common.Recipient, error) { + return &common.Recipient{ + Name: name, + ID: name, + Kind: common.RecipientKindSchedule, + }, nil + }, + } + if conf.OnAccessMonitoringRuleCacheUpdateCallback != nil { + amrhConf.OnCacheUpdateCallback = conf.OnAccessMonitoringRuleCacheUpdateCallback + } + app.accessMonitoringRules = accessmonitoring.NewRuleHandler(amrhConf) app.mainJob = lib.NewServiceJob(app.run) return app, nil diff --git a/integrations/access/pagerduty/config.go b/integrations/access/pagerduty/config.go index f76e9d2f955f..a879cdbc8a69 100644 --- a/integrations/access/pagerduty/config.go +++ b/integrations/access/pagerduty/config.go @@ -47,6 +47,10 @@ type Config struct { // TeleportUser is the name of the Teleport user that will act // as the access request approver TeleportUser string + + // OnAccessMonitoringRuleCacheUpdateCallback is used for checking when + // the Rule cache is updated in tests + OnAccessMonitoringRuleCacheUpdateCallback func(name string) error } type PagerdutyConfig struct { diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index 6f5c94f0527d..d68530614f30 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -430,6 +430,12 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) t.Cleanup(cancel) + ruleName := "test-pagerduty-amr" + lastRuleUpdated := "" + s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(name string) error { + lastRuleUpdated = ruleName + return nil + } s.startApp() _, err := s.ClientByName(integration.RulerUserName). @@ -438,7 +444,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { Kind: types.KindAccessMonitoringRule, Version: types.V1, Metadata: &v1.Metadata{ - Name: "test-pagerduty-amr", + Name: ruleName, }, Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ Subjects: []string{types.KindAccessRequest}, @@ -454,26 +460,29 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { assert.NoError(t, err) // Incident creation may happen before plugins Access Monitoring Rule cache - // has been updated with new rule. Retry until the new rule starts applying. + // has been updated with new rule. Retry until the new cache picks up the rule. require.EventuallyWithT(t, func(t *assert.CollectT) { - // Test execution: create an access request - req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) + require.Equal(t, lastRuleUpdated, ruleName) + }, 10*time.Second, time.Second, "new access monitoring rule did not begin applying") - // Validate the incident has been created in Pagerduty and its ID is stored - // in the plugin_data. - hasIncidentID := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool { - return data.IncidentID != "" - }) + // Test execution: create an access request + req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) - incident, err := s.fakePagerduty.CheckNewIncident(ctx) - assert.NoError(t, err, "no new incidents stored") - assert.Equal(t, incident.ID, hasIncidentID.IncidentID) + // Validate the incident has been created in Pagerduty and its ID is stored + // in the plugin_data. + hasIncidentID := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool { + return data.IncidentID != "" + }) - assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey) - assert.Equal(t, "triggered", incident.Status) + incident, err := s.fakePagerduty.CheckNewIncident(ctx) + assert.NoError(t, err, "no new incidents stored") + assert.Equal(t, incident.ID, hasIncidentID.IncidentID) + + assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey) + assert.Equal(t, "triggered", incident.Status) + + assert.Equal(t, s.pdNotifyService2.ID, hasIncidentID.ServiceID) - assert.Equal(t, s.pdNotifyService2.ID, hasIncidentID.ServiceID) - }, 10*time.Second, time.Second, "new access monitoring rule did not begin applying") assert.NoError(t, s.ClientByName(integration.RulerUserName). AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-pagerduty-amr")) } From 24251a17b8621f6028b000cb01c0a4bca1c8e288 Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Thu, 10 Oct 2024 16:24:18 +0100 Subject: [PATCH 5/7] Make more information available to AMR cache update callback --- .../access/accessmonitoring/access_monitoring_rules.go | 6 +++--- integrations/access/pagerduty/config.go | 4 +++- integrations/access/pagerduty/testlib/suite.go | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go index 5004f81665cf..0146e07fa0d9 100644 --- a/integrations/access/accessmonitoring/access_monitoring_rules.go +++ b/integrations/access/accessmonitoring/access_monitoring_rules.go @@ -48,7 +48,7 @@ type RuleHandler struct { pluginName string fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) - onCacheUpdateCallback func(name string) error + onCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error } // RuleMap is a concurrent map for access monitoring rules. @@ -67,7 +67,7 @@ type RuleHandlerConfig struct { // FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients. FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error) // OnCacheUpdateCallback is a callback that is called when a rule in the cache is created or updated. - OnCacheUpdateCallback func(name string) error + OnCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error } // NewRuleHandler returns a new RuleHandler. @@ -98,7 +98,7 @@ func (amrh *RuleHandler) InitAccessMonitoringRulesCache(ctx context.Context) err } amrh.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr if amrh.onCacheUpdateCallback != nil { - amrh.onCacheUpdateCallback(amr.GetMetadata().Name) + amrh.onCacheUpdateCallback(types.OpPut, amr.GetMetadata().Name, amr) } } return nil diff --git a/integrations/access/pagerduty/config.go b/integrations/access/pagerduty/config.go index a879cdbc8a69..8bf7060652b0 100644 --- a/integrations/access/pagerduty/config.go +++ b/integrations/access/pagerduty/config.go @@ -24,6 +24,8 @@ import ( "github.com/gravitational/trace" "github.com/pelletier/go-toml" + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/common/teleport" "github.com/gravitational/teleport/integrations/lib" @@ -50,7 +52,7 @@ type Config struct { // OnAccessMonitoringRuleCacheUpdateCallback is used for checking when // the Rule cache is updated in tests - OnAccessMonitoringRuleCacheUpdateCallback func(name string) error + OnAccessMonitoringRuleCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error } type PagerdutyConfig struct { diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index d68530614f30..e3b518b6c5d1 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -432,7 +432,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { ruleName := "test-pagerduty-amr" lastRuleUpdated := "" - s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(name string) error { + s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(_ types.OpType, name string, _ *accessmonitoringrulesv1.AccessMonitoringRule) error { lastRuleUpdated = ruleName return nil } @@ -463,7 +463,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { // has been updated with new rule. Retry until the new cache picks up the rule. require.EventuallyWithT(t, func(t *assert.CollectT) { require.Equal(t, lastRuleUpdated, ruleName) - }, 10*time.Second, time.Second, "new access monitoring rule did not begin applying") + }, 3*time.Second, time.Millisecond*100, "new access monitoring rule did not begin applying") // Test execution: create an access request req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) From 47743e41a0fe5e6973ad1e3c19fe8eb907401d77 Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Wed, 16 Oct 2024 16:11:09 +0100 Subject: [PATCH 6/7] Update integrations/access/pagerduty/testlib/suite.go Co-authored-by: Tiago Silva --- integrations/access/pagerduty/testlib/suite.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index e3b518b6c5d1..4e6bfdaf8693 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -430,10 +430,13 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) t.Cleanup(cancel) - ruleName := "test-pagerduty-amr" - lastRuleUpdated := "" + const ruleName = "test-pagerduty-amr" + var collectedNames []string + var mu sync.Mutex s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(_ types.OpType, name string, _ *accessmonitoringrulesv1.AccessMonitoringRule) error { - lastRuleUpdated = ruleName + mu.Lock() + collectedNames=append(collectedNames, name) + mu.Unlock() return nil } s.startApp() From aae10852e81c9e63a81ea5326b6f27da91e1f237 Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Wed, 16 Oct 2024 16:11:16 +0100 Subject: [PATCH 7/7] Update integrations/access/pagerduty/testlib/suite.go Co-authored-by: Tiago Silva --- integrations/access/pagerduty/testlib/suite.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/integrations/access/pagerduty/testlib/suite.go b/integrations/access/pagerduty/testlib/suite.go index 4e6bfdaf8693..5d6bb3aadb2e 100644 --- a/integrations/access/pagerduty/testlib/suite.go +++ b/integrations/access/pagerduty/testlib/suite.go @@ -465,7 +465,9 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() { // Incident creation may happen before plugins Access Monitoring Rule cache // has been updated with new rule. Retry until the new cache picks up the rule. require.EventuallyWithT(t, func(t *assert.CollectT) { - require.Equal(t, lastRuleUpdated, ruleName) + mu.Lock() + require.Contains(t, collectedNames, ruleName) + mu.UnLock() }, 3*time.Second, time.Millisecond*100, "new access monitoring rule did not begin applying") // Test execution: create an access request