diff --git a/dburl.go b/dburl.go index f6e13b8..8a03d3b 100644 --- a/dburl.go +++ b/dburl.go @@ -11,8 +11,12 @@ package dburl import ( "database/sql" + "io/fs" "net/url" "os" + "path" + "path/filepath" + "runtime" "strings" ) @@ -78,6 +82,9 @@ func Parse(urlstr string) (*URL, error) { case err != nil: return nil, err case v.Scheme == "": + if typ, err := SchemeType(urlstr); err == nil { + return Parse(typ + ":" + urlstr) + } return nil, ErrInvalidDatabaseScheme } // create url @@ -95,44 +102,31 @@ func Parse(urlstr string) (*URL, error) { } // get dsn generator scheme, ok := schemeMap[u.Scheme] - if !ok { + switch { + case !ok: return nil, ErrUnknownDatabaseScheme - } - // load real scheme for file: - if scheme.Driver == "file" { - typ, err := SchemeType(u.Opaque) - if err == nil { - if s, ok := schemeMap[typ]; ok { - scheme = s - } - } - } - // if scheme does not understand opaque URLs, retry parsing after building - // fully qualified URL - if !scheme.Opaque && u.Opaque != "" { - var q string - if u.RawQuery != "" { - q = "?" + u.RawQuery - } - var f string - if u.Fragment != "" { - f = "#" + u.Fragment + case scheme.Driver == "file" && u.Opaque != "": + // determine scheme for file + if typ, err := SchemeType(u.Opaque); err == nil { + return Parse(typ + ":" + buildOpaque(u)) } - return Parse(u.OriginalScheme + "://" + u.Opaque + q + f) - } - if scheme.Opaque && u.Opaque == "" { + return nil, ErrUnknownFileExtension + case !scheme.Opaque && u.Opaque != "": + // if scheme does not understand opaque URLs, retry parsing after + // building fully qualified URL + return Parse(u.OriginalScheme + "://" + buildOpaque(u)) + case scheme.Opaque && u.Opaque == "": // force Opaque u.Opaque, u.Host, u.Path, u.RawPath = u.Host+u.Path, "", "", "" - } else if u.Host == "." || (u.Host == "" && strings.TrimPrefix(u.Path, "/") != "") { + case u.Host == ".", u.Host == "" && strings.TrimPrefix(u.Path, "/") != "": // force unix proto u.Transport = "unix" } - // check proto + // check transport if checkTransport || u.Transport != "tcp" { - if scheme.Transport == TransportNone { - return nil, ErrInvalidTransportProtocol - } switch { + case scheme.Transport == TransportNone: + return nil, ErrInvalidTransportProtocol case scheme.Transport&TransportAny != 0 && u.Transport != "", scheme.Transport&TransportTCP != 0 && u.Transport == "tcp", scheme.Transport&TransportUDP != 0 && u.Transport == "udp", @@ -240,23 +234,36 @@ func (u *URL) Normalize(sep, empty string, cut int) string { return strings.Join(s, sep) } -// SchemeType returns the scheme type for a file on disk. +// SchemeType returns the scheme type for a path. func SchemeType(name string) (string, error) { - f, err := os.OpenFile(name, os.O_RDONLY, 0) - if err != nil { - return "", err + // try to resolve the path on unix systems + if runtime.GOOS != "windows" /*&& !mode(name).IsRegular()*/ { + if typ, ok := resolveType(name); ok { + return typ, nil + } } - defer f.Close() - buf := make([]byte, 128) - if _, err := f.Read(buf); err != nil { - return "", err + if f, err := OpenFile(name); err == nil { + defer f.Close() + // file exists, match header + buf := make([]byte, 64) + if n, _ := f.Read(buf); n == 0 { + return "sqlite3", nil + } + for _, typ := range fileTypes { + if typ.f(buf) { + return typ.driver, nil + } + } + return "", ErrUnknownFileHeader } - for _, header := range headerTypes { - if header.f(buf) { - return header.driver, nil + // doesn't exist, match file extension + ext := filepath.Ext(name) + for _, typ := range fileTypes { + if typ.ext.MatchString(ext) { + return typ.driver, nil } } - return "", ErrUnknownFileHeader + return "", ErrUnknownFileExtension } // Error is an error. @@ -275,6 +282,8 @@ const ( ErrUnknownDatabaseScheme Error = "unknown database scheme" // ErrUnknownFileHeader is the unknown file header error. ErrUnknownFileHeader Error = "unknown file header" + // ErrUnknownFileExtension is the unknown file extension error. + ErrUnknownFileExtension Error = "unknown file extension" // ErrInvalidTransportProtocol is the invalid transport protocol error. ErrInvalidTransportProtocol Error = "invalid transport protocol" // ErrRelativePathNotSupported is the relative paths not supported error. @@ -286,3 +295,106 @@ const ( // ErrMissingUser is the missing user error. ErrMissingUser Error = "missing user" ) + +// Stat is the default stat func. +// +// Used internally to stat files, and used when generating the DSNs for +// postgres://, mysql://, file:// schemes, and opaque [URL]'s. +var Stat = func(name string) (fs.FileInfo, error) { + return fs.Stat(os.DirFS(filepath.Dir(name)), filepath.Base(name)) +} + +// OpenFile is the default open file func. +// +// Used internally to read file headers. +var OpenFile = func(name string) (fs.File, error) { + f, err := os.OpenFile(name, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + return f, nil +} + +// buildOpaque builds a opaque path from u. +func buildOpaque(u *URL) string { + var q string + if u.RawQuery != "" { + q = "?" + u.RawQuery + } + var f string + if u.Fragment != "" { + f = "#" + u.Fragment + } + return u.Opaque + q + f +} + +// resolveType tries to resolve a path to a Unix domain socket or directory. +func resolveType(s string) (string, bool) { + dir := s + for dir != "" && dir != "/" && dir != "." { + // chop off :4444 port + i, j := strings.LastIndex(dir, ":"), strings.LastIndex(dir, "/") + if i != -1 && i > j { + dir = dir[:i] + } + switch fi, err := Stat(dir); { + case err == nil && fi.IsDir(): + return "postgres", true + case err == nil && fi.Mode()&fs.ModeSocket != 0: + return "mysql", true + case err == nil: + return "", false + } + if j != -1 { + dir = dir[:j] + } else { + dir = "" + } + } + return "", false +} + +// resolveSocket tries to resolve a path to a Unix domain socket based on the +// form "/path/to/socket/dbname" returning either the original path and the +// empty string, or the components "/path/to/socket" and "dbname", when +// /path/to/socket/dbname is reported by Stat as a socket. +func resolveSocket(s string) (string, string) { + dir, dbname := s, "" + for dir != "" && dir != "/" && dir != "." { + if mode(dir)&fs.ModeSocket != 0 { + return dir, dbname + } + dir, dbname = path.Dir(dir), path.Base(dir) + } + return s, "" +} + +// resolveDir resolves a directory with a :port list. +func resolveDir(s string) (string, string, string) { + dir := s + for dir != "" && dir != "/" && dir != "." { + port := "" + i, j := strings.LastIndex(dir, ":"), strings.LastIndex(dir, "/") + if i != -1 && i > j { + port, dir = dir[i+1:], dir[:i] + } + if mode(dir)&fs.ModeDir != 0 { + dbname := strings.TrimPrefix(strings.TrimPrefix(strings.TrimPrefix(s, dir), ":"+port), "/") + return dir, port, dbname + } + if j != -1 { + dir = dir[:j] + } else { + dir = "" + } + } + return s, "", "" +} + +// mode returns the mode of the path. +func mode(s string) os.FileMode { + if fi, err := Stat(s); err == nil { + return fi.Mode() + } + return 0 +} diff --git a/dburl_test.go b/dburl_test.go index 56c908b..836655e 100644 --- a/dburl_test.go +++ b/dburl_test.go @@ -1,33 +1,14 @@ package dburl import ( + "errors" "io/fs" "os" + "strconv" "testing" "time" ) -type stat fs.FileMode - -func (mode stat) Name() string { return "" } -func (mode stat) Size() int64 { return 1 } -func (mode stat) Mode() fs.FileMode { return fs.FileMode(mode) } -func (mode stat) ModTime() time.Time { return time.Now() } -func (mode stat) IsDir() bool { return fs.FileMode(mode)&fs.ModeDir != 0 } -func (mode stat) Sys() interface{} { return nil } - -func init() { - Stat = func(name string) (fs.FileInfo, error) { - switch name { - case "/var/run/postgresql": - return stat(fs.ModeDir), nil - case "/var/run/mysqld/mysqld.sock": - return stat(fs.ModeSocket), nil - } - return nil, fs.ErrNotExist - } -} - func TestBadParse(t *testing.T) { tests := []struct { s string @@ -75,16 +56,24 @@ func TestBadParse(t *testing.T) { {`tablestore+tcp://`, ErrInvalidTransportProtocol}, {`bend://`, ErrMissingHost}, {`databend://`, ErrMissingHost}, + {`unknown_file.ext3`, ErrInvalidDatabaseScheme}, } - for i, test := range tests { - _, err := Parse(test.s) - if err == nil { - t.Errorf("test %d expected error parsing %q", i, test.s) - continue - } - if err != test.exp { - t.Errorf("test %d expected error parsing %q: expected: %v got: %v", i, test.s, test.exp, err) - } + for i, tt := range tests { + test := tt + t.Run(strconv.Itoa(i), func(t *testing.T) { + testBadParse(t, test.s, test.exp) + }) + } +} + +func testBadParse(t *testing.T, s string, exp error) { + t.Helper() + _, err := Parse(s) + switch { + case err == nil: + t.Errorf("%q expected error nil error, got: %v", s, err) + case !errors.Is(err, exp): + t.Errorf("%q expected error %v, got: %v", s, exp, err) } } @@ -95,140 +84,825 @@ func TestParse(t *testing.T) { exp string path string }{ - {`pg:`, `postgres`, ``, ``}, - {`pg://`, `postgres`, ``, ``}, - {`pg:user:pass@localhost/booktest`, `postgres`, `dbname=booktest host=localhost password=pass user=user`, ``}, - {`pg:/var/run/postgresql`, `postgres`, `host=/var/run/postgresql`, `/var/run/postgresql`}, - {`pg:/var/run/postgresql:6666/mydb`, `postgres`, `dbname=mydb host=/var/run/postgresql port=6666`, `/var/run/postgresql`}, - {`pg:/var/run/postgresql/mydb`, `postgres`, `dbname=mydb host=/var/run/postgresql`, `/var/run/postgresql`}, - {`pg:/var/run/postgresql:7777`, `postgres`, `host=/var/run/postgresql port=7777`, `/var/run/postgresql`}, - {`pg+unix:/var/run/postgresql:4444/booktest`, `postgres`, `dbname=booktest host=/var/run/postgresql port=4444`, `/var/run/postgresql`}, - {`pg:user:pass@/var/run/postgresql/mydb`, `postgres`, `dbname=mydb host=/var/run/postgresql password=pass user=user`, `/var/run/postgresql`}, - {`pg:user:pass@/really/bad/path`, `postgres`, `host=/really/bad/path password=pass user=user`, ``}, - {`my:`, `mysql`, `tcp(localhost:3306)/`, ``}, // 10 - {`my://`, `mysql`, `tcp(localhost:3306)/`, ``}, - {`my:booktest:booktest@localhost/booktest`, `mysql`, `booktest:booktest@tcp(localhost:3306)/booktest`, ``}, - {`my:/var/run/mysqld/mysqld.sock/mydb?timeout=90`, `mysql`, `unix(/var/run/mysqld/mysqld.sock)/mydb?timeout=90`, `/var/run/mysqld/mysqld.sock`}, - {`my:///var/run/mysqld/mysqld.sock/mydb?timeout=90`, `mysql`, `unix(/var/run/mysqld/mysqld.sock)/mydb?timeout=90`, `/var/run/mysqld/mysqld.sock`}, - {`my+unix:user:pass@mysqld.sock?timeout=90`, `mysql`, `user:pass@unix(mysqld.sock)/?timeout=90`, ``}, - {`my:./path/to/socket`, `mysql`, `unix(path/to/socket)/`, ``}, - {`my+unix:./path/to/socket`, `mysql`, `unix(path/to/socket)/`, ``}, - {`mymy:`, `mymysql`, `tcp:localhost:3306*//`, ``}, // 18 - {`mymy://`, `mymysql`, `tcp:localhost:3306*//`, ``}, - {`mymy:user:pass@localhost/booktest`, `mymysql`, `tcp:localhost:3306*booktest/user/pass`, ``}, - {`mymy:/var/run/mysqld/mysqld.sock/mydb?timeout=90&test=true`, `mymysql`, `unix:/var/run/mysqld/mysqld.sock,test,timeout=90*mydb`, `/var/run/mysqld/mysqld.sock`}, - {`mymy:///var/run/mysqld/mysqld.sock/mydb?timeout=90`, `mymysql`, `unix:/var/run/mysqld/mysqld.sock,timeout=90*mydb`, `/var/run/mysqld/mysqld.sock`}, - {`mymy+unix:user:pass@mysqld.sock?timeout=90`, `mymysql`, `unix:mysqld.sock,timeout=90*/user/pass`, ``}, - {`mymy:./path/to/socket`, `mymysql`, `unix:path/to/socket*//`, ``}, - {`mymy+unix:./path/to/socket`, `mymysql`, `unix:path/to/socket*//`, ``}, - {`mssql://`, `sqlserver`, `sqlserver://localhost`, ``}, // 26 - {`mssql://user:pass@localhost/dbname`, `sqlserver`, `sqlserver://user:pass@localhost/?database=dbname`, ``}, - {`mssql://user@localhost/service/dbname`, `sqlserver`, `sqlserver://user@localhost/service?database=dbname`, ``}, - {`mssql://user:!234%23$@localhost:1580/dbname`, `sqlserver`, `sqlserver://user:%21234%23$@localhost:1580/?database=dbname`, ``}, - {`mssql://user:!234%23$@localhost:1580/service/dbname?fedauth=true`, `azuresql`, `sqlserver://user:%21234%23$@localhost:1580/service?database=dbname&fedauth=true`, ``}, - {`azuresql://user:pass@localhost:100/dbname`, `azuresql`, `sqlserver://user:pass@localhost:100/?database=dbname`, ``}, - {`sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, `azuresql`, `sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, ``}, - {`azuresql://xxx.database.windows.net/dbname?fedauth=ActiveDirectoryMSI`, `azuresql`, `sqlserver://xxx.database.windows.net/?database=dbname&fedauth=ActiveDirectoryMSI`, ``}, - { - `adodb://Microsoft.ACE.OLEDB.12.0?Extended+Properties=%22Text%3BHDR%3DNO%3BFMT%3DDelimited%22`, `adodb`, // 30 - `Data Source=.;Extended Properties="Text;HDR=NO;FMT=Delimited";Provider=Microsoft.ACE.OLEDB.12.0`, ``, - }, - { - `adodb://user:pass@Provider.Name:1542/Oracle8i/dbname`, `adodb`, - `Data Source=Oracle8i;Database=dbname;Password=pass;Port=1542;Provider=Provider.Name;User ID=user`, ``, - }, - { - `oo+Postgres+Unicode://user:pass@host:5432/dbname`, `adodb`, - `Provider=MSDASQL.1;Extended Properties="Database=dbname;Driver={Postgres Unicode};PWD=pass;Port=5432;Server=host;UID=user"`, ``, - }, - {`sqlite:///path/to/file.sqlite3`, `sqlite3`, `/path/to/file.sqlite3`, ``}, - {`sq://path/to/file.sqlite3`, `sqlite3`, `path/to/file.sqlite3`, ``}, - {`sq:path/to/file.sqlite3`, `sqlite3`, `path/to/file.sqlite3`, ``}, - {`sq:./path/to/file.sqlite3`, `sqlite3`, `./path/to/file.sqlite3`, ``}, - {`sq://./path/to/file.sqlite3?loc=auto`, `sqlite3`, `./path/to/file.sqlite3?loc=auto`, ``}, - {`sq::memory:?loc=auto`, `sqlite3`, `:memory:?loc=auto`, ``}, - {`sq://:memory:?loc=auto`, `sqlite3`, `:memory:?loc=auto`, ``}, - {`or://user:pass@localhost:3000/sidname`, `oracle`, `oracle://user:pass@localhost:3000/sidname`, ``}, // 41 - {`or://localhost`, `oracle`, `oracle://localhost:1521`, ``}, - {`oracle://user:pass@localhost`, `oracle`, `oracle://user:pass@localhost:1521`, ``}, - {`oracle://user:pass@localhost/service_name/instance_name`, `oracle`, `oracle://user:pass@localhost:1521/service_name/instance_name`, ``}, - {`oracle://user:pass@localhost:2000/xe.oracle.docker`, `oracle`, `oracle://user:pass@localhost:2000/xe.oracle.docker`, ``}, - {`or://username:password@host/ORCL`, `oracle`, `oracle://username:password@host:1521/ORCL`, ``}, - {`odpi://username:password@sales-server:1521/sales.us.acme.com`, `oracle`, `oracle://username:password@sales-server:1521/sales.us.acme.com`, ``}, - {`oracle://username:password@sales-server.us.acme.com/sales.us.oracle.com`, `oracle`, `oracle://username:password@sales-server.us.acme.com:1521/sales.us.oracle.com`, ``}, - {`presto://host:8001/`, `presto`, `http://user@host:8001?catalog=default`, ``}, // 49 - {`presto://host/catalogname/schemaname`, `presto`, `http://user@host:8080?catalog=catalogname&schema=schemaname`, ``}, - {`prs://admin@host/catalogname`, `presto`, `https://admin@host:8443?catalog=catalogname`, ``}, - {`prestodbs://admin:pass@host:9998/catalogname`, `presto`, `https://admin:pass@host:9998?catalog=catalogname`, ``}, - {`ca://host`, `cql`, `host:9042`, ``}, // 53 - {`cassandra://host:9999`, `cql`, `host:9999`, ``}, - {`scy://user@host:9999`, `cql`, `host:9999?username=user`, ``}, - {`scylla://user@host:9999?timeout=1000`, `cql`, `host:9999?timeout=1000&username=user`, ``}, - {`datastax://user:pass@localhost:9999/?timeout=1000`, `cql`, `localhost:9999?password=pass&timeout=1000&username=user`, ``}, - {`ca://user:pass@localhost:9999/dbname?timeout=1000`, `cql`, `localhost:9999?keyspace=dbname&password=pass&timeout=1000&username=user`, ``}, - {`ig://host`, `ignite`, `tcp://host:10800`, ``}, // 59 - {`ignite://host:9999`, `ignite`, `tcp://host:9999`, ``}, - {`gridgain://user@host:9999`, `ignite`, `tcp://host:9999?username=user`, ``}, - {`ig://user@host:9999?timeout=1000`, `ignite`, `tcp://host:9999?timeout=1000&username=user`, ``}, - {`ig://user:pass@localhost:9999/?timeout=1000`, `ignite`, `tcp://localhost:9999?password=pass&timeout=1000&username=user`, ``}, - {`ig://user:pass@localhost:9999/dbname?timeout=1000`, `ignite`, `tcp://localhost:9999/dbname?password=pass&timeout=1000&username=user`, ``}, - {`sf://user@host:9999/dbname/schema?timeout=1000`, `snowflake`, `user@host:9999/dbname/schema?timeout=1000`, ``}, - {`sf://user:pass@localhost:9999/dbname/schema?timeout=1000`, `snowflake`, `user:pass@localhost:9999/dbname/schema?timeout=1000`, ``}, - {`rs://user:pass@amazon.com/dbname`, `postgres`, `postgres://user:pass@amazon.com:5439/dbname`, ``}, // 67 - {`ve://user:pass@vertica-host/dbvertica?tlsmode=server-strict`, `vertica`, `vertica://user:pass@vertica-host:5433/dbvertica?tlsmode=server-strict`, ``}, // 68 - {`moderncsqlite:///path/to/file.sqlite3`, `moderncsqlite`, `/path/to/file.sqlite3`, ``}, // 69 - {`modernsqlite:///path/to/file.sqlite3`, `moderncsqlite`, `/path/to/file.sqlite3`, ``}, - {`mq://path/to/file.sqlite3`, `moderncsqlite`, `path/to/file.sqlite3`, ``}, - {`mq:path/to/file.sqlite3`, `moderncsqlite`, `path/to/file.sqlite3`, ``}, - {`mq:./path/to/file.sqlite3`, `moderncsqlite`, `./path/to/file.sqlite3`, ``}, - {`mq://./path/to/file.sqlite3?loc=auto`, `moderncsqlite`, `./path/to/file.sqlite3?loc=auto`, ``}, - {`mq::memory:?loc=auto`, `moderncsqlite`, `:memory:?loc=auto`, ``}, - {`mq://:memory:?loc=auto`, `moderncsqlite`, `:memory:?loc=auto`, ``}, - {`gr://user:pass@localhost:3000/sidname`, `godror`, `user/pass@//localhost:3000/sidname`, ``}, // 77 - {`gr://localhost`, `godror`, `localhost`, ``}, - {`godror://user:pass@localhost`, `godror`, `user/pass@//localhost`, ``}, - {`godror://user:pass@localhost/service_name/instance_name`, `godror`, `user/pass@//localhost/service_name/instance_name`, ``}, - {`godror://user:pass@localhost:2000/xe.oracle.docker`, `godror`, `user/pass@//localhost:2000/xe.oracle.docker`, ``}, - {`gr://username:password@host/ORCL`, `godror`, `username/password@//host/ORCL`, ``}, - {`gr://username:password@sales-server:1521/sales.us.acme.com`, `godror`, `username/password@//sales-server:1521/sales.us.acme.com`, ``}, - {`godror://username:password@sales-server.us.acme.com/sales.us.oracle.com`, `godror`, `username/password@//sales-server.us.acme.com/sales.us.oracle.com`, ``}, - {`trino://host:8001/`, `trino`, `http://user@host:8001?catalog=default`, ``}, // 85 - {`trino://host/catalogname/schemaname`, `trino`, `http://user@host:8080?catalog=catalogname&schema=schemaname`, ``}, - {`trs://admin@host/catalogname`, `trino`, `https://admin@host:8443?catalog=catalogname`, ``}, - {`pgx://`, `pgx`, `postgres://localhost:5432/`, ``}, - {`ca://`, `cql`, `localhost:9042`, ``}, - {`exa://`, `exasol`, `exa:localhost:8563`, ``}, - {`exa://user:pass@host:1883/dbname?autocommit=1`, `exasol`, `exa:host:1883;autocommit=1;password=pass;schema=dbname;user=user`, ``}, // 91 - {`ots://user:pass@localhost/instance_name`, `ots`, `https://user:pass@localhost/instance_name`, ``}, - {`ots+https://user:pass@localhost/instance_name`, `ots`, `https://user:pass@localhost/instance_name`, ``}, - {`ots+http://user:pass@localhost/instance_name`, `ots`, `http://user:pass@localhost/instance_name`, ``}, - {`tablestore://user:pass@localhost/instance_name`, `ots`, `https://user:pass@localhost/instance_name`, ``}, - {`tablestore+https://user:pass@localhost/instance_name`, `ots`, `https://user:pass@localhost/instance_name`, ``}, - {`tablestore+http://user:pass@localhost/instance_name`, `ots`, `http://user:pass@localhost/instance_name`, ``}, - {`bend://user:pass@localhost/instance_name?sslmode=disabled&warehouse=wh`, `databend`, `bend://user:pass@localhost/instance_name?sslmode=disabled&warehouse=wh`, ``}, - {`databend://user:pass@localhost/instance_name?tenant=tn&warehouse=wh`, `databend`, `databend://user:pass@localhost/instance_name?tenant=tn&warehouse=wh`, ``}, - {`flightsql://user:pass@localhost?timeout=3s&token=foobar&tls=enabled`, `flightsql`, `flightsql://user:pass@localhost?timeout=3s&token=foobar&tls=enabled`, ``}, - {`duckdb:/path/to/foo.db?access_mode=read_only&threads=4`, `duckdb`, `/path/to/foo.db?access_mode=read_only&threads=4`, ``}, - {`dk:///path/to/foo.db?access_mode=read_only&threads=4`, `duckdb`, `/path/to/foo.db?access_mode=read_only&threads=4`, ``}, - {`file:./testdata/test.sqlite3?a=b`, `sqlite3`, `./testdata/test.sqlite3?a=b`, ``}, - {`file:./testdata/test.duckdb?a=b`, `duckdb`, `./testdata/test.duckdb?a=b`, ``}, + { + `pg:`, + `postgres`, + ``, + ``, + }, + { + `pg://`, + `postgres`, + ``, + ``, + }, + { + `pg:user:pass@localhost/booktest`, + `postgres`, + `dbname=booktest host=localhost password=pass user=user`, + ``, + }, + { + `pg:/var/run/postgresql`, + `postgres`, + `host=/var/run/postgresql`, + `/var/run/postgresql`, + }, + { + `pg:/var/run/postgresql:6666/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql port=6666`, + `/var/run/postgresql`, + }, + { + `/var/run/postgresql:6666/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql port=6666`, + `/var/run/postgresql`, + }, + { + `pg:/var/run/postgresql/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql`, + `/var/run/postgresql`, + }, + { + `/var/run/postgresql/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql`, + `/var/run/postgresql`, + }, + { + `pg:/var/run/postgresql:7777`, + `postgres`, + `host=/var/run/postgresql port=7777`, + `/var/run/postgresql`, + }, + { + `pg+unix:/var/run/postgresql:4444/booktest`, + `postgres`, + `dbname=booktest host=/var/run/postgresql port=4444`, + `/var/run/postgresql`, + }, + { + `/var/run/postgresql:7777`, + `postgres`, + `host=/var/run/postgresql port=7777`, + `/var/run/postgresql`, + }, + { + `pg:user:pass@/var/run/postgresql/mydb`, + `postgres`, + `dbname=mydb host=/var/run/postgresql password=pass user=user`, + `/var/run/postgresql`, + }, + { + `pg:user:pass@/really/bad/path`, + `postgres`, + `host=/really/bad/path password=pass user=user`, + ``, + }, + { + `my:`, + `mysql`, + `tcp(localhost:3306)/`, + ``, + }, + { + `my://`, + `mysql`, + `tcp(localhost:3306)/`, + ``, + }, + { + `my:booktest:booktest@localhost/booktest`, + `mysql`, + `booktest:booktest@tcp(localhost:3306)/booktest`, + ``, + }, + { + `my:/var/run/mysqld/mysqld.sock/mydb?timeout=90`, + `mysql`, + `unix(/var/run/mysqld/mysqld.sock)/mydb?timeout=90`, + `/var/run/mysqld/mysqld.sock`, + }, + { + `/var/run/mysqld/mysqld.sock/mydb?timeout=90`, + `mysql`, + `unix(/var/run/mysqld/mysqld.sock)/mydb?timeout=90`, + `/var/run/mysqld/mysqld.sock`, + }, + { + `my:///var/run/mysqld/mysqld.sock/mydb?timeout=90`, + `mysql`, + `unix(/var/run/mysqld/mysqld.sock)/mydb?timeout=90`, + `/var/run/mysqld/mysqld.sock`, + }, + { + `my+unix:user:pass@mysqld.sock?timeout=90`, + `mysql`, + `user:pass@unix(mysqld.sock)/?timeout=90`, + ``, + }, + { + `my:./path/to/socket`, + `mysql`, + `unix(path/to/socket)/`, + ``, + }, + { + `my+unix:./path/to/socket`, + `mysql`, + `unix(path/to/socket)/`, + ``, + }, + { + `mymy:`, + `mymysql`, + `tcp:localhost:3306*//`, + ``, + }, + { + `mymy://`, + `mymysql`, + `tcp:localhost:3306*//`, + ``, + }, + { + `mymy:user:pass@localhost/booktest`, + `mymysql`, + `tcp:localhost:3306*booktest/user/pass`, + ``, + }, + { + `mymy:/var/run/mysqld/mysqld.sock/mydb?timeout=90&test=true`, + `mymysql`, + `unix:/var/run/mysqld/mysqld.sock,test,timeout=90*mydb`, + `/var/run/mysqld/mysqld.sock`, + }, + { + `mymy:///var/run/mysqld/mysqld.sock/mydb?timeout=90`, + `mymysql`, + `unix:/var/run/mysqld/mysqld.sock,timeout=90*mydb`, + `/var/run/mysqld/mysqld.sock`, + }, + { + `mymy+unix:user:pass@mysqld.sock?timeout=90`, + `mymysql`, + `unix:mysqld.sock,timeout=90*/user/pass`, + ``, + }, + { + `mymy:./path/to/socket`, + `mymysql`, + `unix:path/to/socket*//`, + ``, + }, + { + `mymy+unix:./path/to/socket`, + `mymysql`, + `unix:path/to/socket*//`, + ``, + }, + { + `mssql://`, + `sqlserver`, + `sqlserver://localhost`, + ``, + }, + { + `mssql://user:pass@localhost/dbname`, + `sqlserver`, + `sqlserver://user:pass@localhost/?database=dbname`, + ``, + }, + { + `mssql://user@localhost/service/dbname`, + `sqlserver`, + `sqlserver://user@localhost/service?database=dbname`, + ``, + }, + { + `mssql://user:!234%23$@localhost:1580/dbname`, + `sqlserver`, + `sqlserver://user:%21234%23$@localhost:1580/?database=dbname`, + ``, + }, + { + `mssql://user:!234%23$@localhost:1580/service/dbname?fedauth=true`, + `azuresql`, + `sqlserver://user:%21234%23$@localhost:1580/service?database=dbname&fedauth=true`, + ``, + }, + { + `azuresql://user:pass@localhost:100/dbname`, + `azuresql`, + `sqlserver://user:pass@localhost:100/?database=dbname`, + ``, + }, + { + `sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, + `azuresql`, + `sqlserver://xxx.database.windows.net?database=xxx&fedauth=ActiveDirectoryMSI`, + ``, + }, + { + `azuresql://xxx.database.windows.net/dbname?fedauth=ActiveDirectoryMSI`, + `azuresql`, + `sqlserver://xxx.database.windows.net/?database=dbname&fedauth=ActiveDirectoryMSI`, + ``, + }, + { + `adodb://Microsoft.ACE.OLEDB.12.0?Extended+Properties=%22Text%3BHDR%3DNO%3BFMT%3DDelimited%22`, + `adodb`, + `Data Source=.;Extended Properties="Text;HDR=NO;FMT=Delimited";Provider=Microsoft.ACE.OLEDB.12.0`, + ``, + }, + { + `adodb://user:pass@Provider.Name:1542/Oracle8i/dbname`, + `adodb`, + `Data Source=Oracle8i;Database=dbname;Password=pass;Port=1542;Provider=Provider.Name;User ID=user`, + ``, + }, + { + `oo+Postgres+Unicode://user:pass@host:5432/dbname`, + `adodb`, + `Provider=MSDASQL.1;Extended Properties="Database=dbname;Driver={Postgres Unicode};PWD=pass;Port=5432;Server=host;UID=user"`, + ``, + }, + { + `sqlite:///path/to/file.sqlite3`, + `sqlite3`, + `/path/to/file.sqlite3`, + ``, + }, + { + `sq://path/to/file.sqlite3`, + `sqlite3`, + `path/to/file.sqlite3`, + ``, + }, + { + `sq:path/to/file.sqlite3`, + `sqlite3`, + `path/to/file.sqlite3`, + ``, + }, + { + `sq:./path/to/file.sqlite3`, + `sqlite3`, + `./path/to/file.sqlite3`, + ``, + }, + { + `sq://./path/to/file.sqlite3?loc=auto`, + `sqlite3`, + `./path/to/file.sqlite3?loc=auto`, + ``, + }, + { + `sq::memory:?loc=auto`, + `sqlite3`, + `:memory:?loc=auto`, + ``, + }, + { + `sq://:memory:?loc=auto`, + `sqlite3`, + `:memory:?loc=auto`, + ``, + }, + { + `or://user:pass@localhost:3000/sidname`, + `oracle`, + `oracle://user:pass@localhost:3000/sidname`, + ``, + }, + { + `or://localhost`, + `oracle`, + `oracle://localhost:1521`, + ``, + }, + { + `oracle://user:pass@localhost`, + `oracle`, + `oracle://user:pass@localhost:1521`, + ``, + }, + { + `oracle://user:pass@localhost/service_name/instance_name`, + `oracle`, + `oracle://user:pass@localhost:1521/service_name/instance_name`, + ``, + }, + { + `oracle://user:pass@localhost:2000/xe.oracle.docker`, + `oracle`, + `oracle://user:pass@localhost:2000/xe.oracle.docker`, + ``, + }, + { + `or://username:password@host/ORCL`, + `oracle`, + `oracle://username:password@host:1521/ORCL`, + ``, + }, + { + `odpi://username:password@sales-server:1521/sales.us.acme.com`, + `oracle`, + `oracle://username:password@sales-server:1521/sales.us.acme.com`, + ``, + }, + { + `oracle://username:password@sales-server.us.acme.com/sales.us.oracle.com`, + `oracle`, + `oracle://username:password@sales-server.us.acme.com:1521/sales.us.oracle.com`, + ``, + }, + { + `presto://host:8001/`, + `presto`, + `http://user@host:8001?catalog=default`, + ``, + }, + { + `presto://host/catalogname/schemaname`, + `presto`, + `http://user@host:8080?catalog=catalogname&schema=schemaname`, + ``, + }, + { + `prs://admin@host/catalogname`, + `presto`, + `https://admin@host:8443?catalog=catalogname`, + ``, + }, + { + `prestodbs://admin:pass@host:9998/catalogname`, + `presto`, + `https://admin:pass@host:9998?catalog=catalogname`, + ``, + }, + { + `ca://host`, + `cql`, + `host:9042`, + ``, + }, + { + `cassandra://host:9999`, + `cql`, + `host:9999`, + ``, + }, + { + `scy://user@host:9999`, + `cql`, + `host:9999?username=user`, + ``, + }, + { + `scylla://user@host:9999?timeout=1000`, + `cql`, + `host:9999?timeout=1000&username=user`, + ``, + }, + { + `datastax://user:pass@localhost:9999/?timeout=1000`, + `cql`, + `localhost:9999?password=pass&timeout=1000&username=user`, + ``, + }, + { + `ca://user:pass@localhost:9999/dbname?timeout=1000`, + `cql`, + `localhost:9999?keyspace=dbname&password=pass&timeout=1000&username=user`, + ``, + }, + { + `ig://host`, + `ignite`, + `tcp://host:10800`, + ``, + }, + { + `ignite://host:9999`, + `ignite`, + `tcp://host:9999`, + ``, + }, + { + `gridgain://user@host:9999`, + `ignite`, + `tcp://host:9999?username=user`, + ``, + }, + { + `ig://user@host:9999?timeout=1000`, + `ignite`, + `tcp://host:9999?timeout=1000&username=user`, + ``, + }, + { + `ig://user:pass@localhost:9999/?timeout=1000`, + `ignite`, + `tcp://localhost:9999?password=pass&timeout=1000&username=user`, + ``, + }, + { + `ig://user:pass@localhost:9999/dbname?timeout=1000`, + `ignite`, + `tcp://localhost:9999/dbname?password=pass&timeout=1000&username=user`, + ``, + }, + { + `sf://user@host:9999/dbname/schema?timeout=1000`, + `snowflake`, + `user@host:9999/dbname/schema?timeout=1000`, + ``, + }, + { + `sf://user:pass@localhost:9999/dbname/schema?timeout=1000`, + `snowflake`, + `user:pass@localhost:9999/dbname/schema?timeout=1000`, + ``, + }, + { + `rs://user:pass@amazon.com/dbname`, + `postgres`, + `postgres://user:pass@amazon.com:5439/dbname`, + ``, + }, + { + `ve://user:pass@vertica-host/dbvertica?tlsmode=server-strict`, + `vertica`, + `vertica://user:pass@vertica-host:5433/dbvertica?tlsmode=server-strict`, + ``, + }, + { + `moderncsqlite:///path/to/file.sqlite3`, + `moderncsqlite`, + `/path/to/file.sqlite3`, + ``, + }, + { + `modernsqlite:///path/to/file.sqlite3`, + `moderncsqlite`, + `/path/to/file.sqlite3`, + ``, + }, + { + `mq://path/to/file.sqlite3`, + `moderncsqlite`, + `path/to/file.sqlite3`, + ``, + }, + { + `mq:path/to/file.sqlite3`, + `moderncsqlite`, + `path/to/file.sqlite3`, + ``, + }, + { + `mq:./path/to/file.sqlite3`, + `moderncsqlite`, + `./path/to/file.sqlite3`, + ``, + }, + { + `mq://./path/to/file.sqlite3?loc=auto`, + `moderncsqlite`, + `./path/to/file.sqlite3?loc=auto`, + ``, + }, + { + `mq::memory:?loc=auto`, + `moderncsqlite`, + `:memory:?loc=auto`, + ``, + }, + { + `mq://:memory:?loc=auto`, + `moderncsqlite`, + `:memory:?loc=auto`, + ``, + }, + { + `gr://user:pass@localhost:3000/sidname`, + `godror`, + `user/pass@//localhost:3000/sidname`, + ``, + }, + { + `gr://localhost`, + `godror`, + `localhost`, + ``, + }, + { + `godror://user:pass@localhost`, + `godror`, + `user/pass@//localhost`, + ``, + }, + { + `godror://user:pass@localhost/service_name/instance_name`, + `godror`, + `user/pass@//localhost/service_name/instance_name`, + ``, + }, + { + `godror://user:pass@localhost:2000/xe.oracle.docker`, + `godror`, + `user/pass@//localhost:2000/xe.oracle.docker`, + ``, + }, + { + `gr://username:password@host/ORCL`, + `godror`, + `username/password@//host/ORCL`, + ``, + }, + { + `gr://username:password@sales-server:1521/sales.us.acme.com`, + `godror`, + `username/password@//sales-server:1521/sales.us.acme.com`, + ``, + }, + { + `godror://username:password@sales-server.us.acme.com/sales.us.oracle.com`, + `godror`, + `username/password@//sales-server.us.acme.com/sales.us.oracle.com`, + ``, + }, + { + `trino://host:8001/`, + `trino`, + `http://user@host:8001?catalog=default`, + ``, + }, + { + `trino://host/catalogname/schemaname`, + `trino`, + `http://user@host:8080?catalog=catalogname&schema=schemaname`, + ``, + }, + { + `trs://admin@host/catalogname`, + `trino`, + `https://admin@host:8443?catalog=catalogname`, + ``, + }, + { + `pgx://`, + `pgx`, + `postgres://localhost:5432/`, + ``, + }, + { + `ca://`, + `cql`, + `localhost:9042`, + ``, + }, + { + `exa://`, + `exasol`, + `exa:localhost:8563`, + ``, + }, + { + `exa://user:pass@host:1883/dbname?autocommit=1`, + `exasol`, + `exa:host:1883;autocommit=1;password=pass;schema=dbname;user=user`, + ``, + }, + { + `ots://user:pass@localhost/instance_name`, + `ots`, + `https://user:pass@localhost/instance_name`, + ``, + }, + { + `ots+https://user:pass@localhost/instance_name`, + `ots`, + `https://user:pass@localhost/instance_name`, + ``, + }, + { + `ots+http://user:pass@localhost/instance_name`, + `ots`, + `http://user:pass@localhost/instance_name`, + ``, + }, + { + `tablestore://user:pass@localhost/instance_name`, + `ots`, + `https://user:pass@localhost/instance_name`, + ``, + }, + { + `tablestore+https://user:pass@localhost/instance_name`, + `ots`, + `https://user:pass@localhost/instance_name`, + ``, + }, + { + `tablestore+http://user:pass@localhost/instance_name`, + `ots`, + `http://user:pass@localhost/instance_name`, + ``, + }, + { + `bend://user:pass@localhost/instance_name?sslmode=disabled&warehouse=wh`, + `databend`, + `bend://user:pass@localhost/instance_name?sslmode=disabled&warehouse=wh`, + ``, + }, + { + `databend://user:pass@localhost/instance_name?tenant=tn&warehouse=wh`, + `databend`, + `databend://user:pass@localhost/instance_name?tenant=tn&warehouse=wh`, + ``, + }, + { + `flightsql://user:pass@localhost?timeout=3s&token=foobar&tls=enabled`, + `flightsql`, + `flightsql://user:pass@localhost?timeout=3s&token=foobar&tls=enabled`, + ``, + }, + { + `duckdb:/path/to/foo.db?access_mode=read_only&threads=4`, + `duckdb`, + `/path/to/foo.db?access_mode=read_only&threads=4`, + ``, + }, + { + `dk:///path/to/foo.db?access_mode=read_only&threads=4`, + `duckdb`, + `/path/to/foo.db?access_mode=read_only&threads=4`, + ``, + }, + { + `file:./testdata/test.sqlite3?a=b`, + `sqlite3`, + `./testdata/test.sqlite3?a=b`, + ``, + }, + { + `file:./testdata/test.duckdb?a=b`, + `duckdb`, + `./testdata/test.duckdb?a=b`, + ``, + }, + { + `file:__nonexistent__.db`, + `sqlite3`, + `__nonexistent__.db`, + ``, + }, + { + `file:__nonexistent__.sqlite3`, + `sqlite3`, + `__nonexistent__.sqlite3`, + ``, + }, + { + `file:__nonexistent__.duckdb`, + `duckdb`, + `__nonexistent__.duckdb`, + ``, + }, + { + `__nonexistent__.db`, + `sqlite3`, + `__nonexistent__.db`, + ``, + }, + { + `__nonexistent__.sqlite3`, + `sqlite3`, + `__nonexistent__.sqlite3`, + ``, + }, + { + `__nonexistent__.duckdb`, + `duckdb`, + `__nonexistent__.duckdb`, + ``, + }, + { + `file:fake.sqlite3?a=b`, + `sqlite3`, + `fake.sqlite3?a=b`, + ``, + }, + { + `fake.sq`, + `sqlite3`, + `fake.sq`, + ``, + }, + { + `file:fake.duckdb?a=b`, + `duckdb`, + `fake.duckdb?a=b`, + ``, + }, + { + `fake.dk`, + `duckdb`, + `fake.dk`, + ``, + }, + } + for i, tt := range tests { + test := tt + t.Run(strconv.Itoa(i), func(t *testing.T) { + testParse(t, test.s, test.d, test.exp, test.path) + }) + } +} + +func testParse(t *testing.T, s, d, exp, path string) { + t.Helper() + u, err := Parse(s) + switch { + case err != nil: + t.Errorf("%q expected no error, got: %v", s, err) + case u.GoDriver != "" && u.GoDriver != d: + t.Errorf("%q expected go driver %q, got: %q", s, d, u.GoDriver) + case u.GoDriver == "" && u.Driver != d: + t.Errorf("%q expected driver %q, got: %q", s, d, u.Driver) + case u.DSN != exp: + _, err := os.Stat(path) + if path != "" && err != nil && os.IsNotExist(err) { + t.Logf("%q expected dsn %q, got: %q -- ignoring because `%s` does not exist", s, exp, u.DSN, path) + } else { + t.Errorf("%q expected:\n%q\ngot:\n%q", s, exp, u.DSN) + } } - for i, test := range tests { - u, err := Parse(test.s) - switch { - case err != nil: - t.Fatalf("test %d expected no error, got: %v", i, err) - case u.GoDriver != "" && u.GoDriver != test.d: - t.Errorf("test %d expected go driver %q, got: %q", i, test.d, u.GoDriver) - case u.GoDriver == "" && u.Driver != test.d: - t.Errorf("test %d expected driver %q, got: %q", i, test.d, u.Driver) - case u.DSN != test.exp: - _, err := os.Stat(test.path) - if test.path != "" && err != nil && os.IsNotExist(err) { - t.Logf("test %d expected dsn %q, got: %q -- ignoring because `%s` does not exist", i, test.exp, u.DSN, test.path) - } else { - t.Errorf("test %d expected:\n%q\ngot:\n%q", i, test.exp, u.DSN) - } +} + +func init() { + statFile, openFile := Stat, OpenFile + Stat = func(name string) (fs.FileInfo, error) { + if s, ok := newStat(name); ok { + return s, nil } + return statFile(name) } + OpenFile = func(name string) (fs.File, error) { + if s, ok := newStat(name); ok { + return s, nil + } + return openFile(name) + } +} + +type stat struct { + name string + mode fs.FileMode + content string +} + +func newStat(name string) (stat, bool) { + const ( + sqlite3Header = "SQLite format 3\000.........." + duckdbHeader = "12345678DUCK87654321.............." + ) + files := map[string]string{ + "fake.sqlite3": sqlite3Header, + "fake.sq": sqlite3Header, + "fake.duckdb": duckdbHeader, + "fake.dk": duckdbHeader, + } + switch name { + case "/var/run/postgresql": + return stat{name, fs.ModeDir, ""}, true + case "/var/run/mysqld/mysqld.sock": + return stat{name, fs.ModeSocket, ""}, true + case "fake.sqlite3", "fake.sq", "fake.duckdb", "fake.dk": + return stat{name, 0, files[name]}, true + } + return stat{}, false +} + +func (s stat) Name() string { return s.name } +func (s stat) Size() int64 { return int64(len(s.content)) } +func (s stat) Mode() fs.FileMode { return s.mode } +func (s stat) ModTime() time.Time { return time.Now() } +func (s stat) IsDir() bool { return s.mode&fs.ModeDir != 0 } +func (s stat) Sys() interface{} { return nil } +func (s stat) Close() error { return nil } + +func (s stat) Stat() (fs.FileInfo, error) { + return s, nil +} + +func (s stat) Read(b []byte) (int, error) { + v := []byte(s.content) + copy(b, v) + return len(v), nil } diff --git a/dsn.go b/dsn.go index a106fff..f5fe62b 100644 --- a/dsn.go +++ b/dsn.go @@ -2,23 +2,12 @@ package dburl import ( "fmt" - "io/fs" "net/url" - "os" "path" - "path/filepath" "sort" "strings" ) -// Stat is the default stat func. -// -// Used internally to stat files, and is subsequently used when generating the -// DSNs for postgres://, mysql://, and sqlite3:// [URL] schemes. -var Stat = func(name string) (fs.FileInfo, error) { - return fs.Stat(os.DirFS(filepath.Dir(name)), filepath.Base(name)) -} - // GenScheme returns a generator that will generate a scheme based on the // passed scheme DSN. func GenScheme(scheme string) func(*URL) (string, string, error) { @@ -675,52 +664,3 @@ func genOptions(q url.Values, joiner, assign, sep, valSep string, skipWhenEmpty } return "" } - -// resolveSocket tries to resolve a path to a Unix domain socket based on the -// form "/path/to/socket/dbname" returning either the original path and the -// empty string, or the components "/path/to/socket" and "dbname", when -// /path/to/socket/dbname is reported by Stat as a socket. -// -// Used for MySQL DSNs. -func resolveSocket(s string) (string, string) { - dir, dbname := s, "" - for dir != "" && dir != "/" && dir != "." { - if mode(dir)&fs.ModeSocket != 0 { - return dir, dbname - } - dir, dbname = path.Dir(dir), path.Base(dir) - } - return s, "" -} - -// resolveDir resolves a directory with a :port list. -// -// Used for PostgreSQL DSNs. -func resolveDir(s string) (string, string, string) { - dir := s - for dir != "" && dir != "/" && dir != "." { - port := "" - i, j := strings.LastIndex(dir, ":"), strings.LastIndex(dir, "/") - if i != -1 && i > j { - port, dir = dir[i+1:], dir[:i] - } - if mode(dir)&fs.ModeDir != 0 { - rest := strings.TrimPrefix(strings.TrimPrefix(strings.TrimPrefix(s, dir), ":"+port), "/") - return dir, port, rest - } - if j != -1 { - dir = dir[:j] - } else { - dir = "" - } - } - return s, "", "" -} - -// mode returns the mode of the path. -func mode(s string) os.FileMode { - if fi, err := Stat(s); err == nil { - return fi.Mode() - } - return 0 -} diff --git a/example_test.go b/example_test.go index 3207500..a2aae2d 100644 --- a/example_test.go +++ b/example_test.go @@ -7,12 +7,8 @@ import ( "github.com/xo/dburl" ) -func Example_parse() { - u, err := dburl.Parse("pg://user:pass@host:1234/dbname") - if err != nil { - log.Fatal(err) - } - db, err := sql.Open(u.Driver, u.DSN) +func Example() { + db, err := dburl.Open("my://user:pass@host:1234/dbname") if err != nil { log.Fatal(err) } @@ -28,8 +24,12 @@ func Example_parse() { } } -func Example_open() { - db, err := dburl.Open("my://user:pass@host:1234/dbname") +func Example_parse() { + u, err := dburl.Parse("pg://user:pass@host:1234/dbname") + if err != nil { + log.Fatal(err) + } + db, err := sql.Open(u.Driver, u.DSN) if err != nil { log.Fatal(err) } diff --git a/scheme.go b/scheme.go index b4dc024..034020d 100644 --- a/scheme.go +++ b/scheme.go @@ -338,8 +338,8 @@ func init() { for _, scheme := range schemes { Register(scheme) } - RegisterHeaderType("duckdb", isDuckdbHeader) - RegisterHeaderType("sqlite3", isSqlite3Header) + RegisterFileType("duckdb", isDuckdbHeader, `(?i)\.duckdb$`) + RegisterFileType("sqlite3", isSqlite3Header, `(?i)\.(db|sqlite|sqlite3)$`) } // schemeMap is the map of registered schemes. @@ -440,28 +440,34 @@ func RegisterAlias(name, alias string) { registerAlias(name, alias, true) } -// headerTypes are registered header recognition funcs. -var headerTypes []headerType +// fileTypes are registered header recognition funcs. +var fileTypes []fileType -// RegisterHeaderType registers a file header recognition func. -func RegisterHeaderType(driver string, f func([]byte) bool) { - headerTypes = append(headerTypes, headerType{ +// RegisterFileType registers a file header recognition func, and extension regexp. +func RegisterFileType(driver string, f func([]byte) bool, ext string) { + extRE, err := regexp.Compile(ext) + if err != nil { + panic(fmt.Sprintf("invalid extension regexp %q: %v", ext, err)) + } + fileTypes = append(fileTypes, fileType{ driver: driver, f: f, + ext: extRE, }) } -// headerType wraps a header recognition func. -type headerType struct { +// fileType wraps file type information. +type fileType struct { driver string f func([]byte) bool + ext *regexp.Regexp } -// HeaderTypes returns the registered header types. -func HeaderTypes() []string { +// FileTypes returns the registered file types. +func FileTypes() []string { var v []string - for _, header := range headerTypes { - v = append(v, header.driver) + for _, typ := range fileTypes { + v = append(v, typ.driver) } return v } @@ -514,7 +520,7 @@ func ShortAlias(name string) string { // // See: https://www.sqlite.org/fileformat.html func isSqlite3Header(buf []byte) bool { - return len(buf) == 0 || bytes.HasPrefix(buf, sqlite3Header) + return bytes.HasPrefix(buf, sqlite3Header) } // sqlite3Header is the sqlite3 header.