Skip to content

Commit

Permalink
Split interface Limiter and port default Counter from k1LoW/rlutil
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Dec 16, 2023
1 parent 372cdb9 commit be4c5ee
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 51 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ jobs:
with:
go-version-file: go.mod

- name: Run go mod tidy for gostyle
run: go mod tidy

- name: Run lint
uses: reviewdog/action-golangci-lint@v2
with:
Expand Down
8 changes: 2 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@ ci: depsdev test

test:
cp go.mod testdata/go_test.mod
go mod tidy -modfile=testdata/go_test.mod
go test ./... -modfile=testdata/go_test.mod -coverprofile=coverage.out -covermode=count
go test ./... -coverprofile=coverage.out -covermode=count

benchmark: depsdev
go mod tidy -modfile=testdata/go_test.mod
go test -modfile=testdata/go_test.mod -bench . -benchmem -benchtime 10000x -run Benchmark | octocov-go-test-bench --tee > custom_metrics_benchmark.json
go test -bench . -benchmem -benchtime 10000x -run Benchmark | octocov-go-test-bench --tee > custom_metrics_benchmark.json

cachegrind: depsdev
cd testdata/testbin && go build -o testbin
setarch `uname -m` -R valgrind --tool=cachegrind --cachegrind-out-file=cachegrind.out --I1=32768,8,64 --D1=32768,8,64 --LL=8388608,16,64 ./testdata/testbin/testbin 10000
cat cachegrind.out | octocov-cachegrind --tee > custom_metrics_cachegrind.json

lint:
go mod tidy
golangci-lint run ./...
go vet -vettool=`which gostyle` -gostyle.config=$(PWD)/.gostyle.yml ./...
git restore go.*

depsdev:
go install github.com/Songmu/ghch/cmd/ghch@latest
Expand Down
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var ErrRateLimitExceeded error = errors.New("rate limit exceeded")
type Context struct {
StatusCode int
Err error
Limiter Limiter
Limiter *limiter
RequestLimit int
WindowLen time.Duration
RateLimitRemaining int
Expand Down
82 changes: 82 additions & 0 deletions counter/counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package counter

import (
"fmt"
"sync/atomic"
"time"

"github.com/jellydator/ttlcache/v3"
)

// Counter is a sliding window counter implemented with a TTL cache
type Counter struct {
cache *ttlcache.Cache[string, *uint64]
// capacity is the maximum number of items to store in the cache
capacity uint64
// disableAutoDeleteExpired disables the automatic deletion of expired items
disableAutoDeleteExpired bool
}

type Option func(*Counter)

// WithCapacity sets the maximum number of items to store in the cache
func WithCapacity(capacity uint64) Option {
return func(c *Counter) {
c.capacity = capacity
}
}

// DisableAutoDeleteExpired disables the automatic deletion of expired items
func DisableAutoDeleteExpired() Option {
return func(c *Counter) {
c.disableAutoDeleteExpired = true
}
}

// NewCounter creates a new Counter
func New(windowLen time.Duration, opts ...Option) *Counter {
c := &Counter{}
for _, opt := range opts {
opt(c)
}
ttlOpts := []ttlcache.Option[string, *uint64]{
ttlcache.WithTTL[string, *uint64](windowLen * 2),
}
if c.capacity > 0 {
ttlOpts = append(ttlOpts, ttlcache.WithCapacity[string, *uint64](c.capacity))
}
cache := ttlcache.New[string, *uint64](ttlOpts...)
c.cache = cache
if !c.disableAutoDeleteExpired {
go cache.Start()
}
return c
}

// Get returns the count for the given key and window
func (c *Counter) Get(key string, window time.Time) (int, error) { //nostyle:getters
key = generateKey(key, window)
i := c.cache.Get(key)
if i == nil {
return 0, nil
}
return int(*i.Value()), nil
}

// Increment increments the count for the given key and window
func (c *Counter) Increment(key string, currWindow time.Time) error {
key = generateKey(key, currWindow)
zero := uint64(0)
i, _ := c.cache.GetOrSet(key, &zero)
atomic.AddUint64(i.Value(), 1)
return nil
}

func generateKey(key string, window time.Time) string {
return fmt.Sprintf("%s-%d", key, window.Unix())
}

// DeleteExpired deletes expired items from the cache
func (c *Counter) DeleteExpired() {
c.cache.DeleteExpired()
}
84 changes: 84 additions & 0 deletions counter/counter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package counter

import (
"testing"
"time"

"github.com/jellydator/ttlcache/v3"
"golang.org/x/sync/errgroup"
)

func TestGet(t *testing.T) {
windowLen := 1 * time.Second
window := time.Now().Truncate(windowLen)
key := "test"
c := New(windowLen)
t.Run("Zero value", func(t *testing.T) {
for i := 0; i < 2; i++ {
want := 0
got, err := c.Get(key, window)
if err != nil {
t.Error(err)
}
if got != want {
t.Errorf("Get() = %v, want %v", got, want)
}
}
})

t.Run("Get value", func(t *testing.T) {
want := 1
v := uint64(want)
c.cache.Set(generateKey(key, window), &v, ttlcache.DefaultTTL)
got, err := c.Get(key, window)
if err != nil {
t.Error(err)
}
if got != want {
t.Errorf("Get() = %v, want %v", got, want)
}
})
}

func TestIncrement(t *testing.T) {
windowLen := 1 * time.Millisecond
window := time.Now().Truncate(windowLen)
key := "test"
t.Run("Increment simply", func(t *testing.T) {
c := New(windowLen)
for i := 0; i < 5; i++ {
want := i + 1
if err := c.Increment(key, window); err != nil {
t.Error(err)
}
got, err := c.Get(key, window)
if err != nil {
t.Error(err)
}
if got != want {
t.Errorf("Get() = %v, want %v", got, want)
}
}
})

t.Run("Increment in parallel", func(t *testing.T) {
c := New(windowLen)
want := 1000
eg := new(errgroup.Group)
for i := 0; i < want; i++ {
eg.Go(func() error {
return c.Increment(key, window)
})
}
if err := eg.Wait(); err != nil {
t.Error(err)
}
got, err := c.Get(key, window)
if err != nil {
t.Error(err)
}
if got != want {
t.Errorf("Get() = %v, want %v", got, want)
}
})
}
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,10 @@ module github.com/2manymws/rl

go 1.21.0

require golang.org/x/sync v0.5.0
require (
github.com/go-chi/httprate v0.8.0
github.com/jellydator/ttlcache/v3 v3.1.1
golang.org/x/sync v0.5.0
)

require github.com/cespare/xxhash/v2 v2.1.2 // indirect
17 changes: 16 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/httprate v0.8.0 h1:CyKng28yhGnlGXH9EDGC/Qizj29afJQSNW15W/yj34o=
github.com/go-chi/httprate v0.8.0/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A=
github.com/jellydator/ttlcache/v3 v3.1.1 h1:RCgYJqo3jgvhl+fEWvjNW8thxGWsgxi+TPhRir1Y9y8=
github.com/jellydator/ttlcache/v3 v3.1.1/go.mod h1:hi7MGFdMAwZna5n2tuvh63DvFLzVKySzCVW6+0gA2n4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
46 changes: 40 additions & 6 deletions rl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/2manymws/rl/counter"
"golang.org/x/sync/errgroup"
)

