diff --git a/dburl.go b/dburl.go index 69962b2..211dd56 100644 --- a/dburl.go +++ b/dburl.go @@ -24,7 +24,11 @@ func Open(urlstr string) (*sql.DB, error) { if err != nil { return nil, err } - return sql.Open(u.Driver, u.DSN) + driver := u.Driver + if u.GoDriver != "" { + driver = u.GoDriver + } + return sql.Open(driver, u.DSN) } // URL wraps the standard [net/url.URL] type, adding OriginalScheme, Transport, @@ -40,8 +44,12 @@ type URL struct { // Driver is the non-aliased SQL driver name that should be used in a call // to sql/Open. Driver string - // Unaliased is the unaliased driver name. - Unaliased string + // GoDriver is the Go SQL driver name to use when opening a connection to + // the database. Used by Microsoft SQL Server's azuresql URLs, as the + // wire-compatible alias style uses a different syntax style. + GoDriver string + // UnaliasedDriver is the unaliased driver name. + UnaliasedDriver string // DSN is the built connection "data source name" that can be used in a // call to sql/Open. DSN string @@ -123,12 +131,12 @@ func Parse(urlstr string) (*URL, error) { } } // set driver - u.Driver, u.Unaliased = scheme.Driver, scheme.Driver + u.Driver, u.UnaliasedDriver = scheme.Driver, scheme.Driver if scheme.Override != "" { u.Driver = scheme.Override } // generate dsn - if u.DSN, err = scheme.Generator(u); err != nil { + if u.DSN, u.GoDriver, err = scheme.Generator(u); err != nil { return nil, err } return u, nil @@ -185,7 +193,7 @@ func (u *URL) Short() string { // Normalize returns the driver, host, port, database, and user name of a URL, // joined with sep, populating blank fields with empty. func (u *URL) Normalize(sep, empty string, cut int) string { - s := []string{u.Unaliased, "", "", "", ""} + s := []string{u.UnaliasedDriver, "", "", "", ""} if u.Transport != "tcp" && u.Transport != "unix" { s[0] += "+" + u.Transport } diff --git a/dburl_test.go b/dburl_test.go index 71bc249..4a9b6a4 100644 --- a/dburl_test.go +++ b/dburl_test.go @@ -124,10 +124,10 @@ func TestParse(t *testing.T) { {`mssql://user:pass@localhost/dbname`, `sqlserver`, `sqlserver://user:pass@localhost/?database=dbname`, ``}, {`mssql://user@localhost/service/dbname`, `sqlserver`, `sqlserver://user@localhost/service?database=dbname`, ``}, {`mssql://user:!234%23$@localhost:1580/dbname`, `sqlserver`, `sqlserver://user:%21234%23$@localhost:1580/?database=dbname`, ``}, - {`mssql://user:!234%23$@localhost:1580/service/dbname?fedauth=true`, `sqlserver`, `azuresql://user:%21234%23$@localhost:1580/service?database=dbname&fedauth=true`, ``}, - {`azuresql://user:pass@localhost:100/dbname`, `sqlserver`, `azuresql://user:pass@localhost:100/?database=dbname`, ``}, - {`sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, `sqlserver`, `azuresql://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, ``}, - {`azuresql://xxx.database.windows.net/dbname?fedauth=ActiveDirectoryMSI`, `sqlserver`, `azuresql://xxx.database.windows.net/?database=dbname&fedauth=ActiveDirectoryMSI`, ``}, + {`mssql://user:!234%23$@localhost:1580/service/dbname?fedauth=true`, `azuresql`, `sqlserver://user:%21234%23$@localhost:1580/service?database=dbname&fedauth=true`, ``}, + {`azuresql://user:pass@localhost:100/dbname`, `azuresql`, `sqlserver://user:pass@localhost:100/?database=dbname`, ``}, + {`sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, `azuresql`, `sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, ``}, + {`azuresql://xxx.database.windows.net/dbname?fedauth=ActiveDirectoryMSI`, `azuresql`, `sqlserver://xxx.database.windows.net/?database=dbname&fedauth=ActiveDirectoryMSI`, ``}, { `adodb://Microsoft.ACE.OLEDB.12.0?Extended+Properties=%22Text%3BHDR%3DNO%3BFMT%3DDelimited%22`, `adodb`, // 30 `Data Source=.;Extended Properties="Text;HDR=NO;FMT=Delimited";Provider=Microsoft.ACE.OLEDB.12.0`, ``, @@ -214,7 +214,9 @@ func TestParse(t *testing.T) { switch { case err != nil: t.Fatalf("test %d expected no error, got: %v", i, err) - case u.Driver != test.d: + case u.GoDriver != "" && u.GoDriver != test.d: + t.Errorf("test %d expected go driver %q, got: %q", i, test.d, u.GoDriver) + case u.GoDriver == "" && u.Driver != test.d: t.Errorf("test %d expected driver %q, got: %q", i, test.d, u.Driver) case u.DSN != test.exp: _, err := os.Stat(test.path) diff --git a/dsn.go b/dsn.go index e9a9ce4..9bb0ca2 100644 --- a/dsn.go +++ b/dsn.go @@ -21,8 +21,8 @@ var Stat = func(name string) (fs.FileInfo, error) { // GenScheme returns a func that generates a scheme:// style DSN from the // passed URL. -func GenScheme(scheme string) func(*URL) (string, error) { - return func(u *URL) (string, error) { +func GenScheme(scheme string) func(*URL) (string, string, error) { + return func(u *URL) (string, string, error) { z := &url.URL{ Scheme: scheme, Opaque: u.Opaque, @@ -36,27 +36,27 @@ func GenScheme(scheme string) func(*URL) (string, error) { if z.Host == "" { z.Host = "localhost" } - return z.String(), nil + return z.String(), "", nil } } // GenSchemeTruncate generates a DSN by truncating the scheme://. -func GenSchemeTruncate(u *URL) (string, error) { +func GenSchemeTruncate(u *URL) (string, string, error) { s := u.String() if i := strings.Index(s, "://"); i != -1 { - return s[i+3:], nil + return s[i+3:], "", nil } - return s, nil + return s, "", nil } // GenFromURL returns a func that generates a DSN based on parameters of the // passed URL. -func GenFromURL(urlstr string) func(*URL) (string, error) { +func GenFromURL(urlstr string) func(*URL) (string, string, error) { z, err := url.Parse(urlstr) if err != nil { panic(err) } - return func(u *URL) (string, error) { + return func(u *URL) (string, string, error) { opaque := z.Opaque if u.Opaque != "" { opaque = u.Opaque @@ -101,20 +101,20 @@ func GenFromURL(urlstr string) func(*URL) (string, error) { RawQuery: q.Encode(), Fragment: fragment, } - return y.String(), nil + return y.String(), "", nil } } // GenOpaque generates a opaque file path DSN from the passed URL. -func GenOpaque(u *URL) (string, error) { +func GenOpaque(u *URL) (string, string, error) { if u.Opaque == "" { - return "", ErrMissingPath + return "", "", ErrMissingPath } - return u.Opaque + genQueryOptions(u.Query()), nil + return u.Opaque + genQueryOptions(u.Query()), "", nil } // GenAdodb generates a adodb DSN from the passed URL. -func GenAdodb(u *URL) (string, error) { +func GenAdodb(u *URL) (string, string, error) { // grab data source host, port := u.Hostname(), u.Port() dsname, dbname := strings.TrimPrefix(u.Path, "/"), "" @@ -146,11 +146,11 @@ func GenAdodb(u *URL) (string, error) { } u.hostPortDB = []string{host, port, n} } - return genOptionsOdbc(q, true), nil + return genOptionsOdbc(q, true), "", nil } // GenCassandra generates a cassandra DSN from the passed URL. -func GenCassandra(u *URL) (string, error) { +func GenCassandra(u *URL) (string, string, error) { host, port, dbname := "localhost", "9042", strings.TrimPrefix(u.Path, "/") if h := u.Hostname(); h != "" { host = h @@ -170,11 +170,11 @@ func GenCassandra(u *URL) (string, error) { if dbname != "" { q.Set("keyspace", dbname) } - return host + ":" + port + genQueryOptions(q), nil + return host + ":" + port + genQueryOptions(q), "", nil } // GenCosmos generates a cosmos DSN from the passed URL. -func GenCosmos(u *URL) (string, error) { +func GenCosmos(u *URL) (string, string, error) { host, port, dbname := u.Hostname(), u.Port(), strings.TrimPrefix(u.Path, "/") if port != "" { port = ":" + port @@ -183,25 +183,25 @@ func GenCosmos(u *URL) (string, error) { q.Set("AccountEndpoint", "https://"+host+port) // add user/pass if u.User == nil { - return "", ErrMissingUser + return "", "", ErrMissingUser } q.Set("AccountKey", u.User.Username()) if dbname != "" { q.Set("Db", dbname) } - return genOptionsOdbc(q, true), nil + return genOptionsOdbc(q, true), "", nil } // GenDatabend generates a databend DSN from the passed URL. -func GenDatabend(u *URL) (string, error) { +func GenDatabend(u *URL) (string, string, error) { if u.Hostname() == "" { - return "", ErrMissingHost + return "", "", ErrMissingHost } - return u.String(), nil + return u.String(), "", nil } // GenExasol generates a exasol DSN from the passed URL. -func GenExasol(u *URL) (string, error) { +func GenExasol(u *URL) (string, string, error) { host, port, dbname := u.Hostname(), u.Port(), strings.TrimPrefix(u.Path, "/") if host == "" { host = "localhost" @@ -218,11 +218,11 @@ func GenExasol(u *URL) (string, error) { pass, _ := u.User.Password() q.Set("password", pass) } - return fmt.Sprintf("exa:%s:%s%s", host, port, genOptions(q, ";", "=", ";", ",", true)), nil + return fmt.Sprintf("exa:%s:%s%s", host, port, genOptions(q, ";", "=", ";", ",", true)), "", nil } // GenFirebird generates a firebird DSN from the passed URL. -func GenFirebird(u *URL) (string, error) { +func GenFirebird(u *URL) (string, string, error) { z := &url.URL{ User: u.User, Host: u.Host, @@ -231,11 +231,11 @@ func GenFirebird(u *URL) (string, error) { RawQuery: u.RawQuery, Fragment: u.Fragment, } - return strings.TrimPrefix(z.String(), "//"), nil + return strings.TrimPrefix(z.String(), "//"), "", nil } // GenGodror generates a godror DSN from the passed URL. -func GenGodror(u *URL) (string, error) { +func GenGodror(u *URL) (string, string, error) { // Easy Connect Naming method enables clients to connect to a database server // without any configuration. Clients use a connect string for a simple TCP/IP // address, which includes a host name and optional port and service name: @@ -265,11 +265,11 @@ func GenGodror(u *URL) (string, error) { if instance != "" { dsn += "/" + instance } - return dsn, nil + return dsn, "", nil } // GenIgnite generates an ignite DSN from the passed URL. -func GenIgnite(u *URL) (string, error) { +func GenIgnite(u *URL) (string, string, error) { host, port, dbname := "localhost", "10800", strings.TrimPrefix(u.Path, "/") if h := u.Hostname(); h != "" { host = h @@ -289,11 +289,11 @@ func GenIgnite(u *URL) (string, error) { if dbname != "" { dbname = "/" + dbname } - return "tcp://" + host + ":" + port + dbname + genQueryOptions(q), nil + return "tcp://" + host + ":" + port + dbname + genQueryOptions(q), "", nil } // GenMymysql generates a mymysql DSN from the passed URL. -func GenMymysql(u *URL) (string, error) { +func GenMymysql(u *URL) (string, string, error) { host, port, dbname := u.Hostname(), u.Port(), strings.TrimPrefix(u.Path, "/") // resolve path if u.Transport == "unix" { @@ -332,11 +332,11 @@ func GenMymysql(u *URL) (string, error) { } else if strings.HasSuffix(dsn, "*") { dsn += "//" } - return dsn, nil + return dsn, "", nil } // GenMysql generates a mysql DSN from the passed URL. -func GenMysql(u *URL) (string, error) { +func GenMysql(u *URL) (string, string, error) { host, port, dbname := u.Hostname(), u.Port(), strings.TrimPrefix(u.Path, "/") // build dsn var dsn string @@ -374,11 +374,11 @@ func GenMysql(u *URL) (string, error) { } // add proto and database dsn += u.Transport + "(" + host + port + ")" + "/" + dbname - return dsn + genQueryOptions(u.Query()), nil + return dsn + genQueryOptions(u.Query()), "", nil } // GenOdbc generates a odbc DSN from the passed URL. -func GenOdbc(u *URL) (string, error) { +func GenOdbc(u *URL) (string, string, error) { // save host, port, dbname host, port, dbname := u.Hostname(), u.Port(), strings.TrimPrefix(u.Path, "/") if u.hostPortDB == nil { @@ -410,23 +410,23 @@ func GenOdbc(u *URL) (string, error) { p, _ := u.User.Password() q.Set("PWD", p) } - return genOptionsOdbc(q, true), nil + return genOptionsOdbc(q, true), "", nil } // GenOleodbc generates a oleodbc DSN from the passed URL. -func GenOleodbc(u *URL) (string, error) { - props, err := GenOdbc(u) +func GenOleodbc(u *URL) (string, string, error) { + props, _, err := GenOdbc(u) if err != nil { - return "", nil + return "", "", nil } - return `Provider=MSDASQL.1;Extended Properties="` + props + `"`, nil + return `Provider=MSDASQL.1;Extended Properties="` + props + `"`, "", nil } // GenPostgres generates a postgres DSN from the passed URL. -func GenPostgres(u *URL) (string, error) { +func GenPostgres(u *URL) (string, string, error) { host, port, dbname := u.Hostname(), u.Port(), strings.TrimPrefix(u.Path, "/") if host == "." { - return "", ErrRelativePathNotSupported + return "", "", ErrRelativePathNotSupported } // resolve path if u.Transport == "unix" { @@ -450,11 +450,11 @@ func GenPostgres(u *URL) (string, error) { if u.hostPortDB == nil { u.hostPortDB = []string{host, port, dbname} } - return genOptions(q, "", "=", " ", ",", true), nil + return genOptions(q, "", "=", " ", ",", true), "", nil } // GenPresto generates a presto DSN from the passed URL. -func GenPresto(u *URL) (string, error) { +func GenPresto(u *URL) (string, string, error) { z := &url.URL{ Scheme: "http", Opaque: u.Opaque, @@ -496,50 +496,50 @@ func GenPresto(u *URL) (string, error) { q.Set("schema", schema) } z.RawQuery = q.Encode() - return z.String(), nil + return z.String(), "", nil } // GenSnowflake generates a snowflake DSN from the passed URL. -func GenSnowflake(u *URL) (string, error) { +func GenSnowflake(u *URL) (string, string, error) { host, port, dbname := u.Hostname(), u.Port(), strings.TrimPrefix(u.Path, "/") if host == "" { - return "", ErrMissingHost + return "", "", ErrMissingHost } if port != "" { port = ":" + port } // add user/pass if u.User == nil { - return "", ErrMissingUser + return "", "", ErrMissingUser } user := u.User.Username() if pass, _ := u.User.Password(); pass != "" { user += ":" + pass } - return user + "@" + host + port + "/" + dbname + genQueryOptions(u.Query()), nil + return user + "@" + host + port + "/" + dbname + genQueryOptions(u.Query()), "", nil } // GenSpanner generates a spanner DSN from the passed URL. -func GenSpanner(u *URL) (string, error) { +func GenSpanner(u *URL) (string, string, error) { project, instance, dbname := u.Hostname(), "", strings.TrimPrefix(u.Path, "/") if project == "" { - return "", ErrMissingHost + return "", "", ErrMissingHost } i := strings.Index(dbname, "/") if i == -1 { - return "", ErrMissingPath + return "", "", ErrMissingPath } instance, dbname = dbname[:i], dbname[i+1:] if instance == "" || dbname == "" { - return "", ErrMissingPath + return "", "", ErrMissingPath } - return fmt.Sprintf(`projects/%s/instances/%s/databases/%s`, project, instance, dbname), nil + return fmt.Sprintf(`projects/%s/instances/%s/databases/%s`, project, instance, dbname), "", nil } // GenSqlserver generates a sqlserver DSN from the passed URL. -func GenSqlserver(u *URL) (string, error) { +func GenSqlserver(u *URL) (string, string, error) { z := &url.URL{ - Scheme: sqlserverDriver(u), + Scheme: "sqlserver", Opaque: u.Opaque, User: u.User, Host: u.Host, @@ -550,36 +550,31 @@ func GenSqlserver(u *URL) (string, error) { if z.Host == "" { z.Host = "localhost" } + driver := "sqlserver" + if strings.Contains(strings.ToLower(u.Scheme), "azuresql") || + u.Query().Get("fedauth") != "" { + driver = "azuresql" + } v := strings.Split(strings.TrimPrefix(z.Path, "/"), "/") if n, q := len(v), z.Query(); !q.Has("database") && n != 0 && len(v[0]) != 0 { q.Set("database", v[n-1]) z.Path, z.RawQuery = "/"+strings.Join(v[:n-1], "/"), q.Encode() } - return z.String(), nil -} - -// sqlserverDriver returns the driver used for a Microsoft SQL Server URL. -func sqlserverDriver(u *URL) string { - switch { - case u.Query().Has("fedauth"), - strings.Contains(strings.ToLower(u.OriginalScheme), "azuresql"): - return "azuresql" - } - return "sqlserver" + return z.String(), driver, nil } // GenTableStore generates a tablestore DSN from the passed URL. -func GenTableStore(u *URL) (string, error) { +func GenTableStore(u *URL) (string, string, error) { var transport string splits := strings.Split(u.OriginalScheme, "+") if len(splits) == 0 { - return "", ErrInvalidDatabaseScheme + return "", "", ErrInvalidDatabaseScheme } else if len(splits) == 1 || splits[1] == "https" { transport = "https" } else if splits[1] == "http" { transport = "http" } else { - return "", ErrInvalidTransportProtocol + return "", "", ErrInvalidTransportProtocol } z := &url.URL{ Scheme: transport, @@ -591,11 +586,11 @@ func GenTableStore(u *URL) (string, error) { RawQuery: u.RawQuery, Fragment: u.Fragment, } - return z.String(), nil + return z.String(), "", nil } // GenVoltdb generates a voltdb DSN from the passed URL. -func GenVoltdb(u *URL) (string, error) { +func GenVoltdb(u *URL) (string, string, error) { host, port := "localhost", "21212" if h := u.Hostname(); h != "" { host = h @@ -603,7 +598,7 @@ func GenVoltdb(u *URL) (string, error) { if p := u.Port(); p != "" { port = p } - return host + ":" + port, nil + return host + ":" + port, "", nil } // convertOptions converts an option value based on name, value pairs. diff --git a/scheme.go b/scheme.go index a1d0b7c..3f0c985 100644 --- a/scheme.go +++ b/scheme.go @@ -32,7 +32,7 @@ type Scheme struct { // URL information. // // Note: this func should not modify the passed URL. - Generator func(*URL) (string, error) + Generator func(*URL) (string, string, error) // Transport are allowed protocol transport types for the scheme. Transport Transport // Opaque toggles Parse to not re-process URLs with an "opaque" component.