Feat: add token bucket ratelimit filter

This commit is contained in:
anoymouscoder 2021-02-22 18:50:11 +08:00
parent b1c9f661c6
commit 493153939b
6 changed files with 369 additions and 0 deletions

View File

@ -25,6 +25,8 @@
- Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440)
- fix: code quality issues [4513](https://github.com/beego/beego/pull/4513)
- Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525)
- Feat: add token bucket ratelimit filter [4508](https://github.com/beego/beego/pull/4508)
## Fix Sonar

View File

@ -0,0 +1,14 @@
package ratelimit
import "time"
// bucket is an interface store ratelimit info
type bucket interface {
take(amount uint) bool
getCapacity() uint
getRemaining() uint
getRate() time.Duration
}
// bucketOption is constructor option
type bucketOption func(bucket)

View File

@ -0,0 +1,169 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ratelimit
import (
"net/http"
"sync"
"time"
"github.com/beego/beego/v2/server/web"
"github.com/beego/beego/v2/server/web/context"
)
// Limiter is an interface used to ratelimit
type Limiter interface {
take(amount uint, r *http.Request) bool
}
// limiterOption is constructor option
type limiterOption func(l *limiter)
type limiter struct {
sync.RWMutex
capacity uint
rate time.Duration
buckets map[string]bucket
bucketFactory func(opts ...bucketOption) bucket
sessionKey func(r *http.Request) string
resp RejectionResponse
}
// RejectionResponse stores response information
// for the request rejected by limiter
type RejectionResponse struct {
code int
body string
}
const perRequestConsumedAmount = 1
var defaultRejectionResponse = RejectionResponse{
code: 429,
body: "too many requests",
}
// NewLimiter return FilterFunc, the limiter enables rate limit
// according to the configuration.
func NewLimiter(opts ...limiterOption) web.FilterFunc {
l := &limiter{
buckets: make(map[string]bucket),
sessionKey: func(r *http.Request) string {
return defaultSessionKey(r)
},
bucketFactory: NewTokenBucket,
resp: defaultRejectionResponse,
}
for _, o := range opts {
o(l)
}
return func(ctx *context.Context) {
if !l.take(perRequestConsumedAmount, ctx.Request) {
ctx.ResponseWriter.WriteHeader(l.resp.code)
ctx.WriteString(l.resp.body)
}
}
}
// WithSessionKey return limiterOption. WithSessionKey config func
// which defines the request characteristic againstthe limit is applied
func WithSessionKey(f func(r *http.Request) string) limiterOption {
return func(l *limiter) {
l.sessionKey = f
}
}
// WithRate return limiterOption. WithRate config how long it takes to
// generate a token.
func WithRate(r time.Duration) limiterOption {
return func(l *limiter) {
l.rate = r
}
}
// WithCapacity return limiterOption. WithCapacity config the capacity size.
// The bucket with a capacity of n has n tokens after initialization. The capacity
// defines how many requests a client can make in excess of the rate.
func WithCapacity(c uint) limiterOption {
return func(l *limiter) {
l.capacity = c
}
}
// WithBucketFactory return limiterOption. WithBucketFactory customize the
// implementation of Bucket.
func WithBucketFactory(f func(opts ...bucketOption) bucket) limiterOption {
return func(l *limiter) {
l.bucketFactory = f
}
}
// WithRejectionResponse return limiterOption. WithRejectionResponse
// customize the response for the request rejected by the limiter.
func WithRejectionResponse(resp RejectionResponse) limiterOption {
return func(l *limiter) {
l.resp = resp
}
}
func (l *limiter) take(amount uint, r *http.Request) bool {
bucket := l.getBucket(r)
if bucket == nil {
return true
}
return bucket.take(amount)
}
func (l *limiter) getBucket(r *http.Request) bucket {
key := l.sessionKey(r)
l.RLock()
b, ok := l.buckets[key]
l.RUnlock()
if !ok {
b = l.createBucket(key)
}
return b
}
func (l *limiter) createBucket(key string) bucket {
l.Lock()
defer l.Unlock()
// double check avoid overwriting
b, ok := l.buckets[key]
if ok {
return b
}
b = l.bucketFactory(withCapacity(l.capacity), withRate(l.rate))
l.buckets[key] = b
return b
}
func defaultSessionKey(r *http.Request) string {
return ""
}
func RemoteIPSessionKey(r *http.Request) string {
IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For")
}
if IPAddress == "" {
IPAddress = r.RemoteAddr
}
return IPAddress
}

