diff --git a/adapter/app.go b/adapter/app.go index 8502256b..565a9795 100644 --- a/adapter/app.go +++ b/adapter/app.go @@ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - return (*App)(web.Router(rootpath, c, mappingMethods...)) + return (*App)(web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...))) } // UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful diff --git a/adapter/router.go b/adapter/router.go index 900e3eb7..17e270ca 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) + (*web.ControllerRegister)(p).Add(pattern, c, web.SetRouterMethods(c, mappingMethods...)) } // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller diff --git a/server/web/admin.go b/server/web/admin.go index 89c9ddb9..d640c1be 100644 --- a/server/web/admin.go +++ b/server/web/admin.go @@ -112,13 +112,13 @@ func registerAdmin() error { HttpServer: NewHttpServerWithCfg(BConfig), } // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Router("/", c, "get:AdminIndex") - beeAdminApp.Router("/qps", c, "get:QpsIndex") - beeAdminApp.Router("/prof", c, "get:ProfIndex") - beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") - beeAdminApp.Router("/task", c, "get:TaskStatus") - beeAdminApp.Router("/listconf", c, "get:ListConf") - beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") + beeAdminApp.Router("/", c, SetRouterMethods(c, "get:AdminIndex")) + beeAdminApp.Router("/qps", c, SetRouterMethods(c, "get:QpsIndex")) + beeAdminApp.Router("/prof", c, SetRouterMethods(c, "get:ProfIndex")) + beeAdminApp.Router("/healthcheck", c, SetRouterMethods(c, "get:Healthcheck")) + beeAdminApp.Router("/task", c, SetRouterMethods(c, "get:TaskStatus")) + beeAdminApp.Router("/listconf", c, SetRouterMethods(c, "get:ListConf")) + beeAdminApp.Router("/metrics", c, SetRouterMethods(c, "get:PrometheusMetrics")) go beeAdminApp.Run() } diff --git a/server/web/flash_test.go b/server/web/flash_test.go index 2deef54e..c1ca9554 100644 --- a/server/web/flash_test.go +++ b/server/web/flash_test.go @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.Add("/", &TestFlashController{}, SetRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/server/web/namespace.go b/server/web/namespace.go index 4e0c3b85..3598a222 100644 --- a/server/web/namespace.go +++ b/server/web/namespace.go @@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/beego/beego/v2#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) + n.handlers.Add(rootpath, c, SetRouterMethods(c, mappingMethods...)) return n } diff --git a/server/web/router.go b/server/web/router.go index fd94a769..a030224b 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -118,12 +118,27 @@ type ControllerInfo struct { routerType int initialize func() ControllerInterface methodParams []*param.MethodParam + sessionOn bool } +type ControllerOptions func(*ControllerInfo) + func (c *ControllerInfo) GetPattern() string { return c.pattern } +func SetRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions { + return func(c *ControllerInfo) { + c.methods = parseMappingMethods(ctrlInterface, mappingMethod) + } +} + +func SetRouterSessionOn(sessionOn bool) ControllerOptions { + return func(c *ControllerInfo) { + c.sessionOn = sessionOn + } +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree @@ -171,40 +186,67 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { // Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOptions) { + p.addWithMethodParams(pattern, c, nil, opts...) } -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { +func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") + + if len(mappingMethods) == 0 { + return methods + } + + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + continue } + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) } } + return methods +} + +func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { + if len(route.methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + return + } + for k := range route.methods { + if k != "*" { + p.addToRouter(k, route.pattern, route) + continue + } + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + } +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOptions) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + route := &ControllerInfo{} route.pattern = pattern - route.methods = methods route.routerType = routerTypeBeego + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.controllerType = t route.initialize = func() ControllerInterface { vc := reflect.New(route.controllerType) @@ -229,23 +271,18 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt return execController } - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } + for i := range opts { + opts[i](route) } + + globalSessionOn := p.cfg.WebConfig.Session.SessionOn + if !globalSessionOn && route.sessionOn { + logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern) + route.sessionOn = globalSessionOn + } + + p.addRouterForMethod(route) } func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { @@ -273,7 +310,8 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { for _, f := range a.Filters { p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + + p.addWithMethodParams(a.Router, c, a.MethodParams, SetRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) } } } @@ -379,6 +417,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeRESTFul + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.runFunction = f methods := make(map[string]string) if method == "*" { @@ -399,6 +438,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ... route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeHandler + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.handler = h if len(options) > 0 { if _, ok := options[0].(bool); ok { @@ -433,6 +473,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) if !utils.InSlice(rt.Method(i).Name, exceptMethod) { route := &ControllerInfo{} route.routerType = routerTypeBeego + route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.methods = map[string]string{"*": rt.Method(i).Name} route.controllerType = ct pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") @@ -664,12 +705,15 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { r := ctx.Request rw := ctx.ResponseWriter.ResponseWriter var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + currentSessionOn bool + originRouterInfo *ControllerInfo + originFindRouter bool ) if p.cfg.RecoverFunc != nil { @@ -735,7 +779,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } // session init - if p.cfg.WebConfig.Session.SessionOn { + currentSessionOn = p.cfg.WebConfig.Session.SessionOn + originRouterInfo, originFindRouter = p.FindRouter(ctx) + if originFindRouter { + currentSessionOn = originRouterInfo.sessionOn + } + if currentSessionOn { ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) diff --git a/server/web/router_test.go b/server/web/router_test.go index 87997322..f9f862e9 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -89,8 +89,8 @@ func (jc *JSONController) Get() { func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") + handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -113,9 +113,9 @@ func TestUrlFor3(t *testing.T) { func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, SetRouterMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -145,7 +145,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -235,7 +235,7 @@ func TestRouteOk(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -249,7 +249,7 @@ func TestManyRoute(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -266,7 +266,7 @@ func TestEmptyResponse(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.Add("/beego-empty.html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { @@ -750,3 +750,59 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { t.Errorf("TestRouterRequestEntityTooLarge can't run") } } + +func TestRouterSessionSet(t *testing.T) { + oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn + defer func() { + BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn + }() + + // global sessionOn = false, router sessionOn = false + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = false, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + BConfig.WebConfig.Session.SessionOn = true + if err := registerSession(); err != nil { + t.Errorf("register session failed, error: %s", err.Error()) + } + // global sessionOn = true, router sessionOn = false + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = true, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), + SetRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") == "" { + t.Errorf("TestRotuerSessionSet failed") + } + +} diff --git a/server/web/server.go b/server/web/server.go index f0a4f4ea..280828ff 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -266,8 +266,8 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { } // Router see HttpServer.Router -func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - return BeeApp.Router(rootpath, c, mappingMethods...) +func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { + return BeeApp.Router(rootpath, c, opts...) } // Router adds a patterned controller handler to BeeApp. @@ -286,8 +286,8 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *H // beego.Router("/api/create",&RestController{},"post:CreateFood") // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - app.Handlers.Add(rootPath, c, mappingMethods...) +func (app *HttpServer) Router(rootPath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { + app.Handlers.Add(rootPath, c, opts...) return app } diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index c675ae7d..9745dbac 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + handler.Add("/", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + handler.Add("/level1/level2", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)