diff --git a/CHANGELOG.md b/CHANGELOG.md index 22c6d0ab..9f85415f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,4 +15,5 @@ - Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441) - Error Module brief design & using httplib module to validate this design. [4453](https://github.com/beego/beego/pull/4453) - Fix 4444: panic when 404 not found. [4446](https://github.com/beego/beego/pull/4446) -- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452) \ No newline at end of file +- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452) +- Fix 4456: Fix router method expression [4456](https://github.com/beego/beego/pull/4456) \ No newline at end of file diff --git a/server/web/namespace_test.go b/server/web/namespace_test.go index 84e7aca5..30d17cb2 100644 --- a/server/web/namespace_test.go +++ b/server/web/namespace_test.go @@ -25,7 +25,8 @@ import ( ) const ( - exampleBody = "hello world" + exampleBody = "hello world" + examplePointerBody = "hello world pointer" nsNamespace = "/router" nsPath = "/user" @@ -43,6 +44,13 @@ func (m ExampleController) Ping() { } } +func (m *ExampleController) PingPointer() { + err := m.Ctx.Output.Body([]byte(examplePointerBody)) + if err != nil { + fmt.Println(err) + } +} + func (m ExampleController) ping() { err := m.Ctx.Output.Body([]byte("ping method")) if err != nil { diff --git a/server/web/router.go b/server/web/router.go index d7fd3047..e80694c0 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -576,7 +576,7 @@ func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method if len(method) == 0 { panic("method name is empty") } else if method[0] > 96 || method[0] < 65 { - panic("not a public method") + panic(fmt.Sprintf("%s is not a public method", method)) } // check only one param which is the method receiver @@ -584,12 +584,7 @@ func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method panic("invalid number of param in") } - // check the receiver implement ControllerInterface controllerType = funcType.In(0) - _, ok := reflect.New(controllerType).Interface().(ControllerInterface) - if !ok { - panic(controllerType.String() + " is not implemented ControllerInterface") - } // check controller has the method _, exists := controllerType.MethodByName(method) @@ -597,6 +592,16 @@ func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method panic(controllerType.String() + " has no method " + method) } + // check the receiver implement ControllerInterface + if controllerType.Kind() == reflect.Ptr { + controllerType = controllerType.Elem() + } + controller := reflect.New(controllerType) + _, ok := controller.Interface().(ControllerInterface) + if !ok { + panic(controllerType.String() + " is not implemented ControllerInterface") + } + return } diff --git a/server/web/router_test.go b/server/web/router_test.go index 7b8ebb6c..3633aee7 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -42,6 +42,10 @@ func (m TestControllerWithInterface) Ping() { fmt.Println("pong") } +func (m *TestControllerWithInterface) PingPointer() { + fmt.Println("pong pointer") +} + type TestController struct { Controller } @@ -923,6 +927,92 @@ func TestRouterRouterAny(t *testing.T) { } } +func TestRouterRouterGetPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterGet("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterGetPointerMethod can't run") + } +} + +func TestRouterRouterPostPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPost("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPostPointerMethod can't run") + } +} + +func TestRouterRouterHeadPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterHead("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterHeadPointerMethod can't run") + } +} + +func TestRouterRouterPutPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPut("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPutPointerMethod can't run") + } +} + +func TestRouterRouterPatchPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPatch("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPatchPointerMethod can't run") + } +} + +func TestRouterRouterDeletePointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterDelete("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterDeletePointerMethod can't run") + } +} + +func TestRouterRouterAnyPointerMethod(t *testing.T) { + handler := NewControllerRegister() + handler.RouterAny("/user", (*ExampleController).PingPointer) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterAnyPointerMethod can't run, get the response is " + w.Body.String()) + } + } +} + func TestRouterAddRouterMethodPanicInvalidMethod(t *testing.T) { method := "some random method" message := "not support http method: " + strings.ToUpper(method) @@ -961,7 +1051,7 @@ func TestRouterAddRouterMethodPanicNotAMethod(t *testing.T) { func TestRouterAddRouterMethodPanicNotPublicMethod(t *testing.T) { method := http.MethodGet - message := "not a public method" + message := "ping is not a public method" defer func() { err := recover() if err != nil { //产生了panic异常 @@ -994,3 +1084,21 @@ func TestRouterAddRouterMethodPanicNotImplementInterface(t *testing.T) { handler := NewControllerRegister() handler.AddRouterMethod(method, "/user", TestControllerWithInterface.Ping) } + +func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterPointerMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) +}