diff --git a/CHANGELOG.md b/CHANGELOG.md index c4c5b500..94ec3687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - [add singleflight cache for cache module](https://github.com/beego/beego/pull/5119) - [Fix 5129: must set formatter after init the logger](https://github.com/beego/beego/pull/5130) - [Fix 5079: only log msg when the channel is not closed](https://github.com/beego/beego/pull/5132) +- [Fix 4435: Controller SaveToFile remove all temp file](https://github.com/beego/beego/pull/5138) # v2.0.7 - [Upgrade github.com/go-kit/kit, CVE-2022-24450](https://github.com/beego/beego/pull/5121) diff --git a/server/web/context/context_test.go b/server/web/context/context_test.go index 53717d31..fede12fe 100644 --- a/server/web/context/context_test.go +++ b/server/web/context/context_test.go @@ -102,7 +102,7 @@ func TestSetCookie(t *testing.T) { output.Context.Reset(httptest.NewRecorder(), r) for _, item := range c.valueGp { params := item.item - var others = []interface{}{params.MaxAge, params.Path, params.Domain, params.Secure, params.HttpOnly, params.SameSite} + others := []interface{}{params.MaxAge, params.Path, params.Domain, params.Secure, params.HttpOnly, params.SameSite} output.Context.SetCookie(params.Name, params.Value, others...) got := output.Context.ResponseWriter.Header().Get("Set-Cookie") if got != item.want { diff --git a/server/web/context/input_test.go b/server/web/context/input_test.go index 005ccd9a..058dc805 100644 --- a/server/web/context/input_test.go +++ b/server/web/context/input_test.go @@ -71,11 +71,13 @@ func TestBind(t *testing.T) { {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, - {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", + { + "/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", []testItem{{"human", []Human{}, []Human{ {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, - }}}}, + }}}, + }, { "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", diff --git a/server/web/controller_test.go b/server/web/controller_test.go index 6b0a5203..faf9ac98 100644 --- a/server/web/controller_test.go +++ b/server/web/controller_test.go @@ -15,8 +15,11 @@ package web import ( + "bytes" + "io" "io/ioutil" "math" + "mime/multipart" "net/http" "net/http/httptest" "os" @@ -30,6 +33,12 @@ import ( "github.com/beego/beego/v2/server/web/context" ) +var ( + fileKey = "file1" + testFile = filepath.Join(currentWorkDir, "test/static/file.txt") + toFile = filepath.Join(currentWorkDir, "test/static/file2.txt") +) + func TestGetInt(t *testing.T) { i := context.NewInput() i.SetParam("age", "40") @@ -260,6 +269,18 @@ func (t *TestRespController) TestResponse() { _ = t.Resp(bar) } +func (t *TestRespController) TestSaveToFile() { + err := t.SaveToFile(fileKey, toFile) + if err != nil { + t.Ctx.WriteString("save file fail") + } + err = os.Remove(toFile) + if err != nil { + t.Ctx.WriteString("save file fail") + } + t.Ctx.WriteString("save file success") +} + type respTestCase struct { Accept string ExpectedContentLength int64 @@ -309,3 +330,72 @@ func testControllerRespTestCases(t *testing.T, tc respTestCase) { t.Errorf("TestResponse() failed to validate response body '%s' for %s", bodyString, tc.Accept) } } + +func createReqBody(filePath string) (string, io.Reader, error) { + var err error + + buf := new(bytes.Buffer) + bw := multipart.NewWriter(buf) // body writer + + f, err := os.Open(filePath) + if err != nil { + return "", nil, err + } + defer func() { + _ = f.Close() + }() + + // text part1 + p1w, _ := bw.CreateFormField("name") + _, err = p1w.Write([]byte("Tony Bai")) + if err != nil { + return "", nil, err + } + + // text part2 + p2w, _ := bw.CreateFormField("age") + _, err = p2w.Write([]byte("15")) + if err != nil { + return "", nil, err + } + + // file part1 + _, fileName := filepath.Split(filePath) + fw1, _ := bw.CreateFormFile(fileKey, fileName) + _, err = io.Copy(fw1, f) + if err != nil { + return "", nil, err + } + + _ = bw.Close() // write the tail boundry + return bw.FormDataContentType(), buf, nil +} + +func TestControllerSaveFile(t *testing.T) { + // create body + contType, bodyReader, err := createReqBody(testFile) + assert.NoError(t, err) + + // create fake POST request + r, _ := http.NewRequest("POST", "/upload_file", bodyReader) + r.Header.Set("Accept", context.ApplicationForm) + r.Header.Set("Content-Type", contType) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/upload_file", &TestRespController{}, + WithRouterMethods(&TestRespController{}, "post:TestSaveToFile")) + handler.ServeHTTP(w, r) + + response := w.Result() + bs := make([]byte, 100) + n, err := response.Body.Read(bs) + assert.NoError(t, err) + if string(bs[:n]) == "save file fail" { + t.Errorf("TestSaveToFile() failed to validate response") + } + if response.StatusCode != http.StatusOK { + t.Errorf("TestSaveToFile() failed to validate response code for %s", context.ApplicationJSON) + } +} diff --git a/server/web/test/static/file.txt b/server/web/test/static/file.txt new file mode 100644 index 00000000..7083a772 --- /dev/null +++ b/server/web/test/static/file.txt @@ -0,0 +1,45 @@ + + + \ No newline at end of file diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index 703497d3..9265011b 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -209,8 +209,8 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { } func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, - testName, method, path, expectedBodyContent string) { - + testName, method, path, expectedBodyContent string, +) { r, err := http.NewRequest(method, path, nil) if err != nil { t.Errorf("httpRecorderBodyTest NewRequest error: %v", err)