-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement a Token Bucket rate limiter (#4)
* Implement a Token Bucket rate limiter * Misc cleanup, comments * RateLimitManager gracefully handles individual RateLimiters being closed already.
- Loading branch information
1 parent
07b0fa1
commit f1d7cc6
Showing
6 changed files
with
200 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package tcpproxy | ||
|
||
import ( | ||
"log/slog" | ||
"sync" | ||
"time" | ||
) | ||
|
||
// RateLimitManager wraps many RateLimiters and provides mechanisms for getting per-client RateLimiters. | ||
type RateLimitManager struct { | ||
defaultCapacity int | ||
defaultFillRate time.Duration | ||
logger *slog.Logger | ||
rateLimiters map[string]*RateLimiter | ||
mutex sync.RWMutex | ||
} | ||
|
||
// NewRateLimitManager returns a configured RateLimitManager. | ||
func NewRateLimitManager(capacity int, fillRate time.Duration, logger *slog.Logger) *RateLimitManager { | ||
return &RateLimitManager{ | ||
defaultCapacity: capacity, | ||
defaultFillRate: fillRate, | ||
logger: logger, | ||
rateLimiters: make(map[string]*RateLimiter), | ||
mutex: sync.RWMutex{}, | ||
} | ||
} | ||
|
||
// RateLimiterFor returns, or creates, a RateLimiter for a given client string. | ||
func (r *RateLimitManager) RateLimiterFor(client string) *RateLimiter { | ||
var rateLimiter *RateLimiter | ||
|
||
r.mutex.Lock() | ||
if r.rateLimiters[client] == nil { | ||
rateLimiter = NewRateLimiter(int64(r.defaultCapacity), r.defaultFillRate) | ||
r.rateLimiters[client] = rateLimiter | ||
} else { | ||
rateLimiter = r.rateLimiters[client] | ||
} | ||
r.mutex.Unlock() | ||
|
||
return rateLimiter | ||
} | ||
|
||
// Close calls Close() on all known RateLimiters. RateLimiters can only be closed once, however this | ||
// func will handle if a RateLimiter is already closed. | ||
func (r *RateLimitManager) Close() { | ||
r.mutex.RLock() | ||
for _, rateLimiter := range r.rateLimiters { | ||
if rateLimiter != nil { | ||
if err := rateLimiter.Close(); err != nil { | ||
r.logger.Warn("error closing rate limiter", "error", err) | ||
} | ||
} | ||
} | ||
r.mutex.RUnlock() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package tcpproxy | ||
|
||
import ( | ||
"log/slog" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestNewRateLimitManager(t *testing.T) { | ||
rlm := NewRateLimitManager(10, 1*time.Millisecond, slog.Default()) | ||
|
||
user1 := rlm.RateLimiterFor("user1") | ||
assert.Equal(t, user1, rlm.RateLimiterFor("user1")) | ||
|
||
// Ensure we clean up goroutines for the manager and any child RateLimiters | ||
rlm.Close() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package tcpproxy | ||
|
||
import ( | ||
"errors" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
) | ||
|
||
// tokenFillRate is the amount of tokens added to the bucket at the given FillRate. | ||
const tokenFillRate = 1 | ||
|
||
// RateLimiter is an instance of rate limiting configuration, used for a single client. | ||
type RateLimiter struct { | ||
capacity int64 | ||
fillRate time.Duration | ||
tokens atomic.Int64 | ||
closed atomic.Bool | ||
|
||
shutdownC chan struct{} | ||
wg sync.WaitGroup | ||
} | ||
|
||
// NewRateLimiter returns a RateLimiter and spawns a goroutine to add tokens up until the capacity. | ||
// Use Close() to cleanup the goroutine and stop token accumulation. | ||
func NewRateLimiter(capacity int64, fillRate time.Duration) *RateLimiter { | ||
rl := &RateLimiter{ | ||
capacity: capacity, | ||
fillRate: fillRate, | ||
|
||
shutdownC: make(chan struct{}), | ||
} | ||
rl.tokens.Add(capacity) | ||
|
||
rl.wg.Add(1) | ||
go rl.fillTokens() | ||
|
||
return rl | ||
} | ||
|
||
// Close stops the accumulation of tokens for the RateLimiter. | ||
func (r *RateLimiter) Close() error { | ||
if r.closed.Load() { | ||
return errors.New("RateLimiter is already closed") | ||
} else { | ||
close(r.shutdownC) | ||
r.closed.Store(true) | ||
r.wg.Wait() | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// ConnectionAllowed validates the RateLimiter isn't at the limit. If allowed, this returns true and decrements | ||
// the tokens by 1. If false, it returns false and leaves the tokens as is. | ||
func (r *RateLimiter) ConnectionAllowed() bool { | ||
if r.tokens.Load() > 0 { | ||
r.tokens.Add(-1) | ||
return true | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (r *RateLimiter) fillTokens() { | ||
ticker := time.NewTicker(r.fillRate) | ||
for { | ||
select { | ||
case <-r.shutdownC: | ||
ticker.Stop() | ||
r.wg.Done() | ||
return | ||
case <-ticker.C: | ||
tokens := r.tokens.Load() | ||
if tokens != 0 && tokens < r.capacity { | ||
r.tokens.Add(tokenFillRate) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package tcpproxy | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestRateLimiter_ConnectionAllowed(t *testing.T) { | ||
// TODO: Set a FillRate that exceeds test runtime. In the future, we could use | ||
// an interface and mock add/remove tokens from the bucket without the ticker. | ||
rl := NewRateLimiter(1, 1*time.Minute) | ||
|
||
// Allowed immediately at creation with 1 token | ||
assert.Equal(t, true, rl.ConnectionAllowed()) | ||
|
||
// This brings capacity to zero, so connection not allowed | ||
assert.Equal(t, false, rl.ConnectionAllowed()) | ||
|
||
// Ensure we can't double close a RateLimiter | ||
require.NoError(t, rl.Close()) | ||
assert.Error(t, rl.Close()) | ||
} |