From 493153939b5fc86b21251a7d6e212c3c82c34349 Mon Sep 17 00:00:00 2001 From: anoymouscoder <809532742@qq.com> Date: Mon, 22 Feb 2021 18:50:11 +0800 Subject: [PATCH] Feat: add token bucket ratelimit filter --- CHANGELOG.md | 2 + server/web/filter/ratelimit/bucket.go | 14 ++ server/web/filter/ratelimit/limiter.go | 169 ++++++++++++++++++ server/web/filter/ratelimit/limiter_test.go | 76 ++++++++ server/web/filter/ratelimit/token_bucket.go | 76 ++++++++ .../web/filter/ratelimit/token_bucket_test.go | 32 ++++ 6 files changed, 369 insertions(+) create mode 100644 server/web/filter/ratelimit/bucket.go create mode 100644 server/web/filter/ratelimit/limiter.go create mode 100644 server/web/filter/ratelimit/limiter_test.go create mode 100644 server/web/filter/ratelimit/token_bucket.go create mode 100644 server/web/filter/ratelimit/token_bucket_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index b2956256..35659217 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ - Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440) - fix: code quality issues [4513](https://github.com/beego/beego/pull/4513) - Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525) +- Feat: add token bucket ratelimit filter [4508](https://github.com/beego/beego/pull/4508) + ## Fix Sonar diff --git a/server/web/filter/ratelimit/bucket.go b/server/web/filter/ratelimit/bucket.go new file mode 100644 index 00000000..67a3907e --- /dev/null +++ b/server/web/filter/ratelimit/bucket.go @@ -0,0 +1,14 @@ +package ratelimit + +import "time" + +// bucket is an interface store ratelimit info +type bucket interface { + take(amount uint) bool + getCapacity() uint + getRemaining() uint + getRate() time.Duration +} + +// bucketOption is constructor option +type bucketOption func(bucket) diff --git a/server/web/filter/ratelimit/limiter.go b/server/web/filter/ratelimit/limiter.go new file mode 100644 index 00000000..c7f156bf --- /dev/null +++ b/server/web/filter/ratelimit/limiter.go @@ -0,0 +1,169 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +// Limiter is an interface used to ratelimit +type Limiter interface { + take(amount uint, r *http.Request) bool +} + +// limiterOption is constructor option +type limiterOption func(l *limiter) + +type limiter struct { + sync.RWMutex + capacity uint + rate time.Duration + buckets map[string]bucket + bucketFactory func(opts ...bucketOption) bucket + sessionKey func(r *http.Request) string + resp RejectionResponse +} + +// RejectionResponse stores response information +// for the request rejected by limiter +type RejectionResponse struct { + code int + body string +} + +const perRequestConsumedAmount = 1 + +var defaultRejectionResponse = RejectionResponse{ + code: 429, + body: "too many requests", +} + +// NewLimiter return FilterFunc, the limiter enables rate limit +// according to the configuration. +func NewLimiter(opts ...limiterOption) web.FilterFunc { + l := &limiter{ + buckets: make(map[string]bucket), + sessionKey: func(r *http.Request) string { + return defaultSessionKey(r) + }, + bucketFactory: NewTokenBucket, + resp: defaultRejectionResponse, + } + for _, o := range opts { + o(l) + } + + return func(ctx *context.Context) { + if !l.take(perRequestConsumedAmount, ctx.Request) { + ctx.ResponseWriter.WriteHeader(l.resp.code) + ctx.WriteString(l.resp.body) + } + } +} + +// WithSessionKey return limiterOption. WithSessionKey config func +// which defines the request characteristic againstthe limit is applied +func WithSessionKey(f func(r *http.Request) string) limiterOption { + return func(l *limiter) { + l.sessionKey = f + } +} + +// WithRate return limiterOption. WithRate config how long it takes to +// generate a token. +func WithRate(r time.Duration) limiterOption { + return func(l *limiter) { + l.rate = r + } +} + +// WithCapacity return limiterOption. WithCapacity config the capacity size. +// The bucket with a capacity of n has n tokens after initialization. The capacity +// defines how many requests a client can make in excess of the rate. +func WithCapacity(c uint) limiterOption { + return func(l *limiter) { + l.capacity = c + } +} + +// WithBucketFactory return limiterOption. WithBucketFactory customize the +// implementation of Bucket. +func WithBucketFactory(f func(opts ...bucketOption) bucket) limiterOption { + return func(l *limiter) { + l.bucketFactory = f + } +} + +// WithRejectionResponse return limiterOption. WithRejectionResponse +// customize the response for the request rejected by the limiter. +func WithRejectionResponse(resp RejectionResponse) limiterOption { + return func(l *limiter) { + l.resp = resp + } +} + +func (l *limiter) take(amount uint, r *http.Request) bool { + bucket := l.getBucket(r) + if bucket == nil { + return true + } + return bucket.take(amount) +} + +func (l *limiter) getBucket(r *http.Request) bucket { + key := l.sessionKey(r) + l.RLock() + b, ok := l.buckets[key] + l.RUnlock() + if !ok { + b = l.createBucket(key) + } + + return b +} + +func (l *limiter) createBucket(key string) bucket { + l.Lock() + defer l.Unlock() + // double check avoid overwriting + b, ok := l.buckets[key] + if ok { + return b + } + b = l.bucketFactory(withCapacity(l.capacity), withRate(l.rate)) + l.buckets[key] = b + return b +} + + +func defaultSessionKey(r *http.Request) string { + return "" +} + +func RemoteIPSessionKey(r *http.Request) string { + IPAddress := r.Header.Get("X-Real-Ip") + if IPAddress == "" { + IPAddress = r.Header.Get("X-Forwarded-For") + } + if IPAddress == "" { + IPAddress = r.RemoteAddr + } + return IPAddress +} diff --git a/server/web/filter/ratelimit/limiter_test.go b/server/web/filter/ratelimit/limiter_test.go new file mode 100644 index 00000000..cafede5e --- /dev/null +++ b/server/web/filter/ratelimit/limiter_test.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, requestIP, method, path string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.Header.Set("X-Real-Ip", requestIP) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", requestIP, method, path, w.Code, code) + } +} + +func TestLimiter(t *testing.T) { + handler := web.NewControllerRegister() + err := handler.InsertFilter("/foo/*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(1), WithSessionKey(RemoteIPSessionKey))) + if err != nil { + t.Error(err) + } + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + route := "/foo/1" + ip := "127.0.0.1" + testRequest(t, handler, ip, "GET", route, 200) + testRequest(t, handler, ip, "GET", route, 429) + testRequest(t, handler, "127.0.0.2", "GET", route, 200) + time.Sleep(1 * time.Millisecond) + testRequest(t, handler, ip, "GET", route, 200) +} + +func BenchmarkWithoutLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} + +func BenchmarkWithLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + err := handler.InsertFilter("*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(100))) + if err != nil { + b.Error(err) + } + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} diff --git a/server/web/filter/ratelimit/token_bucket.go b/server/web/filter/ratelimit/token_bucket.go new file mode 100644 index 00000000..5906ee9e --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "sync" + "time" +) + +type tokenBucket struct { + sync.RWMutex + remaining uint + capacity uint + lastCheckAt time.Time + rate time.Duration +} + +// NewTokenBucket return an bucket that implements token bucket +func NewTokenBucket(opts ...bucketOption) bucket { + b := &tokenBucket{lastCheckAt: time.Now()} + for _, o := range opts { + o(b) + } + return b +} + +func withCapacity(capacity uint) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.capacity = capacity + bucket.remaining = capacity + } +} + +func withRate(rate time.Duration) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.rate = rate + } +} + +func (b *tokenBucket) getRemaining() uint { + b.RLock() + defer b.RUnlock() + return b.remaining +} + +func (b *tokenBucket) getRate() time.Duration { + b.RLock() + defer b.RUnlock() + return b.rate +} + +func (b *tokenBucket) getCapacity() uint { + b.RLock() + defer b.RUnlock() + return b.capacity +} + +func (b *tokenBucket) take(amount uint) bool { + if b.rate <= 0 { + return true + } + b.Lock() + defer b.Unlock() + now := time.Now() + times := uint(now.Sub(b.lastCheckAt) / b.rate) + b.lastCheckAt = b.lastCheckAt.Add(time.Duration(times) * b.rate) + b.remaining += times + if b.remaining < amount { + return false + } + b.remaining -= amount + if b.remaining > b.capacity { + b.remaining = b.capacity + } + return true +} diff --git a/server/web/filter/ratelimit/token_bucket_test.go b/server/web/filter/ratelimit/token_bucket_test.go new file mode 100644 index 00000000..93a1b3bd --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket_test.go @@ -0,0 +1,32 @@ +package ratelimit + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetRate(t *testing.T) { + b := NewTokenBucket(withRate(1 * time.Second)).(*tokenBucket) + assert.Equal(t, b.getRate(), 1*time.Second) +} + +func TestGetRemainingAndCapacity(t *testing.T) { + b := NewTokenBucket(withCapacity(10)) + assert.Equal(t, b.getRemaining(), uint(10)) + assert.Equal(t, b.getCapacity(), uint(10)) +} + +func TestTake(t *testing.T) { + b := NewTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket) + for i := 0; i < 10; i++ { + assert.True(t, b.take(1)) + } + assert.False(t, b.take(1)) + assert.Equal(t, b.getRemaining(), uint(0)) + b = NewTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket) + assert.True(t, b.take(1)) + time.Sleep(2 * time.Millisecond) + assert.True(t, b.take(1)) +}