From 979c0760240777680c6d65fe29e79919380a4c39 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 10 Dec 2023 23:22:51 +0800 Subject: [PATCH] fix 4936: always transport request body (#5546) --- client/httplib/httplib.go | 51 ++++++++++++++++++---------------- client/httplib/httplib_test.go | 43 ++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 29 deletions(-) diff --git a/client/httplib/httplib.go b/client/httplib/httplib.go index 30a814c0..936e3fd8 100644 --- a/client/httplib/httplib.go +++ b/client/httplib/httplib.go @@ -73,7 +73,6 @@ func NewBeegoRequestWithCtx(ctx context.Context, rawurl, method string) *BeegoHT if err != nil { logs.Error("%+v", berror.Wrapf(err, InvalidURLOrMethod, "invalid raw url or method: %s %s", rawurl, method)) } - return &BeegoHTTPRequest{ url: rawurl, req: req, @@ -81,6 +80,9 @@ func NewBeegoRequestWithCtx(ctx context.Context, rawurl, method string) *BeegoHT files: map[string]string{}, setting: defaultSetting, resp: &http.Response{}, + copyBody: func() io.ReadCloser { + return nil + }, } } @@ -117,7 +119,10 @@ type BeegoHTTPRequest struct { files map[string]string setting BeegoHTTPSettings resp *http.Response - body []byte + // body the response body, not the request body + body []byte + // copyBody support retry strategy to avoid copy request body + copyBody func() io.ReadCloser } // GetRequest returns the request object @@ -281,25 +286,28 @@ func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: - bf := bytes.NewBufferString(t) - b.req.Body = io.NopCloser(bf) - b.req.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(bf), nil - } - b.req.ContentLength = int64(len(t)) + b.reqBody([]byte(t)) case []byte: - bf := bytes.NewBuffer(t) - b.req.Body = io.NopCloser(bf) - b.req.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(bf), nil - } - b.req.ContentLength = int64(len(t)) + b.reqBody(t) default: logs.Error("%+v", berror.Errorf(UnsupportedBodyType, "unsupported body data type: %s", t)) } return b } +func (b *BeegoHTTPRequest) reqBody(data []byte) *BeegoHTTPRequest { + body := io.NopCloser(bytes.NewReader(data)) + b.req.Body = body + b.req.GetBody = func() (io.ReadCloser, error) { + return body, nil + } + b.req.ContentLength = int64(len(data)) + b.copyBody = func() io.ReadCloser { + return io.NopCloser(bytes.NewReader(data)) + } + return b +} + // XMLBody adds the request raw body encoded in XML. func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { @@ -307,11 +315,7 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if err != nil { return b, berror.Wrap(err, InvalidXMLBody, "obj could not be converted to XML data") } - b.req.Body = io.NopCloser(bytes.NewReader(byts)) - b.req.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(bytes.NewReader(byts)), nil - } - b.req.ContentLength = int64(len(byts)) + b.reqBody(byts) b.req.Header.Set(contentTypeKey, "application/xml") } return b, nil @@ -324,8 +328,7 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) if err != nil { return b, berror.Wrap(err, InvalidYAMLBody, "obj could not be converted to YAML data") } - b.req.Body = io.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) + b.reqBody(byts) b.req.Header.Set(contentTypeKey, "application/x+yaml") } return b, nil @@ -338,8 +341,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) if err != nil { return b, berror.Wrap(err, InvalidJSONBody, "obj could not be converted to JSON body") } - b.req.Body = io.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) + b.reqBody(byts) b.req.Header.Set(contentTypeKey, "application/json") } return b, nil @@ -493,7 +495,7 @@ func (b *BeegoHTTPRequest) doRequest(_ context.Context) (*http.Response, error) func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response, err error) { // retries default value is 0, it will run once. // retries equal to -1, it will run forever until success - // retries is setted, it will retries fixed times. + // retries is set, it will retry fixed times. // Sleeps for a 400ms between calls to reduce spam for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { resp, err = client.Do(b.req) @@ -501,6 +503,7 @@ func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response return } time.Sleep(b.setting.RetryDelay) + b.req.Body = b.copyBody() } return nil, berror.Wrap(err, SendRequestFailed, "sending request fail") } diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index 6c823223..dd571058 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -101,9 +101,17 @@ func (h *HttplibTestSuite) SetupSuite() { handler.HandleFunc("/redirect", func(writer http.ResponseWriter, request *http.Request) { http.Redirect(writer, request, "redirect_dst", http.StatusTemporaryRedirect) }) - handler.HandleFunc("redirect_dst", func(writer http.ResponseWriter, request *http.Request) { + handler.HandleFunc("/redirect_dst", func(writer http.ResponseWriter, request *http.Request) { _, _ = writer.Write([]byte("hello")) }) + + handler.HandleFunc("/retry", func(writer http.ResponseWriter, request *http.Request) { + body, err := io.ReadAll(request.Body) + require.NoError(h.T(), err) + assert.Equal(h.T(), []byte("retry body"), body) + panic("mock error") + }) + go func() { _ = http.Serve(listener, handler) }() @@ -362,6 +370,34 @@ func (h *HttplibTestSuite) TestPut() { assert.Equal(t, "PUT", req.req.Method) } +func (h *HttplibTestSuite) TestRetry() { + defaultSetting.Retries = 2 + testCases := []struct { + name string + req func(t *testing.T) *BeegoHTTPRequest + wantErr error + }{ + { + name: "retry_failed", + req: func(t *testing.T) *BeegoHTTPRequest { + req := NewBeegoRequest("http://localhost:8080/retry", http.MethodPost) + req.Body("retry body") + return req + }, + wantErr: io.EOF, + }, + } + + for _, tc := range testCases { + h.T().Run(tc.name, func(t *testing.T) { + req := tc.req(t) + resp, err := req.DoRequest() + assert.ErrorIs(t, err, tc.wantErr) + assert.Nil(t, resp) + }) + } +} + func TestNewBeegoRequest(t *testing.T) { req := NewBeegoRequest("http://beego.vip", "GET") assert.NotNil(t, req) @@ -384,6 +420,7 @@ func TestNewBeegoRequestWithCtx(t *testing.T) { // bad method but still get request req = NewBeegoRequestWithCtx(context.Background(), "http://beego.vip", "G\tET") assert.NotNil(t, req) + assert.NotNil(t, req.copyBody) } func TestBeegoHTTPRequestSetProtocolVersion(t *testing.T) { @@ -461,10 +498,6 @@ func TestBeegoHTTPRequestXMLBody(t *testing.T) { assert.NotNil(t, req.req.GetBody) } -// TODO -func TestBeegoHTTPRequestResponseForValue(t *testing.T) { -} - func TestBeegoHTTPRequestJSONMarshal(t *testing.T) { req := Post("http://beego.vip") req.SetEscapeHTML(false)