Skip to content

Commit

Permalink
🧹 Use cnquery defined methods for AZ authentication. (#1239)
Browse files Browse the repository at this point in the history
Signed-off-by: Preslav <[email protected]>
  • Loading branch information
preslavgerchev authored Apr 8, 2024
1 parent dfd01fa commit 9ea0733
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 71 deletions.
72 changes: 2 additions & 70 deletions cli/reporter/azure_service_bus_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@ package reporter

import (
"context"
"errors"
"fmt"
"regexp"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
pol "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
"github.com/rs/zerolog/log"
"go.mondoo.com/cnquery/v10/providers-sdk/v1/util/azauth"
"go.mondoo.com/cnspec/v10/policy"
)

Expand All @@ -36,7 +33,7 @@ func (h *azureSbusHandler) WriteReport(ctx context.Context, report *policy.Repor
senderName := parts[len(parts)-1]
sbusUrl := strings.TrimSuffix(trimmedUrl, "/"+senderName)

cred, err := h.GetTokenChain(ctx, &azidentity.DefaultAzureCredentialOptions{})
cred, err := azauth.GetChainedToken(&azidentity.DefaultAzureCredentialOptions{})
if err != nil {
return err
}
Expand Down Expand Up @@ -82,68 +79,3 @@ func (h *azureSbusHandler) convertReport(report *policy.ReportCollection) ([]byt
return nil, fmt.Errorf("'%s' is not supported in the azure service bus handler, please use one of the other formats", string(h.format))
}
}

// sometimes we run into a 'managed identity timed out' error when using a managed identity.
// according to https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/azidentity/TROUBLESHOOTING.md#troubleshoot-defaultazurecredential-authentication-issues
// we should instead use the NewManagedIdentityCredential directly. this function adds a bit more by
// also using other credentials to create a chained token credential
func (h *azureSbusHandler) GetTokenChain(ctx context.Context, options *azidentity.DefaultAzureCredentialOptions) (*azidentity.ChainedTokenCredential, error) {
if options == nil {
options = &azidentity.DefaultAzureCredentialOptions{}
}

chain := []azcore.TokenCredential{}

cli, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{})
if err == nil {
chain = append(chain, cli)
}
envCred, err := azidentity.NewEnvironmentCredential(&azidentity.EnvironmentCredentialOptions{ClientOptions: options.ClientOptions})
if err == nil {
chain = append(chain, envCred)
}
mic, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions})
if err == nil {
timedMic := &TimedManagedIdentityCredential{mic: *mic, timeout: 5 * time.Second}
chain = append(chain, timedMic)
}
wic, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
ClientOptions: options.ClientOptions,
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
TenantID: options.TenantID,
})
if err == nil {
chain = append(chain, wic)
}

return azidentity.NewChainedTokenCredential(chain, nil)
}

type TimedManagedIdentityCredential struct {
mic azidentity.ManagedIdentityCredential
timeout time.Duration
}

func (t *TimedManagedIdentityCredential) GetToken(ctx context.Context, opts pol.TokenRequestOptions) (azcore.AccessToken, error) {
ctx, cancel := context.WithTimeout(ctx, t.timeout)
defer cancel()
var tk azcore.AccessToken
var err error
if t.timeout > 0 {
c, cancel := context.WithTimeout(ctx, t.timeout)
defer cancel()
tk, err = t.mic.GetToken(c, opts)
if err != nil {
var authFailedErr *azidentity.AuthenticationFailedError
if errors.As(err, &authFailedErr) && strings.Contains(err.Error(), "context deadline exceeded") {
err = azidentity.NewCredentialUnavailableError("managed identity request timed out")
}
} else {
// some managed identity implementation is available, so don't apply the timeout to future calls
t.timeout = 0
}
} else {
tk, err = t.mic.GetToken(ctx, opts)
}
return tk, err
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ go 1.22
toolchain go1.22.0

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1
github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus v1.7.0
github.com/Masterminds/semver v1.5.0
Expand Down Expand Up @@ -65,6 +64,7 @@ require (
github.com/Antonboom/nilnil v0.1.7 // indirect
github.com/Antonboom/testifylint v1.2.0 // indirect
github.com/Azure/azure-amqp-common-go/v3 v3.2.3 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect
github.com/Azure/go-amqp v1.0.5 // indirect
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
Expand Down

0 comments on commit 9ea0733

Please sign in to comment.