From b17b191e9105469450714b418fab0d945db5927a Mon Sep 17 00:00:00 2001 From: orenbm Date: Wed, 23 Oct 2019 10:07:02 -0400 Subject: [PATCH 1/2] Expose constants and functions so that they are accessible when consuming the project MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Consts   - Prelogin consts (PreloginVERSION, PreloginENCRYPTION, etc.)   - Encrypt consts (EncryptOff, EncryptOn, etc.) - Types   - TdsBuffer   - KeySlice - Functions   - NewTdsBuffer   - ReadNextPacket --- buf.go | 54 +++++++++++++++++++------------------- buf_test.go | 14 +++++----- net.go | 6 ++--- rpc.go | 2 +- tds.go | 74 ++++++++++++++++++++++++++--------------------------- tds_test.go | 12 ++++----- token.go | 24 ++++++++--------- tran.go | 6 ++--- types.go | 20 +++++++-------- 9 files changed, 105 insertions(+), 107 deletions(-) diff --git a/buf.go b/buf.go index ba39b40f..a30ad86a 100644 --- a/buf.go +++ b/buf.go @@ -17,11 +17,11 @@ type header struct { Pad uint8 } -// tdsBuffer reads and writes TDS packets of data to the transport. +// TdsBuffer reads and writes TDS packets of data to the transport. // The write and read buffers are separate to make sending attn signals // possible without locks. Currently attn signals are only sent during // reads, not writes. -type tdsBuffer struct { +type TdsBuffer struct { transport io.ReadWriteCloser packetSize int @@ -39,14 +39,14 @@ type tdsBuffer struct { final bool rPacketType packetType - // afterFirst is assigned to right after tdsBuffer is created and + // afterFirst is assigned to right after TdsBuffer is created and // before the first use. It is executed after the first packet is // written and then removed. afterFirst func() } -func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { - return &tdsBuffer{ +func NewTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *TdsBuffer { + return &TdsBuffer{ packetSize: int(bufsize), wbuf: make([]byte, 1<<16), rbuf: make([]byte, 1<<16), @@ -55,15 +55,15 @@ func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { } } -func (rw *tdsBuffer) ResizeBuffer(packetSize int) { +func (rw *TdsBuffer) ResizeBuffer(packetSize int) { rw.packetSize = packetSize } -func (w *tdsBuffer) PackageSize() int { +func (w *TdsBuffer) PackageSize() int { return w.packetSize } -func (w *tdsBuffer) flush() (err error) { +func (w *TdsBuffer) flush() (err error) { // Write packet size. w.wbuf[0] = byte(w.wPacketType) binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos)) @@ -88,7 +88,7 @@ func (w *tdsBuffer) flush() (err error) { return nil } -func (w *tdsBuffer) Write(p []byte) (total int, err error) { +func (w *TdsBuffer) Write(p []byte) (total int, err error) { for { copied := copy(w.wbuf[w.wpos:w.packetSize], p) w.wpos += copied @@ -103,7 +103,7 @@ func (w *tdsBuffer) Write(p []byte) (total int, err error) { } } -func (w *tdsBuffer) WriteByte(b byte) error { +func (w *TdsBuffer) WriteByte(b byte) error { if int(w.wpos) == len(w.wbuf) || w.wpos == w.packetSize { if err := w.flush(); err != nil { return err @@ -114,7 +114,7 @@ func (w *tdsBuffer) WriteByte(b byte) error { return nil } -func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) { +func (w *TdsBuffer) BeginPacket(packetType packetType, resetSession bool) { status := byte(0) if resetSession { switch packetType { @@ -129,14 +129,14 @@ func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) { w.wPacketType = packetType } -func (w *tdsBuffer) FinishPacket() error { +func (w *TdsBuffer) FinishPacket() error { w.wbuf[1] |= 1 // Mark this as the last packet in the message. return w.flush() } var headerSize = binary.Size(header{}) -func (r *tdsBuffer) readNextPacket() error { +func (r *TdsBuffer) ReadNextPacket() error { h := header{} var err error err = binary.Read(r.transport, binary.BigEndian, &h) @@ -160,20 +160,20 @@ func (r *tdsBuffer) readNextPacket() error { return nil } -func (r *tdsBuffer) BeginRead() (packetType, error) { - err := r.readNextPacket() +func (r *TdsBuffer) BeginRead() (packetType, error) { + err := r.ReadNextPacket() if err != nil { return 0, err } return r.rPacketType, nil } -func (r *tdsBuffer) ReadByte() (res byte, err error) { +func (r *TdsBuffer) ReadByte() (res byte, err error) { if r.rpos == r.rsize { if r.final { return 0, io.EOF } - err = r.readNextPacket() + err = r.ReadNextPacket() if err != nil { return 0, err } @@ -183,7 +183,7 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) { return res, nil } -func (r *tdsBuffer) byte() byte { +func (r *TdsBuffer) byte() byte { b, err := r.ReadByte() if err != nil { badStreamPanic(err) @@ -191,36 +191,36 @@ func (r *tdsBuffer) byte() byte { return b } -func (r *tdsBuffer) ReadFull(buf []byte) { +func (r *TdsBuffer) ReadFull(buf []byte) { _, err := io.ReadFull(r, buf[:]) if err != nil { badStreamPanic(err) } } -func (r *tdsBuffer) uint64() uint64 { +func (r *TdsBuffer) uint64() uint64 { var buf [8]byte r.ReadFull(buf[:]) return binary.LittleEndian.Uint64(buf[:]) } -func (r *tdsBuffer) int32() int32 { +func (r *TdsBuffer) int32() int32 { return int32(r.uint32()) } -func (r *tdsBuffer) uint32() uint32 { +func (r *TdsBuffer) uint32() uint32 { var buf [4]byte r.ReadFull(buf[:]) return binary.LittleEndian.Uint32(buf[:]) } -func (r *tdsBuffer) uint16() uint16 { +func (r *TdsBuffer) uint16() uint16 { var buf [2]byte r.ReadFull(buf[:]) return binary.LittleEndian.Uint16(buf[:]) } -func (r *tdsBuffer) BVarChar() string { +func (r *TdsBuffer) BVarChar() string { return readBVarCharOrPanic(r) } @@ -240,18 +240,18 @@ func readUsVarCharOrPanic(r io.Reader) string { return s } -func (r *tdsBuffer) UsVarChar() string { +func (r *TdsBuffer) UsVarChar() string { return readUsVarCharOrPanic(r) } -func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { +func (r *TdsBuffer) Read(buf []byte) (copied int, err error) { copied = 0 err = nil if r.rpos == r.rsize { if r.final { return 0, io.EOF } - err = r.readNextPacket() + err = r.ReadNextPacket() if err != nil { return } diff --git a/buf_test.go b/buf_test.go index 76efbd1d..42b1aaba 100644 --- a/buf_test.go +++ b/buf_test.go @@ -29,15 +29,15 @@ func (failBuffer) Close() error { return nil } -func makeBuf(bufSize uint16, testData []byte) *tdsBuffer { +func makeBuf(bufSize uint16, testData []byte) *TdsBuffer { buffer := closableBuffer{bytes.NewBuffer(testData)} - return newTdsBuffer(bufSize, &buffer) + return NewTdsBuffer(bufSize, &buffer) } func TestStreamShorterThanHeader(t *testing.T) { //buffer := closableBuffer{*bytes.NewBuffer([]byte{0xFF, 0xFF})} //buffer := closableBuffer{*bytes.NewBuffer([]byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF})} - //tdsBuffer := newTdsBuffer(100, &buffer) + //TdsBuffer := NewTdsBuffer(100, &buffer) buffer := makeBuf(100, []byte{0xFF, 0xFF}) _, err := buffer.BeginRead() if err == nil { @@ -187,7 +187,7 @@ func TestReadFailsOnSecondPacket(t *testing.T) { func TestWrite(t *testing.T) { memBuf := bytes.NewBuffer([]byte{}) - buf := newTdsBuffer(11, closableBuffer{memBuf}) + buf := NewTdsBuffer(11, closableBuffer{memBuf}) buf.BeginPacket(1, false) err := buf.WriteByte(2) if err != nil { @@ -233,7 +233,7 @@ func TestWrite(t *testing.T) { func TestWriteErrors(t *testing.T) { // write should fail if underlying transport fails - buf := newTdsBuffer(uint16(headerSize)+1, failBuffer{}) + buf := NewTdsBuffer(uint16(headerSize)+1, failBuffer{}) buf.BeginPacket(1, false) wrote, err := buf.Write([]byte{0, 0}) // may change from error to panic in future @@ -245,7 +245,7 @@ func TestWriteErrors(t *testing.T) { } // writebyte should fail if underlying transport fails - buf = newTdsBuffer(uint16(headerSize)+1, failBuffer{}) + buf = NewTdsBuffer(uint16(headerSize)+1, failBuffer{}) buf.BeginPacket(1, false) // first write should not fail because if fits in the buffer err = buf.WriteByte(0) @@ -261,7 +261,7 @@ func TestWriteErrors(t *testing.T) { func TestWrite_BufferBounds(t *testing.T) { memBuf := bytes.NewBuffer([]byte{}) - buf := newTdsBuffer(11, closableBuffer{memBuf}) + buf := NewTdsBuffer(11, closableBuffer{memBuf}) buf.BeginPacket(1, false) // write bytes enough to complete a package diff --git a/net.go b/net.go index 94858cc7..64d96f22 100644 --- a/net.go +++ b/net.go @@ -7,8 +7,8 @@ import ( ) type timeoutConn struct { - c net.Conn - timeout time.Duration + c net.Conn + timeout time.Duration } func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn { @@ -65,7 +65,7 @@ func (c timeoutConn) SetWriteDeadline(t time.Time) error { // this connection is used during TLS Handshake // TDS protocol requires TLS handshake messages to be sent inside TDS packets type tlsHandshakeConn struct { - buf *tdsBuffer + buf *TdsBuffer packetPending bool continueRead bool } diff --git a/rpc.go b/rpc.go index 4ca22578..ca660cf2 100644 --- a/rpc.go +++ b/rpc.go @@ -46,7 +46,7 @@ var ( ) // http://msdn.microsoft.com/en-us/library/dd357576.aspx -func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) (err error) { +func sendRpc(buf *TdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) (err error) { buf.BeginPacket(packRPCRequest, resetSession) writeAllHeaders(buf, headers) if len(proc.name) == 0 { diff --git a/tds.go b/tds.go index 5a9f53b7..6b4a1963 100644 --- a/tds.go +++ b/tds.go @@ -100,24 +100,24 @@ const ( // prelogin fields // http://msdn.microsoft.com/en-us/library/dd357559.aspx const ( - preloginVERSION = 0 - preloginENCRYPTION = 1 - preloginINSTOPT = 2 - preloginTHREADID = 3 - preloginMARS = 4 - preloginTRACEID = 5 - preloginTERMINATOR = 0xff + PreloginVERSION = 0 + PreloginENCRYPTION = 1 + PreloginINSTOPT = 2 + PreloginTHREADID = 3 + PreloginMARS = 4 + PreloginTRACEID = 5 + PreloginTERMINATOR = 0xff ) const ( - encryptOff = 0 // Encryption is available but off. - encryptOn = 1 // Encryption is available and on. - encryptNotSup = 2 // Encryption is not available. - encryptReq = 3 // Encryption is required. + EncryptOff = 0 // Encryption is available but off. + EncryptOn = 1 // Encryption is available and on. + EncryptNotSup = 2 // Encryption is not available. + EncryptReq = 3 // Encryption is required. ) type tdsSession struct { - buf *tdsBuffer + buf *TdsBuffer loginAck loginAckStruct database string partner string @@ -146,19 +146,19 @@ type columnStruct struct { ti typeInfo } -type keySlice []uint8 +type KeySlice []uint8 -func (p keySlice) Len() int { return len(p) } -func (p keySlice) Less(i, j int) bool { return p[i] < p[j] } -func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p KeySlice) Len() int { return len(p) } +func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] } +func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // http://msdn.microsoft.com/en-us/library/dd357559.aspx -func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { +func writePrelogin(w *TdsBuffer, fields map[uint8][]byte) error { var err error w.BeginPacket(packPrelogin, false) offset := uint16(5*len(fields) + 1) - keys := make(keySlice, 0, len(fields)) + keys := make(KeySlice, 0, len(fields)) for k, _ := range fields { keys = append(keys, k) } @@ -181,7 +181,7 @@ func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { } offset += size } - err = w.WriteByte(preloginTERMINATOR) + err = w.WriteByte(PreloginTERMINATOR) if err != nil { return err } @@ -199,7 +199,7 @@ func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { return w.FinishPacket() } -func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { +func readPrelogin(r *TdsBuffer) (map[uint8][]byte, error) { packet_type, err := r.BeginRead() if err != nil { return nil, err @@ -215,7 +215,7 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { results := map[uint8][]byte{} for true { rec_type := struct_buf[offset] - if rec_type == preloginTERMINATOR { + if rec_type == PreloginTERMINATOR { break } @@ -345,7 +345,7 @@ func manglePassword(password string) []byte { } // http://msdn.microsoft.com/en-us/library/dd304019.aspx -func sendLogin(w *tdsBuffer, login login) error { +func sendLogin(w *TdsBuffer, login login) error { w.BeginPacket(packLogin7, false) hostname := str2ucs2(login.HostName) username := str2ucs2(login.UserName) @@ -624,7 +624,7 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) { return nil } -func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) { +func sendSqlBatch72(buf *TdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) { buf.BeginPacket(packSQLBatch, resetSession) if err = writeAllHeaders(buf, headers); err != nil { @@ -640,7 +640,7 @@ func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, rese // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx -func sendAttention(buf *tdsBuffer) error { +func sendAttention(buf *TdsBuffer) error { buf.BeginPacket(packAttention, false) return buf.FinishPacket() } @@ -752,7 +752,7 @@ initiate_connection: toconn := newTimeoutConn(conn, p.conn_timeout) - outbuf := newTdsBuffer(p.packetSize, toconn) + outbuf := NewTdsBuffer(p.packetSize, toconn) sess := tdsSession{ buf: outbuf, log: log, @@ -763,18 +763,18 @@ initiate_connection: instance_buf = append(instance_buf, 0) // zero terminate instance name var encrypt byte if p.disableEncryption { - encrypt = encryptNotSup + encrypt = EncryptNotSup } else if p.encrypt { - encrypt = encryptOn + encrypt = EncryptOn } else { - encrypt = encryptOff + encrypt = EncryptOff } fields := map[uint8][]byte{ - preloginVERSION: {0, 0, 0, 0, 0, 0}, - preloginENCRYPTION: {encrypt}, - preloginINSTOPT: instance_buf, - preloginTHREADID: {0, 0, 0, 0}, - preloginMARS: {0}, // MARS disabled + PreloginVERSION: {0, 0, 0, 0, 0, 0}, + PreloginENCRYPTION: {encrypt}, + PreloginINSTOPT: instance_buf, + PreloginTHREADID: {0, 0, 0, 0}, + PreloginMARS: {0}, // MARS disabled } err = writePrelogin(outbuf, fields) @@ -787,16 +787,16 @@ initiate_connection: return nil, err } - encryptBytes, ok := fields[preloginENCRYPTION] + encryptBytes, ok := fields[PreloginENCRYPTION] if !ok { return nil, fmt.Errorf("Encrypt negotiation failed") } encrypt = encryptBytes[0] - if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { + if p.encrypt && (encrypt == EncryptNotSup || encrypt == EncryptOff) { return nil, fmt.Errorf("Server does not support encryption") } - if encrypt != encryptNotSup { + if encrypt != EncryptNotSup { var config tls.Config if p.certificate != "" { pem, err := ioutil.ReadFile(p.certificate) @@ -826,7 +826,7 @@ initiate_connection: if err != nil { return nil, fmt.Errorf("TLS Handshake failed: %v", err) } - if encrypt == encryptOff { + if encrypt == EncryptOff { outbuf.afterFirst = func() { outbuf.transport = toconn } diff --git a/tds_test.go b/tds_test.go index 00360d85..f2bf6686 100644 --- a/tds_test.go +++ b/tds_test.go @@ -22,7 +22,7 @@ func (t *MockTransport) Close() error { func TestSendLogin(t *testing.T) { memBuf := new(MockTransport) - buf := newTdsBuffer(1024, memBuf) + buf := NewTdsBuffer(1024, memBuf) login := login{ TDSVersion: verTDS73, PacketSize: 0x1000, @@ -176,8 +176,6 @@ func (l testLogger) Println(v ...interface{}) { l.t.Log(v...) } - - func TestConnect(t *testing.T) { checkConnStr(t) SetLogger(testLogger{t}) @@ -413,7 +411,7 @@ func TestSSPIAuth(t *testing.T) { func TestUcs22Str(t *testing.T) { // Test valid input - s, err := ucs22str([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding + s, err := ucs22str([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding if err != nil { t.Errorf("ucs22str should not fail for valid ucs2 byte sequence: %s", err) } @@ -429,7 +427,7 @@ func TestUcs22Str(t *testing.T) { } func TestReadUcs2(t *testing.T) { - buf := bytes.NewBuffer([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding + buf := bytes.NewBuffer([]byte{0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding s, err := readUcs2(buf, 3) if err != nil { t.Errorf("readUcs2 should not fail for valid ucs2 byte sequence: %s", err) @@ -447,7 +445,7 @@ func TestReadUcs2(t *testing.T) { func TestReadUsVarChar(t *testing.T) { // should succeed for valid buffer - buf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding with length prefix 3 uint16 + buf := bytes.NewBuffer([]byte{3, 0, 0x31, 0, 0x32, 0, 0x33, 0}) // 123 in UCS2 encoding with length prefix 3 uint16 s, err := readUsVarChar(buf) if err != nil { t.Errorf("readUsVarChar should not fail for valid ucs2 byte sequence: %s", err) @@ -487,4 +485,4 @@ func TestReadBVarByte(t *testing.T) { if err == nil { t.Error("readUsVarByte should fail on short buffer, but it didn't") } -} \ No newline at end of file +} diff --git a/token.go b/token.go index 1acac8a5..788c2b4d 100644 --- a/token.go +++ b/token.go @@ -386,11 +386,11 @@ func processEnvChg(sess *tdsSession) { } // http://msdn.microsoft.com/en-us/library/dd358180.aspx -func parseReturnStatus(r *tdsBuffer) ReturnStatus { +func parseReturnStatus(r *TdsBuffer) ReturnStatus { return ReturnStatus(r.int32()) } -func parseOrder(r *tdsBuffer) (res orderStruct) { +func parseOrder(r *TdsBuffer) (res orderStruct) { len := int(r.uint16()) res.ColIds = make([]uint16, len/2) for i := 0; i < len/2; i++ { @@ -400,7 +400,7 @@ func parseOrder(r *tdsBuffer) (res orderStruct) { } // https://msdn.microsoft.com/en-us/library/dd340421.aspx -func parseDone(r *tdsBuffer) (res doneStruct) { +func parseDone(r *TdsBuffer) (res doneStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() res.RowCount = r.uint64() @@ -408,7 +408,7 @@ func parseDone(r *tdsBuffer) (res doneStruct) { } // https://msdn.microsoft.com/en-us/library/dd340553.aspx -func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) { +func parseDoneInProc(r *TdsBuffer) (res doneInProcStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() res.RowCount = r.uint64() @@ -417,7 +417,7 @@ func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) { type sspiMsg []byte -func parseSSPIMsg(r *tdsBuffer) sspiMsg { +func parseSSPIMsg(r *TdsBuffer) sspiMsg { size := r.uint16() buf := make([]byte, size) r.ReadFull(buf) @@ -431,7 +431,7 @@ type loginAckStruct struct { ProgVer uint32 } -func parseLoginAck(r *tdsBuffer) loginAckStruct { +func parseLoginAck(r *TdsBuffer) loginAckStruct { size := r.uint16() buf := make([]byte, size) r.ReadFull(buf) @@ -448,7 +448,7 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { } // http://msdn.microsoft.com/en-us/library/dd357363.aspx -func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { +func parseColMetadata72(r *TdsBuffer) (columns []columnStruct) { count := r.uint16() if count == 0xffff { // no metadata is sent @@ -468,14 +468,14 @@ func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { } // http://msdn.microsoft.com/en-us/library/dd357254.aspx -func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseRow(r *TdsBuffer, columns []columnStruct, row []interface{}) { for i, column := range columns { row[i] = column.ti.Reader(&column.ti, r) } } // http://msdn.microsoft.com/en-us/library/dd304783.aspx -func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseNbcRow(r *TdsBuffer, columns []columnStruct, row []interface{}) { bitlen := (len(columns) + 7) / 8 pres := make([]byte, bitlen) r.ReadFull(pres) @@ -489,7 +489,7 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseError72(r *tdsBuffer) (res Error) { +func parseError72(r *TdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -503,7 +503,7 @@ func parseError72(r *tdsBuffer) (res Error) { } // http://msdn.microsoft.com/en-us/library/dd304156.aspx -func parseInfo(r *tdsBuffer) (res Error) { +func parseInfo(r *TdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() @@ -517,7 +517,7 @@ func parseInfo(r *tdsBuffer) (res Error) { } // https://msdn.microsoft.com/en-us/library/dd303881.aspx -func parseReturnValue(r *tdsBuffer) (nv namedValue) { +func parseReturnValue(r *TdsBuffer) (nv namedValue) { /* ParamOrdinal ParamName diff --git a/tran.go b/tran.go index cb643681..52832355 100644 --- a/tran.go +++ b/tran.go @@ -28,7 +28,7 @@ const ( isolationSnapshot = 5 ) -func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) { +func sendBeginXact(buf *TdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) { buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmBeginXact @@ -51,7 +51,7 @@ const ( fBeginXact = 1 ) -func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { +func sendCommitXact(buf *TdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmCommitXact @@ -80,7 +80,7 @@ func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags u return buf.FinishPacket() } -func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { +func sendRollbackXact(buf *TdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmRollbackXact diff --git a/types.go b/types.go index b6e7fb2b..4050e382 100644 --- a/types.go +++ b/types.go @@ -90,7 +90,7 @@ type typeInfo struct { Collation cp.Collation UdtInfo udtInfo XmlInfo xmlInfo - Reader func(ti *typeInfo, r *tdsBuffer) (res interface{}) + Reader func(ti *typeInfo, r *TdsBuffer) (res interface{}) Writer func(w io.Writer, ti typeInfo, buf []byte) (err error) } @@ -113,7 +113,7 @@ type xmlInfo struct { XmlSchemaCollection string } -func readTypeInfo(r *tdsBuffer) (res typeInfo) { +func readTypeInfo(r *TdsBuffer) (res typeInfo) { res.TypeId = r.byte() switch res.TypeId { case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4, @@ -309,7 +309,7 @@ func decodeDateTime(buf []byte) time.Time { 0, 0, secs, ns, time.UTC) } -func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { +func readFixedType(ti *typeInfo, r *TdsBuffer) interface{} { r.ReadFull(ti.Buffer) buf := ti.Buffer switch ti.TypeId { @@ -343,7 +343,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { panic("shoulnd't get here") } -func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { +func readByteLenType(ti *typeInfo, r *TdsBuffer) interface{} { size := r.byte() if size == 0 { return nil @@ -442,7 +442,7 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} { +func readShortLenType(ti *typeInfo, r *TdsBuffer) interface{} { size := r.uint16() if size == 0xffff { return nil @@ -485,7 +485,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} { +func readLongLenType(ti *typeInfo, r *TdsBuffer) interface{} { // information about this format can be found here: // http://msdn.microsoft.com/en-us/library/dd304783.aspx // and here: @@ -544,7 +544,7 @@ func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readCollation(r *tdsBuffer) (res cp.Collation) { +func readCollation(r *TdsBuffer) (res cp.Collation) { res.LcidAndFlags = r.uint32() res.SortId = r.byte() return @@ -560,7 +560,7 @@ func writeCollation(w io.Writer, col cp.Collation) (err error) { // reads variant value // http://msdn.microsoft.com/en-us/library/dd303302.aspx -func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { +func readVariantType(ti *typeInfo, r *TdsBuffer) interface{} { size := r.int32() if size == 0 { return nil @@ -652,7 +652,7 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { // partially length prefixed stream // http://msdn.microsoft.com/en-us/library/dd340469.aspx -func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { +func readPLPType(ti *typeInfo, r *TdsBuffer) interface{} { size := r.uint64() var buf *bytes.Buffer switch size { @@ -709,7 +709,7 @@ func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { } } -func readVarLen(ti *typeInfo, r *tdsBuffer) { +func readVarLen(ti *typeInfo, r *TdsBuffer) { switch ti.TypeId { case typeDateN: ti.Size = 3 From 77827dd23a511f3330385ba3f880f5ddb97a3931 Mon Sep 17 00:00:00 2001 From: orenbm Date: Wed, 23 Oct 2019 10:07:47 -0400 Subject: [PATCH 2/2] Add a method that exposes the underlying transport to the backend --- mssql.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mssql.go b/mssql.go index e37109cd..c0a93238 100644 --- a/mssql.go +++ b/mssql.go @@ -152,6 +152,11 @@ type Conn struct { returnStatus *ReturnStatus } +// NetConn exposes the underlying transport to the backend +func (c *Conn) NetConn() net.Conn { + return c.sess.buf.transport.(net.Conn) +} + func (c *Conn) setReturnStatus(s ReturnStatus) { if c.returnStatus == nil { return