Skip to content

Commit

Permalink
add support custom tls config
Browse files Browse the repository at this point in the history
  • Loading branch information
daqingshu committed Nov 26, 2022
1 parent d5affd5 commit c90ae64
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 9 deletions.
41 changes: 41 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,47 @@ Valid values for sslmode are:
the server was signed by a trusted CA and the server host name
matches the one in the certificate)
For support ssl key in memory, we extend sslmode. For example:
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"log"
"github.com/lib/pq"
)
func main() {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile("ca.crt")
if err != nil {
log.Fatal(err)
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
log.Fatal("Failed to append PEM.")
}
clientCert := make([]tls.Certificate, 0, 1)
certs, err := tls.LoadX509KeyPair("client1.crt", "client1.key")
if err != nil {
log.Fatal(err)
}
clientCert = append(clientCert, certs)
err = pq.RegisterTLSConfig("custom", &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
ServerName: "pq.example.com",
})
if err != nil {
log.Fatal(err)
}
connStr := "host=pq.example.com port=5432 user=user1 dbname=pqgotest password=pqgotest sslmode=custom"
db, err := sql.Open("postgres", connStr)
if err != nil {
log.Fatal(err)
}
}
See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
for more information about connection string parameters.
Expand Down
74 changes: 65 additions & 9 deletions ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,74 @@ package pq
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
)

// Registry for custom tls.Configs
var (
tlsConfigLock sync.RWMutex
tlsConfigRegistry map[string]*tls.Config
)

func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "require" || strings.ToLower(key) == "verify-ca" || strings.ToLower(key) == "verify-full" || strings.ToLower(key) == "disable" {
return fmt.Errorf("key '%s' is reserved", key)
}

tlsConfigLock.Lock()
if tlsConfigRegistry == nil {
tlsConfigRegistry = make(map[string]*tls.Config)
}

tlsConfigRegistry[key] = config
tlsConfigLock.Unlock()
return nil
}

// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
tlsConfigLock.Lock()
if tlsConfigRegistry != nil {
delete(tlsConfigRegistry, key)
}
tlsConfigLock.Unlock()
}

func getTLSConfigClone(key string) (config *tls.Config) {
tlsConfigLock.RLock()
if v, ok := tlsConfigRegistry[key]; ok {
config = v.Clone()
}
tlsConfigLock.RUnlock()
return
}

// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}

// Not a valid bool value
return
}

// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
// related settings. The function is nil when no upgrade should take place.
func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
verifyCaOnly := false
tlsConf := tls.Config{}
tlsConf := &tls.Config{}
switch mode := o["sslmode"]; mode {
// "require" is the default.
case "", "require":
Expand Down Expand Up @@ -48,7 +103,12 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
case "disable":
return nil, nil
default:
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
{
tlsConf = getTLSConfigClone(mode)
if tlsConf == nil {
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
}
}
}

// Set Server Name Indication (SNI), if enabled by connection parameters.
Expand All @@ -61,11 +121,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
tlsConf.ServerName = o["host"]
}

err := sslClientCertificates(&tlsConf, o)
if err != nil {
return nil, err
}
err = sslCertificateAuthority(&tlsConf, o)
err := sslClientCertificates(tlsConf, o)
if err != nil {
return nil, err
}
Expand All @@ -78,9 +134,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient

return func(conn net.Conn) (net.Conn, error) {
client := tls.Client(conn, &tlsConf)
client := tls.Client(conn, tlsConf)
if verifyCaOnly {
err := sslVerifyCertificateAuthority(client, &tlsConf)
err := sslVerifyCertificateAuthority(client, tlsConf)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit c90ae64

Please sign in to comment.