Feat: add token bucket ratelimit filter
This commit is contained in:
parent
b1c9f661c6
commit
493153939b
@ -25,6 +25,8 @@
|
||||
- Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440)
|
||||
- fix: code quality issues [4513](https://github.com/beego/beego/pull/4513)
|
||||
- Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525)
|
||||
- Feat: add token bucket ratelimit filter [4508](https://github.com/beego/beego/pull/4508)
|
||||
|
||||
|
||||
|
||||
## Fix Sonar
|
||||
|
||||
14
server/web/filter/ratelimit/bucket.go
Normal file
14
server/web/filter/ratelimit/bucket.go
Normal file
@ -0,0 +1,14 @@
|
||||
package ratelimit
|
||||
|
||||
import "time"
|
||||
|
||||
// bucket is an interface store ratelimit info
|
||||
type bucket interface {
|
||||
take(amount uint) bool
|
||||
getCapacity() uint
|
||||
getRemaining() uint
|
||||
getRate() time.Duration
|
||||
}
|
||||
|
||||
// bucketOption is constructor option
|
||||
type bucketOption func(bucket)
|
||||
169
server/web/filter/ratelimit/limiter.go
Normal file
169
server/web/filter/ratelimit/limiter.go
Normal file
@ -0,0 +1,169 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/beego/beego/v2/server/web"
|
||||
"github.com/beego/beego/v2/server/web/context"
|
||||
)
|
||||
|
||||
// Limiter is an interface used to ratelimit
|
||||
type Limiter interface {
|
||||
take(amount uint, r *http.Request) bool
|
||||
}
|
||||
|
||||
// limiterOption is constructor option
|
||||
type limiterOption func(l *limiter)
|
||||
|
||||
type limiter struct {
|
||||
sync.RWMutex
|
||||
capacity uint
|
||||
rate time.Duration
|
||||
buckets map[string]bucket
|
||||
bucketFactory func(opts ...bucketOption) bucket
|
||||
sessionKey func(r *http.Request) string
|
||||
resp RejectionResponse
|
||||
}
|
||||
|
||||
// RejectionResponse stores response information
|
||||
// for the request rejected by limiter
|
||||
type RejectionResponse struct {
|
||||
code int
|
||||
body string
|
||||
}
|
||||
|
||||
const perRequestConsumedAmount = 1
|
||||
|
||||
var defaultRejectionResponse = RejectionResponse{
|
||||
code: 429,
|
||||
body: "too many requests",
|
||||
}
|
||||
|
||||
// NewLimiter return FilterFunc, the limiter enables rate limit
|
||||
// according to the configuration.
|
||||
func NewLimiter(opts ...limiterOption) web.FilterFunc {
|
||||
l := &limiter{
|
||||
buckets: make(map[string]bucket),
|
||||
sessionKey: func(r *http.Request) string {
|
||||
return defaultSessionKey(r)
|
||||
},
|
||||
bucketFactory: NewTokenBucket,
|
||||
resp: defaultRejectionResponse,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(l)
|
||||
}
|
||||
|
||||
return func(ctx *context.Context) {
|
||||
if !l.take(perRequestConsumedAmount, ctx.Request) {
|
||||
ctx.ResponseWriter.WriteHeader(l.resp.code)
|
||||
ctx.WriteString(l.resp.body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithSessionKey return limiterOption. WithSessionKey config func
|
||||
// which defines the request characteristic againstthe limit is applied
|
||||
func WithSessionKey(f func(r *http.Request) string) limiterOption {
|
||||
return func(l *limiter) {
|
||||
l.sessionKey = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithRate return limiterOption. WithRate config how long it takes to
|
||||
// generate a token.
|
||||
func WithRate(r time.Duration) limiterOption {
|
||||
return func(l *limiter) {
|
||||
l.rate = r
|
||||
}
|
||||
}
|
||||
|
||||
// WithCapacity return limiterOption. WithCapacity config the capacity size.
|
||||
// The bucket with a capacity of n has n tokens after initialization. The capacity
|
||||
// defines how many requests a client can make in excess of the rate.
|
||||
func WithCapacity(c uint) limiterOption {
|
||||
return func(l *limiter) {
|
||||
l.capacity = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithBucketFactory return limiterOption. WithBucketFactory customize the
|
||||
// implementation of Bucket.
|
||||
func WithBucketFactory(f func(opts ...bucketOption) bucket) limiterOption {
|
||||
return func(l *limiter) {
|
||||
l.bucketFactory = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithRejectionResponse return limiterOption. WithRejectionResponse
|
||||
// customize the response for the request rejected by the limiter.
|
||||
func WithRejectionResponse(resp RejectionResponse) limiterOption {
|
||||
return func(l *limiter) {
|
||||
l.resp = resp
|
||||
}
|
||||
}
|
||||
|
||||
func (l *limiter) take(amount uint, r *http.Request) bool {
|
||||
bucket := l.getBucket(r)
|
||||
if bucket == nil {
|
||||
return true
|
||||
}
|
||||
return bucket.take(amount)
|
||||
}
|
||||
|
||||
func (l *limiter) getBucket(r *http.Request) bucket {
|
||||
key := l.sessionKey(r)
|
||||
l.RLock()
|
||||
b, ok := l.buckets[key]
|
||||
l.RUnlock()
|
||||
if !ok {
|
||||
b = l.createBucket(key)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (l *limiter) createBucket(key string) bucket {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
// double check avoid overwriting
|
||||
b, ok := l.buckets[key]
|
||||
if ok {
|
||||
return b
|
||||
}
|
||||
b = l.bucketFactory(withCapacity(l.capacity), withRate(l.rate))
|
||||
l.buckets[key] = b
|
||||
return b
|
||||
}
|
||||
|
||||
|
||||
func defaultSessionKey(r *http.Request) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func RemoteIPSessionKey(r *http.Request) string {
|
||||
IPAddress := r.Header.Get("X-Real-Ip")
|
||||
if IPAddress == "" {
|
||||
IPAddress = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
if IPAddress == "" {
|
||||
IPAddress = r.RemoteAddr
|
||||
}
|
||||
return IPAddress
|
||||
}
|
||||
76
server/web/filter/ratelimit/limiter_test.go
Normal file
76
server/web/filter/ratelimit/limiter_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/beego/beego/v2/server/web"
|
||||
"github.com/beego/beego/v2/server/web/context"
|
||||
)
|
||||
|
||||
func testRequest(t *testing.T, handler *web.ControllerRegister, requestIP, method, path string, code int) {
|
||||
r, _ := http.NewRequest(method, path, nil)
|
||||
r.Header.Set("X-Real-Ip", requestIP)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Code != code {
|
||||
t.Errorf("%s, %s, %s: %d, supposed to be %d", requestIP, method, path, w.Code, code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiter(t *testing.T) {
|
||||
handler := web.NewControllerRegister()
|
||||
err := handler.InsertFilter("/foo/*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(1), WithSessionKey(RemoteIPSessionKey)))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
handler.Any("*", func(ctx *context.Context) {
|
||||
ctx.Output.SetStatus(200)
|
||||
})
|
||||
|
||||
route := "/foo/1"
|
||||
ip := "127.0.0.1"
|
||||
testRequest(t, handler, ip, "GET", route, 200)
|
||||
testRequest(t, handler, ip, "GET", route, 429)
|
||||
testRequest(t, handler, "127.0.0.2", "GET", route, 200)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
testRequest(t, handler, ip, "GET", route, 200)
|
||||
}
|
||||
|
||||
func BenchmarkWithoutLimiter(b *testing.B) {
|
||||
recorder := httptest.NewRecorder()
|
||||
handler := web.NewControllerRegister()
|
||||
web.BConfig.RunMode = web.PROD
|
||||
handler.Any("/foo", func(ctx *context.Context) {
|
||||
ctx.Output.SetStatus(500)
|
||||
})
|
||||
b.ResetTimer()
|
||||
r, _ := http.NewRequest("PUT", "/foo", nil)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
handler.ServeHTTP(recorder, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkWithLimiter(b *testing.B) {
|
||||
recorder := httptest.NewRecorder()
|
||||
handler := web.NewControllerRegister()
|
||||
web.BConfig.RunMode = web.PROD
|
||||
err := handler.InsertFilter("*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(100)))
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
handler.Any("/foo", func(ctx *context.Context) {
|
||||
ctx.Output.SetStatus(500)
|
||||
})
|
||||
b.ResetTimer()
|
||||
r, _ := http.NewRequest("PUT", "/foo", nil)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
handler.ServeHTTP(recorder, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
76
server/web/filter/ratelimit/token_bucket.go
Normal file
76
server/web/filter/ratelimit/token_bucket.go
Normal file
@ -0,0 +1,76 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type tokenBucket struct {
|
||||
sync.RWMutex
|
||||
remaining uint
|
||||
capacity uint
|
||||
lastCheckAt time.Time
|
||||
rate time.Duration
|
||||
}
|
||||
|
||||
// NewTokenBucket return an bucket that implements token bucket
|
||||
func NewTokenBucket(opts ...bucketOption) bucket {
|
||||
b := &tokenBucket{lastCheckAt: time.Now()}
|
||||
for _, o := range opts {
|
||||
o(b)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func withCapacity(capacity uint) bucketOption {
|
||||
return func(b bucket) {
|
||||
bucket := b.(*tokenBucket)
|
||||
bucket.capacity = capacity
|
||||
bucket.remaining = capacity
|
||||
}
|
||||
}
|
||||
|
||||
func withRate(rate time.Duration) bucketOption {
|
||||
return func(b bucket) {
|
||||
bucket := b.(*tokenBucket)
|
||||
bucket.rate = rate
|
||||
}
|
||||
}
|
||||
|
||||
func (b *tokenBucket) getRemaining() uint {
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
return b.remaining
|
||||
}
|
||||
|
||||
func (b *tokenBucket) getRate() time.Duration {
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
return b.rate
|
||||
}
|
||||
|
||||
func (b *tokenBucket) getCapacity() uint {
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
return b.capacity
|
||||
}
|
||||
|
||||
func (b *tokenBucket) take(amount uint) bool {
|
||||
if b.rate <= 0 {
|
||||
return true
|
||||
}
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
now := time.Now()
|
||||
times := uint(now.Sub(b.lastCheckAt) / b.rate)
|
||||
b.lastCheckAt = b.lastCheckAt.Add(time.Duration(times) * b.rate)
|
||||
b.remaining += times
|
||||
if b.remaining < amount {
|
||||
return false
|
||||
}
|
||||
b.remaining -= amount
|
||||
if b.remaining > b.capacity {
|
||||
b.remaining = b.capacity
|
||||
}
|
||||
return true
|
||||
}
|
||||
32
server/web/filter/ratelimit/token_bucket_test.go
Normal file
32
server/web/filter/ratelimit/token_bucket_test.go
Normal file
@ -0,0 +1,32 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetRate(t *testing.T) {
|
||||
b := NewTokenBucket(withRate(1 * time.Second)).(*tokenBucket)
|
||||
assert.Equal(t, b.getRate(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestGetRemainingAndCapacity(t *testing.T) {
|
||||
b := NewTokenBucket(withCapacity(10))
|
||||
assert.Equal(t, b.getRemaining(), uint(10))
|
||||
assert.Equal(t, b.getCapacity(), uint(10))
|
||||
}
|
||||
|
||||
func TestTake(t *testing.T) {
|
||||
b := NewTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket)
|
||||
for i := 0; i < 10; i++ {
|
||||
assert.True(t, b.take(1))
|
||||
}
|
||||
assert.False(t, b.take(1))
|
||||
assert.Equal(t, b.getRemaining(), uint(0))
|
||||
b = NewTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket)
|
||||
assert.True(t, b.take(1))
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
assert.True(t, b.take(1))
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user