diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index 1763b1b5..b8cd1112 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -301,6 +301,28 @@ func TestAddFilter(t *testing.T) { assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) } +func TestFilterChainOrder(t *testing.T) { + req := Get("http://beego.me") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("first"), nil + } + }) + + + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("second"), nil + } + }) + + resp, err := req.DoRequestWithCtx(context.Background()) + assert.Nil(t, err) + data := make([]byte, 5) + _, _ = resp.Body.Read(data) + assert.Equal(t, "first", string(data)) +} + func TestHead(t *testing.T) { req := Head("http://beego.me") assert.NotNil(t, req) diff --git a/server/web/filter_chain_test.go b/server/web/filter_chain_test.go index 2a428b78..5dd38fc5 100644 --- a/server/web/filter_chain_test.go +++ b/server/web/filter_chain_test.go @@ -15,9 +15,12 @@ package web import ( + "fmt" "net/http" "net/http/httptest" + "strconv" "testing" + "time" "github.com/stretchr/testify/assert" @@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) { ns := NewNamespace("/chain") ns.Get("/*", func(ctx *context.Context) { - ctx.Output.Body([]byte("hello")) + _ = ctx.Output.Body([]byte("hello")) }) r, _ := http.NewRequest("GET", "/chain/user", nil) w := httptest.NewRecorder() + BeeApp.Handlers.Init() BeeApp.Handlers.ServeHTTP(w, r) assert.Equal(t, "filter-chain", w.Header().Get("filter")) } + +func TestControllerRegister_InsertFilterChain_Order(t *testing.T) { + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + r, _ := http.NewRequest("GET", "/abc", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.Init() + BeeApp.Handlers.ServeHTTP(w, r) + first := w.Header().Get("first") + second := w.Header().Get("second") + + ft, _ := strconv.ParseInt(first, 10, 64) + st, _ := strconv.ParseInt(second, 10, 64) + + assert.True(t, st > ft) +} diff --git a/server/web/router.go b/server/web/router.go index 613c4172..9ba078bf 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -139,6 +139,12 @@ func WithRouterSessionOn(sessionOn bool) ControllerOption { } } +type filterChainConfig struct { + pattern string + chain FilterChain + opts []FilterOpt +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree @@ -151,6 +157,9 @@ type ControllerRegister struct { // the filter created by FilterChain chainRoot *FilterRouter + // keep registered chain and build it when serve http + filterChains []filterChainConfig + cfg *Config } @@ -171,11 +180,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { }, }, cfg: cfg, + filterChains: make([]filterChainConfig, 0, 4), } res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } +// Init will be executed when HttpServer start running +func (p *ControllerRegister) Init() { + for i := len(p.filterChains) - 1; i >= 0 ; i -- { + fc := p.filterChains[i] + root := p.chainRoot + filterFunc := fc.chain(root.filterFunc) + p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...) + p.chainRoot.next = root + } +} + // Add controller handler and pattern rules to ControllerRegister. // usage: // default methods is the same name as method @@ -513,12 +534,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // } // } func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { - root := p.chainRoot - filterFunc := chain(root.filterFunc) - opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) - p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) - p.chainRoot.next = root + opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) + p.filterChains = append(p.filterChains, filterChainConfig{ + pattern: pattern, + chain: chain, + opts: opts, + }) } // add Filter into diff --git a/server/web/server.go b/server/web/server.go index 6c4eaf9e..548b906b 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { initBeforeHTTPRun() + // init... app.initAddr(addr) + app.Handlers.Init() addr = app.Cfg.Listen.HTTPAddr