Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pagerduty AMR test to prevent flakiness #46390

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
10 changes: 10 additions & 0 deletions integrations/access/accessmonitoring/access_monitoring_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type RuleHandler struct {
pluginName string

fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
onCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

// RuleMap is a concurrent map for access monitoring rules.
Expand All @@ -65,6 +66,8 @@ type RuleHandlerConfig struct {

// FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients.
FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
// OnCacheUpdateCallback is a callback that is called when a rule in the cache is created or updated.
OnCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

// NewRuleHandler returns a new RuleHandler.
Expand All @@ -77,6 +80,7 @@ func NewRuleHandler(conf RuleHandlerConfig) *RuleHandler {
pluginType: conf.PluginType,
pluginName: conf.PluginName,
fetchRecipientCallback: conf.FetchRecipientCallback,
onCacheUpdateCallback: conf.OnCacheUpdateCallback,
}
}

Expand All @@ -93,6 +97,9 @@ func (amrh *RuleHandler) InitAccessMonitoringRulesCache(ctx context.Context) err
continue
}
amrh.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr
if amrh.onCacheUpdateCallback != nil {
amrh.onCacheUpdateCallback(types.OpPut, amr.GetMetadata().Name, amr)
}
}
return nil
}
Expand Down Expand Up @@ -123,6 +130,9 @@ func (amrh *RuleHandler) HandleAccessMonitoringRule(ctx context.Context, event t
return nil
}
amrh.accessMonitoringRules.rules[req.Metadata.Name] = req
if amrh.onCacheUpdateCallback != nil {
amrh.onCacheUpdateCallback(req.Metadata.Name)
}
return nil
case types.OpDelete:
delete(amrh.accessMonitoringRules.rules, event.Resource.GetName())
Expand Down
15 changes: 15 additions & 0 deletions integrations/access/pagerduty/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ func NewApp(conf Config) (*App, error) {
statusSink: conf.StatusSink,
}

amrhConf := accessmonitoring.RuleHandlerConfig{
Client: conf.Client,
PluginType: types.PluginTypePagerDuty,
FetchRecipientCallback: func(_ context.Context, name string) (*common.Recipient, error) {
return &common.Recipient{
Name: name,
ID: name,
Kind: common.RecipientKindSchedule,
}, nil
},
}
if conf.OnAccessMonitoringRuleCacheUpdateCallback != nil {
amrhConf.OnCacheUpdateCallback = conf.OnAccessMonitoringRuleCacheUpdateCallback
}
app.accessMonitoringRules = accessmonitoring.NewRuleHandler(amrhConf)
app.mainJob = lib.NewServiceJob(app.run)

return app, nil
Expand Down
6 changes: 6 additions & 0 deletions integrations/access/pagerduty/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"github.com/gravitational/trace"
"github.com/pelletier/go-toml"

accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/common/teleport"
"github.com/gravitational/teleport/integrations/lib"
Expand All @@ -47,6 +49,10 @@ type Config struct {
// TeleportUser is the name of the Teleport user that will act
// as the access request approver
TeleportUser string

// OnAccessMonitoringRuleCacheUpdateCallback is used for checking when
// the Rule cache is updated in tests
OnAccessMonitoringRuleCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

type PagerdutyConfig struct {
Expand Down
24 changes: 18 additions & 6 deletions integrations/access/pagerduty/testlib/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,12 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
t.Cleanup(cancel)

ruleName := "test-pagerduty-amr"
lastRuleUpdated := ""
s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(_ types.OpType, name string, _ *accessmonitoringrulesv1.AccessMonitoringRule) error {
lastRuleUpdated = ruleName
return nil
}
EdwardDowling marked this conversation as resolved.
Show resolved Hide resolved
s.startApp()

_, err := s.ClientByName(integration.RulerUserName).
Expand All @@ -438,7 +444,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
Kind: types.KindAccessMonitoringRule,
Version: types.V1,
Metadata: &v1.Metadata{
Name: "test-pagerduty-amr",
Name: ruleName,
},
Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Expand All @@ -453,24 +459,30 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
})
assert.NoError(t, err)

// Incident creation may happen before plugins Access Monitoring Rule cache
// has been updated with new rule. Retry until the new cache picks up the rule.
require.EventuallyWithT(t, func(t *assert.CollectT) {
require.Equal(t, lastRuleUpdated, ruleName)
}, 3*time.Second, time.Millisecond*100, "new access monitoring rule did not begin applying")
EdwardDowling marked this conversation as resolved.
Show resolved Hide resolved

// Test execution: create an access request
req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil)

// Validate the incident has been created in Pagerduty and its ID is stored
// in the plugin_data.
pluginData := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool {
hasIncidentID := s.checkPluginData(ctx, req.GetName(), func(data pagerduty.PluginData) bool {
return data.IncidentID != ""
})

incident, err := s.fakePagerduty.CheckNewIncident(ctx)
require.NoError(t, err, "no new incidents stored")

assert.Equal(t, incident.ID, pluginData.IncidentID)
assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID)
assert.NoError(t, err, "no new incidents stored")
assert.Equal(t, incident.ID, hasIncidentID.IncidentID)

assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey)
assert.Equal(t, "triggered", incident.Status)

assert.Equal(t, s.pdNotifyService2.ID, hasIncidentID.ServiceID)

assert.NoError(t, s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-pagerduty-amr"))
}
Expand Down
Loading