diff --git a/README.md b/README.md
index ac6293f..dfb5499 100644
--- a/README.md
+++ b/README.md
@@ -2,12 +2,27 @@
 
-An implementation of a L4 TCP proxy written in Golang. Provides the following functionality:
+An implementation of an L4 TCP proxy written in Go. Provides the following functionality:
 
-* mTLS authentication, with an authorization system to define what groups can access the upstreams
+* mTLS authentication, with an authorization system to define which groups can access the upstreams.
 
-* Builtin least-connection forwarding to available upstreams
+* Built-in least-connection forwarding to available upstreams.
 
 * Per-client rate-limiting, using a token bucket implementation.
 
+## Running
+
+The proxy comes with a wrapper to run it, using hardcoded configuration you can adjust to your needs.
+
+1. Modify `cmd/server/main.go`, adjusting the configuration as you see fit.
+2. Run `make` to build the binary, which will be output to the current directory as `server`.
+3. Run the server with `./server`.
+
+### Running sample upstreams
+
+If you want to run against some sample upstreams (nginx), launch the included Docker Compose file. The `server` is
+already configured to point to these.
+
+    docker-compose up
+
 ## Development
 
 In order to build and develop the proxy, you should have Go installed and available.
diff --git a/cmd/server/main.go b/cmd/server/main.go
new file mode 100644
index 0000000..e5c3660
--- /dev/null
+++ b/cmd/server/main.go
@@ -0,0 +1,71 @@
+package main
+
+import (
+	"log/slog"
+	"os"
+	"os/signal"
+	"sync"
+	"syscall"
+	"time"
+
+	"github.com/joshbranham/tcp-proxy/pkg/tcpproxy"
+)
+
+func main() {
+	logger := slog.Default()
+
+	// TODO: Configure the proxy here. With more time, a configuration file (YAML etc.) and/or CLI arguments
+	// would be a better approach.
+	targets := []string{"localhost:9000", "localhost:9001", "localhost:9002"}
+	loadBalancer, err := tcpproxy.NewLeastConnectionBalancer(targets)
+	if err != nil {
+		logger.Error("creating load balancer", "error", err)
+		os.Exit(1)
+	}
+	config := &tcpproxy.Config{
+		LoadBalancer: loadBalancer,
+		ListenerConfig: &tcpproxy.ListenerConfig{
+			ListenerAddr: "localhost:5000",
+			CA:           "certificates/ca.pem",
+			Certificate:  "tcp-proxy.pem",
+			PrivateKey:   "tcp-proxy.key",
+		},
+		UpstreamConfig: &tcpproxy.UpstreamConfig{
+			Name:    "test",
+			Targets: targets,
+
+			AuthorizedGroups: []string{"engineering"},
+		},
+
+		RateLimitConfig: &tcpproxy.RateLimitConfig{
+			Capacity: 10,
+			FillRate: 5 * time.Second,
+		},
+		Logger: logger,
+	}
+
+	sigC := make(chan os.Signal, 1)
+	signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM)
+
+	proxy, err := tcpproxy.New(config)
+	if err != nil {
+		logger.Error("creating proxy", "error", err)
+		os.Exit(1)
+	}
+
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if err := proxy.Serve(); err != nil {
+			logger.Error("serving proxy", "error", err)
+			os.Exit(1)
+		}
+	}()
+
+	<-sigC
+	logger.Info("shutting down proxy...")
+	_ = proxy.Close()
+	wg.Wait()
+	logger.Info("proxy stopped")
+}
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000..54a8d75
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,18 @@
+version: '3'
+
+services:
+  nginx0:
+    image: nginx
+    ports:
+      - "9000:80"
+    restart: always
+  nginx1:
+    image: nginx
+    ports:
+      - "9001:80"
+    restart: always
+  nginx2:
+    image: nginx
+    ports:
+      - "9002:80"
+    restart: always
diff --git a/pkg/tcpproxy/config.go b/pkg/tcpproxy/config.go
index 7fa1afa..899e392 100644
--- a/pkg/tcpproxy/config.go
+++ b/pkg/tcpproxy/config.go
@@ -6,6 +6,7 @@ import (
 	"errors"
 	"log/slog"
 	"os"
+	"time"
 )
 
 // Config is the top-level configuration object used to configure a Proxy.
@@ -16,6 +17,8 @@ type Config struct {
 	ListenerConfig *ListenerConfig
 	// UpstreamConfig comprises where to route requests as well as which clients are authorized to do so.
 	UpstreamConfig *UpstreamConfig
+	// RateLimitConfig defines how rate limiting should be configured for clients.
+	RateLimitConfig *RateLimitConfig
 	// Logger is a slog.Logger used for logging proxy activities to stdout.
 	Logger *slog.Logger
 
@@ -43,6 +46,15 @@ type UpstreamConfig struct {
 	AuthorizedGroups []string
 }
 
+// RateLimitConfig is the configuration for the built-in token bucket rate limiting implementation. These settings
+// are applied on a per-client basis.
+type RateLimitConfig struct {
+	// Capacity is the maximum number of tokens the bucket can hold.
+	Capacity int
+	// FillRate is how often one token is added to the bucket.
+	FillRate time.Duration
+}
+
 // Validate confirms a given Config has all required fields set.
 func (c *Config) Validate() error {
 	if c.ListenerConfig == nil {
@@ -54,6 +66,9 @@
 	if c.UpstreamConfig == nil {
 		return errors.New("config does not contain a UpstreamConfig")
 	}
+	if c.RateLimitConfig == nil {
+		return errors.New("config does not contain a RateLimitConfig")
+	}
 	if c.Logger == nil {
 		return errors.New("config does not contain a Logger")
 	}
diff --git a/pkg/tcpproxy/proxy_test.go b/pkg/tcpproxy/proxy_test.go
index bde3de7..6c0abe0 100644
--- a/pkg/tcpproxy/proxy_test.go
+++ b/pkg/tcpproxy/proxy_test.go
@@ -100,6 +100,10 @@ func setupTestProxy(t *testing.T, target string, authorizedGroup string) *Proxy
 			AuthorizedGroups: []string{authorizedGroup},
 		},
 
+		RateLimitConfig: &RateLimitConfig{
+			Capacity: 10,
+			FillRate: 10 * time.Second,
+		},
 		Logger: slog.Default(),
 	}
 	proxy, err := New(config)
diff --git a/pkg/tcpproxy/rate_limit_manager.go b/pkg/tcpproxy/rate_limit_manager.go
new file mode 100644
index 0000000..dd315c6
--- /dev/null
+++ b/pkg/tcpproxy/rate_limit_manager.go
@@ -0,0 +1,57 @@
+package tcpproxy
+
+import (
+	"log/slog"
+	"sync"
+	"time"
+)
+
+// RateLimitManager wraps many RateLimiters and provides mechanisms for getting per-client RateLimiters.
+type RateLimitManager struct {
+	defaultCapacity int
+	defaultFillRate time.Duration
+	logger          *slog.Logger
+	rateLimiters    map[string]*RateLimiter
+	mutex           sync.RWMutex
+}
+
+// NewRateLimitManager returns a configured RateLimitManager.
+func NewRateLimitManager(capacity int, fillRate time.Duration, logger *slog.Logger) *RateLimitManager {
+	return &RateLimitManager{
+		defaultCapacity: capacity,
+		defaultFillRate: fillRate,
+		logger:          logger,
+		rateLimiters:    make(map[string]*RateLimiter),
+		mutex:           sync.RWMutex{},
+	}
+}
+
+// RateLimiterFor returns, or lazily creates, the RateLimiter for a given client string.
+func (r *RateLimitManager) RateLimiterFor(client string) *RateLimiter {
+	var rateLimiter *RateLimiter
+
+	r.mutex.Lock()
+	if r.rateLimiters[client] == nil {
+		rateLimiter = NewRateLimiter(int64(r.defaultCapacity), r.defaultFillRate)
+		r.rateLimiters[client] = rateLimiter
+	} else {
+		rateLimiter = r.rateLimiters[client]
+	}
+	r.mutex.Unlock()
+
+	return rateLimiter
+}
+
+// Close calls Close() on all known RateLimiters. A RateLimiter can only be closed once; if one is
+// already closed, this func logs a warning and continues.
+func (r *RateLimitManager) Close() {
+	r.mutex.RLock()
+	for _, rateLimiter := range r.rateLimiters {
+		if rateLimiter != nil {
+			if err := rateLimiter.Close(); err != nil {
+				r.logger.Warn("error closing rate limiter", "error", err)
+			}
+		}
+	}
+	r.mutex.RUnlock()
+}
diff --git a/pkg/tcpproxy/rate_limit_manager_test.go b/pkg/tcpproxy/rate_limit_manager_test.go
new file mode 100644
index 0000000..2ad006d
--- /dev/null
+++ b/pkg/tcpproxy/rate_limit_manager_test.go
@@ -0,0 +1,19 @@
+package tcpproxy
+
+import (
+	"log/slog"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNewRateLimitManager(t *testing.T) {
+	rlm := NewRateLimitManager(10, 1*time.Millisecond, slog.Default())
+
+	user1 := rlm.RateLimiterFor("user1")
+	assert.Equal(t, user1, rlm.RateLimiterFor("user1"))
+
+	// Ensure we clean up goroutines for the manager and any child RateLimiters
+	rlm.Close()
+}
diff --git a/pkg/tcpproxy/rate_limiter.go b/pkg/tcpproxy/rate_limiter.go
new file mode 100644
index 0000000..506a71b
--- /dev/null
+++ b/pkg/tcpproxy/rate_limiter.go
@@ -0,0 +1,85 @@
+package tcpproxy
+
+import (
+	"errors"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// tokenFillRate is the number of tokens added to the bucket at the given FillRate.
+const tokenFillRate = 1
+
+// RateLimiter is an instance of rate limiting configuration, used for a single client.
+type RateLimiter struct {
+	capacity int64
+	fillRate time.Duration
+	tokens   atomic.Int64
+	closed   atomic.Bool
+
+	shutdownC chan struct{}
+	wg        sync.WaitGroup
+}
+
+// NewRateLimiter returns a RateLimiter and spawns a goroutine to add tokens up until the capacity.
+// Use Close() to clean up the goroutine and stop token accumulation.
+func NewRateLimiter(capacity int64, fillRate time.Duration) *RateLimiter {
+	rl := &RateLimiter{
+		capacity: capacity,
+		fillRate: fillRate,
+
+		shutdownC: make(chan struct{}),
+	}
+	rl.tokens.Store(capacity)
+
+	rl.wg.Add(1)
+	go rl.fillTokens()
+
+	return rl
+}
+
+// Close stops the accumulation of tokens for the RateLimiter. It is safe to call concurrently;
+// only the first call succeeds, and subsequent calls return an error.
+func (r *RateLimiter) Close() error {
+	// CompareAndSwap guarantees exactly one caller closes shutdownC, avoiding a double-close panic.
+	if !r.closed.CompareAndSwap(false, true) {
+		return errors.New("RateLimiter is already closed")
+	}
+	close(r.shutdownC)
+	r.wg.Wait()
+
+	return nil
+}
+
+// ConnectionAllowed validates the RateLimiter isn't at the limit. If a token is available, it
+// decrements the tokens by 1 and returns true. Otherwise it returns false and leaves the tokens as is.
+func (r *RateLimiter) ConnectionAllowed() bool {
+	// A compare-and-swap loop prevents concurrent callers from driving the token count below zero.
+	for {
+		tokens := r.tokens.Load()
+		if tokens <= 0 {
+			return false
+		}
+		if r.tokens.CompareAndSwap(tokens, tokens-1) {
+			return true
+		}
+	}
+}
+
+func (r *RateLimiter) fillTokens() {
+	ticker := time.NewTicker(r.fillRate)
+	for {
+		select {
+		case <-r.shutdownC:
+			ticker.Stop()
+			r.wg.Done()
+			return
+		case <-ticker.C:
+			// Refill one token per tick, but never beyond capacity. An empty bucket must still refill,
+			// otherwise a client that drained its tokens would be rate limited forever.
+			if r.tokens.Load() < r.capacity {
+				r.tokens.Add(tokenFillRate)
+			}
+		}
+	}
+}
diff --git a/pkg/tcpproxy/rate_limiter_test.go b/pkg/tcpproxy/rate_limiter_test.go
new file mode 100644
index 0000000..9c58872
--- /dev/null
+++ b/pkg/tcpproxy/rate_limiter_test.go
@@ -0,0 +1,25 @@
+package tcpproxy
+
+import (
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRateLimiter_ConnectionAllowed(t *testing.T) {
+	// The FillRate deliberately exceeds the test runtime so the ticker never fires. TODO: use an
+	// interface so tests can add/remove tokens from the bucket without the ticker.
+	rl := NewRateLimiter(1, 1*time.Minute)
+
+	// Allowed immediately at creation with 1 token
+	assert.True(t, rl.ConnectionAllowed())
+
+	// The bucket is now empty, so the connection is not allowed
+	assert.False(t, rl.ConnectionAllowed())
+
+	// Ensure we can't double-close a RateLimiter
+	require.NoError(t, rl.Close())
+	assert.Error(t, rl.Close())
+}
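
For reviewers, a minimal sketch of how the rate-limiting pieces introduced in this patch fit together. The capacity, fill rate, and `"client-a"` key are illustrative only; where the proxy derives the per-client key from (presumably the authenticated client identity) is outside this diff:

```go
package main

import (
	"fmt"
	"log/slog"
	"time"

	"github.com/joshbranham/tcp-proxy/pkg/tcpproxy"
)

func main() {
	// One manager per proxy; it lazily creates a token bucket per client key.
	manager := tcpproxy.NewRateLimitManager(3, time.Second, slog.Default())
	defer manager.Close() // stops the refill goroutines of all child RateLimiters

	// Repeated lookups with the same key return the same RateLimiter.
	limiter := manager.RateLimiterFor("client-a")

	// The bucket starts full (3 tokens): the first three connections pass,
	// and the fourth is rejected until the ticker refills a token.
	for i := 1; i <= 4; i++ {
		fmt.Printf("connection %d allowed: %v\n", i, limiter.ConnectionAllowed())
	}
}
```

Presumably the proxy calls `ConnectionAllowed` once per accepted connection and drops the connection when it returns false; that call site is not part of this diff.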