Skip to content

Commit

Permalink
Add access monitoring rules to msteams plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
EdwardDowling committed Oct 16, 2024
1 parent dc3052b commit 942da4a
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 27 deletions.
96 changes: 78 additions & 18 deletions integrations/access/msteams/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ package msteams
import (
"context"
"log/slog"
"slices"
"time"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/accessmonitoring"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/common/teleport"
"github.com/gravitational/teleport/integrations/lib"
Expand Down Expand Up @@ -53,7 +55,8 @@ type App struct {
watcherJob lib.ServiceJob
pd *pd.CompareAndSwap[PluginData]

log *slog.Logger
log *slog.Logger
accessMonitoringRules *accessmonitoring.RuleHandler

*lib.Process
}
Expand Down Expand Up @@ -85,13 +88,11 @@ func (a *App) Run(ctx context.Context) error {
}

a.Process = lib.NewProcess(ctx)
a.watcherJob, err = a.newWatcherJob()
if err != nil {
return trace.Wrap(err)
}

a.SpawnCriticalJob(a.mainJob)
a.SpawnCriticalJob(a.watcherJob)

select {
case <-ctx.Done():
Expand All @@ -116,10 +117,14 @@ func (a *App) init(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, initTimeout)
defer cancel()

var err error
a.apiClient, err = common.GetTeleportClient(ctx, a.conf.Teleport)
if err != nil {
return trace.Wrap(err)
if a.conf.Client != nil {
a.apiClient = a.conf.Client
} else {
var err error
a.apiClient, err = common.GetTeleportClient(ctx, a.conf.Teleport)
if err != nil {
return trace.Wrap(err)
}
}

a.pd = pd.NewCAS(
Expand All @@ -145,6 +150,24 @@ func (a *App) init(ctx context.Context) error {
return trace.Wrap(err)
}

a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
Client: a.apiClient,
PluginName: pluginName,
// Map msteams.RecipientData onto the common recipient type used
// by the access monitoring rules watcher.
FetchRecipientCallback: func(ctx context.Context, name string) (*common.Recipient, error) {
msTeamsRecipient, err := a.bot.FetchRecipient(ctx, name)
if err != nil {
return nil, trace.Wrap(err)
}
return &common.Recipient{
Name: name,
ID: msTeamsRecipient.ID,
Kind: string(msTeamsRecipient.Kind),
}, nil
},
})

return a.initBot(ctx)
}

Expand Down Expand Up @@ -187,27 +210,52 @@ func (a *App) initBot(ctx context.Context) error {
return nil
}

// newWatcherJob creates WatcherJob
func (a *App) newWatcherJob() (lib.ServiceJob, error) {
return watcherjob.NewJob(
// run starts the main process
func (a *App) run(ctx context.Context) error {

process := lib.MustGetProcess(ctx)

watchKinds := []types.WatchKind{
{Kind: types.KindAccessRequest},
{Kind: types.KindAccessMonitoringRule},
}
acceptedWatchKinds := make([]string, 0, len(watchKinds))
watcherJob, err := watcherjob.NewJobWithConfirmedWatchKinds(
a.apiClient,
watcherjob.Config{
Watch: types.Watch{
Kinds: []types.WatchKind{{Kind: types.KindAccessRequest}},
},
Watch: types.Watch{Kinds: watchKinds, AllowPartialSuccess: true},
EventFuncTimeout: handlerTimeout,
},
a.onWatcherEvent,
func(ws types.WatchStatus) {
for _, watchKind := range ws.GetKinds() {
acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind)
}
},
)
}

// run starts the main process
func (a *App) run(ctx context.Context) error {
ok, err := a.watcherJob.WaitReady(ctx)
if err != nil {
return trace.Wrap(err)
}

process.SpawnCriticalJob(watcherJob)

ok, err := watcherJob.WaitReady(ctx)
if err != nil {
return trace.Wrap(err)
}
if len(acceptedWatchKinds) == 0 {
return trace.BadParameter("failed to initialize watcher for all the required resources: %+v",
watchKinds)
}
// Check if KindAccessMonitoringRule resources are being watched,
// the role the plugin is running as may not have access.
if slices.Contains(acceptedWatchKinds, types.KindAccessMonitoringRule) {
if err := a.accessMonitoringRules.InitAccessMonitoringRulesCache(ctx); err != nil {
return trace.Wrap(err, "initializing Access Monitoring Rule cache")
}
}
a.watcherJob = watcherJob
a.watcherJob.SetReady(ok)
if ok {
a.log.InfoContext(ctx, "Plugin is ready")
} else {
Expand Down Expand Up @@ -243,6 +291,10 @@ func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, err
// onWatcherEvent called when an access request event is received
func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error {
kind := event.Resource.GetKind()
if kind == types.KindAccessMonitoringRule {
return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event))
}

if kind != types.KindAccessRequest {
return trace.Errorf("unexpected kind %s", kind)
}
Expand Down Expand Up @@ -480,6 +532,14 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest)
recipientSet := stringset.New()

a.log.DebugContext(ctx, "Getting suggested reviewer recipients")
accessRuleRecipients := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req)
accessRuleRecipients.ForEach(func(r common.Recipient) {
recipientSet.Add(r.Name)
})
if recipientSet.Len() != 0 {
return recipientSet.ToSlice()
}

var validEmailsSuggReviewers []string
for _, reviewer := range req.GetSuggestedReviewers() {
if !lib.IsEmail(reviewer) {
Expand Down
80 changes: 71 additions & 9 deletions integrations/access/msteams/testlib/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1"
v1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/msteams"
Expand All @@ -37,7 +39,7 @@ import (
"github.com/gravitational/teleport/integrations/lib/testing/integration"
)

// MsTeamsBaseSuite is the Slack access plugin test suite.
// MsTeamsBaseSuite is the MsTeams access plugin test suite.
// It implements the testify.TestingSuite interface.
type MsTeamsBaseSuite struct {
*integration.AccessRequestSuite
Expand All @@ -51,16 +53,19 @@ type MsTeamsBaseSuite struct {
reviewer2TeamsUser msapi.User
}

// SetupTest starts a fake Slack, generates the plugin configuration, and loads
// the fixtures in Slack. It runs for each test.
// SetupTest starts a fake MsTeams, generates the plugin configuration, and loads
// the fixtures in MsTeams. It runs for each test.
func (s *MsTeamsBaseSuite) SetupTest() {
t := s.T()

err := logger.Setup(logger.Config{Severity: "debug"})
require.NoError(t, err)
s.raceNumber = runtime.GOMAXPROCS(0)

s.fakeTeams = NewFakeTeams(s.raceNumber)
t.Cleanup(s.fakeTeams.Close)

// We need requester users as well, the slack plugin sends messages to users
// We need requester users as well, the MsTeams plugin sends messages to users
// when their access request got approved.
s.requesterOSSTeamsUser = s.fakeTeams.StoreUser(msapi.User{Name: "Requester OSS", Mail: integration.RequesterOSSUserName})
s.requester1TeamsUser = s.fakeTeams.StoreUser(msapi.User{Name: "Requester Ent", Mail: integration.Requester1UserName})
Expand All @@ -71,16 +76,17 @@ func (s *MsTeamsBaseSuite) SetupTest() {

var conf msteams.Config
conf.Teleport = s.TeleportConfig()
apiClient, err := common.GetTeleportClient(context.Background(), s.TeleportConfig())
require.NoError(t, err)
conf.Client = apiClient
conf.StatusSink = s.fakeStatusSink
conf.MSAPI = s.fakeTeams.Config
conf.MSAPI.SetBaseURLs(s.fakeTeams.URL(), s.fakeTeams.URL(), s.fakeTeams.URL())
conf.Log = logger.Config{
Severity: "debug",
}

s.appConfig = &conf
}

// startApp starts the Slack plugin, waits for it to become ready and returns.
// startApp starts the MsTeams plugin, waits for it to become ready and returns.
func (s *MsTeamsBaseSuite) startApp() {
s.T().Helper()
t := s.T()
Expand Down Expand Up @@ -414,7 +420,9 @@ func (s *MsTeamsSuiteEnterprise) TestRace() {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
t.Cleanup(cancel)

s.appConfig.Log.Severity = "debug" // Turn off noisy debug logging
err := logger.Setup(logger.Config{Severity: "info"}) // Turn off noisy debug logging
require.NoError(t, err)

s.startApp()

var (
Expand Down Expand Up @@ -527,3 +535,57 @@ func (s *MsTeamsSuiteEnterprise) TestRace() {
return next
})
}

func (s *MsTeamsSuiteOSS) TestRecipientsFromAccessMonitoringRule() {
t := s.T()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
t.Cleanup(cancel)

s.startApp()

_, err := s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().
CreateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{
Kind: types.KindAccessMonitoringRule,
Version: types.V1,
Metadata: &v1.Metadata{
Name: "test-msteams-amr",
},
Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "!is_empty(access_request.spec.roles)",
Notification: &accessmonitoringrulesv1.Notification{
Name: "msteams",
Recipients: []string{
s.reviewer1TeamsUser.ID,
s.reviewer2TeamsUser.Mail,
},
},
},
})
assert.NoError(t, err)

// Test execution: create an access request
req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil)

s.checkPluginData(ctx, req.GetName(), func(data msteams.PluginData) bool {
return len(data.TeamsData) > 0
})

title := "Access Request " + req.GetName()
msgs, err := s.getNewMessages(ctx, 2)
require.NoError(t, err)

var body1 testTeamsMessage
require.NoError(t, json.Unmarshal([]byte(msgs[0].Body), &body1))
body1.checkTitle(t, title)
require.Equal(t, msgs[0].RecipientID, s.reviewer1TeamsUser.ID)

var body2 testTeamsMessage
require.NoError(t, json.Unmarshal([]byte(msgs[1].Body), &body2))
body1.checkTitle(t, title)
require.Equal(t, msgs[1].RecipientID, s.reviewer2TeamsUser.ID)

assert.NoError(t, s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-msteams-amr"))
}

0 comments on commit 942da4a

Please sign in to comment.