diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f110691d --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.terraform +*.tfstate* +*.log +*.swp +*~ diff --git a/README.md b/README.md index b655176b..1bd940c6 100644 --- a/README.md +++ b/README.md @@ -54,10 +54,14 @@ Other supported formats are listed below. * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. * `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. * `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. -* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. +* `ServerSPN` - The Kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. * `Workstation ID` - The workstation name (default is the host name) * `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. - +* `FedAuth` - The federated authentication scheme to use. + * `ActiveDirectoryApplication` - authenticates using an Azure Active Directory application client ID and client secret or certificate. Set the `user` to `client-ID@tenant-ID` and the `password` to the client secret. If using client certificates, provide the path to the PKCS#12 file containing the certificate and RSA private key in the `ClientCertPath` parameter, and set the `password` to the value needed to open the PKCS#12 file. + * `ActiveDirectoryMSI` - authenticates using the managed service identity (MSI) attached to the VM, or a specific user-assigned identity if a client ID is specified in the `user` field. + * `ActiveDirectoryPassword` - authenticates an Azure Active Directory user account in the form `user@domain.com` with a password. This method is not recommended for general use and does not support multi-factor authentication for accounts. + ### The connection string can be specified in one of three formats: @@ -106,6 +110,24 @@ Other supported formats are listed below. * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar" +### Using Azure Active Directory access tokens for authentication + +Azure Active Directory (AAD) is not supported through the DSN, but through a "token provider" callback. +Since access tokens are relatively short lived and need to be valid when a new connection is made, +a better way to provide them is using a connector: +``` golang +conn, err := mssql.NewAccessTokenConnector( + "Server=test.database.windows.net;Database=testdb", + tokenProvider) +if err != nil { + // handle errors in DSN parsing +} +db := sql.OpenDB(conn) +``` +Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements +actually trigger the retrieval of a token, this happens when the first statment is issued and a connection +is created. + ## Executing Stored Procedures To run a stored procedure, set the query text to the procedure name: diff --git a/accesstokenconnector.go b/accesstokenconnector.go new file mode 100644 index 00000000..8dbe5099 --- /dev/null +++ b/accesstokenconnector.go @@ -0,0 +1,51 @@ +// +build go1.10 + +package mssql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" +) + +var _ driver.Connector = &accessTokenConnector{} + +// accessTokenConnector wraps Connector and injects a +// fresh access token when connecting to the database +type accessTokenConnector struct { + Connector + + accessTokenProvider func() (string, error) +} + +// NewAccessTokenConnector creates a new connector from a DSN and a token provider. +// The token provider func will be called when a new connection is requested and should return a valid access token. +// The returned connector may be used with sql.OpenDB. +func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (driver.Connector, error) { + if tokenProvider == nil { + return nil, errors.New("mssql: tokenProvider cannot be nil") + } + + conn, err := NewConnector(dsn) + if err != nil { + return nil, err + } + + c := &accessTokenConnector{ + Connector: *conn, + accessTokenProvider: tokenProvider, + } + return c, nil +} + +// Connect returns a new database connection +func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider() + if err != nil { + return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err) + } + + return c.Connector.Connect(ctx) +} diff --git a/accesstokenconnector_test.go b/accesstokenconnector_test.go new file mode 100644 index 00000000..1adc49b6 --- /dev/null +++ b/accesstokenconnector_test.go @@ -0,0 +1,86 @@ +// +build go1.10 + +package mssql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "strings" + "testing" +) + +func TestNewAccessTokenConnector(t *testing.T) { + dsn := "Server=server.database.windows.net;Database=db" + tp := func() (string, error) { return "token", nil } + type args struct { + dsn string + tokenProvider func() (string, error) + } + tests := []struct { + name string + args args + want func(driver.Connector) error + wantErr bool + }{ + {"Happy path", + args{dsn, tp}, + func(c driver.Connector) error { + tc, ok := c.(*accessTokenConnector) + if !ok { + return fmt.Errorf("Expected driver to be of type *accessTokenConnector, but got %T", c) + } + p := tc.Connector.params + if p.database != "db" { + return fmt.Errorf("expected params.database=db, but got %v", p.database) + } + if p.host != "server.database.windows.net" { + return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.host) + } + if tc.accessTokenProvider == nil { + return fmt.Errorf("Expected tokenProvider to not be nil") + } + t, err := tc.accessTokenProvider() + if t != "token" || err != nil { + return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err) + } + return nil + }, + false, + }, + {"Nil tokenProvider gives error", + args{dsn, nil}, + nil, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewAccessTokenConnector(tt.args.dsn, tt.args.tokenProvider) + if (err != nil) != tt.wantErr { + t.Errorf("NewAccessTokenConnector() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.want != nil { + if err := tt.want(got); err != nil { + t.Error(err) + } + } + }) + } +} + +func TestAccessTokenConnectorFailsToConnectIfNoAccessToken(t *testing.T) { + errorText := "This is a test" + dsn := "Server=server.database.windows.net;Database=db" + tp := func() (string, error) { return "", errors.New(errorText) } + sut, err := NewAccessTokenConnector(dsn, tp) + if err != nil { + t.Fatalf("expected err==nil, but got %+v", err) + } + _, err = sut.Connect(context.TODO()) + if err == nil || !strings.Contains(err.Error(), errorText) { + t.Fatalf("expected error to contain %q, but got %q", errorText, err) + } +} diff --git a/appveyor.yml b/appveyor.yml index c4d2bb06..51de3c26 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -30,6 +30,8 @@ install: - go version - go env - go get -u github.com/golang-sql/civil + - go get -u golang.org/x/crypto/pkcs12 + - go get -u github.com/Azure/go-autorest/autorest/adal build_script: - go build diff --git a/codecov.yaml b/codecov.yaml new file mode 100644 index 00000000..878af57d --- /dev/null +++ b/codecov.yaml @@ -0,0 +1,2 @@ +ignore: + - token_string.go \ No newline at end of file diff --git a/conn_str.go b/conn_str.go index 4ff54b89..9de5762f 100644 --- a/conn_str.go +++ b/conn_str.go @@ -1,6 +1,7 @@ package mssql import ( + "errors" "fmt" "net" "net/url" @@ -13,6 +14,12 @@ import ( const defaultServerPort = 1433 +const ( + fedAuthActiveDirectoryPassword = "ActiveDirectoryPassword" + fedAuthActiveDirectoryMSI = "ActiveDirectoryMSI" + fedAuthActiveDirectoryApplication = "ActiveDirectoryApplication" +) + type connectParams struct { logFlags uint64 port uint64 @@ -37,6 +44,12 @@ type connectParams struct { failOverPartner string failOverPort uint64 packetSize uint16 + fedAuthLibrary byte + fedAuthADALWorkflow byte + fedAuthAccessToken string + aadTenantID string + aadClientCertPath string + tlsKeyLogFile string } func parseConnectParams(dsn string) (connectParams, error) { @@ -229,6 +242,45 @@ func parseConnectParams(dsn string) (connectParams, error) { } } + p.fedAuthLibrary = fedAuthLibraryReserved + fedAuth, ok := params["fedauth"] + if ok { + switch { + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryPassword): + p.fedAuthLibrary = fedAuthLibraryADAL + p.fedAuthADALWorkflow = fedAuthADALWorkflowPassword + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryMSI): + p.fedAuthLibrary = fedAuthLibraryADAL + p.fedAuthADALWorkflow = fedAuthADALWorkflowMSI + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryApplication): + p.fedAuthLibrary = fedAuthLibrarySecurityToken + p.aadClientCertPath = params["clientcertpath"] + + // Split the user name into client id and tenant id at the @ symbol + at := strings.IndexRune(p.user, '@') + if at < 1 || at >= (len(p.user)-1) { + f := "Expecting user id to be clientID@tenantID: found '%s'" + return p, fmt.Errorf(f, p.user) + } + + p.aadTenantID = p.user[at+1:] + p.user = p.user[0:at] + default: + f := "Invalid federated authentication type '%s': expected %s, %s or %s" + return p, fmt.Errorf(f, fedAuth, fedAuthActiveDirectoryPassword, fedAuthActiveDirectoryMSI, fedAuthActiveDirectoryApplication) + } + + if p.disableEncryption { + f := "Encryption must not be disabled for federated authentication: encrypt='%s'" + return p, fmt.Errorf(f, encrypt) + } + } + + p.tlsKeyLogFile, ok = params["tls key log file"] + if ok && p.tlsKeyLogFile != "" && p.disableEncryption { + return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled") + } + return p, nil } diff --git a/conn_str_test.go b/conn_str_test.go index 7419572c..7e52a69c 100644 --- a/conn_str_test.go +++ b/conn_str_test.go @@ -17,6 +17,15 @@ func TestInvalidConnectionString(t *testing.T) { "trustservercertificate=invalid", "failoverport=invalid", "applicationintent=ReadOnly", + "encrypt=DISABLE;tls key log file=key.log", + + // AAD + "fedauth=ActiveDirectoryApplication;user id=clientidwithouttenantid;clientcertpath=/secrets/spn.pem", + "fedauth=UnknownType", + // encryption cannot be disabled for AAD + "encrypt=DISABLE;fedauth=ActiveDirectoryPassword;user id=tester@tenant.com;password=secret", + "encrypt=DISABLE;fedauth=ActiveDirectoryMSI", + "encrypt=DISABLE;fedauth=ActiveDirectoryApplication;user id=clientid@tenantid;clientcertpath=/secrets/spn.pem", // ODBC mode "odbc:password={", @@ -74,6 +83,17 @@ func TestValidConnectionString(t *testing.T) { {"log=64;packet size=8192", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 8192 }}, {"log=64;packet size=48000", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 32767 }}, + // AAD + {"fedauth=ActiveDirectoryPassword;user id=tester@tenant.com;password=secret", func(p connectParams) bool { + return p.fedAuthLibrary == fedAuthLibraryADAL && p.fedAuthADALWorkflow == fedAuthADALWorkflowPassword + }}, + {"fedauth=ActiveDirectoryMSI", func(p connectParams) bool { + return p.fedAuthLibrary == fedAuthLibraryADAL && p.fedAuthADALWorkflow == fedAuthADALWorkflowMSI + }}, + {"fedauth=ActiveDirectoryApplication;user id=clientid@tenantid;clientcertpath=/secrets/spn.pem", func(p connectParams) bool { + return p.fedAuthLibrary == fedAuthLibrarySecurityToken && p.user == "clientid" && p.aadTenantID == "tenantid" && p.aadClientCertPath == "/secrets/spn.pem" + }}, + // those are supported currently, but maybe should not be {"someparam", func(p connectParams) bool { return true }}, {";;=;", func(p connectParams) bool { return true }}, diff --git a/doc/how-to-test-azure-ad-authentication.md b/doc/how-to-test-azure-ad-authentication.md new file mode 100644 index 00000000..56ccfb68 --- /dev/null +++ b/doc/how-to-test-azure-ad-authentication.md @@ -0,0 +1,176 @@ +# How to test Azure AD authentication + +To test Azure AD authentication requires an Azure SQL server configured with an +[Active Directory administrator](https://docs.microsoft.com/en-us/azure/sql-database/sql-database-aad-authentication-configure). +To test managed identity authentication, an Azure virtual machine configured with +[system-assigned and/or user-assigned identities](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/qs-configure-portal-windows-vm) +is also required. + +The necessary resources can be set up through any means including the +[Azure Portal](https://portal.azure.com/), the Azure CLI, the Azure PowerShell cmdlets or +[Terraform](https://terraform.io/). To support these instructions, use the Terraform script at +[examples/azuread/testing.tf](../examples/azuread/testing.tf). + +## Create Azure infrastructure + +Download [Terraform](https://terraform.io/) to a location on your PATH. + +Log in to Azure using the Azure CLI. + +```console +you@workstation:~$ az login +you@workstation:~$ az account show +``` + +If your Azure account has access to multiple subscriptions, use +`az account set --subscription ` to choose the correct one. You will need to have at +least Contributor access to the portal and permissions in Azure Active Directory to create users +and grants. + +Check out this source repository (if you haven't already), change to the `examples/azuread` +directory and run Terraform: + +```console +you@workstation:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +you@workstation:~$ cd go-mssqldb/examples/azuread +you@workstation:azuread$ terraform init +you@workstation:azuread$ terraform apply +``` + +This will create an Azure resource group, a SQL server with a database, a virtual machine with a +system-assigned identity and user-assigned identity. Resources are named based on a random +prefix: to specify the prefix, use `terraform apply -var prefix=`. + +Upon successful completion, Terraform will display some key details of the infrastructure that has + been created. This includes the SSH key to access the VM, the administrator account and password + for the Azure SQL server, and all the relevant resource names. + +Save the settings to a JSON file: + +```console +you@workstation:azuread$ terraform output -json > settings.json +``` + +Save the SSH private key to a file: + +```console +you@workstation:azuread$ terraform output vm_user_ssh_private_key > ssh-identity +``` + +Copy the `settings.json` to the new VM: + +```console +you@workstation:azuread$ scp -i ssh-identity settings.json "$(terraform output vm_admin_name)@$(terraform output vm_ip_address):" +``` + +## Set up Azure Virtual Machine for testing + +SSH to the new VM to continue setup: + +```console +you@workstation:azuread$ ssh -i ssh-identity "$(terraform output vm_admin_name)@$(terraform output vm_ip_address)" +``` + +Once on the VM, update the system and install some basic packages: + +```console +azureuser@azure-vm:~$ sudo apt update -y +azureuser@azure-vm:~$ sudo apt upgrade -y +azureuser@azure-vm:~$ sudo apt install -y git openssl jq build-essential +azureuser@azure-vm:~$ sudo snap install go --classic +``` + +Install the Azure CLI using the script as shown below, or follow the +[manual install instructions](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli-apt): + +```console +azureuser@azure-vm:~$ curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash +``` + +## Generate service principal certificate file + +Log in to Azure on the VM and set the subscription: + +```console +azureuser@azure-vm:~$ az login +azureuser@azure-vm:~$ az account set --subscription "$(jq -r '.subscription_id.value' settings.json)" +``` + +Use OpenSSL to create a new certificate and key in PEM format, using the : + +```console +azureuser@azure-vm:~$ openssl rand -writerand ~/.rnd +azureuser@azure-vm:~$ openssl req -x509 -nodes -newkey rsa:4096 -keyout client.key -out client.crt \ + -subj "/C=US/ST=MA/L=Boston/O=Global Security/OU=IT Department/CN=AD-SP" +azureuser@azure-vm:~$ openssl pkcs12 -export -out client.p12 -inkey client.key -in client.crt \ + -passout "pass:$(jq -r '.app_sp_client_secret.value' settings.json)" +azureuser@azure-vm:~$ export APP_SP_CLIENT_CERT="$PWD/client.p12" +``` + +Use the Azure CLI to add the client certificate to the application service principal: + +```console +azureuser@azure-vm:~$ az ad sp credential reset --append --cert @client.crt \ + --name "$(jq -r '.app_sp_client_id.value' settings.json)" +``` + +## Build source code and authorize users in database + +Clone this repository, build and run the `examples/azuread` helper that verifies the database +exists and sets up access for the system-assigned and user-assigned identities. + +```console +azureuser@azure-vm:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +azureuser@azure-vm:~$ cd go-mssqldb +azureuser@azure-vm:go-mssqldb$ go generate ./... +azureuser@azure-vm:go-mssqldb$ go build -o azuread ./examples/azuread +azureuser@azure-vm:go-mssqldb$ eval "$(jq -r -f examples/azuread/environment-settings.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ ./azuread -fedauth ActiveDirectoryPassword +``` + +For some basic connectivity tests, use the `examples/simple` helper. Run these commands on the +Azure VM so that identity authentication is possible. + +```console +azureuser@azure-vm:go-mssqldb$ go build -o simple ./examples/simple +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-user-password-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-service-principal-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-system-assigned-id-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-user-assigned-id-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +``` + +## Running the integration tests + +Now that your environment is configured, you can run `go test`: + +```console +azureuser@azure-vm:go-mssqldb$ export SQLSERVER_DSN="$(jq -r -f examples/azuread/ad-system-assigned-id-dsn.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ go test -coverprofile=coverage.out ./... +``` + +## Tear down environment + +After you complete your testing, use Terraform to destroy the infrastructure you created. + +```console +you@workstation:azuread$ terraform destroy +``` + +## Troubleshooting + +After Terraform runs you should be able to see resources that were created in the +[Azure Portal](https://portal.azure.com/). + +If the Azure SQL server is successfully created you can connect to it using the AD admin user +and password in SSMS. SSMS will prompt you to create firewall rules if they are missing. You +can read the AD admin user and password from the `settings.json`, or run: + +```console +you@workstation:azuread$ terraform output sql_ad_admin_user +you@workstation:azuread$ terraform output sql_ad_admin_password +``` + diff --git a/examples/azure-managed-identity/README.md b/examples/azure-managed-identity/README.md new file mode 100644 index 00000000..cb20a760 --- /dev/null +++ b/examples/azure-managed-identity/README.md @@ -0,0 +1,9 @@ +## Azure Managed Identity example + +This example shows how Azure Managed Identity can be used to access SQL Azure. Take note of the +trust boundary before using MSI to prevent exposure of the tokens outside of the trust boundary. + +This example can only be run from a Azure Virtual Machine with Managed Identity configured. +You can follow the steps from this tutorial to turn on managed identity for your VM and grant the +VM access to a SQL Azure database: +https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/tutorial-windows-vm-access-sql diff --git a/examples/azure-managed-identity/managed_identity.go b/examples/azure-managed-identity/managed_identity.go new file mode 100644 index 00000000..2b989406 --- /dev/null +++ b/examples/azure-managed-identity/managed_identity.go @@ -0,0 +1,88 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + + "github.com/Azure/go-autorest/autorest/adal" + mssql "github.com/denisenkom/go-mssqldb" +) + +var ( + debug = flag.Bool("debug", false, "enable debugging") + server = flag.String("server", "", "the database server") + database = flag.String("database", "", "the database") +) + +func main() { + flag.Parse() + + if *debug { + fmt.Printf(" server:%s\n", *server) + fmt.Printf(" database:%s\n", *database) + } + + if *server == "" { + log.Fatal("Server name cannot be left empty") + } + + if *database == "" { + log.Fatal("Database name cannot be left empty") + } + + connString := fmt.Sprintf("Server=%s;Database=%s", *server, *database) + if *debug { + fmt.Printf(" connString:%s\n", connString) + } + + tokenProvider, err := getMSITokenProvider() + if err != nil { + log.Fatal("Error creating token provider for system assigned Azure Managed Identity:", err.Error()) + } + + connector, err := mssql.NewAccessTokenConnector( + connString, tokenProvider) + if err != nil { + log.Fatal("Connector creation failed:", err.Error()) + } + conn := sql.OpenDB(connector) + defer conn.Close() + + stmt, err := conn.Prepare("select 1, 'abc'") + if err != nil { + log.Fatal("Prepare failed:", err.Error()) + } + defer stmt.Close() + + row := stmt.QueryRow() + var somenumber int64 + var somechars string + err = row.Scan(&somenumber, &somechars) + if err != nil { + log.Fatal("Scan failed:", err.Error()) + } + fmt.Printf("somenumber:%d\n", somenumber) + fmt.Printf("somechars:%s\n", somechars) + + fmt.Printf("bye\n") +} + +func getMSITokenProvider() (func() (string, error), error) { + msiEndpoint, err := adal.GetMSIEndpoint() + if err != nil { + return nil, err + } + msi, err := adal.NewServicePrincipalTokenFromMSI( + msiEndpoint, "https://database.windows.net/") + if err != nil { + return nil, err + } + + return func() (string, error) { + msi.EnsureFresh() + token := msi.OAuthToken() + return token, nil + }, nil +} diff --git a/examples/azuread/ad-service-principal-dsn.jq b/examples/azuread/ad-service-principal-dsn.jq new file mode 100644 index 00000000..b2426037 --- /dev/null +++ b/examples/azuread/ad-service-principal-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.app_sp_client_id.value)%40\(.tenant_id.value):\(.app_sp_client_secret.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryApplication" \ No newline at end of file diff --git a/examples/azuread/ad-system-assigned-id-dsn.jq b/examples/azuread/ad-system-assigned-id-dsn.jq new file mode 100644 index 00000000..288f2b7c --- /dev/null +++ b/examples/azuread/ad-system-assigned-id-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" \ No newline at end of file diff --git a/examples/azuread/ad-user-assigned-id-dsn.jq b/examples/azuread/ad-user-assigned-id-dsn.jq new file mode 100644 index 00000000..df31d09d --- /dev/null +++ b/examples/azuread/ad-user-assigned-id-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.user_assigned_identity_client_id.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" \ No newline at end of file diff --git a/examples/azuread/ad-user-password-dsn.jq b/examples/azuread/ad-user-password-dsn.jq new file mode 100644 index 00000000..beebc5e1 --- /dev/null +++ b/examples/azuread/ad-user-password-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.sql_ad_admin_user.value):\(.sql_ad_admin_password.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryPassword" \ No newline at end of file diff --git a/examples/azuread/azuread.go b/examples/azuread/azuread.go new file mode 100644 index 00000000..cbcf9342 --- /dev/null +++ b/examples/azuread/azuread.go @@ -0,0 +1,144 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + "net/url" + "os" + "strings" + "time" + + _ "github.com/denisenkom/go-mssqldb" +) + +var ( + debug = flag.Bool("debug", false, "enable debugging") + server = flag.String("server", os.Getenv("SQL_SERVER"), "the database server name") + port = flag.Int("port", 1433, "the database port") + database = flag.String("database", os.Getenv("SQL_DATABASE"), "the database name") + user = flag.String("user", os.Getenv("SQL_AD_ADMIN_USER"), "the AD administrator user name") + password = flag.String("password", os.Getenv("SQL_AD_ADMIN_PASSWORD"), "the AD administrator password") + fedauth = flag.String("fedauth", "ActiveDirectoryPassword", "the federated authentication scheme to use") + appName = flag.String("app-name", os.Getenv("APP_NAME"), "the application name to authorize") + vmName = flag.String("vm-name", os.Getenv("VM_NAME"), "the system identity name to authorize for this VM") + uaName = flag.String("ua-name", os.Getenv("UA_NAME"), "the user assigned identity name to authorize for this VM") +) + +func createConnStr(database string) string { + connString := fmt.Sprintf("sqlserver://%s:%s@%s:%d?encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port) + + if database != "" && database != "master" { + connString = connString + "&database=" + url.QueryEscape(database) + } + + if *fedauth != "" { + connString = connString + "&fedauth=" + url.QueryEscape(*fedauth) + } + + if *debug { + connString = connString + "&log=127" + } + + return connString +} + +func createDatabaseIfNotExists() error { + // Check database exists by connecting to master on the Azure SQL server + connString := createConnStr("master") + + log.Printf("Open: %s\n", connString) + + conn, err := sql.Open("sqlserver", connString) + if err != nil { + return err + } + + defer conn.Close() + + if err = conn.Ping(); err != nil { + return err + } + + quoted := strings.Replace(*database, "]", "]]", -1) + sql := "IF NOT EXISTS (SELECT 1 FROM sys.databases WHERE name = @p1)\n CREATE DATABASE [" + quoted + "] ( SERVICE_OBJECTIVE = 'S0' )" + log.Printf("Exec: @p1 = '%s'\n%s\n", *database, sql) + _, err = conn.Exec(sql, *database) + + return err +} + +func addExternalUserIfNotExists(user string) error { + connString := createConnStr(*database) + + log.Printf("Open: %s\n", connString) + + var conn *sql.DB + var err error + + for retry := 0; retry < 8; retry++ { + conn, err = sql.Open("sqlserver", connString) + if err == nil { + if err = conn.Ping(); err == nil { + break + } + } + log.Printf("Connection failed: %v", err) + log.Println("Retry in 15 seconds") + time.Sleep(15 * time.Second) + } + if err != nil { + log.Printf("Connection failed: %v", err) + log.Println("No further retries will be attempted") + return err + } + + defer conn.Close() + + quoted := strings.Replace(user, "]", "]]", -1) + sql := "IF NOT EXISTS (SELECT 1 FROM sys.database_principals WHERE name = @p1)\n CREATE USER [" + quoted + "] FROM EXTERNAL PROVIDER" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + if err != nil { + return err + } + + sql = "IF IS_ROLEMEMBER('db_owner', @p1) = 0\n ALTER ROLE [db_owner] ADD MEMBER [" + quoted + "]" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + + return err +} + +func main() { + flag.Parse() + + err := createDatabaseIfNotExists() + if err != nil { + log.Fatalf("Unable to create database [%s]: %v", *database, err) + } + + if *vmName != "" { + err = addExternalUserIfNotExists(*vmName) + if err != nil { + log.Fatalf("Unable to create user for system-assigned identity [%s]: %v", *vmName, err) + } + } + + if *appName != "" { + err = addExternalUserIfNotExists(*appName) + if err != nil { + log.Fatalf("Unable to create user for application identity [%s]: %v", *appName, err) + } + } + + if *uaName != "" { + err = addExternalUserIfNotExists(*uaName) + if err != nil { + log.Fatalf("Unable to create user for user-assigned identity [%s]: %v", *uaName, err) + } + } +} diff --git a/examples/azuread/environment-settings.jq b/examples/azuread/environment-settings.jq new file mode 100644 index 00000000..a8c9192a --- /dev/null +++ b/examples/azuread/environment-settings.jq @@ -0,0 +1,20 @@ +# Convert Terraform settings to shell environment exports. +[ + "set -a", + "SQL_SERVER=" + (.sql_server_fqdn.value | @sh), + "SQL_ADMIN_USER=" + (.sql_admin_user.value | @sh), + "SQL_ADMIN_PASSWORD=" + (.sql_admin_password.value | @sh), + "SQL_AD_ADMIN_USER=" + (.sql_ad_admin_user.value | @sh), + "SQL_AD_ADMIN_PASSWORD=" + (.sql_ad_admin_password.value | @sh), + "APP_SP_CLIENT_ID=" + (.app_sp_client_id.value | @sh), + "APP_SP_CLIENT_SECRET=" + (.app_sp_client_secret.value | @sh), + "SQL_DATABASE=" + (.sql_database_name.value | @sh), + "APP_NAME=" + (.app_name.value | @sh), + "VM_NAME=" + (.vm_name.value | @sh), + "VM_CLIENT_ID=" + (.vm_client_id.value | @sh), + "UA_NAME=" + (.user_assigned_identity_name.value | @sh), + "UA_CLIENT_ID=" + (.user_assigned_identity_client_id.value | @sh), + "AZURE_SUBSCRIPTION_ID=" + (.subscription_id.value | @sh), + "AZURE_TENANT_ID=" + (.tenant_id.value | @sh), + "set +a" +] | map([.]) | .[] | @tsv diff --git a/examples/azuread/testing.tf b/examples/azuread/testing.tf new file mode 100644 index 00000000..cef943d8 --- /dev/null +++ b/examples/azuread/testing.tf @@ -0,0 +1,490 @@ +# +# Terraform setup for Azure SQL with Azure Active Directory authentication +# + +# +# Set up Terraform provider versions +# +provider "azuread" { + version = "~> 0.7" +} + +provider "azurerm" { + version = "~> 1.36" +} + +provider "http" { + version = "~> 1.1" +} + +provider "random" { + version = "~> 2.2" +} + +provider "tls" { + version = "~> 2.1" +} + +# +# Variables +# +# These variables allow limited overrides to control the resource creation. +# To specify, run terraform apply -var name1=value1 [-var name2=value2]... +# E.g. terraform apply -var prefix=my-stuff +# will use "my-stuff" in place of the randomly generated ID that is used by default. +# +variable "prefix" { + description = "Prefix for Azure resource names" + type = string + default = "" +} + +variable "location" { + description = "Azure location for resources" + type = string + default = "East US" +} + +variable "vm_admin_name" { + description = "Name of administrative user on virtual machine" + type = string + default = "azureuser" +} + +variable "ssh_key" { + description = "Path to RSA SSH private key (unencrypted)" + type = string + default = "~/.ssh/id_rsa" +} + +variable "workstation_ip" { + description = "IP address of this workstation to add to SQL server firewall rules" + type = string + default = "" +} + +# +# If the prefix is not specified via the variable, a sixteen character alphanumeric suffix is +# generated and then the prefix is set to "go-mssql-test-" + +# +resource "random_string" "random_prefix" { + length = 16 + lower = true + number = true + upper = false + special = false +} + +# +# Set up a local variable to capture the prefix to use - either the user-specified from the +# variable, or else the generated name using the random string above. +# +# Some resource names (e.g. SQL server) are more restricted than others - e.g. hyphens are +# not permitted - so we create a restricted name prefix as well as a regular name prefix. +# +locals { + regular_name_prefix = var.prefix != "" ? var.prefix : "go-mssql-test-${random_string.random_prefix.result}" + restricted_name_prefix = var.prefix != "" ? lower(replace(var.prefix, "/[^A-Za-z0-9]/", "")) : "gomssqltest${random_string.random_prefix.result}" +} + +# +# SSH Key - generate if not available at the file named in the variable. +# Terraform will complain if var.ssh_key is empty as this is interpreted as referring to the +# current working directory, and that is not a file. Instead, if you want to avoid using an +# existing SSH key, make it a literal "no" or some other string that is not an existing file or +# directory. +# +data "tls_public_key" "file_ssh_key" { + count = fileexists(var.ssh_key) ? 1 : 0 + private_key_pem = fileexists(var.ssh_key) ? file(var.ssh_key) : "" +} + +resource "tls_private_key" "rand_ssh_key" { + algorithm = "ECDSA" +} + +locals { + private_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.private_key_pem : tls_private_key.rand_ssh_key.private_key_pem + public_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_pem : tls_private_key.rand_ssh_key.public_key_pem + public_key_openssh = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_openssh : tls_private_key.rand_ssh_key.public_key_openssh +} + +# +# Retrieve tenant, subscription and default domain information based on the current Azure login. +# +data "azurerm_client_config" "current" { +} + +data "azurerm_subscription" "current" { +} + +data "azuread_domains" "current" { + only_default = "true" +} + +# +# Use ipify.org to determine workstation IP if not provided. +# If this guesses incorrectly, specify your workstation IP with -var worstation_ip= +# when you run terraform apply. +# +data "http" "workstation_ip" { + url = "https://api.ipify.org/" +} + +locals { + workstation_ip = var.workstation_ip != "" ? var.workstation_ip : chomp(data.http.workstation_ip.body) +} + +# +# Set up the Azure resource group for all the test resources. +# +resource "azurerm_resource_group" "rg" { + name = "${local.regular_name_prefix}-rg" + location = var.location +} + +# +# Set up an AD User to use as AD Administrator for the Azure SQL server. +# +# Using a regular user account makes it simpler to log in as the user with SSMS or the Go +# driver when setting up the other permissions for the identities that will be tested. +# It appears to although you can make the AD Administrator a service principal, doing so +# leads to issues during logins that do not occur when the AD Administrator is a normal +# AD User account. +# +resource "random_password" "sql_ad_admin_sp_password" { + length = 32 + special = true +} + +resource "azuread_user" "sql_ad_admin" { + user_principal_name = "SQLAdmin.${local.restricted_name_prefix}@${data.azuread_domains.current.domains[0].domain_name}" + display_name = "SQL Admin for ${local.restricted_name_prefix}" + mail_nickname = "SQLAdmin.${local.restricted_name_prefix}" + password = random_password.sql_ad_admin_sp_password.result +} + +# +# Set up the Azure SQL Server +# +# A normal (non-AD) administrator username and password are also provisioned. However, it is +# not possible to create AD users without logging in via an AD-authenticated account, so this +# non-AD administrator is not able to create new AD user accounts. +# +resource "random_password" "sql_admin_password" { + length = 16 + special = true +} + +resource "azurerm_sql_server" "sql_server" { + name = local.restricted_name_prefix + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + version = "12.0" + administrator_login = "sql-admin" + administrator_login_password = random_password.sql_admin_password.result +} + +resource "azurerm_sql_active_directory_administrator" "sql_server" { + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + login = "sql-ad-admin" + tenant_id = data.azurerm_client_config.current.tenant_id + object_id = azuread_user.sql_ad_admin.id +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_azure" { + name = "AllowAzureAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = "0.0.0.0" + end_ip_address = "0.0.0.0" +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_workstation" { + name = "AllowWorkstationAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = local.workstation_ip + end_ip_address = local.workstation_ip +} + +# +# Set up the test database on the Azure SQL server +# +resource "azurerm_sql_database" "sql_db" { + name = "go-mssqldb" + + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + location = azurerm_sql_server.sql_server.location + + requested_service_objective_name = "S0" +} + +# +# Create a service principal that will be granted access to the database, +# representing an application login to the database. +# +resource "azuread_application" "app" { + name = "${local.regular_name_prefix}-app" +} + +resource "azuread_service_principal" "app_sp" { + application_id = azuread_application.app.application_id + app_role_assignment_required = false +} + +resource "random_password" "app_sp_password" { + length = 32 + special = true +} + +resource "azuread_service_principal_password" "app_sp" { + service_principal_id = azuread_service_principal.app_sp.id + value = random_password.app_sp_password.result + end_date_relative = "8760h" +} + + +# +# Create a user-assigned identity that we will add to the VM in addition to the +# system-assigned identity. +# +resource "azurerm_user_assigned_identity" "vm_user_id" { + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + name = "${local.restricted_name_prefix}-user-id" +} + +# +# Create an Azure VM for testing managed identity authentication. +# +# To support the Azure VM, we need a virtual network, a subnet, the public IP, the network +# security group, and the network interface. The network security group allows incoming SSH +# from the anywhere on the internet. +# +resource "azurerm_virtual_network" "vm_vnet" { + name = "${local.regular_name_prefix}-vnet" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + address_space = ["10.0.0.0/16"] +} + +resource "azurerm_subnet" "vm_subnet" { + name = "${local.regular_name_prefix}-vm-sn" + resource_group_name = azurerm_resource_group.rg.name + virtual_network_name = azurerm_virtual_network.vm_vnet.name + address_prefix = "10.0.2.0/24" +} + +resource "azurerm_public_ip" "vm_ip" { + name = "${local.regular_name_prefix}-vm-ip" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + allocation_method = "Dynamic" + idle_timeout_in_minutes = 30 +} + +resource "azurerm_network_security_group" "vm_nsg" { + name = "${local.regular_name_prefix}-vm-nsg" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + security_rule { + name = "SSH" + priority = 1001 + direction = "Inbound" + access = "Allow" + protocol = "Tcp" + source_port_range = "*" + destination_port_range = "22" + source_address_prefix = "*" + destination_address_prefix = "*" + } +} + +resource "azurerm_network_interface" "vm_nic" { + name = "${local.regular_name_prefix}-vm-nic" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + network_security_group_id = azurerm_network_security_group.vm_nsg.id + + ip_configuration { + name = "${local.regular_name_prefix}-vm-nic-config" + subnet_id = azurerm_subnet.vm_subnet.id + private_ip_address_allocation = "Dynamic" + public_ip_address_id = azurerm_public_ip.vm_ip.id + } +} + +# +# Given the networking setup, now create the Azure VM +# +resource "azurerm_virtual_machine" "vm" { + name = "${local.regular_name_prefix}-vm" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + network_interface_ids = [azurerm_network_interface.vm_nic.id] + vm_size = "Standard_B1s" + + storage_os_disk { + name = "${local.regular_name_prefix}-vm-os" + caching = "ReadWrite" + create_option = "FromImage" + managed_disk_type = "Standard_LRS" + } + + storage_image_reference { + publisher = "Canonical" + offer = "UbuntuServer" + sku = "18.04-LTS" + version = "latest" + } + + os_profile { + computer_name = "${local.regular_name_prefix}-vm" + admin_username = var.vm_admin_name + } + + os_profile_linux_config { + disable_password_authentication = true + ssh_keys { + path = "/home/${var.vm_admin_name}/.ssh/authorized_keys" + key_data = local.public_key_openssh + } + } + + # Configure the VM with both SystemAssigned and a UserAssigned identity + identity { + type = "SystemAssigned, UserAssigned" + identity_ids = [azurerm_user_assigned_identity.vm_user_id.id] + } +} + +# Retrieve the application ID corresponding to the service principal ID assigned to the VM. +data "azuread_service_principal" "vm_sp" { + object_id = azurerm_virtual_machine.vm.identity.0.principal_id +} + +# Wait for public IP to be assigned after VM is created so we can report it in the outputs. +data "azurerm_public_ip" "vm_ip" { + name = azurerm_public_ip.vm_ip.name + resource_group_name = azurerm_virtual_machine.vm.resource_group_name +} + +# +# After provisioning or refreshing, Terraform will populate these outputs. +# These capture the necessary pieces of information to access the new infrastructure. +# +output "tenant_id" { + description = "Azure tenant ID" + value = data.azurerm_client_config.current.tenant_id +} + +output "subscription_id" { + description = "Azure subscription ID" + value = data.azurerm_client_config.current.subscription_id +} + +output "sql_server_name" { + description = "Azure SQL server name" + value = azurerm_sql_server.sql_server.name +} + +output "sql_server_fqdn" { + description = "Azure SQL server domain name" + value = azurerm_sql_server.sql_server.fully_qualified_domain_name +} + +output "sql_ad_admin_user" { + description = "Azure SQL administrator name (AD authentication)" + value = azuread_user.sql_ad_admin.user_principal_name +} + +output "sql_ad_admin_password" { + description = "Azure SQL administrator password (AD authentication)" + value = random_password.sql_ad_admin_sp_password.result + sensitive = true +} + +output "sql_admin_user" { + description = "Azure SQL administrator name (SQL server authentication)" + value = azurerm_sql_server.sql_server.administrator_login +} + +output "sql_admin_password" { + description = "Azure SQL administrator password (SQL server authentication)" + value = random_password.sql_admin_password.result + sensitive = true +} + +output "sql_database_name" { + description = "Azure SQL database name" + value = azurerm_sql_database.sql_db.name +} + +output "vm_name" { + description = "Azure virtual machine name" + value = azurerm_virtual_machine.vm.name +} + +output "vm_client_id" { + description = "Azure VM system-assigned identity client ID" + value = data.azuread_service_principal.vm_sp.application_id +} + +output "vm_principal_id" { + description = "Azure VM system-assigned identity principal ID" + value = azurerm_virtual_machine.vm.identity.0.principal_id +} + +output "vm_ip_address" { + description = "Azure virtual machine public IP" + value = data.azurerm_public_ip.vm_ip.ip_address +} + +output "vm_admin_name" { + description = "Azure virtual machine admin user name" + value = var.vm_admin_name +} + +output "vm_user_ssh_private_key" { + description = "Azure virtual machine admin user private SSH key" + value = local.private_key_pem + sensitive = true +} + +output "vm_user_ssh_openssh_key" { + description = "Azure virtual machine admin user SSH public key" + value = local.public_key_openssh + sensitive = true +} + +output "app_sp_client_id" { + description = "Service principal client ID for application user" + value = azuread_application.app.application_id +} + +output "app_name" { + description = "Service principal name for application user" + value = azuread_application.app.name +} + +output "app_sp_client_secret" { + description = "Service principal client secret for application user" + value = random_password.app_sp_password.result + sensitive = true +} + +output "user_assigned_identity_name" { + description = "User-assigned identity for the Azure virtual machine" + value = azurerm_user_assigned_identity.vm_user_id.name +} + +output "user_assigned_identity_client_id" { + description = "User-assigned identity client ID" + value = azurerm_user_assigned_identity.vm_user_id.client_id +} diff --git a/examples/simple/simple.go b/examples/simple/simple.go index 67f88aa4..2c8965c2 100644 --- a/examples/simple/simple.go +++ b/examples/simple/simple.go @@ -5,12 +5,16 @@ import ( "flag" "fmt" "log" + "net/url" + "os" _ "github.com/denisenkom/go-mssqldb" ) var ( + database = flag.String("database", "", "the database name") debug = flag.Bool("debug", false, "enable debugging") + dsn = flag.String("dsn", os.Getenv("SQLSERVER_DSN"), "complete SQL DSN") password = flag.String("password", "", "the database password") port *int = flag.Int("port", 1433, "the database port") server = flag.String("server", "", "the database server") @@ -20,24 +24,35 @@ var ( func main() { flag.Parse() - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) + var connString string + + if *dsn == "" { + if *debug { + fmt.Printf(" server: %s\n", *server) + fmt.Printf(" port: %d\n", *port) + fmt.Printf(" user: %s\n", *user) + fmt.Printf(" password: %s\n", *password) + fmt.Printf(" database: %s\n", *database) + } + + connString = fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port, url.QueryEscape(*database)) + } else { + connString = *dsn } - connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port) if *debug { - fmt.Printf(" connString:%s\n", connString) + fmt.Printf(" dsn: %s\n", connString) } + conn, err := sql.Open("mssql", connString) if err != nil { log.Fatal("Open connection failed:", err.Error()) } defer conn.Close() - stmt, err := conn.Prepare("select 1, 'abc'") + stmt, err := conn.Prepare("select 1, 'abc', suser_name()") if err != nil { log.Fatal("Prepare failed:", err.Error()) } @@ -46,12 +61,14 @@ func main() { row := stmt.QueryRow() var somenumber int64 var somechars string - err = row.Scan(&somenumber, &somechars) + var someuser string + err = row.Scan(&somenumber, &somechars, &someuser) if err != nil { log.Fatal("Scan failed:", err.Error()) } - fmt.Printf("somenumber:%d\n", somenumber) - fmt.Printf("somechars:%s\n", somechars) + fmt.Printf("number: %d\n", somenumber) + fmt.Printf("chars: %s\n", somechars) + fmt.Printf("user: %s\n", someuser) fmt.Printf("bye\n") } diff --git a/examples/tvp/tvp.go b/examples/tvp/tvp.go index a07bb652..eae614ef 100644 --- a/examples/tvp/tvp.go +++ b/examples/tvp/tvp.go @@ -1,3 +1,5 @@ +// +build go1.9 + package main import ( diff --git a/fedauth.go b/fedauth.go new file mode 100644 index 00000000..8dbfcf44 --- /dev/null +++ b/fedauth.go @@ -0,0 +1,101 @@ +package mssql + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + + "github.com/Azure/go-autorest/autorest/adal" + "golang.org/x/crypto/pkcs12" +) + +const ( + activeDirectoryEndpoint = "https://login.microsoftonline.com/" + azureSQLResource = "https://database.windows.net/" + driverClientID = "7f98cb04-cd1e-40df-9140-3bf7e2cea4db" +) + +func fedAuthGetClientCertificate(clientCertPath, clientCertPassword string) (*x509.Certificate, *rsa.PrivateKey, error) { + pkcs, err := ioutil.ReadFile(clientCertPath) + if err != nil { + return nil, nil, fmt.Errorf("Failed to read the AD client certificate from path %s: %v", clientCertPath, err) + } + + privateKey, certificate, err := pkcs12.Decode(pkcs, clientCertPassword) + if err != nil { + return nil, nil, fmt.Errorf("Failed to read the AD client certificate from path %s: %v", clientCertPath, err) + } + + rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey) + if !isRsaKey { + return nil, nil, fmt.Errorf("AD client certificate at path %s must contain an RSA private key", clientCertPath) + } + + return certificate, rsaPrivateKey, nil +} + +func fedAuthGetAccessToken(ctx context.Context, resource, tenantID string, p connectParams, log optionalLogger) (accessToken string, err error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // tenant ID is resolved. When the workflow provides a complete endpoint + // URL for the tenant, the URL resolution just returns that endpoint. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) + if err != nil { + log.Printf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", activeDirectoryEndpoint, tenantID, err) + return "", err + } + + var token *adal.ServicePrincipalToken + if p.fedAuthLibrary == fedAuthLibrarySecurityToken { + // When the security token library is used, the token is obtained without input + // from the server, so the AD endpoint and Azure SQL resource URI are provided + // from the constants above. + if p.aadClientCertPath != "" { + var certificate *x509.Certificate + var rsaPrivateKey *rsa.PrivateKey + certificate, rsaPrivateKey, err = fedAuthGetClientCertificate(p.aadClientCertPath, p.password) + if err == nil { + token, err = adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, p.user, certificate, rsaPrivateKey, azureSQLResource) + } + } else { + token, err = adal.NewServicePrincipalToken(*oauthConfig, p.user, p.password, azureSQLResource) + } + } else if p.fedAuthLibrary == fedAuthLibraryADAL { + // When the ADAL workflow is used, the server provides the endpoint (STS URL) + // and resource (server SPN) during the login process. The STS URL is passed + // as the tenant ID and has already been used to build the OAuth config. + if p.fedAuthADALWorkflow == fedAuthADALWorkflowPassword { + token, err = adal.NewServicePrincipalTokenFromUsernamePassword(*oauthConfig, driverClientID, p.user, p.password, resource) + + } else if p.fedAuthADALWorkflow == fedAuthADALWorkflowMSI { + // When using MSI, to request a specific client ID or user-assigned identity, + // provide the ID as the username. + var msiEndpoint string + msiEndpoint, err = adal.GetMSIEndpoint() + if err == nil { + if p.user == "" { + token, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource) + } else { + token, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, p.user) + } + } + } + } else { + return "", errors.New("Unsupported federated authentication library") + } + + if err != nil { + log.Printf("Failed to obtain service principal token for client id %s in tenant %s: %v", p.user, tenantID, err) + return "", err + } + + err = token.RefreshWithContext(ctx) + if err != nil { + log.Printf("Failed to refresh service principal token for client id %s in tenant %s: %v", p.user, tenantID, err) + return "", err + } + + return token.Token().AccessToken, nil +} diff --git a/fedauth_test.go b/fedauth_test.go new file mode 100644 index 00000000..7b6e1de1 --- /dev/null +++ b/fedauth_test.go @@ -0,0 +1,189 @@ +package mssql + +import ( + "context" + "database/sql" + "net/url" + "os" + "strings" + "testing" +) + +func checkAzureSQLEnvironment(fedAuth string, t *testing.T) (*url.URL, string) { + u := &url.URL{ + Scheme: "sqlserver", + Host: os.Getenv("SQL_SERVER"), + } + + if u.Host == "" { + t.Skip("Azure SQL Server name not provided in SQL_SERVER environment variable") + } + + database := os.Getenv("SQL_DATABASE") + if database == "" { + t.Skip("Azure SQL database name not provided in SQL_DATABASE environment variable") + } + + tenantID := os.Getenv("AZURE_TENANT_ID") + if tenantID == "" { + t.Skip("Azure tenant ID not provided in AZURE_TENANT_ID environment variable") + } + + query := u.Query() + + query.Add("database", database) + query.Add("encrypt", "true") + query.Add("fedauth", fedAuth) + + u.RawQuery = query.Encode() + + return u, tenantID +} + +func checkFedAuthUserPassword(t *testing.T) *url.URL { + u, _ := checkAzureSQLEnvironment(fedAuthActiveDirectoryPassword, t) + + username := os.Getenv("SQL_AD_ADMIN_USER") + password := os.Getenv("SQL_AD_ADMIN_PASSWORD") + + if username == "" || password == "" { + t.Skip("Username and password login requires SQL_AD_ADMIN_USER and SQL_AD_ADMIN_PASSWORD environment variables") + } + + u.User = url.UserPassword(username, password) + + return u +} + +func checkFedAuthAppPassword(t *testing.T) *url.URL { + u, tenantID := checkAzureSQLEnvironment(fedAuthActiveDirectoryApplication, t) + + appClientID := os.Getenv("APP_SP_CLIENT_ID") + appPassword := os.Getenv("APP_SP_CLIENT_SECRET") + + if appClientID == "" || appPassword == "" { + t.Skip("Application (service principal) login requires APP_SP_CLIENT_ID and APP_SP_CLIENT_SECRET environment variables") + } + + u.User = url.UserPassword(appClientID+"@"+tenantID, appPassword) + + return u +} + +func checkFedAuthAppCertPath(t *testing.T) *url.URL { + u := checkFedAuthAppPassword(t) + + appCertPath := os.Getenv("APP_SP_CLIENT_CERT") + if appCertPath == "" { + t.Skip("Application (service principal) certificate login requires APP_SP_CLIENT_CERT with path to certificate") + } + + query := u.Query() + query.Add("clientcertpath", appCertPath) + u.RawQuery = query.Encode() + + return u +} + +func checkFedAuthVMSystemID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment(fedAuthActiveDirectoryMSI, t) + + vmClientID := os.Getenv("VM_CLIENT_ID") + if vmClientID == "" { + t.Skip("System-assigned identity login test requires VM_CLIENT_ID environment variable") + } + + return u, vmClientID + "@" + tenantID +} + +func checkFedAuthVMUserAssignedID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment(fedAuthActiveDirectoryMSI, t) + + uaClientID := os.Getenv("UA_CLIENT_ID") + if uaClientID == "" { + t.Skip("User-assigned identity login test requires UA_CLIENT_ID environment variable") + } + + u.User = url.User(uaClientID) + + return u, uaClientID + "@" + tenantID +} + +func checkLoggedInUser(expected string, u *url.URL, t *testing.T) { + db, err := sql.Open("sqlserver", u.String()) + if err != nil { + t.Fatalf("Failed to open URL %v: %v", u, err) + } + + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sql := "SELECT SUSER_NAME()" + + stmt, err := db.PrepareContext(ctx, sql) + if err != nil { + t.Fatalf("Failed to prepare query %s: %v", sql, err) + } + + defer stmt.Close() + + rows, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatalf("Failed to fetch query result for %s: %v", sql, err) + } + + defer rows.Close() + + var username string + if !rows.Next() { + t.Fatalf("Empty result set for query %s", sql) + } + + err = rows.Scan(&username) + if err != nil { + t.Fatalf("Failed to fetch first row for %s: %v", sql, err) + } + + if !strings.EqualFold(username, expected) { + t.Fatalf("Expected username %s: actual: %s", expected, username) + } + + t.Logf("Logged in username %s matches expected %s", username, expected) +} + +func TestFedAuthWithUserAndPassword(t *testing.T) { + SetLogger(testLogger{t}) + u := checkFedAuthUserPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingPassword(t *testing.T) { + SetLogger(testLogger{t}) + u := checkFedAuthAppPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingCertificate(t *testing.T) { + SetLogger(testLogger{t}) + u := checkFedAuthAppCertPath(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithSystemAssignedIdentity(t *testing.T) { + u, vmName := checkFedAuthVMSystemID(t) + SetLogger(testLogger{t}) + + checkLoggedInUser(vmName, u, t) +} + +func TestFedAuthWithUserAssignedIdentity(t *testing.T) { + SetLogger(testLogger{t}) + u, uaName := checkFedAuthVMUserAssignedID(t) + + checkLoggedInUser(uaName, u, t) +} diff --git a/go.mod b/go.mod index ebc02ab8..2526b04a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/denisenkom/go-mssqldb go 1.11 require ( + github.com/Azure/go-autorest/autorest/adal v0.8.0 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c + golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd // indirect ) diff --git a/go.sum b/go.sum index 1887801b..581ab1c2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,30 @@ +github.com/Azure/go-autorest v13.3.0+incompatible h1:8Ix0VdeOllBx9jEcZ2Wb1uqWUpE1awmJiaHztwaJCPk= +github.com/Azure/go-autorest/autorest v0.9.0 h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs= +github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= +github.com/Azure/go-autorest/autorest v0.9.2 h1:6AWuh3uWrsZJcNoCHrCF/+g4aKPCU39kaMO6/qrnK/4= +github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= +github.com/Azure/go-autorest/autorest/adal v0.8.0 h1:CxTzQrySOxDnKpLjFJeZAS5Qrv/qFPkgLjx5bOAi//I= +github.com/Azure/go-autorest/autorest/adal v0.8.0/go.mod h1:Z6vX6WXXuyieHAXwMj0S6HY6e6wcHn37qQMBQlvY3lc= +github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA= +github.com/Azure/go-autorest/autorest/date v0.2.0 h1:yW+Zlqf26583pE43KhfnhFcdmSWlm5Ew6bxipnr/tbM= +github.com/Azure/go-autorest/autorest/date v0.2.0/go.mod h1:vcORJHLJEh643/Ioh9+vPmf1Ij9AEBM5FuBIXLmIy0g= +github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.3.0/go.mod h1:a8FDP3DYzQ4RYfVAxAN3SVSiiO77gL2j2ronKKP0syM= +github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= +github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k= +github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd h1:Zc7EU2PqpsNeIfOoVA7hvQX4cS3YDJEs5KlfatT3hLo= +golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/log_conn.go b/log_conn.go new file mode 100644 index 00000000..4777e4c1 --- /dev/null +++ b/log_conn.go @@ -0,0 +1,80 @@ +package mssql + +import ( + "encoding/hex" + "net" + "strings" + "time" +) + +type connLogger struct { + conn net.Conn + readKind, writeKind string + readCount, writeCount int + logger Logger +} + +var _ net.Conn = &connLogger{} + +func newConnLogger(conn net.Conn, kind string, logger Logger) net.Conn { + if len(kind) > 0 && !strings.HasPrefix(kind, " ") { + kind = " " + kind + } + + cl := &connLogger{ + conn: conn, + readKind: "R" + kind, + writeKind: "W" + kind, + logger: logger, + } + + return cl +} + +func (cl *connLogger) Read(p []byte) (n int, err error) { + n, err = cl.conn.Read(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.readKind, cl.readCount, dump) + cl.readCount += n + } + + return +} + +func (cl *connLogger) Write(p []byte) (n int, err error) { + n, err = cl.conn.Write(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.writeKind, cl.writeCount, dump) + cl.writeCount += n + } + + return +} + +func (cl *connLogger) Close() (err error) { + return cl.conn.Close() +} + +func (cl *connLogger) LocalAddr() net.Addr { + return cl.conn.LocalAddr() +} + +func (cl *connLogger) RemoteAddr() net.Addr { + return cl.conn.RemoteAddr() +} + +func (cl *connLogger) SetDeadline(t time.Time) error { + return cl.conn.SetDeadline(t) +} + +func (cl *connLogger) SetReadDeadline(t time.Time) error { + return cl.conn.SetReadDeadline(t) +} + +func (cl *connLogger) SetWriteDeadline(t time.Time) error { + return cl.conn.SetWriteDeadline(t) +} diff --git a/tds.go b/tds.go index 94198364..c4c8da9b 100644 --- a/tds.go +++ b/tds.go @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "net" + "os" "sort" "strconv" "strings" @@ -89,24 +90,27 @@ const ( // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx packAttention = 6 - packBulkLoadBCP = 7 - packTransMgrReq = 14 - packNormal = 15 - packLogin7 = 16 - packSSPIMessage = 17 - packPrelogin = 18 + packBulkLoadBCP = 7 + packFedAuthToken = 8 + packTransMgrReq = 14 + packNormal = 15 + packLogin7 = 16 + packSSPIMessage = 17 + packPrelogin = 18 ) // prelogin fields // http://msdn.microsoft.com/en-us/library/dd357559.aspx const ( - preloginVERSION = 0 - preloginENCRYPTION = 1 - preloginINSTOPT = 2 - preloginTHREADID = 3 - preloginMARS = 4 - preloginTRACEID = 5 - preloginTERMINATOR = 0xff + preloginVERSION = 0 + preloginENCRYPTION = 1 + preloginINSTOPT = 2 + preloginTHREADID = 3 + preloginMARS = 4 + preloginTRACEID = 5 + preloginFEDAUTHREQUIRED = 6 + preloginNONCEOPT = 7 + preloginTERMINATOR = 0xff ) const ( @@ -116,6 +120,33 @@ const ( encryptReq = 3 // Encryption is required. ) +const ( + featExtSESSIONRECOVERY byte = 0x01 + featExtFEDAUTH byte = 0x02 + featExtCOLUMNENCRYPTION byte = 0x04 + featExtGLOBALTRANSACTIONS byte = 0x05 + featExtAZURESQLSUPPORT byte = 0x08 + featExtDATACLASSIFICATION byte = 0x09 + featExtUTF8SUPPORT byte = 0x0A + featExtTERMINATOR byte = 0xFF +) + +// Federated authentication library affects the login data structure and message sequence. +const ( + fedAuthLibraryLiveIDCompactToken = 0x00 + fedAuthLibrarySecurityToken = 0x01 + fedAuthLibraryADAL = 0x02 + + fedAuthLibraryReserved = 0x7F +) + +// Federated authentication ADAL workflow affects the mechanism used to authenticate. +const ( + fedAuthADALWorkflowPassword = 0x01 + fedAuthADALWorkflowIntegrated = 0x02 + fedAuthADALWorkflowMSI = 0x03 +) + type tdsSession struct { buf *tdsBuffer loginAck loginAckStruct @@ -137,6 +168,7 @@ const ( logParams = 16 logTransaction = 32 logDebug = 64 + logTraffic = 128 ) type columnStruct struct { @@ -238,6 +270,16 @@ const ( fIntSecurity = 0x80 ) +// OptionFlags3 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fChangePassword = 1 + fSendYukonBinaryXML = 2 + fUserInstance = 4 + fUnknownCollationHandling = 8 + fExtension = 0x10 +) + // TypeFlags const ( // 4 bits for fSQLType @@ -246,29 +288,34 @@ const ( ) type login struct { - TDSVersion uint32 - PacketSize uint32 - ClientProgVer uint32 - ClientPID uint32 - ConnectionID uint32 - OptionFlags1 uint8 - OptionFlags2 uint8 - TypeFlags uint8 - OptionFlags3 uint8 - ClientTimeZone int32 - ClientLCID uint32 - HostName string - UserName string - Password string - AppName string - ServerName string - CtlIntName string - Language string - Database string - ClientID [6]byte - SSPI []byte - AtchDBFile string - ChangePassword string + TDSVersion uint32 + PacketSize uint32 + ClientProgVer uint32 + ClientPID uint32 + ConnectionID uint32 + OptionFlags1 uint8 + OptionFlags2 uint8 + TypeFlags uint8 + OptionFlags3 uint8 + ClientTimeZone int32 + ClientLCID uint32 + HostName string + UserName string + Password string + AppName string + ServerName string + CtlIntName string + Language string + Database string + ClientID [6]byte + SSPI []byte + AtchDBFile string + ChangePassword string + FedAuthLibrary byte + FedAuthEcho byte + FedAuthToken string + FedAuthNonce []byte + FedAuthADALWorkflow byte } type loginHeader struct { @@ -295,7 +342,7 @@ type loginHeader struct { ServerNameOffset uint16 ServerNameLength uint16 ExtensionOffset uint16 - ExtensionLenght uint16 + ExtensionLength uint16 CtlIntNameOffset uint16 CtlIntNameLength uint16 LanguageOffset uint16 @@ -357,42 +404,81 @@ func sendLogin(w *tdsBuffer, login login) error { database := str2ucs2(login.Database) atchdbfile := str2ucs2(login.AtchDBFile) changepassword := str2ucs2(login.ChangePassword) + fedauthtoken := str2ucs2(login.FedAuthToken) + + // Determine if any feature extensions need to be written so we know whether + // to include the option flag and offset to the data. + var featureExtLength uint32 + fedAuth := login.FedAuthLibrary == fedAuthLibrarySecurityToken || login.FedAuthLibrary == fedAuthLibraryADAL + if login.FedAuthLibrary == fedAuthLibrarySecurityToken { + // Each feature extension record is 1 byte to indicate the type of the + // feature extension, four bytes for the data size, then the data size. + // For the SecurityToken data, the size is one byte to indicate that + // it's the SecurityToken library, four bytes for the token length, + // then the token bytes. + featureExtLength += uint32(1 + 4 + 1 + 4 + len(fedauthtoken)) + } else if login.FedAuthLibrary == fedAuthLibraryADAL { + // In addition to the 1 + 4 bytes for the feature extension header, + // ADAL just requires one byte to indicate the library and one to + // set the workflow (password, integrated or MSI). + featureExtLength += uint32(1 + 4 + 1 + 1) + } + + // If any feature extension records are written, a final single-byte terminator + // record must also be written, and the fExtension flag must be set. + if featureExtLength > 0 { + featureExtLength++ + login.OptionFlags3 |= fExtension + } else { + login.OptionFlags3 &^= fExtension + } + hdr := loginHeader{ - TDSVersion: login.TDSVersion, - PacketSize: login.PacketSize, - ClientProgVer: login.ClientProgVer, - ClientPID: login.ClientPID, - ConnectionID: login.ConnectionID, - OptionFlags1: login.OptionFlags1, - OptionFlags2: login.OptionFlags2, - TypeFlags: login.TypeFlags, - OptionFlags3: login.OptionFlags3, - ClientTimeZone: login.ClientTimeZone, - ClientLCID: login.ClientLCID, - HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), - UserNameLength: uint16(utf8.RuneCountInString(login.UserName)), - PasswordLength: uint16(utf8.RuneCountInString(login.Password)), - AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), - ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), - CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), - LanguageLength: uint16(utf8.RuneCountInString(login.Language)), - DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), - ClientID: login.ClientID, - SSPILength: uint16(len(login.SSPI)), - AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), - ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)), + TDSVersion: login.TDSVersion, + PacketSize: login.PacketSize, + ClientProgVer: login.ClientProgVer, + ClientPID: login.ClientPID, + ConnectionID: login.ConnectionID, + OptionFlags1: login.OptionFlags1, + OptionFlags2: login.OptionFlags2, + TypeFlags: login.TypeFlags, + OptionFlags3: login.OptionFlags3, + ClientTimeZone: login.ClientTimeZone, + ClientLCID: login.ClientLCID, + HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), + AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), + ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), + CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), + LanguageLength: uint16(utf8.RuneCountInString(login.Language)), + DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), + ClientID: login.ClientID, + SSPILength: uint16(len(login.SSPI)), + AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), } offset := uint16(binary.Size(hdr)) hdr.HostNameOffset = offset offset += uint16(len(hostname)) - hdr.UserNameOffset = offset - offset += uint16(len(username)) - hdr.PasswordOffset = offset - offset += uint16(len(password)) + + if !fedAuth { + hdr.UserNameOffset = offset + hdr.UserNameLength = uint16(utf8.RuneCountInString(login.UserName)) + offset += uint16(len(username)) + hdr.PasswordOffset = offset + hdr.PasswordLength = uint16(utf8.RuneCountInString(login.Password)) + offset += uint16(len(password)) + } + hdr.AppNameOffset = offset offset += uint16(len(appname)) hdr.ServerNameOffset = offset offset += uint16(len(servername)) + + if featureExtLength > 0 { + hdr.ExtensionOffset = offset + hdr.ExtensionLength = 4 + offset += hdr.ExtensionLength + } + hdr.CtlIntNameOffset = offset offset += uint16(len(ctlintname)) hdr.LanguageOffset = offset @@ -403,9 +489,14 @@ func sendLogin(w *tdsBuffer, login login) error { offset += uint16(len(login.SSPI)) hdr.AtchDBFileOffset = offset offset += uint16(len(atchdbfile)) - hdr.ChangePasswordOffset = offset - offset += uint16(len(changepassword)) - hdr.Length = uint32(offset) + + if !fedAuth { + hdr.ChangePasswordOffset = offset + hdr.ChangePasswordLength = uint16(utf8.RuneCountInString(login.ChangePassword)) + offset += uint16(len(changepassword)) + } + + hdr.Length = uint32(offset) + featureExtLength var err error err = binary.Write(w, binary.LittleEndian, &hdr) if err != nil { @@ -415,13 +506,16 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } - _, err = w.Write(username) - if err != nil { - return err - } - _, err = w.Write(password) - if err != nil { - return err + + if !fedAuth { + _, err = w.Write(username) + if err != nil { + return err + } + _, err = w.Write(password) + if err != nil { + return err + } } _, err = w.Write(appname) if err != nil { @@ -431,6 +525,12 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } + if featureExtLength > 0 { + err = binary.Write(w, binary.LittleEndian, uint32(offset)) + if err != nil { + return err + } + } _, err = w.Write(ctlintname) if err != nil { return err @@ -451,10 +551,61 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } - _, err = w.Write(changepassword) + + if !fedAuth { + _, err = w.Write(changepassword) + if err != nil { + return err + } + } + + // Write the feature extension record for federated authentication, if in use + if login.FedAuthLibrary == fedAuthLibrarySecurityToken { + w.WriteByte(featExtFEDAUTH) + binary.Write(w, binary.LittleEndian, uint32(1+4+len(fedauthtoken))) + w.WriteByte(login.FedAuthLibrary<<1 | login.FedAuthEcho) + binary.Write(w, binary.LittleEndian, uint32(len(fedauthtoken))) + w.Write(fedauthtoken) + } else if login.FedAuthLibrary == fedAuthLibraryADAL { + w.WriteByte(featExtFEDAUTH) + binary.Write(w, binary.LittleEndian, uint32(1+1)) + w.WriteByte(login.FedAuthLibrary<<1 | login.FedAuthEcho) + w.WriteByte(login.FedAuthADALWorkflow) + } + // Write the feature extension terminator if any feature extensions are written + if featureExtLength > 0 { + w.WriteByte(featExtTERMINATOR) + } + return w.FinishPacket() +} + +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0 +func sendFedAuthInfo(w *tdsBuffer, login login) (err error) { + fedauthtoken := str2ucs2(login.FedAuthToken) + tokenlen := len(fedauthtoken) + datalen := 4 + tokenlen + len(login.FedAuthNonce) + + w.BeginPacket(packFedAuthToken, false) + err = binary.Write(w, binary.LittleEndian, uint32(datalen)) if err != nil { - return err + return + } + + err = binary.Write(w, binary.LittleEndian, uint32(tokenlen)) + if err != nil { + return } + + _, err = w.Write(fedauthtoken) + if err != nil { + return + } + + _, err = w.Write(login.FedAuthNonce) + if err != nil { + return + } + return w.FinishPacket() } @@ -751,6 +902,10 @@ initiate_connection: return nil, err } + if p.logFlags&logTraffic != 0 { + conn = newConnLogger(conn, "TCP", log) + } + toconn := newTimeoutConn(conn, p.conn_timeout) outbuf := newTdsBuffer(p.packetSize, toconn) @@ -778,6 +933,10 @@ initiate_connection: preloginMARS: {0}, // MARS disabled } + if p.fedAuthLibrary != fedAuthLibraryReserved { + fields[preloginFEDAUTHREQUIRED] = []byte{1} + } + err = writePrelogin(outbuf, fields) if err != nil { return nil, err @@ -788,6 +947,21 @@ initiate_connection: return nil, err } + // If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication + // is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated + // authentication is allowed, while 1 means only federated authentication is allowed. + var fedAuthEcho byte + if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok { + if len(fedAuthSupport) == 1 { + // We need to be able to echo the value back to the server + fedAuthEcho = fedAuthSupport[0] + } else { + return nil, fmt.Errorf("Federated authentication flag length should be 1: is %d", len(fedAuthSupport)) + } + } else if p.fedAuthLibrary != fedAuthLibraryReserved { + return nil, fmt.Errorf("Federated authentication is not supported by the server") + } + encryptBytes, ok := fields[preloginENCRYPTION] if !ok { return nil, fmt.Errorf("Encrypt negotiation failed") @@ -811,6 +985,14 @@ initiate_connection: if p.trustServerCertificate { config.InsecureSkipVerify = true } + if p.tlsKeyLogFile != "" { + if w, err := os.OpenFile(p.tlsKeyLogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600); err == nil { + defer w.Close() + config.KeyLogWriter = w + } else { + return nil, fmt.Errorf("Cannot open TLS key log file %s: %v", p.tlsKeyLogFile, err) + } + } config.ServerName = p.hostInCertificate // fix for https://github.com/denisenkom/go-mssqldb/issues/166 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, @@ -823,7 +1005,11 @@ initiate_connection: tlsConn := tls.Client(&passthrough, &config) err = tlsConn.Handshake() passthrough.c = toconn - outbuf.transport = tlsConn + if sess.logFlags&logTraffic != 0 { + outbuf.transport = newConnLogger(tlsConn, "TLS", log) + } else { + outbuf.transport = tlsConn + } if err != nil { return nil, fmt.Errorf("TLS Handshake failed: %v", err) } @@ -835,23 +1021,61 @@ initiate_connection: } login := login{ - TDSVersion: verTDS74, - PacketSize: uint32(outbuf.PackageSize()), - Database: p.database, - OptionFlags2: fODBC, // to get unlimited TEXTSIZE - HostName: p.workstation, - ServerName: p.host, - AppName: p.appname, - TypeFlags: p.typeFlags, + TDSVersion: verTDS74, + PacketSize: uint32(outbuf.PackageSize()), + Database: p.database, + OptionFlags2: fODBC, // to get unlimited TEXTSIZE + HostName: p.workstation, + ServerName: p.host, + AppName: p.appname, + TypeFlags: p.typeFlags, + FedAuthLibrary: p.fedAuthLibrary, + FedAuthEcho: fedAuthEcho, + FedAuthADALWorkflow: p.fedAuthADALWorkflow, } + auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation) - if auth_ok { + if p.fedAuthAccessToken != "" { // accesstoken ignores user/password + if p.logFlags&logDebug != 0 { + log.Println("Using provided access token") + } + login.FedAuthToken = p.fedAuthAccessToken + login.FedAuthLibrary = fedAuthLibrarySecurityToken + } else if login.FedAuthLibrary == fedAuthLibraryReserved && auth_ok { + if p.logFlags&logDebug != 0 { + log.Println("Starting SSPI login") + } login.SSPI, err = auth.InitialBytes() if err != nil { return nil, err } login.OptionFlags2 |= fIntSecurity defer auth.Free() + } else if login.FedAuthLibrary == fedAuthLibrarySecurityToken { + login.FedAuthToken, err = fedAuthGetAccessToken(ctx, "", p.aadTenantID, p, log) + if err != nil { + if p.logFlags&logDebug != 0 { + log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err) + } + return nil, err + } + if p.logFlags&logDebug != 0 { + log.Println("Successfully obtained service principal token for federated authentication security token library") + } + } else if login.FedAuthLibrary == fedAuthLibraryADAL { + if login.FedAuthADALWorkflow == fedAuthADALWorkflowPassword { + if p.logFlags&logDebug != 0 { + log.Printf("Starting ADAL username/password workflow for user %s", p.user) + } + login.UserName = p.user + login.Password = p.password + } else if login.FedAuthADALWorkflow == fedAuthADALWorkflowMSI { + if p.logFlags&logDebug != 0 { + log.Println("Starting ADAL managed service identity (MSI) workflow") + } + } else { + return nil, fmt.Errorf("Unsupported ADAL workflow type 0x%02x", int(login.FedAuthADALWorkflow)) + } } else { login.UserName = p.user login.Password = p.password @@ -885,6 +1109,20 @@ initiate_connection: } sspi_msg = nil } + case fedAuthInfoStruct: + // For ADAL workflows this contains the STS URL and server SPN. + // If received outside of an ADAL workflow, ignore. + if login.FedAuthLibrary == fedAuthLibraryADAL { + login.FedAuthToken, err = fedAuthGetAccessToken(ctx, token.ServerSPN, token.STSURL, p, log) + if err != nil { + return nil, err + } + // Now need to send the token as a FEDINFO packet + err = sendFedAuthInfo(outbuf, login) + if err != nil { + return nil, err + } + } case loginAckStruct: success = true sess.loginAck = token diff --git a/tds_test.go b/tds_test.go index 00360d85..546df24e 100644 --- a/tds_test.go +++ b/tds_test.go @@ -70,6 +70,108 @@ func TestSendLogin(t *testing.T) { } } +func TestSendLoginWithFedAuthToken(t *testing.T) { + memBuf := new(MockTransport) + buf := newTdsBuffer(1024, memBuf) + login := login{ + TDSVersion: verTDS74, + PacketSize: 0x1000, + ClientProgVer: 0x01060100, + ClientPID: 100, + ClientTimeZone: -4 * 60, + ClientID: [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab}, + OptionFlags1: 0xe0, + OptionFlags3: 8, + HostName: "subdev1", + AppName: "appname", + ServerName: "servername", + CtlIntName: "library", + Language: "en", + Database: "database", + ClientLCID: 0x204, + FedAuthLibrary: fedAuthLibrarySecurityToken, + FedAuthToken: "fedauthtoken", + } + err := sendLogin(buf, login) + if err != nil { + t.Error("sendLogin should succeed") + } + ref := []byte{ + 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, 0, 16, 0, 0, 0, 1, + 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, + 0, 94, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, 142, + 0, 4, 0, 146, 0, 7, 0, 160, 0, 2, 0, 164, 0, 8, 0, 18, 52, 86, 120, 144, 171, + 180, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, + 0, 100, 0, 101, 0, 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, + 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, 97, + 0, 109, 0, 101, 0, 180, 0, 0, 0, 108, 0, 105, 0, 98, 0, 114, 0, 97, 0, 114, 0, + 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, + 0, 2, 29, 0, 0, 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, 116, + 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 255} + out := memBuf.Bytes() + if !bytes.Equal(ref, out) { + fmt.Println("Expected:") + fmt.Print(hex.Dump(ref)) + fmt.Println("Returned:") + fmt.Print(hex.Dump(out)) + t.Error("input output don't match") + } +} + +func TestSendLoginWithFedAuthADAL(t *testing.T) { + memBuf := new(MockTransport) + buf := newTdsBuffer(1024, memBuf) + login := login{ + TDSVersion: verTDS74, + PacketSize: 0x1000, + ClientProgVer: 0x01060100, + ClientPID: 100, + ClientTimeZone: -4 * 60, + ClientID: [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab}, + OptionFlags1: 0xe0, + OptionFlags3: 8, + HostName: "subdev1", + AppName: "appname", + ServerName: "servername", + CtlIntName: "library", + Language: "en", + Database: "database", + ClientLCID: 0x204, + + FedAuthLibrary: fedAuthLibraryADAL, + FedAuthADALWorkflow: fedAuthADALWorkflowPassword, + UserName: "username", + Password: "AADpassword", + } + + err := sendLogin(buf, login) + if err != nil { + t.Error("sendLogin should succeed") + } + ref := []byte{ + 16, 1, 0, 196, 0, 0, 1, 0, 188, 0, 0, 0, 4, 0, 0, 116, 0, 16, + 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, + 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 108, 0, 7, 0, 122, 0, 10, 0, 142, 0, 4, 0, 146, 0, 7, 0, 160, 0, + 2, 0, 164, 0, 8, 0, 18, 52, 86, 120, 144, 171, 180, 0, 0, 0, + 180, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, 0, + 100, 0, 101, 0, 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, + 97, 0, 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, + 114, 0, 110, 0, 97, 0, 109, 0, 101, 0, 180, 0, 0, 0, 108, 0, 105, + 0, 98, 0, 114, 0, 97, 0, 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, + 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 2, 2, 0, 0, + 0, 4, 1, 255} + out := memBuf.Bytes() + t.Logf("out= %+v", out) + if !bytes.Equal(ref, out) { + fmt.Println("Expected:") + fmt.Print(hex.Dump(ref)) + fmt.Println("Returned:") + fmt.Print(hex.Dump(out)) + t.Error("input output don't match") + } +} + func TestSendSqlBatch(t *testing.T) { checkConnStr(t) p, err := parseConnectParams(makeConnStr(t).String()) @@ -176,8 +278,6 @@ func (l testLogger) Println(v ...interface{}) { l.t.Log(v...) } - - func TestConnect(t *testing.T) { checkConnStr(t) SetLogger(testLogger{t}) @@ -413,7 +513,7 @@ func TestSSPIAuth(t *testing.T) { func TestUcs22Str(t *testing.T) { // Test valid input - s, err := ucs22str([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding + s, err := ucs22str([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding if err != nil { t.Errorf("ucs22str should not fail for valid ucs2 byte sequence: %s", err) } @@ -429,7 +529,7 @@ func TestUcs22Str(t *testing.T) { } func TestReadUcs2(t *testing.T) { - buf := bytes.NewBuffer([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding + buf := bytes.NewBuffer([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding s, err := readUcs2(buf, 3) if err != nil { t.Errorf("readUcs2 should not fail for valid ucs2 byte sequence: %s", err) @@ -447,7 +547,7 @@ func TestReadUcs2(t *testing.T) { func TestReadUsVarChar(t *testing.T) { // should succeed for valid buffer - buf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding with length prefix 3 uint16 + buf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding with length prefix 3 uint16 s, err := readUsVarChar(buf) if err != nil { t.Errorf("readUsVarChar should not fail for valid ucs2 byte sequence: %s", err) @@ -487,4 +587,4 @@ func TestReadBVarByte(t *testing.T) { if err == nil { t.Error("readUsVarByte should fail on short buffer, but it didn't") } -} \ No newline at end of file +} diff --git a/token.go b/token.go index 1acac8a5..7e4c28af 100644 --- a/token.go +++ b/token.go @@ -6,31 +6,34 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "strconv" "strings" ) -//go:generate stringer -type token +//go:generate go run golang.org/x/tools/cmd/stringer -type token type token byte // token ids const ( - tokenReturnStatus token = 121 // 0x79 - tokenColMetadata token = 129 // 0x81 - tokenOrder token = 169 // 0xA9 - tokenError token = 170 // 0xAA - tokenInfo token = 171 // 0xAB - tokenReturnValue token = 0xAC - tokenLoginAck token = 173 // 0xad - tokenRow token = 209 // 0xd1 - tokenNbcRow token = 210 // 0xd2 - tokenEnvChange token = 227 // 0xE3 - tokenSSPI token = 237 // 0xED - tokenDone token = 253 // 0xFD - tokenDoneProc token = 254 - tokenDoneInProc token = 255 + tokenReturnStatus token = 121 // 0x79 + tokenColMetadata token = 129 // 0x81 + tokenOrder token = 169 // 0xA9 + tokenError token = 170 // 0xAA + tokenInfo token = 171 // 0xAB + tokenReturnValue token = 0xAC + tokenLoginAck token = 173 // 0xad + tokenFeatureExtAck token = 174 // 0xAE + tokenRow token = 209 // 0xd1 + tokenNbcRow token = 210 // 0xd2 + tokenEnvChange token = 227 // 0xE3 + tokenSSPI token = 237 // 0xED + tokenFedAuthInfo token = 238 // 0xEE + tokenDone token = 253 // 0xFD + tokenDoneProc token = 254 + tokenDoneInProc token = 255 ) // done flags @@ -69,6 +72,11 @@ const ( envRouting = 20 ) +const ( + fedAuthInfoSTSURL = 0x01 + fedAuthInfoSPN = 0x02 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -424,6 +432,73 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg { return sspiMsg(buf) } +type fedAuthInfoStruct struct { + STSURL string + ServerSPN string +} + +type fedAuthInfoOpt struct { + fedAuthInfoID byte + dataLength, dataOffset uint32 +} + +func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { + size := r.uint32() + + var STSURL, SPN string + var err error + + // Each fedAuthInfoOpt is one byte to indicate the info ID, + // then a four byte offset and a four byte length. + count := r.uint32() + offset := uint32(4) + opts := make([]fedAuthInfoOpt, count) + + for i := uint32(0); i < count; i++ { + fedAuthInfoID := r.byte() + dataLength := r.uint32() + dataOffset := r.uint32() + offset += 1 + 4 + 4 + + opts[i] = fedAuthInfoOpt{ + fedAuthInfoID: fedAuthInfoID, + dataLength: dataLength, + dataOffset: dataOffset, + } + } + + data := make([]byte, size-offset) + r.ReadFull(data) + + for i := uint32(0); i < count; i++ { + if opts[i].dataOffset < offset { + badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d", + opts[i].dataOffset, offset) + } else if opts[i].dataOffset+opts[i].dataLength > size { + badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d", + opts[i].dataOffset+opts[i].dataLength, size) + } else { + optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength] + if opts[i].fedAuthInfoID == fedAuthInfoSTSURL { + STSURL, err = ucs22str(optData) + } else if opts[i].fedAuthInfoID == fedAuthInfoSPN { + SPN, err = ucs22str(optData) + } else { + err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + } + + if err != nil { + badStreamPanic(err) + } + } + } + + return fedAuthInfoStruct{ + STSURL: STSURL, + ServerSPN: SPN, + } +} + type loginAckStruct struct { Interface uint8 TDSVersion uint32 @@ -447,6 +522,45 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { return res } +type fedAuthAckStruct struct { + Nonce []byte + Signature []byte +} + +func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { + ack := map[byte]interface{}{} + + for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { + length := r.uint32() + + switch feature { + case featExtFEDAUTH: + // In theory we need to know the federated authentication library to + // know how to parse, but the alternatives provide compatible structures. + fedAuthAck := fedAuthAckStruct{} + if length >= 32 { + fedAuthAck.Nonce = make([]byte, 32) + r.ReadFull(fedAuthAck.Nonce) + length -= 32 + } + if length >= 32 { + fedAuthAck.Signature = make([]byte, 32) + r.ReadFull(fedAuthAck.Signature) + length -= 32 + } + ack[feature] = fedAuthAck + + } + + // Skip unprocessed bytes + if length > 0 { + io.CopyN(ioutil.Discard, r, int64(length)) + } + } + + return ack +} + // http://msdn.microsoft.com/en-us/library/dd357363.aspx func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { count := r.uint16() @@ -571,12 +685,18 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return + case tokenFedAuthInfo: + ch <- parseFedAuthInfo(sess.buf) + return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus case tokenLoginAck: loginAck := parseLoginAck(sess.buf) ch <- loginAck + case tokenFeatureExtAck: + featureExtAck := parseFeatureExtAck(sess.buf) + ch <- featureExtAck case tokenOrder: order := parseOrder(sess.buf) ch <- order diff --git a/token_string.go b/token_string.go index c075b23b..74389fdd 100644 --- a/token_string.go +++ b/token_string.go @@ -1,29 +1,46 @@ -// Code generated by "stringer -type token"; DO NOT EDIT +// Code generated by "stringer -type token"; DO NOT EDIT. package mssql -import "fmt" +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[tokenReturnStatus-121] + _ = x[tokenColMetadata-129] + _ = x[tokenOrder-169] + _ = x[tokenError-170] + _ = x[tokenInfo-171] + _ = x[tokenReturnValue-172] + _ = x[tokenLoginAck-173] + _ = x[tokenFeatureExtAck-174] + _ = x[tokenRow-209] + _ = x[tokenNbcRow-210] + _ = x[tokenEnvChange-227] + _ = x[tokenSSPI-237] + _ = x[tokenFedAuthInfo-238] + _ = x[tokenDone-253] + _ = x[tokenDoneProc-254] + _ = x[tokenDoneInProc-255] +} const ( _token_name_0 = "tokenReturnStatus" _token_name_1 = "tokenColMetadata" - _token_name_2 = "tokenOrdertokenErrortokenInfo" - _token_name_3 = "tokenLoginAck" - _token_name_4 = "tokenRowtokenNbcRow" - _token_name_5 = "tokenEnvChange" - _token_name_6 = "tokenSSPI" - _token_name_7 = "tokenDonetokenDoneProctokenDoneInProc" + _token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck" + _token_name_3 = "tokenRowtokenNbcRow" + _token_name_4 = "tokenEnvChange" + _token_name_5 = "tokenSSPItokenFedAuthInfo" + _token_name_6 = "tokenDonetokenDoneProctokenDoneInProc" ) var ( - _token_index_0 = [...]uint8{0, 17} - _token_index_1 = [...]uint8{0, 16} - _token_index_2 = [...]uint8{0, 10, 20, 29} - _token_index_3 = [...]uint8{0, 13} - _token_index_4 = [...]uint8{0, 8, 19} - _token_index_5 = [...]uint8{0, 14} - _token_index_6 = [...]uint8{0, 9} - _token_index_7 = [...]uint8{0, 9, 22, 37} + _token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76} + _token_index_3 = [...]uint8{0, 8, 19} + _token_index_5 = [...]uint8{0, 9, 25} + _token_index_6 = [...]uint8{0, 9, 22, 37} ) func (i token) String() string { @@ -32,22 +49,21 @@ func (i token) String() string { return _token_name_0 case i == 129: return _token_name_1 - case 169 <= i && i <= 171: + case 169 <= i && i <= 174: i -= 169 return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] - case i == 173: - return _token_name_3 case 209 <= i && i <= 210: i -= 209 - return _token_name_4[_token_index_4[i]:_token_index_4[i+1]] + return _token_name_3[_token_index_3[i]:_token_index_3[i+1]] case i == 227: - return _token_name_5 - case i == 237: - return _token_name_6 + return _token_name_4 + case 237 <= i && i <= 238: + i -= 237 + return _token_name_5[_token_index_5[i]:_token_index_5[i+1]] case 253 <= i && i <= 255: i -= 253 - return _token_name_7[_token_index_7[i]:_token_index_7[i+1]] + return _token_name_6[_token_index_6[i]:_token_index_6[i+1]] default: - return fmt.Sprintf("token(%d)", i) + return "token(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/token_test.go b/token_test.go new file mode 100644 index 00000000..3cf24b86 --- /dev/null +++ b/token_test.go @@ -0,0 +1,127 @@ +package mssql + +import ( + "reflect" + "testing" +) + +func Test_parseFedAuthInfo(t *testing.T) { + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/0e4486d6-d407-4962-9803-0c1a4d4d87ce + tokenBytes := []byte{ + 4 + 9 + 9 + 58 + 68, 0, 0, 0, // TokenLength + 2, 0, 0, 0, // CountOfInfoIDs + // FedAuthInfoOpts: + 2, // FedAuthInfoID = SPN + 58, 0, 0, 0, // FedAuthInfoDataLen + 4 + 18, 0, 0, 0, // FedAuthInfoDataOffset + 1, // FedAuthInfoID = STSURL + 68, 0, 0, 0, // FedAuthInfoDataLen + 4 + 18 + 58, 0, 0, 0, // FedAuthInfoDataOffset + + // https://database.windows.net/ + // 58 bytes + 104, 0, 116, 0, 116, 0, 112, 0, 115, 0, 58, 0, 47, 0, 47, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 46, 0, 119, 0, 105, 0, 110, 0, 100, 0, 111, 0, 119, 0, 115, 0, 46, 0, 110, 0, 101, 0, 116, 0, 47, 0, + // https://login.microsoftonline.com/ + // 68 bytes + 104, 0, 116, 0, 116, 0, 112, 0, 115, 0, 58, 0, 47, 0, 47, 0, 108, 0, 111, 0, 103, 0, 105, 0, 110, 0, 46, 0, 109, 0, 105, 0, 99, 0, 114, 0, 111, 0, 115, 0, 111, 0, 102, 0, 116, 0, 111, 0, 110, 0, 108, 0, 105, 0, 110, 0, 101, 0, 46, 0, 99, 0, 111, 0, 109, 0, 47, 0, + } + + memBuf := new(MockTransport) + buf := newTdsBuffer(1024, memBuf) + buf.rbuf = tokenBytes + buf.rsize = len(tokenBytes) + buf.rpos = 0 + + got := parseFedAuthInfo(buf) + want := fedAuthInfoStruct{ + STSURL: "https://login.microsoftonline.com/", + ServerSPN: "https://database.windows.net/", + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("Expected %+v, got %+v", want, got) + } +} + +func Test_parseFeatureExtAck(t *testing.T) { + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a + + testCases := []struct { + name string + tokenBytes []byte + expected map[byte]interface{} + }{ + {"Nonce and Signature", + []byte{ + 0x02, // FeatureId == FEDAUTH + 64, 0, 0, 0, // FeatureAckDataLen + // nonce + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + // sig + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + + 0xff, // terminator + }, + map[byte]interface{}{ + 2: fedAuthAckStruct{ + Nonce: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + Signature: []byte{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + }}, + }, + { + "Nonce only", + []byte{ + 0x02, // FeatureId == FEDAUTH + 32, 0, 0, 0, // FeatureAckDataLen + // nonce + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 0xff, // terminator + }, + map[byte]interface{}{ + 2: fedAuthAckStruct{ + Nonce: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + }}, + }, + { + "Empty", + []byte{ + 0x02, // FeatureId == FEDAUTH + 0, 0, 0, 0, // FeatureAckDataLen + 0xff, // terminator + }, + map[byte]interface{}{ + 2: fedAuthAckStruct{}}, + }, + { + "Ignored", + []byte{ + // this feature should be ignored, go-mssqldb does not handle it + 0x08, // FeatureId == AZURESQLSUPPORT + 1, 0, 0, 0, // FeatureAckDataLen + 0x0, //The server does not support the AZURESQLSUPPORT feature extension. + + 0x02, // FeatureId == FEDAUTH + 0, 0, 0, 0, // FeatureAckDataLen + 0xff, // terminator + }, + map[byte]interface{}{ + 2: fedAuthAckStruct{}}, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + memBuf := new(MockTransport) + buf := newTdsBuffer(1024, memBuf) + buf.rbuf = tt.tokenBytes + buf.rsize = len(tt.tokenBytes) + buf.rpos = 0 + + got := parseFeatureExtAck(buf) + want := tt.expected + + if !reflect.DeepEqual(got, want) { + t.Fatalf("Expected %+v, got %+v", want, got) + } + }) + } +} diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 6cf42641..f4c2ddff 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -209,7 +209,8 @@ func TestTVP(t *testing.T) { bFalse := false floatValue64 := 0.123 floatValue32 := float32(-10.123) - timeNow := time.Now().UTC() + // Best datetime2 precision is 100 ns granularity + timeNow := time.Now().UTC().Truncate(100 * time.Nanosecond) param1 := []TvptableRow{ { PBinary: []byte("ccc"), @@ -462,7 +463,8 @@ func TestTVP_WithTag(t *testing.T) { bFalse := false floatValue64 := 0.123 floatValue32 := float32(-10.123) - timeNow := time.Now().UTC() + // Default (and maximum) datetime2 precision is 7 digits or 100ns + timeNow := time.Now().UTC().Truncate(100 * time.Nanosecond) param1 := []TvptableRowWithSkipTag{ { PBinary: []byte("ccc"),