diff --git a/.gitignore b/.gitignore index 4fa5e7b8..1dda7039 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,8 @@ /.idea /.connstr +.vscode +.terraform +*.tfstate* +*.log +*.swp +*~ diff --git a/accesstokenconnector.go b/accesstokenconnector.go index 8dbe5099..086aef38 100644 --- a/accesstokenconnector.go +++ b/accesstokenconnector.go @@ -6,19 +6,8 @@ import ( "context" "database/sql/driver" "errors" - "fmt" ) -var _ driver.Connector = &accessTokenConnector{} - -// accessTokenConnector wraps Connector and injects a -// fresh access token when connecting to the database -type accessTokenConnector struct { - Connector - - accessTokenProvider func() (string, error) -} - // NewAccessTokenConnector creates a new connector from a DSN and a token provider. // The token provider func will be called when a new connection is requested and should return a valid access token. // The returned connector may be used with sql.OpenDB. @@ -32,20 +21,26 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) ( return nil, err } - c := &accessTokenConnector{ - Connector: *conn, - accessTokenProvider: tokenProvider, + conn.FederatedAuthenticationProvider = &accessTokenProvider{ + tokenProvider: tokenProvider, } - return c, nil + + return conn, nil } -// Connect returns a new database connection -func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) { - var err error - c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider() - if err != nil { - return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err) - } +type accessTokenProvider struct { + tokenProvider func() (string, error) +} + +func (p *accessTokenProvider) ConfigureProvider(fa *FederatedAuthenticationState) error { + fa.FedAuthLibrary = FedAuthLibrarySecurityToken + return nil +} + +func (p *accessTokenProvider) ProvideSecurityToken(ctx context.Context) (string, error) { + return p.tokenProvider() +} - return c.Connector.Connect(ctx) +func (p *accessTokenProvider) ProvideActiveDirectoryToken(ctx context.Context, serverSPN, stsURL string) (string, error) { + return p.tokenProvider() } diff --git a/accesstokenconnector_test.go b/accesstokenconnector_test.go index 826dedba..cbe7c06a 100644 --- a/accesstokenconnector_test.go +++ b/accesstokenconnector_test.go @@ -30,21 +30,21 @@ func TestNewAccessTokenConnector(t *testing.T) { dsn: dsn, tokenProvider: tp}, want: func(c driver.Connector) error { - tc, ok := c.(*accessTokenConnector) + tc, ok := c.(*Connector) if !ok { - return fmt.Errorf("Expected driver to be of type *accessTokenConnector, but got %T", c) + return fmt.Errorf("Expected driver to be of type *Connector, but got %T", c) } - p := tc.Connector.params + p := tc.params if p.database != "db" { return fmt.Errorf("expected params.database=db, but got %v", p.database) } if p.host != "server.database.windows.net" { return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.host) } - if tc.accessTokenProvider == nil { - return fmt.Errorf("Expected tokenProvider to not be nil") + if tc.FederatedAuthenticationProvider == nil { + return fmt.Errorf("Expected federated authentication provider to not be nil") } - t, err := tc.accessTokenProvider() + t, err := tc.FederatedAuthenticationProvider.ProvideSecurityToken(context.TODO()) if t != "token" || err != nil { return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err) } diff --git a/conn_str.go b/conn_str.go index bef0b9fc..bbe47d0b 100644 --- a/conn_str.go +++ b/conn_str.go @@ -37,7 +37,8 @@ type connectParams struct { failOverPartner string failOverPort uint64 packetSize uint16 - fedAuthAccessToken string + fedAuthWorkflow string + aadClientCertPath string } // default packet size for TDS buffer @@ -232,6 +233,13 @@ func parseConnectParams(dsn string) (connectParams, error) { } } + p.fedAuthWorkflow, ok = params["fedauth"] + if ok && p.disableEncryption { + f := "Encryption must not be disabled for federated authentication: encrypt='%s'" + return p, fmt.Errorf(f, encrypt) + } + p.aadClientCertPath, _ = params["clientcertpath"] + return p, nil } @@ -247,8 +255,8 @@ func (p connectParams) toUrl() *url.URL { } res := url.URL{ Scheme: "sqlserver", - Host: p.host, - User: url.UserPassword(p.user, p.password), + Host: p.host, + User: url.UserPassword(p.user, p.password), } if p.instance != "" { res.Path = p.instance diff --git a/conn_str_test.go b/conn_str_test.go index bb6e2682..8929ded2 100644 --- a/conn_str_test.go +++ b/conn_str_test.go @@ -34,6 +34,9 @@ func TestInvalidConnectionString(t *testing.T) { // URL mode "sqlserver://\x00", "sqlserver://host?key=value1&key=value2", // duplicate keys + + // cannot use federated authentication when encryption is disabled + "encrypt=disable;fedauth=ActiveDirectoryMSI", } for _, connStr := range connStrings { _, err := parseConnectParams(connStr) @@ -78,6 +81,10 @@ func TestValidConnectionString(t *testing.T) { {"log=64;packet size=8192", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 8192 }}, {"log=64;packet size=48000", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 32767 }}, + // federated authentication workflow + {"fedauth=ActiveDirectoryPassword", func(p connectParams) bool { return p.fedAuthWorkflow == "ActiveDirectoryPassword" }}, + {"clientcertpath=client.pem", func(p connectParams) bool { return p.aadClientCertPath == "client.pem" }}, + // those are supported currently, but maybe should not be {"someparam", func(p connectParams) bool { return true }}, {";;=;", func(p connectParams) bool { return true }}, diff --git a/fedauth.go b/fedauth.go new file mode 100644 index 00000000..b6565025 --- /dev/null +++ b/fedauth.go @@ -0,0 +1,94 @@ +package mssql + +import ( + "context" +) + +// Federated authentication library affects the login data structure and message sequence. +const ( + // FedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme + FedAuthLibraryLiveIDCompactToken = 0x00 + + // FedAuthLibrarySecurityToken specifies a token-based authentication where the token is available + // without additional information provided during the login sequence. + FedAuthLibrarySecurityToken = 0x01 + + // FedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the + // login sequence using the server SPN and STS URL provided by the server during login. + FedAuthLibraryADAL = 0x02 + + // FedAuthLibraryReserved is used to indicate that no federated authentication scheme applies. + FedAuthLibraryReserved = 0x7F +) + +// Federated authentication ADAL workflow affects the mechanism used to authenticate. +const ( + // FedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory + FedAuthADALWorkflowPassword = 0x01 + + // FedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory + FedAuthADALWorkflowIntegrated = 0x02 + + // FedAuthADALWorkflowMSI uses the managed identity service to obtain a token + FedAuthADALWorkflowMSI = 0x03 +) + +type FederatedAuthenticationState struct { + // FedAuthWorkflow captures the "fedauth" connection parameter + FedAuthWorkflow string + + // UserName is initially set to the user id connection parameter. + // The federated authentication configurer can modify this value to + // change what is sent in the login packet. + UserName string + + // Password is initially set to the user id connection parameter. + // The federated authentication configurer can modify this value to + // change what is sent in the login packet. + Password string + + // Password is initially set to the client cert path connection parameter. + ClientCertPath string + + // FedAuthLibrary is populated by the federated authentication provider. + FedAuthLibrary int + + // ADALWorkflow is populated by the federated authentication provider. + ADALWorkflow byte + + // FedAuthEcho is populated from the prelogin response + FedAuthEcho bool + + // FedAuthToken is populated during login with the value from the provider. + FedAuthToken string + + // Nonce is populated during login with the value from the provider. + Nonce []byte + + // Signature is populated during login with the value from the server. + Signature []byte +} + +// FederatedAuthenticationProvider implementations use the connection string +// parameters to determine the library and workflow, if any, and obtain tokens +// during the login sequence. +type FederatedAuthenticationProvider interface { + // Configure accepts the incoming connection parameters and determines + // the values for the authentication library and ADAL workflow. + ConfigureProvider(fedAuth *FederatedAuthenticationState) error + + // ProvideActiveDirectoryToken implementations are called during federated + // authentication login sequences where the server provides a service + // principal name and security token service endpoint that should be used + // to obtain the token. Implementations should contact the security token + // service specified and obtain the appropriate token, or return an error + // to indicate why a token is not available. + ProvideActiveDirectoryToken(ctx context.Context, serverSPN, stsURL string) (string, error) + + // ProvideSecurityToken implementations are called during federated + // authentication security token login sequences at the point when the + // security token is required. The string returned should be the access + // token to supply to the server, otherwise an error can be returned to + // indicate why a token is not available. + ProvideSecurityToken(ctx context.Context) (string, error) +} diff --git a/mssql.go b/mssql.go index 6a66cda3..28445bcc 100644 --- a/mssql.go +++ b/mssql.go @@ -58,6 +58,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) { if err != nil { return nil, err } + return &Connector{ params: params, driver: d, @@ -126,6 +127,10 @@ type Connector struct { // Dialer sets a custom dialer for all network operations. // If Dialer is not set, normal net dialers are used. Dialer Dialer + + // FederatedAuthenticationProvider handles choosing the parameters + // and obtaining tokens for federated authentication login scenarios. + FederatedAuthenticationProvider FederatedAuthenticationProvider } type Dialer interface { @@ -148,7 +153,7 @@ type Conn struct { processQueryText bool connectionGood bool - outs map[string]interface{} + outs map[string]interface{} } func (c *Conn) checkBadConn(err error) error { @@ -653,9 +658,9 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { } type Rows struct { - stmt *Stmt - cols []columnStruct - reader *tokenProcessor + stmt *Stmt + cols []columnStruct + reader *tokenProcessor nextCols []columnStruct cancel func() @@ -669,7 +674,7 @@ func (rc *Rows) Close() error { for { tok, err := rc.reader.nextToken() if err == nil { - if tok == nil { + if tok == nil { return nil } else { // continue consuming tokens diff --git a/tds.go b/tds.go index 5d30bbf1..46291884 100644 --- a/tds.go +++ b/tds.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/ioutil" + "log" "net" "sort" "strconv" @@ -81,20 +82,21 @@ const ( // packet types // https://msdn.microsoft.com/en-us/library/dd304214.aspx const ( - packSQLBatch packetType = 1 - packRPCRequest packetType = 3 - packReply packetType = 4 + packSQLBatch packetType = 1 + packRPCRequest packetType = 3 + packReply packetType = 4 // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx - packAttention packetType = 6 - - packBulkLoadBCP packetType = 7 - packTransMgrReq packetType = 14 - packNormal packetType = 15 - packLogin7 packetType = 16 - packSSPIMessage packetType = 17 - packPrelogin packetType = 18 + packAttention packetType = 6 + + packBulkLoadBCP packetType = 7 + packFedAuthToken packetType = 8 + packTransMgrReq packetType = 14 + packNormal packetType = 15 + packLogin7 packetType = 16 + packSSPIMessage packetType = 17 + packPrelogin packetType = 18 ) // prelogin fields @@ -118,6 +120,17 @@ const ( encryptReq = 3 // Encryption is required. ) +const ( + featExtSESSIONRECOVERY byte = 0x01 + featExtFEDAUTH byte = 0x02 + featExtCOLUMNENCRYPTION byte = 0x04 + featExtGLOBALTRANSACTIONS byte = 0x05 + featExtAZURESQLSUPPORT byte = 0x08 + featExtDATACLASSIFICATION byte = 0x09 + featExtUTF8SUPPORT byte = 0x0A + featExtTERMINATOR byte = 0xFF +) + type tdsSession struct { buf *tdsBuffer loginAck loginAckStruct @@ -244,6 +257,16 @@ const ( fIntSecurity = 0x80 ) +// OptionFlags3 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fChangePassword = 1 + fSendYukonBinaryXML = 2 + fUserInstance = 4 + fUnknownCollationHandling = 8 + fExtension = 0x10 +) + // TypeFlags const ( // 4 bits for fSQLType @@ -251,12 +274,6 @@ const ( fReadOnlyIntent = 32 ) -// OptionFlags3 -// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac -const ( - fExtension = 0x10 -) - type login struct { TDSVersion uint32 PacketSize uint32 @@ -330,37 +347,43 @@ func (e featureExts) toBytes() []byte { return d } -type featureExtFedAuthSTS struct { - FedAuthEcho bool - FedAuthToken string - Nonce []byte +// Implement featureExt interface based on federated authentication state +func (e *FederatedAuthenticationState) featureID() byte { + return featExtFEDAUTH } -func (e *featureExtFedAuthSTS) featureID() byte { - return 0x02 -} - -func (e *featureExtFedAuthSTS) toBytes() []byte { +func (e *FederatedAuthenticationState) toBytes() []byte { if e == nil { return nil } - options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT + options := byte(e.FedAuthLibrary) << 1 if e.FedAuthEcho { options |= 1 // fFedAuthEcho } - d := make([]byte, 5) - d[0] = options + // Feature extension format depends on the federated auth library. + // Options are described at + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac + var d []byte + + switch e.FedAuthLibrary { + case FedAuthLibrarySecurityToken: + d = make([]byte, 5) + d[0] = options + + // looks like string in + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 + tokenBytes := str2ucs2(e.FedAuthToken) + binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work + d = append(d, tokenBytes...) - // looks like string in - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 - tokenBytes := str2ucs2(e.FedAuthToken) - binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work - d = append(d, tokenBytes...) + if len(e.Nonce) == 32 { + d = append(d, e.Nonce...) + } - if len(e.Nonce) == 32 { - d = append(d, e.Nonce...) + case FedAuthLibraryADAL: + d = []byte{options, e.ADALWorkflow} } return d @@ -440,7 +463,7 @@ func manglePassword(password string) []byte { } // http://msdn.microsoft.com/en-us/library/dd304019.aspx -func sendLogin(w *tdsBuffer, login login) error { +func sendLogin(w *tdsBuffer, login *login) error { w.BeginPacket(packLogin7, false) hostname := str2ucs2(login.HostName) username := str2ucs2(login.UserName) @@ -576,6 +599,36 @@ func sendLogin(w *tdsBuffer, login login) error { return w.FinishPacket() } +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0 +func sendFedAuthInfo(w *tdsBuffer, fedAuth *FederatedAuthenticationState) (err error) { + fedauthtoken := str2ucs2(fedAuth.FedAuthToken) + tokenlen := len(fedauthtoken) + datalen := 4 + tokenlen + len(fedAuth.Nonce) + + w.BeginPacket(packFedAuthToken, false) + err = binary.Write(w, binary.LittleEndian, uint32(datalen)) + if err != nil { + return + } + + err = binary.Write(w, binary.LittleEndian, uint32(tokenlen)) + if err != nil { + return + } + + _, err = w.Write(fedauthtoken) + if err != nil { + return + } + + _, err = w.Write(fedAuth.Nonce) + if err != nil { + return + } + + return w.FinishPacket() +} + func readUcs2(r io.Reader, numchars int) (res string, err error) { buf := make([]byte, numchars*2) _, err = io.ReadFull(r, buf) @@ -835,6 +888,144 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne return conn, err } +func configureFedAuth(c *Connector, p connectParams) (fe *FederatedAuthenticationState, err error) { + if c == nil || c.FederatedAuthenticationProvider == nil { + return &FederatedAuthenticationState{ + FedAuthLibrary: FedAuthLibraryReserved, + }, nil + } + + fe = &FederatedAuthenticationState{ + FedAuthWorkflow: p.fedAuthWorkflow, + UserName: p.user, + Password: p.password, + ClientCertPath: p.aadClientCertPath, + + FedAuthLibrary: FedAuthLibraryReserved, + } + + err = c.FederatedAuthenticationProvider.ConfigureProvider(fe) + if err != nil { + return nil, err + } + + return +} + +func preparePreloginFields(p connectParams, fe *FederatedAuthenticationState) map[uint8][]byte { + instance_buf := []byte(p.instance) + instance_buf = append(instance_buf, 0) // zero terminate instance name + + var encrypt byte + if p.disableEncryption { + encrypt = encryptNotSup + } else if p.encrypt { + encrypt = encryptOn + } else { + encrypt = encryptOff + } + + fields := map[uint8][]byte{ + preloginVERSION: {0, 0, 0, 0, 0, 0}, + preloginENCRYPTION: {encrypt}, + preloginINSTOPT: instance_buf, + preloginTHREADID: {0, 0, 0, 0}, + preloginMARS: {0}, // MARS disabled + } + + if fe.FedAuthLibrary != FedAuthLibraryReserved { + fields[preloginFEDAUTHREQUIRED] = []byte{1} + } + + return fields +} + +func interpretPreloginResponse(p connectParams, fe *FederatedAuthenticationState, fields map[uint8][]byte) (encrypt byte, err error) { + // If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication + // is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated + // authentication is allowed, while 1 means only federated authentication is allowed. + if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok { + if len(fedAuthSupport) != 1 { + return 0, fmt.Errorf("Federated authentication flag length should be 1: is %d", len(fedAuthSupport)) + } + + // We need to be able to echo the value back to the server + fe.FedAuthEcho = fedAuthSupport[0] != 0 + } else if fe.FedAuthLibrary != FedAuthLibraryReserved { + return 0, fmt.Errorf("Federated authentication is not supported by the server") + } + + encryptBytes, ok := fields[preloginENCRYPTION] + if !ok { + return 0, fmt.Errorf("encrypt negotiation failed") + } + encrypt = encryptBytes[0] + if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { + return 0, fmt.Errorf("server does not support encryption") + } + + return +} + +func prepareLogin(ctx context.Context, c *Connector, p connectParams, auth auth, fe *FederatedAuthenticationState, packetSize uint32) (l *login, err error) { + l = &login{ + TDSVersion: verTDS74, + PacketSize: packetSize, + Database: p.database, + OptionFlags2: fODBC, // to get unlimited TEXTSIZE + HostName: p.workstation, + ServerName: p.host, + AppName: p.appname, + TypeFlags: p.typeFlags, + } + switch { + case fe.FedAuthLibrary == FedAuthLibrarySecurityToken: + if p.logFlags&logDebug != 0 { + log.Println("Starting federated authentication using security token") + } + + fe.FedAuthToken, err = c.FederatedAuthenticationProvider.ProvideSecurityToken(ctx) + if err != nil { + if p.logFlags&logDebug != 0 { + log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err) + } + return nil, err + } + + l.FeatureExt.Add(fe) + + case fe.FedAuthLibrary == FedAuthLibraryADAL: + if p.logFlags&logDebug != 0 { + log.Println("Starting federated authentication using ADAL") + } + + l.UserName = fe.UserName + l.Password = fe.Password + + l.FeatureExt.Add(fe) + + case auth != nil: + if p.logFlags&logDebug != 0 { + log.Println("Starting SSPI login") + } + + l.SSPI, err = auth.InitialBytes() + if err != nil { + return nil, err + } + + l.OptionFlags2 |= fIntSecurity + return l, nil + + default: + // Default to SQL server authentication with user and password + l.UserName = p.user + l.Password = p.password + } + + return l, nil +} + func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { dialCtx := ctx if p.dial_timeout > 0 { @@ -847,7 +1038,7 @@ func connect(ctx context.Context, c *Connector, log optionalLogger, p connectPar // both instance name and port specified // when port is specified instance name is not used // you should not provide instance name when you provide port - log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored"); + log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored") } if p.instance != "" && p.port == 0 { p.instance = strings.ToUpper(p.instance) @@ -885,24 +1076,13 @@ initiate_connection: logFlags: p.logFlags, } - instance_buf := []byte(p.instance) - instance_buf = append(instance_buf, 0) // zero terminate instance name - var encrypt byte - if p.disableEncryption { - encrypt = encryptNotSup - } else if p.encrypt { - encrypt = encryptOn - } else { - encrypt = encryptOff - } - fields := map[uint8][]byte{ - preloginVERSION: {0, 0, 0, 0, 0, 0}, - preloginENCRYPTION: {encrypt}, - preloginINSTOPT: instance_buf, - preloginTHREADID: {0, 0, 0, 0}, - preloginMARS: {0}, // MARS disabled + fedAuth, err := configureFedAuth(c, p) + if err != nil { + return nil, err } + fields := preparePreloginFields(p, fedAuth) + err = writePrelogin(packPrelogin, outbuf, fields) if err != nil { return nil, err @@ -913,13 +1093,9 @@ initiate_connection: return nil, err } - encryptBytes, ok := fields[preloginENCRYPTION] - if !ok { - return nil, fmt.Errorf("encrypt negotiation failed") - } - encrypt = encryptBytes[0] - if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { - return nil, fmt.Errorf("server does not support encryption") + encrypt, err := interpretPreloginResponse(p, fedAuth, fields) + if err != nil { + return nil, err } if encrypt != encryptNotSup { @@ -959,81 +1135,92 @@ initiate_connection: } } - login := login{ - TDSVersion: verTDS74, - PacketSize: uint32(outbuf.PackageSize()), - Database: p.database, - OptionFlags2: fODBC, // to get unlimited TEXTSIZE - HostName: p.workstation, - ServerName: p.host, - AppName: p.appname, - TypeFlags: p.typeFlags, - } auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation) - switch { - case p.fedAuthAccessToken != "": // accesstoken ignores user/password - featurext := &featureExtFedAuthSTS{ - FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1, - FedAuthToken: p.fedAuthAccessToken, - Nonce: fields[preloginNONCEOPT], - } - login.FeatureExt.Add(featurext) - case authOk: - login.SSPI, err = auth.InitialBytes() - if err != nil { - return nil, err - } - login.OptionFlags2 |= fIntSecurity + if authOk { defer auth.Free() - default: - login.UserName = p.user - login.Password = p.password + } else { + auth = nil } + + login, err := prepareLogin(ctx, c, p, auth, fedAuth, uint32(outbuf.PackageSize())) + if err != nil { + return nil, err + } + err = sendLogin(outbuf, login) if err != nil { return nil, err } - // processing login response - reader := startReading(&sess, ctx, nil) - for { - tok, err := reader.nextToken() - if err == nil { + // Loop until a packet containing a login acknowledgement is received. + // SSPI and federated authentication scenarios may require multiple + // packet exchanges to complete the login sequence. + for loginAck := false; !loginAck; { + reader := startReading(&sess, ctx, nil) + + for { + tok, err := reader.nextToken() + if err != nil { + return nil, err + } + if tok == nil { break - } else { - switch token := tok.(type) { - case sspiMsg: - sspi_msg, err := auth.NextBytes(token) + } + + switch token := tok.(type) { + case sspiMsg: + sspi_msg, err := auth.NextBytes(token) + if err != nil { + return nil, err + } + if len(sspi_msg) > 0 { + outbuf.BeginPacket(packSSPIMessage, false) + _, err = outbuf.Write(sspi_msg) if err != nil { return nil, err } - if len(sspi_msg) > 0 { - outbuf.BeginPacket(packSSPIMessage, false) - _, err = outbuf.Write(sspi_msg) - if err != nil { - return nil, err - } - err = outbuf.FinishPacket() - if err != nil { - return nil, err - } - sspi_msg = nil - } - case loginAckStruct: - sess.loginAck = token - /*case error: - return nil, fmt.Errorf("login error: %s", token.Error())*/ - case doneStruct: - if token.isError() { - return nil, fmt.Errorf("login error: %s", token.getError()) + err = outbuf.FinishPacket() + if err != nil { + return nil, err } + sspi_msg = nil + } + // TODO: for Live ID authentication it may be necessary to + // compare fedAuth.Nonce == token.Nonce and keep track of signature + //case fedAuthAckStruct: + //fedAuth.Signature = token.Signature + case fedAuthInfoStruct: + // For ADAL workflows this contains the STS URL and server SPN. + // If received outside of an ADAL workflow, ignore. + if c == nil || c.FederatedAuthenticationProvider == nil { + continue } + + // Request the AD token given the server SPN and STS URL + fedAuth.FedAuthToken, err = c.FederatedAuthenticationProvider.ProvideActiveDirectoryToken(ctx, token.ServerSPN, token.STSURL) + if err != nil { + return nil, err + } + + // Now need to send the token as a FEDINFO packet + err = sendFedAuthInfo(outbuf, fedAuth) + if err != nil { + return nil, err + } + case loginAckStruct: + sess.loginAck = token + loginAck = true + case doneStruct: + if token.isError() { + return nil, fmt.Errorf("login error: %s", token.getError()) + } + case error: + return nil, fmt.Errorf("login error: %s", token.Error()) } - } else { - return nil, err } } + if sess.routedServer != "" { toconn.Close() p.host = sess.routedServer @@ -1050,4 +1237,4 @@ func (sess *tdsSession) setReturnStatus(status ReturnStatus) { if sess.returnStatus != nil { *sess.returnStatus = status } -} \ No newline at end of file +} diff --git a/tds_login_test.go b/tds_login_test.go new file mode 100644 index 00000000..abbeaa8c --- /dev/null +++ b/tds_login_test.go @@ -0,0 +1,368 @@ +package mssql + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "regexp" + "sync/atomic" + "testing" +) + +type MockTransportDialer struct { + expected, responses []string + server, client net.Conn + result chan error + count int32 +} + +func NewMockTransportDialer(expected, responses []string) *MockTransportDialer { + server, client := net.Pipe() + + return &MockTransportDialer{ + expected: expected, + responses: responses, + server: server, + client: client, + result: make(chan error, 2), + } +} + +func (d *MockTransportDialer) DialContext(ctx context.Context, network string, addr string) (conn net.Conn, err error) { + if atomic.AddInt32(&d.count, 1) != 1 { + return nil, errors.New("No concurrent connections to mock dialer") + } + + go testLoginSequenceServer(d.result, d.server, d.expected, d.responses) + + return d.client, nil +} + +func testLoginSequenceServer(result chan error, conn net.Conn, expectedPackets, responsePackets []string) { + defer func() { + conn.Close() + close(result) + }() + + spacesRE := regexp.MustCompile("\\s+") + + packet := make([]byte, 1024) + for i, expectedHex := range expectedPackets { + expectedBytes, err := hex.DecodeString(spacesRE.ReplaceAllString(expectedHex, "")) + if err != nil { + result <- err + return + } + + for b := 0; b < len(expectedBytes) && err == nil; { + n, err := conn.Read(packet) + + // Ignore EOF: ErrPipeClosed is the real signal + if err == io.EOF { + err = nil + continue + } + + if err != nil { + result <- err + return + } + + for bi := 0; bi < n; bi++ { + if expectedBytes[bi+b] != packet[bi] { + err = fmt.Errorf("Client sent unexpected byte %02X != %02X at offset %d of packet %d", + packet[bi], expectedBytes[bi+b], bi+b, i) + result <- err + return + } + } + + b = b + n + } + + if i >= len(responsePackets) || responsePackets[i] == "" { + continue + } + + responseBytes, err := hex.DecodeString(spacesRE.ReplaceAllString(responsePackets[i], "")) + if err != nil { + result <- err + return + } + + for b := 0; b < len(responseBytes); { + n, err := conn.Write(responseBytes[b:]) + + if err != nil { + result <- err + return + } + + b = b + n + } + } + + result <- nil +} + +func TestLoginWithSQLServerAuth(t *testing.T) { + conn, err := NewConnector("sqlserver://test:secret@localhost:1433?Workstation ID=localhost&log=128") + if err != nil { + t.Errorf("Unable to parse dummy DSN: %v", err) + } + SetLogger(testLogger{t}) + + mock := NewMockTransportDialer( + []string{ + " 12 01 00 2f 00 00 01 00 00 00 1a 00 06 01 00 20\n" + + "00 01 02 00 21 00 01 03 00 22 00 04 04 00 26 00\n" + + "01 ff 00 00 00 00 00 00 00 00 00 00 00 00 00\n", + " 10 01 00 b2 00 00 01 00 aa 00 00 00 04 00 00 74\n" + + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + + "00 02 00 00 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "70 00 04 00 78 00 06 00 84 00 0a 00 98 00 09 00\n" + + "00 00 00 00 aa 00 00 00 aa 00 00 00 aa 00 00 00\n" + + "00 00 00 00 00 00 aa 00 00 00 aa 00 00 00 aa 00\n" + + "00 00 00 00 00 00 6c 00 6f 00 63 00 61 00 6c 00\n" + + "68 00 6f 00 73 00 74 00 74 00 65 00 73 00 74 00\n" + + "92 a5 f3 a5 93 a5 82 a5 f3 a5 e2 a5 67 00 6f 00\n" + + "2d 00 6d 00 73 00 73 00 71 00 6c 00 64 00 62 00\n" + + "6c 00 6f 00 63 00 61 00 6c 00 68 00 6f 00 73 00\n" + + "74 00\n", + }, + []string{ + " 04 01 00 20 00 00 01 00 00 00 10 00 06 01 00 16\n" + + "00 01 06 00 17 00 01 FF 0C 00 07 D0 00 00 02 01\n", + " 04 01 00 4A 00 00 01 00 AD 32 00 01 74 00 00 04\n" + + "14 4d 00 69 00 63 00 72 00 6f 00 73 00 6f 00 66\n" + + "00 74 00 20 00 53 00 51 00 4c 00 20 00 53 00 65\n" + + "00 72 00 76 00 65 00 72 00 0c 00 07 d0 fd 00 00\n" + + "00 00 00 00 00 00 00 00 00 00\n", + }, + ) + + conn.Dialer = mock + + _, err = connect(context.Background(), conn, driverInstanceNoProcess.log, conn.params) + + err = <-mock.result + if err != nil { + t.Error(err) + } +} + +func TestLoginWithSecurityTokenAuth(t *testing.T) { + conn, err := NewConnector("sqlserver://localhost:1433?Workstation ID=localhost&log=128") + if err != nil { + t.Errorf("Unable to parse dummy DSN: %v", err) + } + + conn.FederatedAuthenticationProvider = &accessTokenProvider{ + tokenProvider: func() (string, error) { + return "", nil + }, + } + + SetLogger(testLogger{t}) + + mock := NewMockTransportDialer( + []string{ + " 12 01 00 35 00 00 01 00 00 00 1F 00 06 01 00 25\n" + + "00 01 02 00 26 00 01 03 00 27 00 04 04 00 2B 00\n" + + "01 06 00 2c 00 01 ff 00 00 00 00 00 00 00 00 00\n" + + "00 00 00 00 01\n", + " 10 01 00 BB 00 00 01 00 B3 00 00 00 04 00 00 74\n" + + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + + "00 02 00 10 00 00 00 00 00 00 00 00 5E 00 09 00\n" + + "70 00 00 00 70 00 00 00 70 00 0A 00 84 00 09 00\n" + + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + + "00 00 00 00 00 00 6C 00 6F 00 63 00 61 00 6C 00\n" + + "68 00 6F 00 73 00 74 00 67 00 6F 00 2D 00 6D 00\n" + + "73 00 73 00 71 00 6C 00 64 00 62 00 6C 00 6F 00\n" + + "63 00 61 00 6C 00 68 00 6F 00 73 00 74 00 9A 00\n" + + "00 00 02 13 00 00 00 03 0E 00 00 00 3C 00 74 00\n" + + "6F 00 6B 00 65 00 6E 00 3E 00 FF\n", + }, + []string{ + " 04 01 00 20 00 00 01 00 00 00 10 00 06 01 00 16\n" + + "00 01 06 00 17 00 01 FF 0C 00 07 D0 00 00 02 01\n", + " 04 01 00 4A 00 00 01 00 AD 32 00 01 74 00 00 04\n" + + "14 4d 00 69 00 63 00 72 00 6f 00 73 00 6f 00 66\n" + + "00 74 00 20 00 53 00 51 00 4c 00 20 00 53 00 65\n" + + "00 72 00 76 00 65 00 72 00 0c 00 07 d0 fd 00 00\n" + + "00 00 00 00 00 00 00 00 00 00\n", + }, + ) + + conn.Dialer = mock + + _, err = connect(context.Background(), conn, driverInstanceNoProcess.log, conn.params) + + err = <-mock.result + if err != nil { + t.Error(err) + } +} + +type testADALProvider struct { + configureError error + tokenError error + workflow byte +} + +func (p *testADALProvider) ConfigureProvider(fa *FederatedAuthenticationState) error { + if p.configureError != nil { + return p.configureError + } + + fa.FedAuthLibrary = FedAuthLibraryADAL + fa.ADALWorkflow = p.workflow + + return nil +} + +func (p *testADALProvider) ProvideSecurityToken(ctx context.Context) (string, error) { + if p.tokenError != nil { + return "", p.tokenError + } + + return "", nil +} + +func (p *testADALProvider) ProvideActiveDirectoryToken(ctx context.Context, serverSPN, stsURL string) (string, error) { + if p.tokenError != nil { + return "", p.tokenError + } + + return "", nil +} + +func TestLoginWithADALUsernamePasswordAuth(t *testing.T) { + conn, err := NewConnector("sqlserver://localhost:1433?Workstation ID=localhost&log=128&fedauth=ActiveDirectoryPassword") + if err != nil { + t.Errorf("Unable to parse dummy DSN: %v", err) + } + + conn.FederatedAuthenticationProvider = &testADALProvider{ + workflow: FedAuthADALWorkflowPassword, + } + + SetLogger(testLogger{t}) + + mock := NewMockTransportDialer( + []string{ + " 12 01 00 35 00 00 01 00 00 00 1F 00 06 01 00 25\n" + + "00 01 02 00 26 00 01 03 00 27 00 04 04 00 2B 00\n" + + "01 06 00 2C 00 01 ff 00 00 00 00 00 00 00 00 00\n" + + "00 00 00 00 01\n", + " 10 01 00 aa 00 00 01 00 a2 00 00 00 04 00 00 74\n" + + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + + "00 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "70 00 00 00 70 00 00 00 70 00 0a 00 84 00 09 00\n" + + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + + "00 00 00 00 00 00 6c 00 6f 00 63 00 61 00 6c 00\n" + + "68 00 6f 00 73 00 74 00 67 00 6f 00 2d 00 6d 00\n" + + "73 00 73 00 71 00 6c 00 64 00 62 00 6c 00 6f 00\n" + + "63 00 61 00 6c 00 68 00 6f 00 73 00 74 00 9a 00\n" + + "00 00 02 02 00 00 00 05 01 ff\n", + " 08 01 00 1e 00 00 01 00 12 00 00 00 0e 00 00 00\n" + + "3c 00 74 00 6f 00 6b 00 65 00 6e 00 3e 00\n", + }, + []string{ + " 04 01 00 20 00 00 01 00 00 00 10 00 06 01 00 16\n" + + "00 01 06 00 17 00 01 FF 0C 00 07 D0 00 00 02 01\n", + " 04 01 00 97 00 00 01 00 EE 8A 00 00 00 02 00 00\n" + + "00 02 3A 00 00 00 16 00 00 00 01 3A 00 00 00 50\n" + + "00 00 00 68 00 74 00 74 00 70 00 73 00 3A 00 2F\n" + + "00 2F 00 64 00 61 00 74 00 61 00 62 00 61 00 73\n" + + "00 65 00 2E 00 77 00 69 00 6E 00 64 00 6F 00 77\n" + + "00 73 00 2E 00 6E 00 65 00 74 00 2F 00 68 00 74\n" + + "00 74 00 70 00 73 00 3A 00 2F 00 2F 00 65 00 78\n" + + "00 61 00 6D 00 70 00 6C 00 65 00 2E 00 63 00 6F\n" + + "00 6D 00 2F 00 61 00 75 00 74 00 68 00 6F 00 72\n" + + "00 69 00 74 00 79 00\n", + " 04 01 00 4A 00 00 01 00 AD 32 00 01 74 00 00 04\n" + + "14 4d 00 69 00 63 00 72 00 6f 00 73 00 6f 00 66\n" + + "00 74 00 20 00 53 00 51 00 4c 00 20 00 53 00 65\n" + + "00 72 00 76 00 65 00 72 00 0c 00 07 d0 fd 00 00\n" + + "00 00 00 00 00 00 00 00 00 00\n", + }, + ) + + conn.Dialer = mock + + _, err = connect(context.Background(), conn, driverInstanceNoProcess.log, conn.params) + + err = <-mock.result + if err != nil { + t.Error(err) + } +} + +func TestLoginWithADALManagedIdentityAuth(t *testing.T) { + conn, err := NewConnector("sqlserver://localhost:1433?Workstation ID=localhost&log=128&fedauth=ActiveDirectoryMSI") + if err != nil { + t.Errorf("Unable to parse dummy DSN: %v", err) + } + + conn.FederatedAuthenticationProvider = &testADALProvider{ + workflow: FedAuthADALWorkflowMSI, + } + + SetLogger(testLogger{t}) + + mock := NewMockTransportDialer( + []string{ + " 12 01 00 35 00 00 01 00 00 00 1F 00 06 01 00 25\n" + + "00 01 02 00 26 00 01 03 00 27 00 04 04 00 2B 00\n" + + "01 06 00 2C 00 01 ff 00 00 00 00 00 00 00 00 00\n" + + "00 00 00 00 01\n", + " 10 01 00 aa 00 00 01 00 a2 00 00 00 04 00 00 74\n" + + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + + "00 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "70 00 00 00 70 00 00 00 70 00 0a 00 84 00 09 00\n" + + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + + "00 00 00 00 00 00 6c 00 6f 00 63 00 61 00 6c 00\n" + + "68 00 6f 00 73 00 74 00 67 00 6f 00 2d 00 6d 00\n" + + "73 00 73 00 71 00 6c 00 64 00 62 00 6c 00 6f 00\n" + + "63 00 61 00 6c 00 68 00 6f 00 73 00 74 00 9a 00\n" + + "00 00 02 02 00 00 00 05 03 ff\n", + " 08 01 00 1e 00 00 01 00 12 00 00 00 0e 00 00 00\n" + + "3c 00 74 00 6f 00 6b 00 65 00 6e 00 3e 00\n", + }, + []string{ + " 04 01 00 20 00 00 01 00 00 00 10 00 06 01 00 16\n" + + "00 01 06 00 17 00 01 FF 0C 00 07 D0 00 00 02 01\n", + " 04 01 00 97 00 00 01 00 EE 8A 00 00 00 02 00 00\n" + + "00 02 3A 00 00 00 16 00 00 00 01 3A 00 00 00 50\n" + + "00 00 00 68 00 74 00 74 00 70 00 73 00 3A 00 2F\n" + + "00 2F 00 64 00 61 00 74 00 61 00 62 00 61 00 73\n" + + "00 65 00 2E 00 77 00 69 00 6E 00 64 00 6F 00 77\n" + + "00 73 00 2E 00 6E 00 65 00 74 00 2F 00 68 00 74\n" + + "00 74 00 70 00 73 00 3A 00 2F 00 2F 00 65 00 78\n" + + "00 61 00 6D 00 70 00 6C 00 65 00 2E 00 63 00 6F\n" + + "00 6D 00 2F 00 61 00 75 00 74 00 68 00 6F 00 72\n" + + "00 69 00 74 00 79 00\n", + " 04 01 00 4A 00 00 01 00 AD 32 00 01 74 00 00 04\n" + + "14 4d 00 69 00 63 00 72 00 6f 00 73 00 6f 00 66\n" + + "00 74 00 20 00 53 00 51 00 4c 00 20 00 53 00 65\n" + + "00 72 00 76 00 65 00 72 00 0c 00 07 d0 fd 00 00\n" + + "00 00 00 00 00 00 00 00 00 00\n", + }, + ) + + conn.Dialer = mock + + _, err = connect(context.Background(), conn, driverInstanceNoProcess.log, conn.params) + + err = <-mock.result + if err != nil { + t.Error(err) + } +} diff --git a/tds_test.go b/tds_test.go index 6998328b..37696c1d 100644 --- a/tds_test.go +++ b/tds_test.go @@ -19,6 +19,27 @@ func (t *MockTransport) Close() error { return nil } +func TestConstantsDefined(t *testing.T) { + // This test is just here to avoid complaints about unused code. + // These constants are part of the spec but not yet used. + for _, b := range []byte{ + featExtSESSIONRECOVERY, featExtCOLUMNENCRYPTION, featExtGLOBALTRANSACTIONS, + featExtAZURESQLSUPPORT, featExtDATACLASSIFICATION, featExtUTF8SUPPORT, + } { + if b == 0 { + t.Fail() + } + } + + for _, i := range []int{ + FedAuthLibraryLiveIDCompactToken, fChangePassword, fSendYukonBinaryXML, + } { + if i < 0 { + t.Fail() + } + } +} + func TestSendLogin(t *testing.T) { memBuf := new(MockTransport) buf := newTdsBuffer(1024, memBuf) @@ -42,7 +63,7 @@ func TestSendLogin(t *testing.T) { ClientLCID: 0x204, AtchDBFile: "filepath", } - err := sendLogin(buf, login) + err := sendLogin(buf, &login) if err != nil { t.Error("sendLogin should succeed") } @@ -89,24 +110,28 @@ func TestSendLoginWithFeatureExt(t *testing.T) { Database: "database", ClientLCID: 0x204, } - login.FeatureExt.Add(&featureExtFedAuthSTS{ - FedAuthToken: "fedauthtoken", + login.FeatureExt.Add(&FederatedAuthenticationState{ + FedAuthLibrary: FedAuthLibrarySecurityToken, + FedAuthToken: "fedauthtoken", }) - err := sendLogin(buf, login) + err := sendLogin(buf, &login) if err != nil { t.Error("sendLogin should succeed") } ref := []byte{ - 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, 0, 16, 0, 0, 0, 1, - 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, - 0, 94, 0, 7, 0, 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, 176, - 0, 4, 0, 142, 0, 7, 0, 156, 0, 2, 0, 160, 0, 8, 0, 18, 52, 86, 120, 144, 171, - 176, 0, 0, 0, 176, 0, 0, 0, 176, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, - 0, 100, 0, 101, 0, 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, - 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, 97, - 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114, 0, 97, 0, 114, 0, 121, 0, 101, - 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, - 0, 0, 2, 29, 0, 0, 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, + 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, + 0, 16, 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, + 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, + 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, + 176, 0, 4, 0, 142, 0, 7, 0, 156, 0, 2, 0, 160, 0, 8, 0, + 18, 52, 86, 120, 144, 171, 176, 0, 0, 0, 176, 0, 0, 0, 176, 0, + 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, 0, 100, 0, 101, 0, + 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, 109, 0, + 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, + 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114, 0, 97, 0, + 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, + 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, 0, 0, 2, 29, 0, 0, + 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 255} out := memBuf.Bytes() if !bytes.Equal(ref, out) { @@ -211,6 +236,10 @@ func TestConnect(t *testing.T) { func TestConnectViaIp(t *testing.T) { params := testConnParams(t) + if params.encrypt { + t.Skip("Unable to test connection to IP for servers that expect encryption") + } + ips, err := net.LookupIP(params.host) if err != nil { t.Fatal("Unable to lookup IP", err) diff --git a/token.go b/token.go index 914e5f58..c9d45256 100644 --- a/token.go +++ b/token.go @@ -6,10 +6,11 @@ import ( "errors" "fmt" "io" + "io/ioutil" "strconv" ) -//go:generate stringer -type token +//go:generate go run golang.org/x/tools/cmd/stringer -type token type token byte @@ -27,6 +28,7 @@ const ( tokenNbcRow token = 210 // 0xd2 tokenEnvChange token = 227 // 0xE3 tokenSSPI token = 237 // 0xED + tokenFedAuthInfo token = 238 // 0xEE tokenDone token = 253 // 0xFD tokenDoneProc token = 254 tokenDoneInProc token = 255 @@ -68,6 +70,11 @@ const ( envRouting = 20 ) +const ( + fedAuthInfoSTSURL = 0x01 + fedAuthInfoSPN = 0x02 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -402,6 +409,78 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg { return sspiMsg(buf) } +type fedAuthInfoStruct struct { + STSURL string + ServerSPN string +} + +type fedAuthInfoOpt struct { + fedAuthInfoID byte + dataLength, dataOffset uint32 +} + +func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { + size := r.uint32() + + var STSURL, SPN string + var err error + + // Each fedAuthInfoOpt is one byte to indicate the info ID, + // then a four byte offset and a four byte length. + count := r.uint32() + offset := uint32(4) + opts := make([]fedAuthInfoOpt, count) + + for i := uint32(0); i < count; i++ { + fedAuthInfoID := r.byte() + dataLength := r.uint32() + dataOffset := r.uint32() + offset += 1 + 4 + 4 + + opts[i] = fedAuthInfoOpt{ + fedAuthInfoID: fedAuthInfoID, + dataLength: dataLength, + dataOffset: dataOffset, + } + } + + data := make([]byte, size-offset) + r.ReadFull(data) + + for i := uint32(0); i < count; i++ { + if opts[i].dataOffset < offset { + badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d", + opts[i].dataOffset, offset) + // returns via panic + } + + if opts[i].dataOffset+opts[i].dataLength > size { + badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d", + opts[i].dataOffset+opts[i].dataLength, size) + // returns via panic + } + + optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength] + switch opts[i].fedAuthInfoID { + case fedAuthInfoSTSURL: + STSURL, err = ucs22str(optData) + case fedAuthInfoSPN: + SPN, err = ucs22str(optData) + default: + err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + } + + if err != nil { + badStreamPanic(err) + } + } + + return fedAuthInfoStruct{ + STSURL: STSURL, + ServerSPN: SPN, + } +} + type loginAckStruct struct { Interface uint8 TDSVersion uint32 @@ -426,19 +505,43 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { } // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a -func parseFeatureExtAck(r *tdsBuffer) { - // at most 1 featureAck per feature in featureExt - // go-mssqldb will add at most 1 feature, the spec defines 7 different features - for i := 0; i < 8; i++ { - featureID := r.byte() // FeatureID - if featureID == 0xff { - return +type fedAuthAckStruct struct { + Nonce []byte + Signature []byte +} + +func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { + ack := map[byte]interface{}{} + + for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { + length := r.uint32() + + switch feature { + case featExtFEDAUTH: + // In theory we need to know the federated authentication library to + // know how to parse, but the alternatives provide compatible structures. + fedAuthAck := fedAuthAckStruct{} + if length >= 32 { + fedAuthAck.Nonce = make([]byte, 32) + r.ReadFull(fedAuthAck.Nonce) + length -= 32 + } + if length >= 32 { + fedAuthAck.Signature = make([]byte, 32) + r.ReadFull(fedAuthAck.Signature) + length -= 32 + } + ack[feature] = fedAuthAck + + } + + // Skip unprocessed bytes + if length > 0 { + io.CopyN(ioutil.Discard, r, int64(length)) } - size := r.uint32() // FeatureAckDataLen - d := make([]byte, size) - r.ReadFull(d) } - panic("parsed more than 7 featureAck's, protocol implementation error?") + + return ack } // http://msdn.microsoft.com/en-us/library/dd357363.aspx @@ -556,7 +659,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } var columns []columnStruct errs := make([]Error, 0, 5) - for tokens := 0; ; tokens += 1{ + for tokens := 0; ; tokens += 1 { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { sess.log.Printf("got token %v", token) @@ -565,6 +668,9 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return + case tokenFedAuthInfo: + ch <- parseFedAuthInfo(sess.buf) + return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus @@ -572,7 +678,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin loginAck := parseLoginAck(sess.buf) ch <- loginAck case tokenFeatureExtAck: - parseFeatureExtAck(sess.buf) + featureExtAck := parseFeatureExtAck(sess.buf) + ch <- featureExtAck case tokenOrder: order := parseOrder(sess.buf) ch <- order @@ -648,12 +755,12 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } type tokenProcessor struct { - tokChan chan tokenStruct - ctx context.Context - sess *tdsSession - outs map[string]interface{} - lastRow []interface{} - rowCount int64 + tokChan chan tokenStruct + ctx context.Context + sess *tdsSession + outs map[string]interface{} + lastRow []interface{} + rowCount int64 firstError error } @@ -662,9 +769,9 @@ func startReading(sess *tdsSession, ctx context.Context, outs map[string]interfa go processSingleResponse(sess, tokChan, outs) return &tokenProcessor{ tokChan: tokChan, - ctx: ctx, - sess: sess, - outs: outs, + ctx: ctx, + sess: sess, + outs: outs, } } @@ -693,7 +800,7 @@ func (t *tokenProcessor) iterateResponse() error { } case ReturnStatus: t.sess.setReturnStatus(token) - /*case error: + /*case error: if resultError == nil { resultError = token }*/ diff --git a/token_string.go b/token_string.go index c075b23b..74389fdd 100644 --- a/token_string.go +++ b/token_string.go @@ -1,29 +1,46 @@ -// Code generated by "stringer -type token"; DO NOT EDIT +// Code generated by "stringer -type token"; DO NOT EDIT. package mssql -import "fmt" +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[tokenReturnStatus-121] + _ = x[tokenColMetadata-129] + _ = x[tokenOrder-169] + _ = x[tokenError-170] + _ = x[tokenInfo-171] + _ = x[tokenReturnValue-172] + _ = x[tokenLoginAck-173] + _ = x[tokenFeatureExtAck-174] + _ = x[tokenRow-209] + _ = x[tokenNbcRow-210] + _ = x[tokenEnvChange-227] + _ = x[tokenSSPI-237] + _ = x[tokenFedAuthInfo-238] + _ = x[tokenDone-253] + _ = x[tokenDoneProc-254] + _ = x[tokenDoneInProc-255] +} const ( _token_name_0 = "tokenReturnStatus" _token_name_1 = "tokenColMetadata" - _token_name_2 = "tokenOrdertokenErrortokenInfo" - _token_name_3 = "tokenLoginAck" - _token_name_4 = "tokenRowtokenNbcRow" - _token_name_5 = "tokenEnvChange" - _token_name_6 = "tokenSSPI" - _token_name_7 = "tokenDonetokenDoneProctokenDoneInProc" + _token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck" + _token_name_3 = "tokenRowtokenNbcRow" + _token_name_4 = "tokenEnvChange" + _token_name_5 = "tokenSSPItokenFedAuthInfo" + _token_name_6 = "tokenDonetokenDoneProctokenDoneInProc" ) var ( - _token_index_0 = [...]uint8{0, 17} - _token_index_1 = [...]uint8{0, 16} - _token_index_2 = [...]uint8{0, 10, 20, 29} - _token_index_3 = [...]uint8{0, 13} - _token_index_4 = [...]uint8{0, 8, 19} - _token_index_5 = [...]uint8{0, 14} - _token_index_6 = [...]uint8{0, 9} - _token_index_7 = [...]uint8{0, 9, 22, 37} + _token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76} + _token_index_3 = [...]uint8{0, 8, 19} + _token_index_5 = [...]uint8{0, 9, 25} + _token_index_6 = [...]uint8{0, 9, 22, 37} ) func (i token) String() string { @@ -32,22 +49,21 @@ func (i token) String() string { return _token_name_0 case i == 129: return _token_name_1 - case 169 <= i && i <= 171: + case 169 <= i && i <= 174: i -= 169 return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] - case i == 173: - return _token_name_3 case 209 <= i && i <= 210: i -= 209 - return _token_name_4[_token_index_4[i]:_token_index_4[i+1]] + return _token_name_3[_token_index_3[i]:_token_index_3[i+1]] case i == 227: - return _token_name_5 - case i == 237: - return _token_name_6 + return _token_name_4 + case 237 <= i && i <= 238: + i -= 237 + return _token_name_5[_token_index_5[i]:_token_index_5[i+1]] case 253 <= i && i <= 255: i -= 253 - return _token_name_7[_token_index_7[i]:_token_index_7[i+1]] + return _token_name_6[_token_index_6[i]:_token_index_6[i+1]] default: - return fmt.Sprintf("token(%d)", i) + return "token(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/token_test.go b/token_test.go new file mode 100644 index 00000000..a77ea483 --- /dev/null +++ b/token_test.go @@ -0,0 +1,42 @@ +package mssql + +import ( + "encoding/hex" + "regexp" + "testing" +) + +func TestParseFeatureExtAck(t *testing.T) { + spacesRE := regexp.MustCompile("\\s+") + + tests := []string{ + " FF", + " 01 03 00 00 00 AB CD EF FF", + " 02 00 00 00 00 FF\n", + " 02 20 00 00 00 00 01 02 03 04 05 06 07 08 09 0A\n" + + "0B 0C 0D 0E 0F 10 11 12 13 14 15 16 17 18 19 1A\n" + + "1B 1C 1D 1E 1F FF\n", + " 02 40 00 00 00 00 01 02 03 04 05 06 07 08 09 0A\n" + + "0B 0C 0D 0E 0F 10 11 12 13 14 15 16 17 18 19 1A\n" + + "1B 1C 1D 1E 1F 20 21 22 23 24 25 26 27 28 29 2A\n" + + "2B 2C 2D 2E 2F 30 31 32 33 34 35 36 37 38 39 3A\n" + + "3B 3C 3D 3E 3F FF\n", + } + + for _, tst := range tests { + b, err := hex.DecodeString(spacesRE.ReplaceAllString(tst, "")) + if err != nil { + t.Log(err) + t.FailNow() + } + + r := &tdsBuffer{ + packetSize: len(b), + rbuf: b, + rpos: 0, + rsize: len(b), + } + + parseFeatureExtAck(r) + } +}