reverse filter chain sort and add test to ensure the FIFO

This commit is contained in:
Ming Deng 2021-01-08 20:52:36 +08:00
parent 7d2c5486be
commit 9105402f8c
4 changed files with 88 additions and 6 deletions

View File

@ -301,6 +301,28 @@ func TestAddFilter(t *testing.T) {
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
}
func TestFilterChainOrder(t *testing.T) {
req := Get("http://beego.me")
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("first"), nil
}
})
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("second"), nil
}
})
resp, err := req.DoRequestWithCtx(context.Background())
assert.Nil(t, err)
data := make([]byte, 5)
_, _ = resp.Body.Read(data)
assert.Equal(t, "first", string(data))
}
func TestHead(t *testing.T) {
req := Head("http://beego.me")
assert.NotNil(t, req)

View File

@ -15,9 +15,12 @@
package web
import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) {
ns := NewNamespace("/chain")
ns.Get("/*", func(ctx *context.Context) {
ctx.Output.Body([]byte("hello"))
_ = ctx.Output.Body([]byte("hello"))
})
r, _ := http.NewRequest("GET", "/chain/user", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r)
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
}
func TestControllerRegister_InsertFilterChain_Order(t *testing.T) {
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
r, _ := http.NewRequest("GET", "/abc", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r)
first := w.Header().Get("first")
second := w.Header().Get("second")
ft, _ := strconv.ParseInt(first, 10, 64)
st, _ := strconv.ParseInt(second, 10, 64)
assert.True(t, st > ft)
}

View File

@ -139,6 +139,12 @@ func WithRouterSessionOn(sessionOn bool) ControllerOption {
}
}
type filterChainConfig struct {
pattern string
chain FilterChain
opts []FilterOpt
}
// ControllerRegister containers registered router rules, controller handlers and filters.
type ControllerRegister struct {
routers map[string]*Tree
@ -151,6 +157,9 @@ type ControllerRegister struct {
// the filter created by FilterChain
chainRoot *FilterRouter
// keep registered chain and build it when serve http
filterChains []filterChainConfig
cfg *Config
}
@ -171,11 +180,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
},
},
cfg: cfg,
filterChains: make([]filterChainConfig, 0, 4),
}
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
return res
}
// Init will be executed when HttpServer start running
func (p *ControllerRegister) Init() {
for i := len(p.filterChains) - 1; i >= 0 ; i -- {
fc := p.filterChains[i]
root := p.chainRoot
filterFunc := fc.chain(root.filterFunc)
p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...)
p.chainRoot.next = root
}
}
// Add controller handler and pattern rules to ControllerRegister.
// usage:
// default methods is the same name as method
@ -513,12 +534,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// }
// }
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
root := p.chainRoot
filterFunc := chain(root.filterFunc)
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.chainRoot = newFilterRouter(pattern, filterFunc, opts...)
p.chainRoot.next = root
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.filterChains = append(p.filterChains, filterChainConfig{
pattern: pattern,
chain: chain,
opts: opts,
})
}
// add Filter into

View File

@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
// init...
app.initAddr(addr)
app.Handlers.Init()
addr = app.Cfg.Listen.HTTPAddr