Feat: add token bucket ratelimit filter
This commit is contained in:
parent
b1c9f661c6
commit
493153939b
@ -25,6 +25,8 @@
|
|||||||
- Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440)
|
- Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440)
|
||||||
- fix: code quality issues [4513](https://github.com/beego/beego/pull/4513)
|
- fix: code quality issues [4513](https://github.com/beego/beego/pull/4513)
|
||||||
- Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525)
|
- Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525)
|
||||||
|
- Feat: add token bucket ratelimit filter [4508](https://github.com/beego/beego/pull/4508)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Fix Sonar
|
## Fix Sonar
|
||||||
|
|||||||
14
server/web/filter/ratelimit/bucket.go
Normal file
14
server/web/filter/ratelimit/bucket.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// bucket is an interface store ratelimit info
|
||||||
|
type bucket interface {
|
||||||
|
take(amount uint) bool
|
||||||
|
getCapacity() uint
|
||||||
|
getRemaining() uint
|
||||||
|
getRate() time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// bucketOption is constructor option
|
||||||
|
type bucketOption func(bucket)
|
||||||
169
server/web/filter/ratelimit/limiter.go
Normal file
169
server/web/filter/ratelimit/limiter.go
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/server/web"
|
||||||
|
"github.com/beego/beego/v2/server/web/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Limiter is an interface used to ratelimit
|
||||||
|
type Limiter interface {
|
||||||
|
take(amount uint, r *http.Request) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// limiterOption is constructor option
|
||||||
|
type limiterOption func(l *limiter)
|
||||||
|
|
||||||
|
type limiter struct {
|
||||||
|
sync.RWMutex
|
||||||
|
capacity uint
|
||||||
|
rate time.Duration
|
||||||
|
buckets map[string]bucket
|
||||||
|
bucketFactory func(opts ...bucketOption) bucket
|
||||||
|
sessionKey func(r *http.Request) string
|
||||||
|
resp RejectionResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectionResponse stores response information
|
||||||
|
// for the request rejected by limiter
|
||||||
|
type RejectionResponse struct {
|
||||||
|
code int
|
||||||
|
body string
|
||||||
|
}
|
||||||
|
|
||||||
|
const perRequestConsumedAmount = 1
|
||||||
|
|
||||||
|
var defaultRejectionResponse = RejectionResponse{
|
||||||
|
code: 429,
|
||||||
|
body: "too many requests",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLimiter return FilterFunc, the limiter enables rate limit
|
||||||
|
// according to the configuration.
|
||||||
|
func NewLimiter(opts ...limiterOption) web.FilterFunc {
|
||||||
|
l := &limiter{
|
||||||
|
buckets: make(map[string]bucket),
|
||||||
|
sessionKey: func(r *http.Request) string {
|
||||||
|
return defaultSessionKey(r)
|
||||||
|
},
|
||||||
|
bucketFactory: NewTokenBucket,
|
||||||
|
resp: defaultRejectionResponse,
|
||||||
|
}
|
||||||
|
for _, o := range opts {
|
||||||
|
o(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
if !l.take(perRequestConsumedAmount, ctx.Request) {
|
||||||
|
ctx.ResponseWriter.WriteHeader(l.resp.code)
|
||||||
|
ctx.WriteString(l.resp.body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSessionKey return limiterOption. WithSessionKey config func
|
||||||
|
// which defines the request characteristic againstthe limit is applied
|
||||||
|
func WithSessionKey(f func(r *http.Request) string) limiterOption {
|
||||||
|
return func(l *limiter) {
|
||||||
|
l.sessionKey = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRate return limiterOption. WithRate config how long it takes to
|
||||||
|
// generate a token.
|
||||||
|
func WithRate(r time.Duration) limiterOption {
|
||||||
|
return func(l *limiter) {
|
||||||
|
l.rate = r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCapacity return limiterOption. WithCapacity config the capacity size.
|
||||||
|
// The bucket with a capacity of n has n tokens after initialization. The capacity
|
||||||
|
// defines how many requests a client can make in excess of the rate.
|
||||||
|
func WithCapacity(c uint) limiterOption {
|
||||||
|
return func(l *limiter) {
|
||||||
|
l.capacity = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBucketFactory return limiterOption. WithBucketFactory customize the
|
||||||
|
// implementation of Bucket.
|
||||||
|
func WithBucketFactory(f func(opts ...bucketOption) bucket) limiterOption {
|
||||||
|
return func(l *limiter) {
|
||||||
|
l.bucketFactory = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRejectionResponse return limiterOption. WithRejectionResponse
|
||||||
|
// customize the response for the request rejected by the limiter.
|
||||||
|
func WithRejectionResponse(resp RejectionResponse) limiterOption {
|
||||||
|
return func(l *limiter) {
|
||||||
|
l.resp = resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *limiter) take(amount uint, r *http.Request) bool {
|
||||||
|
bucket := l.getBucket(r)
|
||||||
|
if bucket == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return bucket.take(amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *limiter) getBucket(r *http.Request) bucket {
|
||||||
|
key := l.sessionKey(r)
|
||||||
|
l.RLock()
|
||||||
|
b, ok := l.buckets[key]
|
||||||
|
l.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
b = l.createBucket(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *limiter) createBucket(key string) bucket {
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
// double check avoid overwriting
|
||||||
|
b, ok := l.buckets[key]
|
||||||
|
if ok {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
b = l.bucketFactory(withCapacity(l.capacity), withRate(l.rate))
|
||||||
|
l.buckets[key] = b
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func defaultSessionKey(r *http.Request) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func RemoteIPSessionKey(r *http.Request) string {
|
||||||
|
IPAddress := r.Header.Get("X-Real-Ip")
|
||||||
|
if IPAddress == "" {
|
||||||
|
IPAddress = r.Header.Get("X-Forwarded-For")
|
||||||
|
}
|
||||||
|
if IPAddress == "" {
|
||||||
|
IPAddress = r.RemoteAddr
|
||||||
|
}
|
||||||
|
return IPAddress
|
||||||
|
}
|
||||||
76
server/web/filter/ratelimit/limiter_test.go
Normal file
76
server/web/filter/ratelimit/limiter_test.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/server/web"
|
||||||
|
"github.com/beego/beego/v2/server/web/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testRequest(t *testing.T, handler *web.ControllerRegister, requestIP, method, path string, code int) {
|
||||||
|
r, _ := http.NewRequest(method, path, nil)
|
||||||
|
r.Header.Set("X-Real-Ip", requestIP)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Code != code {
|
||||||
|
t.Errorf("%s, %s, %s: %d, supposed to be %d", requestIP, method, path, w.Code, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimiter(t *testing.T) {
|
||||||
|
handler := web.NewControllerRegister()
|
||||||
|
err := handler.InsertFilter("/foo/*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(1), WithSessionKey(RemoteIPSessionKey)))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
handler.Any("*", func(ctx *context.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
route := "/foo/1"
|
||||||
|
ip := "127.0.0.1"
|
||||||
|
testRequest(t, handler, ip, "GET", route, 200)
|
||||||
|
testRequest(t, handler, ip, "GET", route, 429)
|
||||||
|
testRequest(t, handler, "127.0.0.2", "GET", route, 200)
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
testRequest(t, handler, ip, "GET", route, 200)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWithoutLimiter(b *testing.B) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
handler := web.NewControllerRegister()
|
||||||
|
web.BConfig.RunMode = web.PROD
|
||||||
|
handler.Any("/foo", func(ctx *context.Context) {
|
||||||
|
ctx.Output.SetStatus(500)
|
||||||
|
})
|
||||||
|
b.ResetTimer()
|
||||||
|
r, _ := http.NewRequest("PUT", "/foo", nil)
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
handler.ServeHTTP(recorder, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWithLimiter(b *testing.B) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
handler := web.NewControllerRegister()
|
||||||
|
web.BConfig.RunMode = web.PROD
|
||||||
|
err := handler.InsertFilter("*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(100)))
|
||||||
|
if err != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
handler.Any("/foo", func(ctx *context.Context) {
|
||||||
|
ctx.Output.SetStatus(500)
|
||||||
|
})
|
||||||
|
b.ResetTimer()
|
||||||
|
r, _ := http.NewRequest("PUT", "/foo", nil)
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
handler.ServeHTTP(recorder, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
76
server/web/filter/ratelimit/token_bucket.go
Normal file
76
server/web/filter/ratelimit/token_bucket.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokenBucket struct {
|
||||||
|
sync.RWMutex
|
||||||
|
remaining uint
|
||||||
|
capacity uint
|
||||||
|
lastCheckAt time.Time
|
||||||
|
rate time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTokenBucket return an bucket that implements token bucket
|
||||||
|
func NewTokenBucket(opts ...bucketOption) bucket {
|
||||||
|
b := &tokenBucket{lastCheckAt: time.Now()}
|
||||||
|
for _, o := range opts {
|
||||||
|
o(b)
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func withCapacity(capacity uint) bucketOption {
|
||||||
|
return func(b bucket) {
|
||||||
|
bucket := b.(*tokenBucket)
|
||||||
|
bucket.capacity = capacity
|
||||||
|
bucket.remaining = capacity
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func withRate(rate time.Duration) bucketOption {
|
||||||
|
return func(b bucket) {
|
||||||
|
bucket := b.(*tokenBucket)
|
||||||
|
bucket.rate = rate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *tokenBucket) getRemaining() uint {
|
||||||
|
b.RLock()
|
||||||
|
defer b.RUnlock()
|
||||||
|
return b.remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *tokenBucket) getRate() time.Duration {
|
||||||
|
b.RLock()
|
||||||
|
defer b.RUnlock()
|
||||||
|
return b.rate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *tokenBucket) getCapacity() uint {
|
||||||
|
b.RLock()
|
||||||
|
defer b.RUnlock()
|
||||||
|
return b.capacity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *tokenBucket) take(amount uint) bool {
|
||||||
|
if b.rate <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
b.Lock()
|
||||||
|
defer b.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
times := uint(now.Sub(b.lastCheckAt) / b.rate)
|
||||||
|
b.lastCheckAt = b.lastCheckAt.Add(time.Duration(times) * b.rate)
|
||||||
|
b.remaining += times
|
||||||
|
if b.remaining < amount {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
b.remaining -= amount
|
||||||
|
if b.remaining > b.capacity {
|
||||||
|
b.remaining = b.capacity
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
32
server/web/filter/ratelimit/token_bucket_test.go
Normal file
32
server/web/filter/ratelimit/token_bucket_test.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetRate(t *testing.T) {
|
||||||
|
b := NewTokenBucket(withRate(1 * time.Second)).(*tokenBucket)
|
||||||
|
assert.Equal(t, b.getRate(), 1*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRemainingAndCapacity(t *testing.T) {
|
||||||
|
b := NewTokenBucket(withCapacity(10))
|
||||||
|
assert.Equal(t, b.getRemaining(), uint(10))
|
||||||
|
assert.Equal(t, b.getCapacity(), uint(10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTake(t *testing.T) {
|
||||||
|
b := NewTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
assert.True(t, b.take(1))
|
||||||
|
}
|
||||||
|
assert.False(t, b.take(1))
|
||||||
|
assert.Equal(t, b.getRemaining(), uint(0))
|
||||||
|
b = NewTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket)
|
||||||
|
assert.True(t, b.take(1))
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
assert.True(t, b.take(1))
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user