diff --git a/CHANGELOG.md b/CHANGELOG.md index fb51c215..31eaaa0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ # developing +- Add: Resp() method for web.Controller. [4588](https://github.com/beego/beego/pull/4588) - Web mock and test support. [4565](https://github.com/beego/beego/pull/4565) [4574](https://github.com/beego/beego/pull/4574) - Error codes definition of cache module. [4493](https://github.com/beego/beego/pull/4493) - Remove generateCommentRoute http hook. Using `bee generate routers` commands instead.[4486](https://github.com/beego/beego/pull/4486) [bee PR 762](https://github.com/beego/bee/pull/762) diff --git a/server/web/controller.go b/server/web/controller.go index 6a29c79e..2648a544 100644 --- a/server/web/controller.go +++ b/server/web/controller.go @@ -21,8 +21,6 @@ import ( "encoding/xml" "errors" "fmt" - "github.com/gogo/protobuf/proto" - "gopkg.in/yaml.v2" "html/template" "io" "mime/multipart" @@ -37,6 +35,8 @@ import ( "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/context/param" "github.com/beego/beego/v2/server/web/session" + "github.com/gogo/protobuf/proto" + "gopkg.in/yaml.v2" ) var ( @@ -436,6 +436,22 @@ func (c *Controller) URLFor(endpoint string, values ...interface{}) string { } return URLFor(endpoint, values...) } +// Resp sends response based on the Accept Header +// By default response will be in JSON +func (c *Controller) Resp(data interface{}) error { + accept := c.Ctx.Input.Header("Accept") + switch accept { + case context.ApplicationYAML: + c.Data["yaml"] = data + return c.ServeYAML() + case context.ApplicationXML, context.TextXML: + c.Data["xml"] = data + return c.ServeXML() + default: + c.Data["json"] = data + return c.ServeJSON() + } +} // ServeJSON sends a json response with encoding charset. func (c *Controller) ServeJSON(encoding ...bool) error { diff --git a/server/web/controller_test.go b/server/web/controller_test.go index 8a52f097..7f810c1f 100644 --- a/server/web/controller_test.go +++ b/server/web/controller_test.go @@ -15,16 +15,18 @@ package web import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "io/ioutil" "math" "net/http" + "net/http/httptest" "os" "path/filepath" "strconv" "testing" "github.com/beego/beego/v2/server/web/context" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetInt(t *testing.T) { @@ -244,3 +246,67 @@ func TestBindYAML(t *testing.T) { require.NoError(t, err) assert.Equal(t, "FOO", s.Foo) } + +type TestRespController struct { + Controller +} + +func (t *TestRespController) TestResponse() { + type S struct { + Foo string `json:"foo" xml:"foo" yaml:"foo"` + } + + bar := S{Foo: "bar"} + + _ = t.Resp(bar) +} + +type respTestCase struct { + Accept string + ExpectedContentLength int64 + ExpectedResponse string +} + +func TestControllerResp(t *testing.T) { + // test cases + tcs := []respTestCase{ + {Accept: context.ApplicationJSON, ExpectedContentLength: 18, ExpectedResponse: "{\n \"foo\": \"bar\"\n}"}, + {Accept: context.ApplicationXML, ExpectedContentLength: 25, ExpectedResponse: "\n bar\n"}, + {Accept: context.ApplicationYAML, ExpectedContentLength: 9, ExpectedResponse: "foo: bar\n"}, + {Accept: "OTHER", ExpectedContentLength: 18, ExpectedResponse: "{\n \"foo\": \"bar\"\n}"}, + } + + for _, tc := range tcs { + testControllerRespTestCases(t, tc) + } +} + +func testControllerRespTestCases(t *testing.T, tc respTestCase) { + // create fake GET request + r, _ := http.NewRequest("GET", "/", nil) + r.Header.Set("Accept", tc.Accept) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/", &TestRespController{}, WithRouterMethods(&TestRespController{}, "get:TestResponse")) + handler.ServeHTTP(w, r) + + response := w.Result() + if response.ContentLength != tc.ExpectedContentLength { + t.Errorf("TestResponse() unable to validate content length %d for %s", response.ContentLength, tc.Accept) + } + + if response.StatusCode != http.StatusOK { + t.Errorf("TestResponse() failed to validate response code for %s", tc.Accept) + } + + bodyBytes, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Errorf("TestResponse() failed to parse response body for %s", tc.Accept) + } + bodyString := string(bodyBytes) + if bodyString != tc.ExpectedResponse { + t.Errorf("TestResponse() failed to validate response body '%s' for %s", bodyString, tc.Accept) + } +}