Skip to content

Commit

Permalink
Implement a Token Bucket rate limiter (#4)
Browse files Browse the repository at this point in the history
* Implement a Token Bucket rate limiter

* Misc cleanup, comments

* RateLimitManager gracefully handles individual RateLimiters being closed already.
  • Loading branch information
joshbranham authored Mar 8, 2024
1 parent 07b0fa1 commit f1d7cc6
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 0 deletions.
15 changes: 15 additions & 0 deletions pkg/tcpproxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tcpproxy
import (
"errors"
"log/slog"
"time"
)

// Config is the top-level configuration object used to configure a Proxy.
Expand All @@ -13,6 +14,8 @@ type Config struct {
ListenerConfig *ListenerConfig
// UpstreamConfig comprises where to route requests as well as which clients are authorized to do so.
UpstreamConfig *UpstreamConfig
// RateLimitConfig defines how rate limiting should be configured for clients.
RateLimitConfig *RateLimitConfig

// Logger is a slog.Logger used for logging proxy activities to stdout.
Logger *slog.Logger
Expand Down Expand Up @@ -41,6 +44,15 @@ type UpstreamConfig struct {
AuthorizedGroups []string
}

// RateLimitConfig is the configuration for the built-in token bucket rate limiting implementation. These settings
// are applied on a per-client basis: each client gets its own bucket built from these values.
type RateLimitConfig struct {
	// Capacity is the maximum number of tokens the bucket can hold. The bucket starts full,
	// so this is also the initial burst a client is allowed.
	Capacity int
	// FillRate is the interval at which 1 token is added back to the bucket.
	FillRate time.Duration
}

// Validate confirms a given Config has all required fields set.
func (c *Config) Validate() error {
if c.ListenerConfig == nil {
Expand All @@ -52,6 +64,9 @@ func (c *Config) Validate() error {
if c.UpstreamConfig == nil {
return errors.New("config does not contain a UpstreamConfig")
}
if c.RateLimitConfig == nil {
return errors.New("config does not contain a RateLimitConfig")
}
if c.Logger == nil {
return errors.New("config does not contain a Logger")
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/tcpproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ func setupTestProxy(t *testing.T, target string) *Proxy {
// TODO: Not implemented yet
AuthorizedGroups: []string{""},
},
RateLimitConfig: &RateLimitConfig{
Capacity: 10,
FillRate: 5 * time.Second,
},
Logger: slog.Default(),
}
proxy, err := New(config)
Expand Down
57 changes: 57 additions & 0 deletions pkg/tcpproxy/rate_limit_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package tcpproxy

import (
"log/slog"
"sync"
"time"
)

// RateLimitManager wraps many RateLimiters and provides mechanisms for getting per-client RateLimiters.
type RateLimitManager struct {
	// defaultCapacity and defaultFillRate seed every RateLimiter created by RateLimiterFor.
	defaultCapacity int
	defaultFillRate time.Duration
	// logger records non-fatal problems (e.g. closing an already-closed limiter).
	logger *slog.Logger
	// rateLimiters maps a client identifier to its dedicated limiter; guarded by mutex.
	rateLimiters map[string]*RateLimiter
	mutex        sync.RWMutex
}

// NewRateLimitManager returns a configured RateLimitManager. Every limiter it
// hands out via RateLimiterFor will use the supplied capacity and fillRate.
func NewRateLimitManager(capacity int, fillRate time.Duration, logger *slog.Logger) *RateLimitManager {
	manager := &RateLimitManager{
		defaultCapacity: capacity,
		defaultFillRate: fillRate,
		logger:          logger,
		rateLimiters:    make(map[string]*RateLimiter),
	}
	// The zero-value RWMutex is ready to use; no explicit initialization needed.
	return manager
}

// RateLimiterFor returns, or creates, a RateLimiter for a given client string.
// Calls for the same client always yield the same RateLimiter instance.
func (r *RateLimitManager) RateLimiterFor(client string) *RateLimiter {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	// Lazily create the limiter on first sight of this client.
	limiter, ok := r.rateLimiters[client]
	if !ok {
		limiter = NewRateLimiter(int64(r.defaultCapacity), r.defaultFillRate)
		r.rateLimiters[client] = limiter
	}

	return limiter
}

// Close calls Close() on all known RateLimiters. RateLimiters can only be closed once, however this
// func will handle if a RateLimiter is already closed.
func (r *RateLimitManager) Close() {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	for _, limiter := range r.rateLimiters {
		if limiter == nil {
			continue
		}
		// An already-closed limiter returns an error; log it and keep going
		// so one bad entry doesn't prevent the rest from being cleaned up.
		if err := limiter.Close(); err != nil {
			r.logger.Warn("error closing rate limiter", "error", err)
		}
	}
}
19 changes: 19 additions & 0 deletions pkg/tcpproxy/rate_limit_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package tcpproxy

import (
"log/slog"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestNewRateLimitManager(t *testing.T) {
	manager := NewRateLimitManager(10, 1*time.Millisecond, slog.Default())

	// Repeated lookups for the same client must return the identical limiter.
	limiter := manager.RateLimiterFor("user1")
	assert.Equal(t, limiter, manager.RateLimiterFor("user1"))

	// Ensure we clean up goroutines for the manager and any child RateLimiters
	manager.Close()
}
80 changes: 80 additions & 0 deletions pkg/tcpproxy/rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package tcpproxy

import (
"errors"
"sync"
"sync/atomic"
"time"
)

// tokenFillRate is the amount of tokens added to the bucket at the given FillRate.
const tokenFillRate = 1

// RateLimiter is an instance of rate limiting configuration, used for a single client.
type RateLimiter struct {
	// capacity is the maximum tokens the bucket can hold; the bucket starts full.
	capacity int64
	// fillRate is the interval between additions of tokenFillRate tokens.
	fillRate time.Duration
	// tokens is the current bucket level, decremented by ConnectionAllowed
	// and replenished by the fillTokens goroutine.
	tokens atomic.Int64
	// closed records whether Close has been called; guards against double close.
	closed atomic.Bool

	// shutdownC signals the fillTokens goroutine to exit; wg waits for it.
	shutdownC chan struct{}
	wg        sync.WaitGroup
}

// NewRateLimiter returns a RateLimiter and spawns a goroutine to add tokens up until the capacity.
// Use Close() to cleanup the goroutine and stop token accumulation.
func NewRateLimiter(capacity int64, fillRate time.Duration) *RateLimiter {
	limiter := &RateLimiter{
		capacity: capacity,
		fillRate: fillRate,

		shutdownC: make(chan struct{}),
	}
	// Start with a full bucket so a new client gets its initial burst.
	limiter.tokens.Store(capacity)

	limiter.wg.Add(1)
	go limiter.fillTokens()

	return limiter
}

// Close stops the accumulation of tokens for the RateLimiter and waits for the
// fill goroutine to exit. It returns an error if the RateLimiter was already
// closed. Close is safe to call from multiple goroutines: exactly one caller
// wins and performs the shutdown.
func (r *RateLimiter) Close() error {
	// CompareAndSwap makes the check-and-set atomic. The previous
	// Load-then-Store sequence allowed two concurrent callers to both pass
	// the "already closed?" check and double-close shutdownC, which panics.
	if !r.closed.CompareAndSwap(false, true) {
		return errors.New("RateLimiter is already closed")
	}

	close(r.shutdownC)
	r.wg.Wait()

	return nil
}

// ConnectionAllowed validates the RateLimiter isn't at the limit. If allowed, this returns true and decrements
// the tokens by 1. If false, it returns false and leaves the tokens as is.
func (r *RateLimiter) ConnectionAllowed() bool {
	// A compare-and-swap loop makes check-and-decrement atomic. The previous
	// Load-then-Add sequence let two concurrent callers both observe 1 token
	// and both decrement, driving the bucket negative and over-admitting.
	for {
		tokens := r.tokens.Load()
		if tokens <= 0 {
			return false
		}
		if r.tokens.CompareAndSwap(tokens, tokens-1) {
			return true
		}
		// Another goroutine raced us; re-read and retry.
	}
}

// fillTokens runs in its own goroutine, adding tokenFillRate tokens every
// fillRate interval until Close is called, never exceeding capacity.
func (r *RateLimiter) fillTokens() {
	ticker := time.NewTicker(r.fillRate)
	defer ticker.Stop()
	defer r.wg.Done()

	for {
		select {
		case <-r.shutdownC:
			return
		case <-ticker.C:
			// Refill whenever below capacity. The previous guard
			// (tokens != 0 && tokens < capacity) skipped refills once the
			// bucket drained to exactly 0, permanently starving the client.
			if r.tokens.Load() < r.capacity {
				r.tokens.Add(tokenFillRate)
			}
		}
	}
}
25 changes: 25 additions & 0 deletions pkg/tcpproxy/rate_limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package tcpproxy

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRateLimiter_ConnectionAllowed(t *testing.T) {
	// TODO: Set a FillRate that exceeds test runtime. In the future, we could use
	// an interface and mock add/remove tokens from the bucket without the ticker.
	limiter := NewRateLimiter(1, 1*time.Minute)

	// Allowed immediately at creation with 1 token
	assert.True(t, limiter.ConnectionAllowed())

	// This brings capacity to zero, so connection not allowed
	assert.False(t, limiter.ConnectionAllowed())

	// Ensure we can't double close a RateLimiter
	require.NoError(t, limiter.Close())
	assert.Error(t, limiter.Close())
}

0 comments on commit f1d7cc6

Please sign in to comment.