diff --git a/conn.go b/conn.go index f313c149..000cad90 100644 --- a/conn.go +++ b/conn.go @@ -1068,6 +1068,8 @@ func isDriverSetting(key string) bool { return true case "fallback_application_name": return true + case "keepalives", "keepalives_interval": + return true case "connect_timeout": return true case "disable_prepared_binary_result": diff --git a/connector.go b/connector.go index d7d47261..cde9fabe 100644 --- a/connector.go +++ b/connector.go @@ -5,8 +5,11 @@ import ( "database/sql/driver" "errors" "fmt" + "net" "os" + "strconv" "strings" + "time" ) // Connector represents a fixed configuration for the pq driver with a given @@ -107,9 +110,41 @@ func NewConnector(dsn string) (*Connector, error) { } // SSL is not necessary or supported over UNIX domain sockets - if network, _ := network(o); network == "unix" { + ntw, _ := network(o) + if ntw == "unix" { o["sslmode"] = "disable" } - return &Connector{opts: o, dialer: defaultDialer{}}, nil + var d net.Dialer + if ntw == "tcp" { + d.KeepAlive, err = keepalive(o) + if err != nil { + return nil, err + } + } + + return &Connector{opts: o, dialer: defaultDialer{d}}, nil +} + +// keepalive returns the interval between keep-alive probes controlled by keepalives_interval. +// If zero, keep-alive probes are sent with a default value (see net.Dialer). +// If negative, keep-alive probes are disabled. +// +// The keepalives parameter controls whether client-side TCP keepalives are used. +// The default value is 1, meaning on, but you can change this to 0, meaning off, if keepalives are not wanted. +func keepalive(o values) (time.Duration, error) { + v, ok := o["keepalives"] + if v == "0" { + return -1, nil + } + + if v, ok = o["keepalives_interval"]; !ok { + return 0, nil + } + + keepintvl, err := strconv.ParseInt(v, 10, 0) + if err != nil { + return 0, fmt.Errorf("invalid value for parameter keepalives_interval: %w", err) + } + return time.Duration(keepintvl) * time.Second, nil } diff --git a/connector_test.go b/connector_test.go index 3d2c67b0..054a06ce 100644 --- a/connector_test.go +++ b/connector_test.go @@ -6,7 +6,10 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" + "strconv" "testing" + "time" ) func TestNewConnector_WorksWithOpenDB(t *testing.T) { @@ -65,3 +68,70 @@ func TestNewConnector_Driver(t *testing.T) { } txn.Rollback() } + +func TestNewConnectorKeepalive(t *testing.T) { + c, err := NewConnector("keepalives=1 keepalives_interval=10") + if err != nil { + t.Fatal(err) + } + db := sql.OpenDB(c) + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + txn.Rollback() + + d, _ := c.dialer.(defaultDialer) + want := 10 * time.Second + if want != d.d.KeepAlive { + t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive) + } +} + +func TestKeepalive(t *testing.T) { + var tt = map[string]struct { + input values + want time.Duration + }{ + "keepalives on": {values{"keepalives": "1"}, 0}, + "keepalives on by default": {nil, 0}, + "keepalives off": {values{"keepalives": "0"}, -1}, + "keepalives_interval 5 seconds": {values{"keepalives_interval": "5"}, 5 * time.Second}, + "keepalives_interval default": {values{"keepalives_interval": "0"}, 0}, + "keepalives_interval off": {values{"keepalives_interval": "-1"}, -1 * time.Second}, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + got, err := keepalive(tc.input) + if err != nil { + t.Fatal(err) + } + if tc.want != got { + t.Fatalf("expected: %v, got: %v", tc.want, got) + } + }) + } +} + +func TestKeepaliveError(t *testing.T) { + var tt = map[string]struct { + input values + want error + }{ + "keepalives_interval whitespace": {values{"keepalives_interval": " "}, strconv.ErrSyntax}, + "keepalives_interval float": {values{"keepalives_interval": "1.1"}, strconv.ErrSyntax}, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + _, err := keepalive(tc.input) + if !errors.Is(err, tc.want) { + t.Fatalf("expected: %v, got: %v", tc.want, err) + } + }) + } +} diff --git a/doc.go b/doc.go index b5718480..246d5ffd 100644 --- a/doc.go +++ b/doc.go @@ -51,6 +51,12 @@ supported: * sslmode - Whether or not to use SSL (default is require, this is not the default for libpq) * fallback_application_name - An application_name to fall back to if one isn't provided. + * keepalives - Whether or not to use client-side TCP keepalives + (the default value is 1, meaning on, but you can change this to 0, meaning off) + * keepalives_interval - The number of seconds after which a TCP keepalive message + that is not acknowledged by the server should be retransmitted. + If zero or not specified, keep-alive probes are sent with a default value (see net.Dialer). + If negative, keep-alive probes are disabled. * connect_timeout - Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. * sslcert - Cert file location. The file must contain PEM encoded data.