Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection logging to help with debugging #626

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions conn_str.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mssql

import (
"errors"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -39,6 +40,7 @@ type connectParams struct {
packetSize uint16
fedAuthLibrary int
fedAuthADALWorkflow byte
tlsKeyLogFile string
}

// default packet size for TDS buffer
Expand Down Expand Up @@ -235,6 +237,11 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
}

p.tlsKeyLogFile, ok = params["tls key log file"]
if ok && p.tlsKeyLogFile != "" && p.disableEncryption {
return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled")
}

return p, nil
}

Expand Down
5 changes: 3 additions & 2 deletions conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func TestValidConnectionString(t *testing.T) {
{"trustservercertificate=false", func(p connectParams) bool { return !p.trustServerCertificate }},
{"certificate=abc", func(p connectParams) bool { return p.certificate == "abc" }},
{"hostnameincertificate=abc", func(p connectParams) bool { return p.hostInCertificate == "abc" }},
{"tls key log file=tls.log", func(p connectParams) bool { return p.tlsKeyLogFile == "tls.log" }},
{"connection timeout=3;dial timeout=4;keepalive=5", func(p connectParams) bool {
return p.conn_timeout == 3*time.Second && p.dial_timeout == 4*time.Second && p.keepAlive == 5*time.Second
}},
Expand Down Expand Up @@ -186,10 +187,10 @@ func testConnParams(t testing.TB) connectParams {
}
if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 {
return connectParams{
host: os.Getenv("HOST"),
host: os.Getenv("HOST"),
instance: os.Getenv("INSTANCE"),
database: os.Getenv("DATABASE"),
user: os.Getenv("SQLUSER"),
user: os.Getenv("SQLUSER"),
password: os.Getenv("SQLPASSWORD"),
logFlags: logFlags,
}
Expand Down
80 changes: 80 additions & 0 deletions log_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package mssql

import (
"encoding/hex"
"net"
"strings"
"time"
)

type connLogger struct {
conn net.Conn
readKind, writeKind string
readCount, writeCount int
logger Logger
}

var _ net.Conn = &connLogger{}

func newConnLogger(conn net.Conn, kind string, logger Logger) net.Conn {
if len(kind) > 0 && !strings.HasPrefix(kind, " ") {
kind = " " + kind
}

cl := &connLogger{
conn: conn,
readKind: "R" + kind,
writeKind: "W" + kind,
logger: logger,
}

return cl
}

func (cl *connLogger) Read(p []byte) (n int, err error) {
n, err = cl.conn.Read(p)

if n > 0 {
dump := hex.Dump(p)
cl.logger.Printf("%s %d\n%s", cl.readKind, cl.readCount, dump)
cl.readCount += n
}

return
}

func (cl *connLogger) Write(p []byte) (n int, err error) {
n, err = cl.conn.Write(p)

if n > 0 {
dump := hex.Dump(p)
cl.logger.Printf("%s %d\n%s", cl.writeKind, cl.writeCount, dump)
cl.writeCount += n
}

return
}

func (cl *connLogger) Close() (err error) {
return cl.conn.Close()
}

func (cl *connLogger) LocalAddr() net.Addr {
return cl.conn.LocalAddr()
}

func (cl *connLogger) RemoteAddr() net.Addr {
return cl.conn.RemoteAddr()
}

func (cl *connLogger) SetDeadline(t time.Time) error {
return cl.conn.SetDeadline(t)
}

func (cl *connLogger) SetReadDeadline(t time.Time) error {
return cl.conn.SetReadDeadline(t)
}

func (cl *connLogger) SetWriteDeadline(t time.Time) error {
return cl.conn.SetWriteDeadline(t)
}
121 changes: 121 additions & 0 deletions log_conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package mssql

import (
"net"
"sync/atomic"
"testing"
"time"
)

func TestConnLoggerOperations(t *testing.T) {
clt := &connLoggerTest{}
cl := newConnLogger(clt, "test", nullLogger{})
packet := append(make([]byte, 0, 10), 1, 2, 3, 4, 5)
n, err := cl.Read(packet)
if n != 10 || err != nil {
t.Error("Unexpected return value from call to Read()")
}

n, err = cl.Write(packet)
if n != 5 || err != nil {
t.Error("Unexpected return value from call to Write()")
}

if cl.Close() != nil {
t.Error("Unexpected return value from call to Close()")
}

if cl.LocalAddr() == nil {
t.Error("Unexpected return value from call to LocalAddr()")
}

if cl.RemoteAddr() == nil {
t.Error("Unexpected return value from call to RemoteAddr()")
}

if cl.SetDeadline(time.Now()) != nil {
t.Error("Unexpected return value from call to SetDeadline()")
}

if cl.SetReadDeadline(time.Now()) != nil {
t.Error("Unexpected return value from call to SetReadDeadline()")
}

if cl.SetWriteDeadline(time.Now()) != nil {
t.Error("Unexpected return value from call to SetWriteDeadline()")
}

if atomic.LoadInt32(&clt.calls) != 8 {
t.Error("Unexpected number of calls recorded")
}
}

type connLoggerTest struct {
calls int32
}

var _ net.Conn = &connLoggerTest{}

type addressTest struct {
}

var _ net.Addr = &addressTest{}

type nullLogger struct {
}

var _ Logger = nullLogger{}

func (n nullLogger) Printf(format string, v ...interface{}) {
}

func (n nullLogger) Println(v ...interface{}) {
}

func (a *addressTest) Network() string {
return "test"
}

func (a *addressTest) String() string {
return "test"
}

func (cl *connLoggerTest) Read(p []byte) (int, error) {
atomic.AddInt32(&cl.calls, 1)
return cap(p), nil
}

func (cl *connLoggerTest) Write(p []byte) (int, error) {
atomic.AddInt32(&cl.calls, 1)
return len(p), nil
}

func (cl *connLoggerTest) Close() error {
atomic.AddInt32(&cl.calls, 1)
return nil
}

func (cl *connLoggerTest) LocalAddr() net.Addr {
atomic.AddInt32(&cl.calls, 1)
return &addressTest{}
}

func (cl *connLoggerTest) RemoteAddr() net.Addr {
atomic.AddInt32(&cl.calls, 1)
return &addressTest{}
}

func (cl *connLoggerTest) SetDeadline(t time.Time) error {
atomic.AddInt32(&cl.calls, 1)
return nil
}

func (cl *connLoggerTest) SetReadDeadline(t time.Time) error {
atomic.AddInt32(&cl.calls, 1)
return nil
}

func (cl *connLoggerTest) SetWriteDeadline(t time.Time) error {
atomic.AddInt32(&cl.calls, 1)
return nil
}
20 changes: 19 additions & 1 deletion tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"io/ioutil"
"net"
"os"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -152,6 +153,7 @@ const (
logParams = 16
logTransaction = 32
logDebug = 64
logTraffic = 128
)

type columnStruct struct {
Expand Down Expand Up @@ -1059,6 +1061,10 @@ initiate_connection:
return nil, err
}

if p.logFlags&logTraffic != 0 {
conn = newConnLogger(conn, "TCP", log)
}

toconn := newTimeoutConn(conn, p.conn_timeout)

outbuf := newTdsBuffer(p.packetSize, toconn)
Expand Down Expand Up @@ -1104,6 +1110,14 @@ initiate_connection:
if p.trustServerCertificate {
config.InsecureSkipVerify = true
}
if p.tlsKeyLogFile != "" {
if w, err := os.OpenFile(p.tlsKeyLogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600); err == nil {
defer w.Close()
config.KeyLogWriter = w
} else {
return nil, fmt.Errorf("Cannot open TLS key log file %s: %v", p.tlsKeyLogFile, err)
}
}
config.ServerName = p.hostInCertificate
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
Expand All @@ -1116,7 +1130,11 @@ initiate_connection:
tlsConn := tls.Client(&passthrough, &config)
err = tlsConn.Handshake()
passthrough.c = toconn
outbuf.transport = tlsConn
if sess.logFlags&logTraffic != 0 {
outbuf.transport = newConnLogger(tlsConn, "TLS", log)
} else {
outbuf.transport = tlsConn
}
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
Expand Down