View File

@ -0,0 +1,76 @@
package ratelimit
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/beego/beego/v2/server/web"
"github.com/beego/beego/v2/server/web/context"
)
func testRequest(t *testing.T, handler *web.ControllerRegister, requestIP, method, path string, code int) {
r, _ := http.NewRequest(method, path, nil)
r.Header.Set("X-Real-Ip", requestIP)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != code {
t.Errorf("%s, %s, %s: %d, supposed to be %d", requestIP, method, path, w.Code, code)
}
}
func TestLimiter(t *testing.T) {
handler := web.NewControllerRegister()
err := handler.InsertFilter("/foo/*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(1), WithSessionKey(RemoteIPSessionKey)))
if err != nil {
t.Error(err)
}
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
route := "/foo/1"
ip := "127.0.0.1"
testRequest(t, handler, ip, "GET", route, 200)
testRequest(t, handler, ip, "GET", route, 429)
testRequest(t, handler, "127.0.0.2", "GET", route, 200)
time.Sleep(1 * time.Millisecond)
testRequest(t, handler, ip, "GET", route, 200)
}
func BenchmarkWithoutLimiter(b *testing.B) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
web.BConfig.RunMode = web.PROD
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
r, _ := http.NewRequest("PUT", "/foo", nil)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
handler.ServeHTTP(recorder, r)
}
})
}
func BenchmarkWithLimiter(b *testing.B) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
web.BConfig.RunMode = web.PROD
err := handler.InsertFilter("*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(100)))
if err != nil {
b.Error(err)
}
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
r, _ := http.NewRequest("PUT", "/foo", nil)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
handler.ServeHTTP(recorder, r)
}
})
}

View File

@ -0,0 +1,76 @@
package ratelimit
import (
"sync"
"time"
)
type tokenBucket struct {
sync.RWMutex
remaining uint
capacity uint
lastCheckAt time.Time
rate time.Duration
}
// NewTokenBucket return an bucket that implements token bucket
func NewTokenBucket(opts ...bucketOption) bucket {
b := &tokenBucket{lastCheckAt: time.Now()}
for _, o := range opts {
o(b)
}
return b
}
func withCapacity(capacity uint) bucketOption {
return func(b bucket) {
bucket := b.(*tokenBucket)
bucket.capacity = capacity
bucket.remaining = capacity
}
}
func withRate(rate time.Duration) bucketOption {
return func(b bucket) {
bucket := b.(*tokenBucket)
bucket.rate = rate
}
}
func (b *tokenBucket) getRemaining() uint {
b.RLock()
defer b.RUnlock()
return b.remaining
}
func (b *tokenBucket) getRate() time.Duration {
b.RLock()
defer b.RUnlock()
return b.rate
}
func (b *tokenBucket) getCapacity() uint {
b.RLock()
defer b.RUnlock()
return b.capacity
}
func (b *tokenBucket) take(amount uint) bool {
if b.rate <= 0 {
return true
}
b.Lock()
defer b.Unlock()
now := time.Now()
times := uint(now.Sub(b.lastCheckAt) / b.rate)
b.lastCheckAt = b.lastCheckAt.Add(time.Duration(times) * b.rate)
b.remaining += times
if b.remaining < amount {
return false
}
b.remaining -= amount
if b.remaining > b.capacity {
b.remaining = b.capacity
}
return true
}

View File

@ -0,0 +1,32 @@
package ratelimit
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestGetRate(t *testing.T) {
b := NewTokenBucket(withRate(1 * time.Second)).(*tokenBucket)
assert.Equal(t, b.getRate(), 1*time.Second)
}
func TestGetRemainingAndCapacity(t *testing.T) {
b := NewTokenBucket(withCapacity(10))
assert.Equal(t, b.getRemaining(), uint(10))
assert.Equal(t, b.getCapacity(), uint(10))
}
func TestTake(t *testing.T) {
b := NewTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket)
for i := 0; i < 10; i++ {
assert.True(t, b.take(1))
}
assert.False(t, b.take(1))
assert.Equal(t, b.getRemaining(), uint(0))
b = NewTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket)
assert.True(t, b.take(1))
time.Sleep(2 * time.Millisecond)
assert.True(t, b.take(1))
}