Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add socks4 support #22

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The package has the following features:
* Rules to do granular filtering of commands
* Custom DNS resolution
* Unit tests
* Support for socks4 if enabled

TODO
====
Expand Down
211 changes: 138 additions & 73 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,61 @@ type conn interface {
}

// NewRequest creates a new Request from the tcp connection
func NewRequest(bufConn io.Reader) (*Request, error) {
// Read the version byte
header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return nil, fmt.Errorf("Failed to get command version: %v", err)
func NewRequest(bufConn io.Reader, reqVersion byte) (*Request, error) {
request := &Request{
Version: reqVersion,
bufConn: bufConn,
}

// Ensure we are compatible
if header[0] != socks5Version {
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
}
if reqVersion == socks5Version {
header := []byte{0, 0, 0}
// Read the version byte
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return nil, fmt.Errorf("Failed to get command version: %v", err)
}

// Read in the destination address
dest, err := readAddrSpec(bufConn)
if err != nil {
return nil, err
}
// Ensure we are compatible
if header[0] != socks5Version {
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
}

request := &Request{
Version: socks5Version,
Command: header[1],
DestAddr: dest,
bufConn: bufConn,
request.Command = header[1]

var err error
// Read in the destination address
request.DestAddr, err = readAddrSpec(bufConn)
if err != nil {
return nil, err
}
} else {
header := []byte{0}
// Read the command byte
if _, err := io.ReadAtLeast(bufConn, header, 1); err != nil {
return nil, fmt.Errorf("Failed to get command: %v", err)
}

// Ensure we are compatible
if header[0] != 1 {
return nil, fmt.Errorf("Unsupported command: %v", header[0])
}
request.Command = header[0]

var err error
// Read in the destination address
request.DestAddr, err = readAddrSpecv4(bufConn)
if err != nil {
return nil, err
}

authStr := make([]byte, 256)
n, err := io.ReadAtLeast(bufConn, authStr, 1)
if err != nil {
return nil, fmt.Errorf("Failed to get auth string: %v", err)
}

if n > 0 {
request.AuthContext = &AuthContext{UserPassAuth, map[string]string{"Username": string(authStr[:n-1])}}
}
}

return request, nil
Expand All @@ -122,14 +154,14 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
// Resolve the address if we have a FQDN
dest := req.DestAddr
if dest.FQDN != "" {
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
nctx, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := sendReply(conn, hostUnreachable, nil); err != nil {
if err := sendReply(conn, hostUnreachable, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
}
ctx = ctx_
ctx = nctx
dest.IP = addr
}

Expand All @@ -148,7 +180,7 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
case AssociateCommand:
return s.handleAssociate(ctx, conn, req)
default:
if err := sendReply(conn, commandNotSupported, nil); err != nil {
if err := sendReply(conn, commandNotSupported, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Unsupported command: %v", req.Command)
Expand All @@ -158,13 +190,13 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
// handleConnect is used to handle a connect command
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
if nctx, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
ctx = nctx
}

// Attempt to connect
Expand All @@ -174,6 +206,7 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
return net.Dial(net_, addr)
}
}

target, err := dial(ctx, "tcp", req.realDestAddr.Address())
if err != nil {
msg := err.Error()
Expand All @@ -183,7 +216,7 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
} else if strings.Contains(msg, "network is unreachable") {
resp = networkUnreachable
}
if err := sendReply(conn, resp, nil); err != nil {
if err := sendReply(conn, resp, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
Expand All @@ -193,7 +226,7 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
// Send success
local := target.LocalAddr().(*net.TCPAddr)
bind := AddrSpec{IP: local.IP, Port: local.Port}
if err := sendReply(conn, successReply, &bind); err != nil {
if err := sendReply(conn, successReply, &bind, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}

Expand All @@ -216,17 +249,17 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
// handleBind is used to handle a connect command
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
if nctx, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
ctx = nctx
}

// TODO: Support bind
if err := sendReply(conn, commandNotSupported, nil); err != nil {
if err := sendReply(conn, commandNotSupported, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return nil
Expand All @@ -235,17 +268,17 @@ func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
if nctx, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
ctx = nctx
}

// TODO: Support associate
if err := sendReply(conn, commandNotSupported, nil); err != nil {
if err := sendReply(conn, commandNotSupported, nil, req.Version); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return nil
Expand Down Expand Up @@ -303,46 +336,78 @@ func readAddrSpec(r io.Reader) (*AddrSpec, error) {
return d, nil
}

// sendReply is used to send a reply message
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
// Format the address
var addrType uint8
var addrBody []byte
var addrPort uint16
switch {
case addr == nil:
addrType = ipv4Address
addrBody = []byte{0, 0, 0, 0}
addrPort = 0

case addr.FQDN != "":
addrType = fqdnAddress
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
addrPort = uint16(addr.Port)

case addr.IP.To4() != nil:
addrType = ipv4Address
addrBody = []byte(addr.IP.To4())
addrPort = uint16(addr.Port)

case addr.IP.To16() != nil:
addrType = ipv6Address
addrBody = []byte(addr.IP.To16())
addrPort = uint16(addr.Port)
// Expects port, follwed by the address
func readAddrSpecv4(r io.Reader) (*AddrSpec, error) {
d := &AddrSpec{}

default:
return fmt.Errorf("Failed to format address: %v", addr)
// Read the port
port := []byte{0, 0}
if _, err := io.ReadAtLeast(r, port, 2); err != nil {
return nil, err
}
d.Port = (int(port[0]) << 8) | int(port[1])

// Format the message
msg := make([]byte, 6+len(addrBody))
msg[0] = socks5Version
msg[1] = resp
msg[2] = 0 // Reserved
msg[3] = addrType
copy(msg[4:], addrBody)
msg[4+len(addrBody)] = byte(addrPort >> 8)
msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
addr := []byte{0, 0, 0, 0}
if _, err := io.ReadAtLeast(r, addr, 4); err != nil {
return nil, err
}
d.IP = net.IP(addr)

return d, nil
}

// sendReply is used to send a reply message
func sendReply(w io.Writer, resp uint8, addr *AddrSpec, version byte) error {
var msg []byte
if version == socks5Version {
// Format the address
var addrType uint8
var addrBody []byte
var addrPort uint16
switch {
case addr == nil:
addrType = ipv4Address
addrBody = []byte{0, 0, 0, 0}
addrPort = 0

case addr.FQDN != "":
addrType = fqdnAddress
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
addrPort = uint16(addr.Port)

case addr.IP.To4() != nil:
addrType = ipv4Address
addrBody = []byte(addr.IP.To4())
addrPort = uint16(addr.Port)

case addr.IP.To16() != nil:
addrType = ipv6Address
addrBody = []byte(addr.IP.To16())
addrPort = uint16(addr.Port)

default:
return fmt.Errorf("Failed to format address: %v", addr)
}

// Format the message
msg = make([]byte, 6+len(addrBody))
msg[0] = socks5Version
msg[1] = resp
msg[2] = 0 // Reserved
msg[3] = addrType
copy(msg[4:], addrBody)
msg[4+len(addrBody)] = byte(addrPort >> 8)
msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
} else {
msg = make([]byte, 8)
msg[0] = 0
if resp == successReply {
msg[1] = 0x5a
} else {
msg[1] = 0x5b
}
// bytes 3-8 are reserved
}

// Send the message
_, err := w.Write(msg)
Expand Down
4 changes: 2 additions & 2 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestRequest_Connect(t *testing.T) {

// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
req, err := NewRequest(buf, 5)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -143,7 +143,7 @@ func TestRequest_Connect_RuleFail(t *testing.T) {

// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
req, err := NewRequest(buf, 5)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down
Loading