From 463ec4b0851174221198583775dbffde3d77ce6b Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Fri, 20 Nov 2020 10:43:30 +0800 Subject: [PATCH 1/6] fix issue #4176: refactor Router with option pattern --- adapter/app.go | 2 +- adapter/namespace.go | 2 +- adapter/router.go | 2 +- server/web/admin.go | 14 ++--- server/web/flash_test.go | 2 +- server/web/namespace.go | 2 +- server/web/router.go | 96 +++++++++++++++++++++-------------- server/web/router_test.go | 18 +++---- server/web/server.go | 8 +-- server/web/unregroute_test.go | 24 ++++----- 10 files changed, 96 insertions(+), 74 deletions(-) diff --git a/adapter/app.go b/adapter/app.go index e20cd9d2..d85af7ee 100644 --- a/adapter/app.go +++ b/adapter/app.go @@ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - return (*App)(web.Router(rootpath, c, mappingMethods...)) + return (*App)(web.Router(rootpath, c, web.WithMethods(c, mappingMethods...))) } // UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful diff --git a/adapter/namespace.go b/adapter/namespace.go index 98cbd8a5..fb6beffa 100644 --- a/adapter/namespace.go +++ b/adapter/namespace.go @@ -272,7 +272,7 @@ func NSInclude(cList ...ControllerInterface) LinkNamespace { // NSRouter call Namespace Router func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { return func(namespace *Namespace) { - web.Router(rootpath, c, mappingMethods...) + web.Router(rootpath, c, web.WithMethods(c, mappingMethods...)) } } diff --git a/adapter/router.go b/adapter/router.go index c91a09f1..8c4f2949 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) + (*web.ControllerRegister)(p).Add(pattern, c, web.WithMethods(c, mappingMethods...)) } // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller diff --git a/server/web/admin.go b/server/web/admin.go index 1b06f486..0595359b 100644 --- a/server/web/admin.go +++ b/server/web/admin.go @@ -112,13 +112,13 @@ func registerAdmin() error { HttpServer: NewHttpServerWithCfg(BConfig), } // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Router("/", c, "get:AdminIndex") - beeAdminApp.Router("/qps", c, "get:QpsIndex") - beeAdminApp.Router("/prof", c, "get:ProfIndex") - beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") - beeAdminApp.Router("/task", c, "get:TaskStatus") - beeAdminApp.Router("/listconf", c, "get:ListConf") - beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") + beeAdminApp.Router("/", c, WithMethods(c, "get:AdminIndex")) + beeAdminApp.Router("/qps", c, WithMethods(c, "get:QpsIndex")) + beeAdminApp.Router("/prof", c, WithMethods(c, "get:ProfIndex")) + beeAdminApp.Router("/healthcheck", c, WithMethods(c, "get:Healthcheck")) + beeAdminApp.Router("/task", c, WithMethods(c, "get:TaskStatus")) + beeAdminApp.Router("/listconf", c, WithMethods(c, "get:ListConf")) + beeAdminApp.Router("/metrics", c, WithMethods(c, "get:PrometheusMetrics")) go beeAdminApp.Run() } diff --git a/server/web/flash_test.go b/server/web/flash_test.go index 2deef54e..4c59acb2 100644 --- a/server/web/flash_test.go +++ b/server/web/flash_test.go @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.Add("/", &TestFlashController{}, WithMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/server/web/namespace.go b/server/web/namespace.go index 58afb6c7..fee5ff95 100644 --- a/server/web/namespace.go +++ b/server/web/namespace.go @@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/astaxie/beego#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) + n.handlers.Add(rootpath, c, WithMethods(c, mappingMethods...)) return n } diff --git a/server/web/router.go b/server/web/router.go index 7bb89d82..f88d98db 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -120,10 +120,18 @@ type ControllerInfo struct { methodParams []*param.MethodParam } +type ControllerOptions func(*ControllerInfo) + func (c *ControllerInfo) GetPattern() string { return c.pattern } +func WithMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions { + return func(c *ControllerInfo) { + c.methods = parseMappingMethods(ctrlInterface, mappingMethod) + } +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree @@ -171,39 +179,65 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { // Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOptions) { + p.addWithMethodParams(pattern, c, nil, opts...) } -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { +func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") + + if len(mappingMethods) == 0 { + return methods + } + + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + continue } + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) } } + return methods +} + +func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { + if len(route.methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + return + } + for k := range route.methods { + if k != "*" { + p.addToRouter(k, route.pattern, route) + continue + } + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + } +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOptions) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + route := &ControllerInfo{} route.pattern = pattern - route.methods = methods route.routerType = routerTypeBeego route.controllerType = t route.initialize = func() ControllerInterface { @@ -229,23 +263,11 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt return execController } - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } + for i := range opts { + opts[i](route) } + p.addRouterForMethod(route) } func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { @@ -274,7 +296,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + p.addWithMethodParams(a.Router, c, a.MethodParams, WithMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) } } } diff --git a/server/web/router_test.go b/server/web/router_test.go index 59ccd1fc..3ffc4a60 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -89,8 +89,8 @@ func (jc *JSONController) Get() { func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") + handler.Add("/api/list", &TestController{}, WithMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, WithMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -113,9 +113,9 @@ func TestUrlFor3(t *testing.T) { func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, WithMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -145,7 +145,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/api/list", &TestController{}, WithMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -235,7 +235,7 @@ func TestRouteOk(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.Add("/person/:last/:first", &TestController{}, WithMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -249,7 +249,7 @@ func TestManyRoute(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -266,7 +266,7 @@ func TestEmptyResponse(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.Add("/beego-empty.html", &TestController{}, WithMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { diff --git a/server/web/server.go b/server/web/server.go index 25841563..55e30d38 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -266,8 +266,8 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { } // Router see HttpServer.Router -func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - return BeeApp.Router(rootpath, c, mappingMethods...) +func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { + return BeeApp.Router(rootpath, c, opts...) } // Router adds a patterned controller handler to BeeApp. @@ -286,8 +286,8 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *H // beego.Router("/api/create",&RestController{},"post:CreateFood") // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - app.Handlers.Add(rootPath, c, mappingMethods...) +func (app *HttpServer) Router(rootPath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { + app.Handlers.Add(rootPath, c, opts...) return app } diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index c675ae7d..16bbf032 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + handler.Add("/", &TestPostUnregController{}, WithMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1", &TestPostUnregController{}, WithMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + handler.Add("/level1/level2", &TestPostUnregController{}, WithMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) From 963de3e68b4ccf14a6810264ca6b6b37c4aa1265 Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Tue, 24 Nov 2020 10:15:28 +0800 Subject: [PATCH 2/6] add SessionOn option for web.Router --- adapter/app.go | 2 +- adapter/namespace.go | 2 +- adapter/router.go | 2 +- server/web/admin.go | 14 ++++++------- server/web/flash_test.go | 2 +- server/web/namespace.go | 2 +- server/web/router.go | 37 ++++++++++++++++++++++++++--------- server/web/router_test.go | 18 ++++++++--------- server/web/unregroute_test.go | 24 +++++++++++------------ 9 files changed, 61 insertions(+), 42 deletions(-) diff --git a/adapter/app.go b/adapter/app.go index d85af7ee..ee008642 100644 --- a/adapter/app.go +++ b/adapter/app.go @@ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - return (*App)(web.Router(rootpath, c, web.WithMethods(c, mappingMethods...))) + return (*App)(web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...))) } // UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful diff --git a/adapter/namespace.go b/adapter/namespace.go index fb6beffa..6a148383 100644 --- a/adapter/namespace.go +++ b/adapter/namespace.go @@ -272,7 +272,7 @@ func NSInclude(cList ...ControllerInterface) LinkNamespace { // NSRouter call Namespace Router func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { return func(namespace *Namespace) { - web.Router(rootpath, c, web.WithMethods(c, mappingMethods...)) + web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...)) } } diff --git a/adapter/router.go b/adapter/router.go index 8c4f2949..7a6f8307 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - (*web.ControllerRegister)(p).Add(pattern, c, web.WithMethods(c, mappingMethods...)) + (*web.ControllerRegister)(p).Add(pattern, c, web.SetRouterMethods(c, mappingMethods...)) } // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller diff --git a/server/web/admin.go b/server/web/admin.go index 0595359b..6c531b84 100644 --- a/server/web/admin.go +++ b/server/web/admin.go @@ -112,13 +112,13 @@ func registerAdmin() error { HttpServer: NewHttpServerWithCfg(BConfig), } // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Router("/", c, WithMethods(c, "get:AdminIndex")) - beeAdminApp.Router("/qps", c, WithMethods(c, "get:QpsIndex")) - beeAdminApp.Router("/prof", c, WithMethods(c, "get:ProfIndex")) - beeAdminApp.Router("/healthcheck", c, WithMethods(c, "get:Healthcheck")) - beeAdminApp.Router("/task", c, WithMethods(c, "get:TaskStatus")) - beeAdminApp.Router("/listconf", c, WithMethods(c, "get:ListConf")) - beeAdminApp.Router("/metrics", c, WithMethods(c, "get:PrometheusMetrics")) + beeAdminApp.Router("/", c, SetRouterMethods(c, "get:AdminIndex")) + beeAdminApp.Router("/qps", c, SetRouterMethods(c, "get:QpsIndex")) + beeAdminApp.Router("/prof", c, SetRouterMethods(c, "get:ProfIndex")) + beeAdminApp.Router("/healthcheck", c, SetRouterMethods(c, "get:Healthcheck")) + beeAdminApp.Router("/task", c, SetRouterMethods(c, "get:TaskStatus")) + beeAdminApp.Router("/listconf", c, SetRouterMethods(c, "get:ListConf")) + beeAdminApp.Router("/metrics", c, SetRouterMethods(c, "get:PrometheusMetrics")) go beeAdminApp.Run() } diff --git a/server/web/flash_test.go b/server/web/flash_test.go index 4c59acb2..c1ca9554 100644 --- a/server/web/flash_test.go +++ b/server/web/flash_test.go @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, WithMethods(&TestFlashController{}, "get:TestWriteFlash")) + handler.Add("/", &TestFlashController{}, SetRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/server/web/namespace.go b/server/web/namespace.go index fee5ff95..aac64d7a 100644 --- a/server/web/namespace.go +++ b/server/web/namespace.go @@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/astaxie/beego#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, WithMethods(c, mappingMethods...)) + n.handlers.Add(rootpath, c, SetRouterMethods(c, mappingMethods...)) return n } diff --git a/server/web/router.go b/server/web/router.go index f88d98db..6ef53b87 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -118,6 +118,7 @@ type ControllerInfo struct { routerType int initialize func() ControllerInterface methodParams []*param.MethodParam + sessionOn bool } type ControllerOptions func(*ControllerInfo) @@ -126,12 +127,18 @@ func (c *ControllerInfo) GetPattern() string { return c.pattern } -func WithMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions { +func SetRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions { return func(c *ControllerInfo) { c.methods = parseMappingMethods(ctrlInterface, mappingMethod) } } +func SetRouterSessionOn(sessionOn bool) ControllerOptions { + return func(c *ControllerInfo) { + c.sessionOn = sessionOn + } +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree @@ -239,6 +246,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeBeego + route.sessionOn = BConfig.WebConfig.Session.SessionOn route.controllerType = t route.initialize = func() ControllerInterface { vc := reflect.New(route.controllerType) @@ -296,7 +304,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } - p.addWithMethodParams(a.Router, c, a.MethodParams, WithMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) + p.addWithMethodParams(a.Router, c, a.MethodParams, SetRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) } } } @@ -402,6 +410,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeRESTFul + route.sessionOn = BConfig.WebConfig.Session.SessionOn route.runFunction = f methods := make(map[string]string) if method == "*" { @@ -428,6 +437,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ... route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeHandler + route.sessionOn = BConfig.WebConfig.Session.SessionOn route.handler = h if len(options) > 0 { if _, ok := options[0].(bool); ok { @@ -462,6 +472,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) if !utils.InSlice(rt.Method(i).Name, exceptMethod) { route := &ControllerInfo{} route.routerType = routerTypeBeego + route.sessionOn = BConfig.WebConfig.Session.SessionOn route.methods = map[string]string{"*": rt.Method(i).Name} route.controllerType = ct pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") @@ -693,12 +704,15 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { r := ctx.Request rw := ctx.ResponseWriter.ResponseWriter var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + currentSessionOn bool + originRouterInfo *ControllerInfo + originFindRouter bool ) if p.cfg.RecoverFunc != nil { @@ -764,7 +778,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } // session init - if p.cfg.WebConfig.Session.SessionOn { + currentSessionOn = BConfig.WebConfig.Session.SessionOn + originRouterInfo, originFindRouter = p.FindRouter(ctx) + if originFindRouter { + currentSessionOn = originRouterInfo.sessionOn + } + if currentSessionOn { ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) diff --git a/server/web/router_test.go b/server/web/router_test.go index 3ffc4a60..fc1e8019 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -89,8 +89,8 @@ func (jc *JSONController) Get() { func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, WithMethods(&TestController{}, "*:List")) - handler.Add("/person/:last/:first", &TestController{}, WithMethods(&TestController{}, "*:Param")) + handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -113,9 +113,9 @@ func TestUrlFor3(t *testing.T) { func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithMethods(&TestController{}, "*:List")) - handler.Add("/v1/:username/edit", &TestController{}, WithMethods(&TestController{}, "get:GetURL")) - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithMethods(&TestController{}, "*:Param")) + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, SetRouterMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -145,7 +145,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, WithMethods(&TestController{}, "*:List")) + handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -235,7 +235,7 @@ func TestRouteOk(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, WithMethods(&TestController{}, "get:GetParams")) + handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -249,7 +249,7 @@ func TestManyRoute(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithMethods(&TestController{}, "get:GetManyRouter")) + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -266,7 +266,7 @@ func TestEmptyResponse(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, WithMethods(&TestController{}, "get:GetEmptyBody")) + handler.Add("/beego-empty.html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index 16bbf032..9745dbac 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedRoot")) - handler.Add("/level1", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) - handler.Add("/level1/level2", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) + handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, WithMethods(&TestPostUnregController{}, "get:GetFixedRoot")) + handler.Add("/", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedRoot")) - handler.Add("/level1", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) - handler.Add("/level1/level2", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) + handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, WithMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedRoot")) - handler.Add("/level1", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) - handler.Add("/level1/level2", &TestPreUnregController{}, WithMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) + handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, WithMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) + handler.Add("/level1/level2", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) From 7d7c4539d33bf1e0a49e834254ad20d3b70934b2 Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Tue, 24 Nov 2020 11:22:07 +0800 Subject: [PATCH 3/6] add check of global sessionOn to prevent nil pointer panic: sessionOn of router is set to false while global sessionOn is false --- server/web/router.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/web/router.go b/server/web/router.go index 6ef53b87..46641e15 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -781,7 +781,11 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { currentSessionOn = BConfig.WebConfig.Session.SessionOn originRouterInfo, originFindRouter = p.FindRouter(ctx) if originFindRouter { - currentSessionOn = originRouterInfo.sessionOn + if !currentSessionOn && originRouterInfo.sessionOn { + logs.Warn("global sessionOn is false, sessionOn of router [%s] %s can't be set to true", ctx.Request.Method, ctx.Request.URL.Path) + } else { + currentSessionOn = originRouterInfo.sessionOn + } } if currentSessionOn { ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) From 34068e63e674b26cc268ea92d3afae4baa45afbd Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Tue, 24 Nov 2020 15:58:06 +0800 Subject: [PATCH 4/6] add test cases for router sessionOn config --- server/web/router_test.go | 56 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/server/web/router_test.go b/server/web/router_test.go index fc1e8019..e78e8633 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -750,3 +750,59 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { t.Errorf("TestRouterRequestEntityTooLarge can't run") } } + +func TestRouterSessionSet(t *testing.T) { + oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn + defer func() { + BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn + }() + + // global sessionOn = false, router sessionOn = false + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = false, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + BConfig.WebConfig.Session.SessionOn = true + if err := registerSession(); err != nil { + t.Errorf("register session failed, error: %s", err.Error()) + } + // global sessionOn = true, router sessionOn = false + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = true, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") == "" { + t.Errorf("TestRotuerSessionSet failed") + } + +} From 3783847dcefc06c4ac9c836a72e857e0b25fa1a1 Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Sat, 26 Dec 2020 17:12:42 +0800 Subject: [PATCH 5/6] modify cfg reference --- server/web/router.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/web/router.go b/server/web/router.go index 645af4f5..8a78f382 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -246,7 +246,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeBeego - route.sessionOn = BConfig.WebConfig.Session.SessionOn + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.controllerType = t route.initialize = func() ControllerInterface { vc := reflect.New(route.controllerType) @@ -410,7 +410,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeRESTFul - route.sessionOn = BConfig.WebConfig.Session.SessionOn + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.runFunction = f methods := make(map[string]string) if method == "*" { @@ -437,7 +437,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ... route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeHandler - route.sessionOn = BConfig.WebConfig.Session.SessionOn + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.handler = h if len(options) > 0 { if _, ok := options[0].(bool); ok { @@ -472,7 +472,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) if !utils.InSlice(rt.Method(i).Name, exceptMethod) { route := &ControllerInfo{} route.routerType = routerTypeBeego - route.sessionOn = BConfig.WebConfig.Session.SessionOn + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.methods = map[string]string{"*": rt.Method(i).Name} route.controllerType = ct pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") @@ -778,7 +778,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } // session init - currentSessionOn = BConfig.WebConfig.Session.SessionOn + currentSessionOn = p.cfg.WebConfig.Session.SessionOn originRouterInfo, originFindRouter = p.FindRouter(ctx) if originFindRouter { if !currentSessionOn && originRouterInfo.sessionOn { From dd35e4b1b2e100492c1b2ec048422a345b8c866d Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Sat, 26 Dec 2020 17:56:49 +0800 Subject: [PATCH 6/6] move route sessionOn calculate to router adding --- server/web/router.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/server/web/router.go b/server/web/router.go index 8a78f382..0d4a9610 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -275,6 +275,13 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt for i := range opts { opts[i](route) } + + globalSessionOn := p.cfg.WebConfig.Session.SessionOn + if !globalSessionOn && route.sessionOn { + logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern) + route.sessionOn = globalSessionOn + } + p.addRouterForMethod(route) } @@ -781,11 +788,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { currentSessionOn = p.cfg.WebConfig.Session.SessionOn originRouterInfo, originFindRouter = p.FindRouter(ctx) if originFindRouter { - if !currentSessionOn && originRouterInfo.sessionOn { - logs.Warn("global sessionOn is false, sessionOn of router [%s] %s can't be set to true", ctx.Request.Method, ctx.Request.URL.Path) - } else { - currentSessionOn = originRouterInfo.sessionOn - } + currentSessionOn = originRouterInfo.sessionOn } if currentSessionOn { ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)