From 3981234bfbcc6a71b6ba6c3b302b631efba87cde Mon Sep 17 00:00:00 2001 From: Deng Ming Date: Tue, 19 Oct 2021 21:47:09 +0800 Subject: [PATCH] set default rate and capacity for ratelimit filter --- CHANGELOG.md | 1 + server/web/filter/ratelimit/limiter.go | 37 ++++++++----------- server/web/filter/ratelimit/token_bucket.go | 4 +- .../web/filter/ratelimit/token_bucket_test.go | 8 ++-- server/web/filter/session/filter.go | 1 + 5 files changed, 24 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5bdf27d..18231a90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -67,6 +67,7 @@ - Fix 4698: Prompt error when config format is incorrect. [4757](https://github.com/beego/beego/pull/4757) - Fix 4674: Tx Orm missing debug log [4756](https://github.com/beego/beego/pull/4756) - Fix 4759: fix numeric notation of permissions [4759](https://github.com/beego/beego/pull/4759) +- set default rate and capacity for ratelimit filter [4796](https://github.com/beego/beego/pull/4796) ## Fix Sonar - [4677](https://github.com/beego/beego/pull/4677) diff --git a/server/web/filter/ratelimit/limiter.go b/server/web/filter/ratelimit/limiter.go index 5b64b5dd..e0aac5c1 100644 --- a/server/web/filter/ratelimit/limiter.go +++ b/server/web/filter/ratelimit/limiter.go @@ -15,7 +15,6 @@ package ratelimit import ( - "net/http" "sync" "time" @@ -23,11 +22,6 @@ import ( "github.com/beego/beego/v2/server/web/context" ) -// Limiter is an interface used to ratelimit -type Limiter interface { - take(amount uint, r *http.Request) bool -} - // limiterOption is constructor option type limiterOption func(l *limiter) @@ -37,7 +31,7 @@ type limiter struct { rate time.Duration buckets map[string]bucket bucketFactory func(opts ...bucketOption) bucket - sessionKey func(r *http.Request) string + sessionKey func(ctx *context.Context) string resp RejectionResponse } @@ -60,10 +54,10 @@ var defaultRejectionResponse = RejectionResponse{ func NewLimiter(opts ...limiterOption) web.FilterFunc { l := &limiter{ buckets: make(map[string]bucket), - sessionKey: func(r *http.Request) string { - return defaultSessionKey(r) - }, - bucketFactory: NewTokenBucket, + sessionKey: defaultSessionKey, + rate: time.Millisecond * 10, + capacity: 100, + bucketFactory: newTokenBucket, resp: defaultRejectionResponse, } for _, o := range opts { @@ -71,7 +65,7 @@ func NewLimiter(opts ...limiterOption) web.FilterFunc { } return func(ctx *context.Context) { - if !l.take(perRequestConsumedAmount, ctx.Request) { + if !l.take(perRequestConsumedAmount, ctx) { ctx.ResponseWriter.WriteHeader(l.resp.code) ctx.WriteString(l.resp.body) } @@ -79,8 +73,8 @@ func NewLimiter(opts ...limiterOption) web.FilterFunc { } // WithSessionKey return limiterOption. WithSessionKey config func -// which defines the request characteristic againstthe limit is applied -func WithSessionKey(f func(r *http.Request) string) limiterOption { +// which defines the request characteristic against the limit is applied +func WithSessionKey(f func(ctx *context.Context) string) limiterOption { return func(l *limiter) { l.sessionKey = f } @@ -119,16 +113,16 @@ func WithRejectionResponse(resp RejectionResponse) limiterOption { } } -func (l *limiter) take(amount uint, r *http.Request) bool { - bucket := l.getBucket(r) +func (l *limiter) take(amount uint, ctx *context.Context) bool { + bucket := l.getBucket(ctx) if bucket == nil { return true } return bucket.take(amount) } -func (l *limiter) getBucket(r *http.Request) bucket { - key := l.sessionKey(r) +func (l *limiter) getBucket(ctx *context.Context) bucket { + key := l.sessionKey(ctx) l.RLock() b, ok := l.buckets[key] l.RUnlock() @@ -152,11 +146,12 @@ func (l *limiter) createBucket(key string) bucket { return b } -func defaultSessionKey(r *http.Request) string { - return "" +func defaultSessionKey(ctx *context.Context) string { + return "BEEGO_ALL" } -func RemoteIPSessionKey(r *http.Request) string { +func RemoteIPSessionKey(ctx *context.Context) string { + r := ctx.Request IPAddress := r.Header.Get("X-Real-Ip") if IPAddress == "" { IPAddress = r.Header.Get("X-Forwarded-For") diff --git a/server/web/filter/ratelimit/token_bucket.go b/server/web/filter/ratelimit/token_bucket.go index 5906ee9e..da7bb7fc 100644 --- a/server/web/filter/ratelimit/token_bucket.go +++ b/server/web/filter/ratelimit/token_bucket.go @@ -13,8 +13,8 @@ type tokenBucket struct { rate time.Duration } -// NewTokenBucket return an bucket that implements token bucket -func NewTokenBucket(opts ...bucketOption) bucket { +// newTokenBucket return an bucket that implements token bucket +func newTokenBucket(opts ...bucketOption) bucket { b := &tokenBucket{lastCheckAt: time.Now()} for _, o := range opts { o(b) diff --git a/server/web/filter/ratelimit/token_bucket_test.go b/server/web/filter/ratelimit/token_bucket_test.go index 93a1b3bd..91088eb6 100644 --- a/server/web/filter/ratelimit/token_bucket_test.go +++ b/server/web/filter/ratelimit/token_bucket_test.go @@ -8,24 +8,24 @@ import ( ) func TestGetRate(t *testing.T) { - b := NewTokenBucket(withRate(1 * time.Second)).(*tokenBucket) + b := newTokenBucket(withRate(1 * time.Second)).(*tokenBucket) assert.Equal(t, b.getRate(), 1*time.Second) } func TestGetRemainingAndCapacity(t *testing.T) { - b := NewTokenBucket(withCapacity(10)) + b := newTokenBucket(withCapacity(10)) assert.Equal(t, b.getRemaining(), uint(10)) assert.Equal(t, b.getCapacity(), uint(10)) } func TestTake(t *testing.T) { - b := NewTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket) + b := newTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket) for i := 0; i < 10; i++ { assert.True(t, b.take(1)) } assert.False(t, b.take(1)) assert.Equal(t, b.getRemaining(), uint(0)) - b = NewTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket) + b = newTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket) assert.True(t, b.take(1)) time.Sleep(2 * time.Millisecond) assert.True(t, b.take(1)) diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go index 40f3e198..b3bb99c5 100644 --- a/server/web/filter/session/filter.go +++ b/server/web/filter/session/filter.go @@ -11,6 +11,7 @@ import ( // Session maintain session for web service // Session new a session storage and store it into webContext.Context +// experimental feature, we may change this in the future func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { sessionConfig := session.NewManagerConfig(options...) sessionManager, _ := session.NewManager(string(providerType), sessionConfig)