Expand All @@ -34,18 +35,47 @@ type Limiter interface {
ShouldSetXRateLimitHeaders(*Context) bool
// OnRequestLimit returns the handler to be called when the rate limit is exceeded
OnRequestLimit(*Context) http.HandlerFunc
}

type Counter interface {
// Get returns the current count for the key and window
Get(key string, window time.Time) (count int, err error) //nostyle:getters
// Increment increments the count for the key and window
Increment(key string, currWindow time.Time) error
}

type limiter struct {
Limiter
Get func(key string, window time.Time) (count int, err error) //nostyle:getters
Increment func(key string, currWindow time.Time) error
}

func newLimiter(l Limiter) *limiter {
const defaultWindowLen = 1 * time.Hour
ll := &limiter{
Limiter: l,
}
if c, ok := l.(Counter); ok {
ll.Get = c.Get
ll.Increment = c.Increment
} else {
dl := defaultWindowLen
r, err := ll.Rule(&http.Request{})
if err == nil {
dl = r.WindowLen
}
cc := counter.New(dl)
ll.Get = cc.Get
ll.Increment = cc.Increment
}
return ll
}

type limitHandler struct {
key string
reqLimit int
windowLen time.Duration
limiter Limiter
limiter *limiter
rateLimitRemaining int
rateLimitReset int
mu sync.Mutex
Expand All @@ -69,12 +99,16 @@ func (lh *limitHandler) status(now, currWindow time.Time) (float64, error) {
}

type limitMw struct {
limiters []Limiter
limiters []*limiter
}

func newLimitMw(limiters []Limiter) *limitMw {
var ls []*limiter
for _, l := range limiters {
ls = append(ls, newLimiter(l))
}
return &limitMw{
limiters: limiters,
limiters: ls,
}
}

Expand All @@ -83,8 +117,8 @@ func (lm *limitMw) Handler(next http.Handler) http.Handler {
now := time.Now().UTC()
var lastLH *limitHandler
eg, ctx := errgroup.WithContext(context.Background())
for _, limiter := range lm.limiters {
rule, err := limiter.Rule(r)
for _, l := range lm.limiters {
rule, err := l.Rule(r)
if err != nil {
http.Error(w, err.Error(), http.StatusPreconditionRequired)
return
Expand All @@ -102,7 +136,7 @@ func (lm *limitMw) Handler(next http.Handler) http.Handler {
key: rule.Key,
reqLimit: rule.ReqLimit,
windowLen: rule.WindowLen,
limiter: limiter,
limiter: l,
}
lastLH = lh
eg.Go(func() error {
Expand Down
8 changes: 7 additions & 1 deletion testdata/go_test.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,10 @@ module github.com/2manymws/rl

go 1.21.0

require golang.org/x/sync v0.5.0
require (
github.com/go-chi/httprate v0.8.0
github.com/jellydator/ttlcache/v3 v3.1.1
golang.org/x/sync v0.5.0
)

require github.com/cespare/xxhash/v2 v2.1.2 // indirect
6 changes: 0 additions & 6 deletions testdata/go_test.sum

This file was deleted.

Loading

0 comments on commit be4c5ee

Please sign in to comment.