diff --git a/dburl.go b/dburl.go index 8a03d3b..340736f 100644 --- a/dburl.go +++ b/dburl.go @@ -105,16 +105,23 @@ func Parse(urlstr string) (*URL, error) { switch { case !ok: return nil, ErrUnknownDatabaseScheme - case scheme.Driver == "file" && u.Opaque != "": + case scheme.Driver == "file": // determine scheme for file - if typ, err := SchemeType(u.Opaque); err == nil { - return Parse(typ + ":" + buildOpaque(u)) + s := u.opaqueOrPath() + switch { + case u.Transport != "tcp", strings.Index(u.OriginalScheme, "+") != -1: + return nil, ErrInvalidTransportProtocol + case s == "": + return nil, ErrMissingPath + } + if typ, err := SchemeType(s); err == nil { + return Parse(typ + "://" + u.buildOpaque()) } return nil, ErrUnknownFileExtension case !scheme.Opaque && u.Opaque != "": // if scheme does not understand opaque URLs, retry parsing after // building fully qualified URL - return Parse(u.OriginalScheme + "://" + buildOpaque(u)) + return Parse(u.OriginalScheme + "://" + u.buildOpaque()) case scheme.Opaque && u.Opaque == "": // force Opaque u.Opaque, u.Host, u.Path, u.RawPath = u.Host+u.Path, "", "", "" @@ -234,10 +241,35 @@ func (u *URL) Normalize(sep, empty string, cut int) string { return strings.Join(s, sep) } +// buildOpaque builds a opaque path. +func (u *URL) buildOpaque() string { + var up string + if u.User != nil { + up = u.User.String() + "@" + } + var q string + if u.RawQuery != "" { + q = "?" + u.RawQuery + } + var f string + if u.Fragment != "" { + f = "#" + u.Fragment + } + return up + u.opaqueOrPath() + q + f +} + +// opaqueOrPath returns the opaque or path value. +func (u *URL) opaqueOrPath() string { + if u.Opaque != "" { + return u.Opaque + } + return u.Path +} + // SchemeType returns the scheme type for a path. func SchemeType(name string) (string, error) { // try to resolve the path on unix systems - if runtime.GOOS != "windows" /*&& !mode(name).IsRegular()*/ { + if runtime.GOOS != "windows" { if typ, ok := resolveType(name); ok { return typ, nil } @@ -315,21 +347,13 @@ var OpenFile = func(name string) (fs.File, error) { return f, nil } -// buildOpaque builds a opaque path from u. -func buildOpaque(u *URL) string { - var q string - if u.RawQuery != "" { - q = "?" + u.RawQuery - } - var f string - if u.Fragment != "" { - f = "#" + u.Fragment - } - return u.Opaque + q + f -} - // resolveType tries to resolve a path to a Unix domain socket or directory. func resolveType(s string) (string, bool) { + if i := strings.LastIndex(s, "?"); i != -1 { + if _, err := Stat(s[:i]); err == nil { + s = s[:i] + } + } dir := s for dir != "" && dir != "/" && dir != "." { // chop off :4444 port diff --git a/dburl_test.go b/dburl_test.go index 836655e..fa82bd0 100644 --- a/dburl_test.go +++ b/dburl_test.go @@ -816,10 +816,51 @@ func TestParse(t *testing.T) { `fake.dk`, ``, }, + { + `file:/var/run/mysqld/mysqld.sock/mydb?timeout=90`, + `mysql`, + `unix(/var/run/mysqld/mysqld.sock)/mydb?timeout=90`, + `/var/run/mysqld/mysqld.sock`, + }, + { + `file:/var/run/postgresql`, + `postgres`, + `host=/var/run/postgresql`, + `/var/run/postgresql`, + }, + { + `file:/var/run/postgresql:6666/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql port=6666`, + `/var/run/postgresql`, + }, + { + `file:/var/run/postgresql/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql`, + `/var/run/postgresql`, + }, + { + `file:/var/run/postgresql:7777`, + `postgres`, + `host=/var/run/postgresql port=7777`, + `/var/run/postgresql`, + }, + { + `file://user:pass@/var/run/postgresql/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql password=pass user=user`, + `/var/run/postgresql`, + }, } + m := make(map[string]bool) for i, tt := range tests { test := tt t.Run(strconv.Itoa(i), func(t *testing.T) { + if _, ok := m[test.s]; ok { + t.Fatalf("%s is already tested", test.s) + } + m[test.s] = true testParse(t, test.s, test.d, test.exp, test.path) }) } diff --git a/scheme.go b/scheme.go index 034020d..9fd8033 100644 --- a/scheme.go +++ b/scheme.go @@ -512,7 +512,10 @@ func SchemeDriverAndAliases(name string) (string, []string) { // ShortAlias returns the short alias for the scheme name. func ShortAlias(name string) string { - return schemeMap[name].Aliases[0] + if scheme, ok := schemeMap[name]; ok { + return scheme.Aliases[0] + } + return "" } // isSqlite3Header returns true when the passed header is empty or starts with