From 80567990989e1689f8408ad39da4a5639ac1254f Mon Sep 17 00:00:00 2001 From: Stephane Erbrech Date: Wed, 2 Sep 2020 15:16:44 -0700 Subject: [PATCH 1/4] pass the context to the tokenprovider func --- accesstokenconnector.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/accesstokenconnector.go b/accesstokenconnector.go index 8dbe5099..48d8e9aa 100644 --- a/accesstokenconnector.go +++ b/accesstokenconnector.go @@ -16,13 +16,13 @@ var _ driver.Connector = &accessTokenConnector{} type accessTokenConnector struct { Connector - accessTokenProvider func() (string, error) + accessTokenProvider func(ctx context.Context) (string, error) } // NewAccessTokenConnector creates a new connector from a DSN and a token provider. // The token provider func will be called when a new connection is requested and should return a valid access token. // The returned connector may be used with sql.OpenDB. -func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (driver.Connector, error) { +func NewAccessTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (driver.Connector, error) { if tokenProvider == nil { return nil, errors.New("mssql: tokenProvider cannot be nil") } @@ -42,7 +42,7 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) ( // Connect returns a new database connection func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) { var err error - c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider() + c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider(ctx) if err != nil { return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err) } From 5b789b5a14cbbfa7e9c6308e50db451a7a21f940 Mon Sep 17 00:00:00 2001 From: Stephane Erbrech Date: Wed, 2 Sep 2020 15:18:33 -0700 Subject: [PATCH 2/4] update tests --- accesstokenconnector_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/accesstokenconnector_test.go b/accesstokenconnector_test.go index 826dedba..558b9f1e 100644 --- a/accesstokenconnector_test.go +++ b/accesstokenconnector_test.go @@ -13,10 +13,10 @@ import ( func TestNewAccessTokenConnector(t *testing.T) { dsn := "Server=server.database.windows.net;Database=db" - tp := func() (string, error) { return "token", nil } + tp := func(ctx context.Context) (string, error) { return "token", nil } type args struct { dsn string - tokenProvider func() (string, error) + tokenProvider func(ctx context.Context) (string, error) } tests := []struct { name string @@ -44,7 +44,7 @@ func TestNewAccessTokenConnector(t *testing.T) { if tc.accessTokenProvider == nil { return fmt.Errorf("Expected tokenProvider to not be nil") } - t, err := tc.accessTokenProvider() + t, err := tc.accessTokenProvider(context.TODO()) if t != "token" || err != nil { return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err) } @@ -80,7 +80,7 @@ func TestNewAccessTokenConnector(t *testing.T) { func TestAccessTokenConnectorFailsToConnectIfNoAccessToken(t *testing.T) { errorText := "This is a test" dsn := "Server=server.database.windows.net;Database=db" - tp := func() (string, error) { return "", errors.New(errorText) } + tp := func(ctx context.Context) (string, error) { return "", errors.New(errorText) } sut, err := NewAccessTokenConnector(dsn, tp) if err != nil { t.Fatalf("expected err==nil, but got %+v", err) From 17d7c7d598531e96d1668830c85b668e6f1ea0f3 Mon Sep 17 00:00:00 2001 From: Stephane Erbrech Date: Wed, 2 Sep 2020 15:25:41 -0700 Subject: [PATCH 3/4] update example --- examples/azuread-accesstoken/managed_identity.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/azuread-accesstoken/managed_identity.go b/examples/azuread-accesstoken/managed_identity.go index 0adf99cb..7c7e0447 100644 --- a/examples/azuread-accesstoken/managed_identity.go +++ b/examples/azuread-accesstoken/managed_identity.go @@ -1,6 +1,7 @@ package main import ( + "context" "database/sql" "flag" "fmt" @@ -63,7 +64,7 @@ func main() { fmt.Printf("bye\n") } -func getMSITokenProvider() (func() (string, error), error) { +func getMSITokenProvider() (func(ctx context.Context) (string, error), error) { msiEndpoint, err := adal.GetMSIEndpoint() if err != nil { return nil, err @@ -74,8 +75,8 @@ func getMSITokenProvider() (func() (string, error), error) { return nil, err } - return func() (string, error) { - msi.EnsureFresh() + return func(ctx context.Context) (string, error) { + msi.EnsureFreshContext(ctx) token := msi.OAuthToken() return token, nil }, nil From c3f802842b39fe5468f4526347dbd8d23d2ef8da Mon Sep 17 00:00:00 2001 From: Stephane Erbrech Date: Wed, 2 Sep 2020 15:32:50 -0700 Subject: [PATCH 4/4] update example to newer version --- examples/azuread-accesstoken/go.mod | 2 +- examples/azuread-accesstoken/go.sum | 8 ++++---- examples/azuread-accesstoken/managed_identity.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/azuread-accesstoken/go.mod b/examples/azuread-accesstoken/go.mod index eb113f3e..f85bd1b6 100644 --- a/examples/azuread-accesstoken/go.mod +++ b/examples/azuread-accesstoken/go.mod @@ -4,5 +4,5 @@ go 1.13 require ( github.com/Azure/go-autorest/autorest/adal v0.8.1 - github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73 + github.com/denisenkom/go-mssqldb v0.0.0-20200831201914-36b6ff1bbc10 ) diff --git a/examples/azuread-accesstoken/go.sum b/examples/azuread-accesstoken/go.sum index 17148ff8..14517678 100644 --- a/examples/azuread-accesstoken/go.sum +++ b/examples/azuread-accesstoken/go.sum @@ -1,7 +1,5 @@ -github.com/Azure/go-autorest v13.3.2+incompatible h1:VxzPyuhtnlBOzc4IWCZHqpyH2d+QMLQEuy3wREyY4oc= github.com/Azure/go-autorest/autorest v0.9.0 h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs= github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= -github.com/Azure/go-autorest/autorest v0.9.4 h1:1cM+NmKw91+8h5vfjgzK4ZGLuN72k87XVZBWyGwNjUM= github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= github.com/Azure/go-autorest/autorest/adal v0.8.1 h1:pZdL8o72rK+avFWl+p9nE8RWi1JInZrWJYlnpfXJwHk= github.com/Azure/go-autorest/autorest/adal v0.8.1/go.mod h1:ZjhuQClTqx435SRJ2iMlOxPYt3d2C/T/7TiQCVZSn3Q= @@ -10,12 +8,14 @@ github.com/Azure/go-autorest/autorest/date v0.2.0 h1:yW+Zlqf26583pE43KhfnhFcdmSW github.com/Azure/go-autorest/autorest/date v0.2.0/go.mod h1:vcORJHLJEh643/Ioh9+vPmf1Ij9AEBM5FuBIXLmIy0g= github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.3.0 h1:qJumjCaCudz+OcqE9/XtEPfvtOjOmKaui4EOpFI6zZc= github.com/Azure/go-autorest/autorest/mocks v0.3.0/go.mod h1:a8FDP3DYzQ4RYfVAxAN3SVSiiO77gL2j2ronKKP0syM= +github.com/Azure/go-autorest/logger v0.1.0 h1:ruG4BSDXONFRrZZJ2GUXDiUyVpayPmb1GnWeHDdaNKY= github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k= github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= -github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73 h1:OGNva6WhsKst5OZf7eZOklDztV3hwtTHovdrLHV+MsA= -github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/denisenkom/go-mssqldb v0.0.0-20200831201914-36b6ff1bbc10 h1:uuDqxM2PbeYyXcKIo/IP0ZLGDzougMipEBBrCOzr50w= +github.com/denisenkom/go-mssqldb v0.0.0-20200831201914-36b6ff1bbc10/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= diff --git a/examples/azuread-accesstoken/managed_identity.go b/examples/azuread-accesstoken/managed_identity.go index 7c7e0447..489efb32 100644 --- a/examples/azuread-accesstoken/managed_identity.go +++ b/examples/azuread-accesstoken/managed_identity.go @@ -76,7 +76,7 @@ func getMSITokenProvider() (func(ctx context.Context) (string, error), error) { } return func(ctx context.Context) (string, error) { - msi.EnsureFreshContext(ctx) + msi.EnsureFreshWithContext(ctx) token := msi.OAuthToken() return token, nil }, nil