Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure Active Directory authentication #548

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
.terraform
*.tfstate*
*.log
*.swp
*~
26 changes: 24 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@ Other supported formats are listed below.
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* `ServerSPN` - The Kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* `Workstation ID` - The workstation name (default is the host name)
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.

* `FedAuth` - The federated authentication scheme to use.
* `ActiveDirectoryApplication` - authenticates using an Azure Active Directory application client ID and client secret or certificate. Set the `user` to `client-ID@tenant-ID` and the `password` to the client secret. If using client certificates, provide the path to the PKCS#12 file containing the certificate and RSA private key in the `ClientCertPath` parameter, and set the `password` to the value needed to open the PKCS#12 file.
* `ActiveDirectoryMSI` - authenticates using the managed service identity (MSI) attached to the VM, or a specific user-assigned identity if a client ID is specified in the `user` field.
* `ActiveDirectoryPassword` - authenticates an Azure Active Directory user account in the form `[email protected]` with a password. This method is not recommended for general use and does not support multi-factor authentication for accounts.

### The connection string can be specified in one of three formats:


Expand Down Expand Up @@ -106,6 +110,24 @@ Other supported formats are listed below.
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar"

### Using Azure Active Directory access tokens for authentication

Azure Active Directory (AAD) is not supported through the DSN, but through a "token provider" callback.
Since access tokens are relatively short lived and need to be valid when a new connection is made,
a better way to provide them is using a connector:
``` golang
conn, err := mssql.NewAccessTokenConnector(
"Server=test.database.windows.net;Database=testdb",
tokenProvider)
if err != nil {
// handle errors in DSN parsing
}
db := sql.OpenDB(conn)
```
Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements
actually trigger the retrieval of a token, this happens when the first statment is issued and a connection
is created.

## Executing Stored Procedures

To run a stored procedure, set the query text to the procedure name:
Expand Down
51 changes: 51 additions & 0 deletions accesstokenconnector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// +build go1.10

package mssql

import (
"context"
"database/sql/driver"
"errors"
"fmt"
)

var _ driver.Connector = &accessTokenConnector{}

// accessTokenConnector wraps Connector and injects a
// fresh access token when connecting to the database
type accessTokenConnector struct {
Connector

accessTokenProvider func() (string, error)
}

// NewAccessTokenConnector creates a new connector from a DSN and a token provider.
// The token provider func will be called when a new connection is requested and should return a valid access token.
// The returned connector may be used with sql.OpenDB.
func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (driver.Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}

conn, err := NewConnector(dsn)
if err != nil {
return nil, err
}

c := &accessTokenConnector{
Connector: *conn,
accessTokenProvider: tokenProvider,
}
return c, nil
}

// Connect returns a new database connection
func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider()
if err != nil {
return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err)
}

return c.Connector.Connect(ctx)
}
86 changes: 86 additions & 0 deletions accesstokenconnector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// +build go1.10

package mssql

import (
"context"
"database/sql/driver"
"errors"
"fmt"
"strings"
"testing"
)

func TestNewAccessTokenConnector(t *testing.T) {
dsn := "Server=server.database.windows.net;Database=db"
tp := func() (string, error) { return "token", nil }
type args struct {
dsn string
tokenProvider func() (string, error)
}
tests := []struct {
name string
args args
want func(driver.Connector) error
wantErr bool
}{
{"Happy path",
args{dsn, tp},
func(c driver.Connector) error {
tc, ok := c.(*accessTokenConnector)
if !ok {
return fmt.Errorf("Expected driver to be of type *accessTokenConnector, but got %T", c)
}
p := tc.Connector.params
if p.database != "db" {
return fmt.Errorf("expected params.database=db, but got %v", p.database)
}
if p.host != "server.database.windows.net" {
return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.host)
}
if tc.accessTokenProvider == nil {
return fmt.Errorf("Expected tokenProvider to not be nil")
}
t, err := tc.accessTokenProvider()
if t != "token" || err != nil {
return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err)
}
return nil
},
false,
},
{"Nil tokenProvider gives error",
args{dsn, nil},
nil,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewAccessTokenConnector(tt.args.dsn, tt.args.tokenProvider)
if (err != nil) != tt.wantErr {
t.Errorf("NewAccessTokenConnector() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil {
if err := tt.want(got); err != nil {
t.Error(err)
}
}
})
}
}

func TestAccessTokenConnectorFailsToConnectIfNoAccessToken(t *testing.T) {
errorText := "This is a test"
dsn := "Server=server.database.windows.net;Database=db"
tp := func() (string, error) { return "", errors.New(errorText) }
sut, err := NewAccessTokenConnector(dsn, tp)
if err != nil {
t.Fatalf("expected err==nil, but got %+v", err)
}
_, err = sut.Connect(context.TODO())
if err == nil || !strings.Contains(err.Error(), errorText) {
t.Fatalf("expected error to contain %q, but got %q", errorText, err)
}
}
2 changes: 2 additions & 0 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ install:
- go version
- go env
- go get -u github.com/golang-sql/civil
- go get -u golang.org/x/crypto/pkcs12
- go get -u github.com/Azure/go-autorest/autorest/adal

build_script:
- go build
Expand Down
2 changes: 2 additions & 0 deletions codecov.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ignore:
- token_string.go
52 changes: 52 additions & 0 deletions conn_str.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mssql

import (
"errors"
"fmt"
"net"
"net/url"
Expand All @@ -13,6 +14,12 @@ import (

const defaultServerPort = 1433

const (
fedAuthActiveDirectoryPassword = "ActiveDirectoryPassword"
fedAuthActiveDirectoryMSI = "ActiveDirectoryMSI"
fedAuthActiveDirectoryApplication = "ActiveDirectoryApplication"
)

type connectParams struct {
logFlags uint64
port uint64
Expand All @@ -37,6 +44,12 @@ type connectParams struct {
failOverPartner string
failOverPort uint64
packetSize uint16
fedAuthLibrary byte
fedAuthADALWorkflow byte
fedAuthAccessToken string
aadTenantID string
aadClientCertPath string
tlsKeyLogFile string
}

func parseConnectParams(dsn string) (connectParams, error) {
Expand Down Expand Up @@ -229,6 +242,45 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
}

p.fedAuthLibrary = fedAuthLibraryReserved
fedAuth, ok := params["fedauth"]
if ok {
switch {
case strings.EqualFold(fedAuth, fedAuthActiveDirectoryPassword):
p.fedAuthLibrary = fedAuthLibraryADAL
p.fedAuthADALWorkflow = fedAuthADALWorkflowPassword
case strings.EqualFold(fedAuth, fedAuthActiveDirectoryMSI):
p.fedAuthLibrary = fedAuthLibraryADAL
p.fedAuthADALWorkflow = fedAuthADALWorkflowMSI
case strings.EqualFold(fedAuth, fedAuthActiveDirectoryApplication):
p.fedAuthLibrary = fedAuthLibrarySecurityToken
p.aadClientCertPath = params["clientcertpath"]

// Split the user name into client id and tenant id at the @ symbol
at := strings.IndexRune(p.user, '@')
if at < 1 || at >= (len(p.user)-1) {
f := "Expecting user id to be clientID@tenantID: found '%s'"
return p, fmt.Errorf(f, p.user)
}

p.aadTenantID = p.user[at+1:]
p.user = p.user[0:at]
default:
f := "Invalid federated authentication type '%s': expected %s, %s or %s"
return p, fmt.Errorf(f, fedAuth, fedAuthActiveDirectoryPassword, fedAuthActiveDirectoryMSI, fedAuthActiveDirectoryApplication)
}

if p.disableEncryption {
f := "Encryption must not be disabled for federated authentication: encrypt='%s'"
return p, fmt.Errorf(f, encrypt)
}
}

p.tlsKeyLogFile, ok = params["tls key log file"]
if ok && p.tlsKeyLogFile != "" && p.disableEncryption {
return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled")
}

return p, nil
}

Expand Down
20 changes: 20 additions & 0 deletions conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ func TestInvalidConnectionString(t *testing.T) {
"trustservercertificate=invalid",
"failoverport=invalid",
"applicationintent=ReadOnly",
"encrypt=DISABLE;tls key log file=key.log",

// AAD
"fedauth=ActiveDirectoryApplication;user id=clientidwithouttenantid;clientcertpath=/secrets/spn.pem",
"fedauth=UnknownType",
// encryption cannot be disabled for AAD
"encrypt=DISABLE;fedauth=ActiveDirectoryPassword;user [email protected];password=secret",
"encrypt=DISABLE;fedauth=ActiveDirectoryMSI",
"encrypt=DISABLE;fedauth=ActiveDirectoryApplication;user id=clientid@tenantid;clientcertpath=/secrets/spn.pem",

// ODBC mode
"odbc:password={",
Expand Down Expand Up @@ -74,6 +83,17 @@ func TestValidConnectionString(t *testing.T) {
{"log=64;packet size=8192", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 8192 }},
{"log=64;packet size=48000", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 32767 }},

// AAD
{"fedauth=ActiveDirectoryPassword;user [email protected];password=secret", func(p connectParams) bool {
return p.fedAuthLibrary == fedAuthLibraryADAL && p.fedAuthADALWorkflow == fedAuthADALWorkflowPassword
}},
{"fedauth=ActiveDirectoryMSI", func(p connectParams) bool {
return p.fedAuthLibrary == fedAuthLibraryADAL && p.fedAuthADALWorkflow == fedAuthADALWorkflowMSI
}},
{"fedauth=ActiveDirectoryApplication;user id=clientid@tenantid;clientcertpath=/secrets/spn.pem", func(p connectParams) bool {
return p.fedAuthLibrary == fedAuthLibrarySecurityToken && p.user == "clientid" && p.aadTenantID == "tenantid" && p.aadClientCertPath == "/secrets/spn.pem"
}},

// those are supported currently, but maybe should not be
{"someparam", func(p connectParams) bool { return true }},
{";;=;", func(p connectParams) bool { return true }},
Expand Down
Loading