diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..d85c6373 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,17 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "daily" + open-pull-requests-limit: 10 + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 412274a3..8142982a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/stale@v1 + - uses: actions/stale@v3.0.19 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: 'This issue is inactive for a long time.' diff --git a/.travis.yml b/.travis.yml index 5ccd3645..9382e985 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,6 +2,8 @@ language: go go: - "1.14.x" + - "1.15.x" + - "1.16.x" services: - redis-server - mysql @@ -12,7 +14,7 @@ env: global: - GO_REPO_FULLNAME="github.com/beego/beego/v2" matrix: - - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db + - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" before_install: diff --git a/CHANGELOG.md b/CHANGELOG.md index 5749cf07..1c04dc50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,11 @@ # developing - Add template functions eq,lt to support uint and int compare. [4607](https://github.com/beego/beego/pull/4607) +- Add http client and option func. [4455](https://github.com/beego/beego/issues/4455) +- Add: Convenient way to generate mock object [4620](https://github.com/beego/beego/issues/4620) +- Infra: use dependabot to update dependencies. [4623](https://github.com/beego/beego/pull/4623) - Lint: use golangci-lint. [4619](https://github.com/beego/beego/pull/4619) - Chore: format code. [4615](https://github.com/beego/beego/pull/4615) +- Test on Go v1.15.x & v1.16.x. [4614](https://github.com/beego/beego/pull/4614) - Env: non-empty GOBIN & GOPATH. [4613](https://github.com/beego/beego/pull/4613) - Chore: update dependencies. [4611](https://github.com/beego/beego/pull/4611) - Update orm_test.go/TestInsertOrUpdate with table-driven. [4609](https://github.com/beego/beego/pull/4609) @@ -45,10 +49,13 @@ - Optimize AddAutoPrefix: only register one router in case-insensitive mode. [4582](https://github.com/beego/beego/pull/4582) - Init exceptMethod by using reflection. [4583](https://github.com/beego/beego/pull/4583) - Deprecated BeeMap and replace all usage with `sync.map` [4616](https://github.com/beego/beego/pull/4616) +- TaskManager support graceful shutdown [4635](https://github.com/beego/beego/pull/4635) ## Fix Sonar +- [4624](https://github.com/beego/beego/pull/4624) - [4608](https://github.com/beego/beego/pull/4608) - [4473](https://github.com/beego/beego/pull/4473) - [4474](https://github.com/beego/beego/pull/4474) - [4479](https://github.com/beego/beego/pull/4479) +- [4639](https://github.com/beego/beego/pull/4639) diff --git a/client/cache/calc_utils.go b/client/cache/calc_utils.go index 417f8337..f8b7f24a 100644 --- a/client/cache/calc_utils.go +++ b/client/cache/calc_utils.go @@ -12,6 +12,11 @@ var ( ErrNotIntegerType = berror.Error(NotIntegerType, "item val is not (u)int (u)int32 (u)int64") ) +const ( + MinUint32 uint32 = 0 + MinUint64 uint64 = 0 +) + func incr(originVal interface{}) (interface{}, error) { switch val := originVal.(type) { case int: @@ -75,12 +80,12 @@ func decr(originVal interface{}) (interface{}, error) { } return val - 1, nil case uint32: - if val == 0 { + if val == MinUint32 { return nil, ErrDecrementOverflow } return val - 1, nil case uint64: - if val == 0 { + if val == MinUint64 { return nil, ErrDecrementOverflow } return val - 1, nil diff --git a/client/httplib/README.md b/client/httplib/README.md index 1d22f341..a5723c6d 100644 --- a/client/httplib/README.md +++ b/client/httplib/README.md @@ -9,7 +9,7 @@ httplib is an libs help you to curl remote url. you can use Get to crawl data. import "github.com/beego/beego/v2/client/httplib" - + str, err := httplib.Get("http://beego.me/").String() if err != nil { // error @@ -39,7 +39,7 @@ Example: // GET httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) - + // POST httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) diff --git a/client/httplib/client_option.go b/client/httplib/client_option.go new file mode 100644 index 00000000..f970e67d --- /dev/null +++ b/client/httplib/client_option.go @@ -0,0 +1,155 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/url" + "time" +) + +type ( + ClientOption func(client *Client) + BeegoHTTPRequestOption func(request *BeegoHTTPRequest) +) + +// WithEnableCookie will enable cookie in all subsequent request +func WithEnableCookie(enable bool) ClientOption { + return func(client *Client) { + client.Setting.EnableCookie = enable + } +} + +// WithEnableCookie will adds UA in all subsequent request +func WithUserAgent(userAgent string) ClientOption { + return func(client *Client) { + client.Setting.UserAgent = userAgent + } +} + +// WithTLSClientConfig will adds tls config in all subsequent request +func WithTLSClientConfig(config *tls.Config) ClientOption { + return func(client *Client) { + client.Setting.TLSClientConfig = config + } +} + +// WithTransport will set transport field in all subsequent request +func WithTransport(transport http.RoundTripper) ClientOption { + return func(client *Client) { + client.Setting.Transport = transport + } +} + +// WithProxy will set http proxy field in all subsequent request +func WithProxy(proxy func(*http.Request) (*url.URL, error)) ClientOption { + return func(client *Client) { + client.Setting.Proxy = proxy + } +} + +// WithCheckRedirect will specifies the policy for handling redirects in all subsequent request +func WithCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) ClientOption { + return func(client *Client) { + client.Setting.CheckRedirect = redirect + } +} + +// WithHTTPSetting can replace beegoHTTPSeting +func WithHTTPSetting(setting BeegoHTTPSettings) ClientOption { + return func(client *Client) { + client.Setting = setting + } +} + +// WithEnableGzip will enable gzip in all subsequent request +func WithEnableGzip(enable bool) ClientOption { + return func(client *Client) { + client.Setting.Gzip = enable + } +} + +// BeegoHttpRequestOption + +// WithTimeout sets connect time out and read-write time out for BeegoRequest. +func WithTimeout(connectTimeout, readWriteTimeout time.Duration) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.SetTimeout(connectTimeout, readWriteTimeout) + } +} + +// WithHeader adds header item string in request. +func WithHeader(key, value string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header(key, value) + } +} + +// WithCookie adds a cookie to the request. +func WithCookie(cookie *http.Cookie) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header("Cookie", cookie.String()) + } +} + +// Withtokenfactory adds a custom function to set Authorization +func WithTokenFactory(tokenFactory func() string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + t := tokenFactory() + + request.Header("Authorization", t) + } +} + +// WithBasicAuth adds a custom function to set basic auth +func WithBasicAuth(basicAuth func() (string, string)) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + username, password := basicAuth() + request.SetBasicAuth(username, password) + } +} + +// WithFilters will use the filter as the invocation filters +func WithFilters(fcs ...FilterChain) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.SetFilters(fcs...) + } +} + +// WithContentType adds ContentType in header +func WithContentType(contentType string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header(contentTypeKey, contentType) + } +} + +// WithParam adds query param in to request. +func WithParam(key, value string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Param(key, value) + } +} + +// WithRetry set retry times and delay for the request +// default is 0 (never retry) +// -1 retry indefinitely (forever) +// Other numbers specify the exact retry amount +func WithRetry(times int, delay time.Duration) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Retries(times) + request.RetryDelay(delay) + } +} diff --git a/client/httplib/client_option_test.go b/client/httplib/client_option_test.go new file mode 100644 index 00000000..79e5103f --- /dev/null +++ b/client/httplib/client_option_test.go @@ -0,0 +1,261 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type respCarrier struct { + bytes []byte +} + +func (r *respCarrier) SetBytes(bytes []byte) { + r.bytes = bytes +} + +func (r *respCarrier) String() string { + return string(r.bytes) +} + +func TestOption_WithEnableCookie(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/", + WithEnableCookie(true)) + if err != nil { + t.Fatal(err) + } + + v := "smallfish" + resp := &respCarrier{} + err = client.Get(resp, "/cookies/set?k1="+v) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + err = client.Get(resp, "/cookies") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestOption_WithUserAgent(t *testing.T) { + v := "beego" + client, err := NewClient("test", "http://httpbin.org/", + WithUserAgent(v)) + if err != nil { + t.Fatal(err) + } + + resp := &respCarrier{} + err = client.Get(resp, "/headers") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestOption_WithCheckRedirect(t *testing.T) { + client, err := NewClient("test", "https://goolnk.com/33BD2j", + WithCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + })) + if err != nil { + t.Fatal(err) + } + err = client.Get(nil, "") + assert.NotNil(t, err) +} + +func TestOption_WithHTTPSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + + client, err := NewClient("test", "http://httpbin.org/", + WithHTTPSetting(setting)) + if err != nil { + t.Fatal(err) + } + + resp := &respCarrier{} + err = client.Get(resp, "/get") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestOption_WithHeader(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + client.CommonOpts = append(client.CommonOpts, WithHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36")) + + resp := &respCarrier{} + err = client.Get(resp, "/headers") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), "Mozilla/5.0") + if n == -1 { + t.Fatal("Mozilla/5.0 not found in user-agent") + } +} + +func TestOption_WithTokenFactory(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + client.CommonOpts = append(client.CommonOpts, + WithTokenFactory(func() string { + return "testauth" + })) + + resp := &respCarrier{} + err = client.Get(resp, "/headers") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), "testauth") + if n == -1 { + t.Fatal("Auth is not set in request") + } +} + +func TestOption_WithBasicAuth(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + resp := &respCarrier{} + err = client.Get(resp, "/basic-auth/user/passwd", + WithBasicAuth(func() (string, string) { + return "user", "passwd" + })) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + n := strings.Index(resp.String(), "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestOption_WithContentType(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + v := "application/json" + resp := &respCarrier{} + err = client.Get(resp, "/headers", WithContentType(v)) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in header") + } +} + +func TestOption_WithParam(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + v := "smallfish" + resp := &respCarrier{} + err = client.Get(resp, "/get", WithParam("username", v)) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in header") + } +} + +func TestOption_WithRetry(t *testing.T) { + client, err := NewClient("test", "https://goolnk.com/33BD2j", + WithCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + })) + if err != nil { + t.Fatal(err) + } + + retryAmount := 1 + retryDelay := 800 * time.Millisecond + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _ = client.Get(nil, "", WithRetry(retryAmount, retryDelay)) + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } +} diff --git a/client/httplib/error_code.go b/client/httplib/error_code.go index bd349a34..177419ad 100644 --- a/client/httplib/error_code.go +++ b/client/httplib/error_code.go @@ -124,3 +124,11 @@ Make sure that: 1. You pass valid structure pointer to the function; 2. The body is valid YAML document `) + +var UnmarshalResponseToObjectFailed = berror.DefineCode(5001011, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +There are several cases that cause this error: +1. You pass valid structure pointer to the function; +2. The body is valid json, Yaml or XML document +`) diff --git a/client/httplib/httpclient.go b/client/httplib/httpclient.go new file mode 100644 index 00000000..c2a61fcf --- /dev/null +++ b/client/httplib/httpclient.go @@ -0,0 +1,174 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" +) + +// Client provides an HTTP client supporting chain call +type Client struct { + Name string + Endpoint string + CommonOpts []BeegoHTTPRequestOption + + Setting BeegoHTTPSettings +} + +// HTTPResponseCarrier If value implement HTTPResponseCarrier. http.Response will pass to SetHTTPResponse +type HTTPResponseCarrier interface { + SetHTTPResponse(resp *http.Response) +} + +// HTTPBodyCarrier If value implement HTTPBodyCarrier. http.Response.Body will pass to SetReader +type HTTPBodyCarrier interface { + SetReader(r io.ReadCloser) +} + +// HTTPBytesCarrier If value implement HTTPBytesCarrier. +// All the byte in http.Response.Body will pass to SetBytes +type HTTPBytesCarrier interface { + SetBytes(bytes []byte) +} + +// HTTPStatusCarrier If value implement HTTPStatusCarrier. http.Response.StatusCode will pass to SetStatusCode +type HTTPStatusCarrier interface { + SetStatusCode(status int) +} + +// HttpHeaderCarrier If value implement HttpHeaderCarrier. http.Response.Header will pass to SetHeader +type HTTPHeadersCarrier interface { + SetHeader(header map[string][]string) +} + +// NewClient return a new http client +func NewClient(name string, endpoint string, opts ...ClientOption) (*Client, error) { + res := &Client{ + Name: name, + Endpoint: endpoint, + } + setting := GetDefaultSetting() + res.Setting = setting + for _, o := range opts { + o(res) + } + return res, nil +} + +func (c *Client) customReq(req *BeegoHTTPRequest, opts []BeegoHTTPRequestOption) { + req.Setting(c.Setting) + opts = append(c.CommonOpts, opts...) + for _, o := range opts { + o(req) + } +} + +// handleResponse try to parse body to meaningful value +func (c *Client) handleResponse(value interface{}, req *BeegoHTTPRequest) error { + // make sure req.resp is not nil + _, err := req.Bytes() + if err != nil { + return err + } + + err = c.handleCarrier(value, req) + if err != nil { + return err + } + + return req.ToValue(value) +} + +// handleCarrier set http data to value +func (c *Client) handleCarrier(value interface{}, req *BeegoHTTPRequest) error { + if value == nil { + return nil + } + + if carrier, ok := value.(HTTPResponseCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + req.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + carrier.SetHTTPResponse(req.resp) + } + if carrier, ok := value.(HTTPBodyCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + reader := ioutil.NopCloser(bytes.NewReader(b)) + carrier.SetReader(reader) + } + if carrier, ok := value.(HTTPBytesCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + carrier.SetBytes(b) + } + if carrier, ok := value.(HTTPStatusCarrier); ok { + carrier.SetStatusCode(req.resp.StatusCode) + } + if carrier, ok := value.(HTTPHeadersCarrier); ok { + carrier.SetHeader(req.resp.Header) + } + return nil +} + +// Get Send a GET request and try to give its result value +func (c *Client) Get(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Get(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} + +// Post Send a POST request and try to give its result value +func (c *Client) Post(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error { + req := Post(c.Endpoint + path) + c.customReq(req, opts) + if body != nil { + req = req.Body(body) + } + return c.handleResponse(value, req) +} + +// Put Send a Put request and try to give its result value +func (c *Client) Put(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error { + req := Put(c.Endpoint + path) + c.customReq(req, opts) + if body != nil { + req = req.Body(body) + } + return c.handleResponse(value, req) +} + +// Delete Send a Delete request and try to give its result value +func (c *Client) Delete(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Delete(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} + +// Head Send a Head request and try to give its result value +func (c *Client) Head(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Head(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} diff --git a/client/httplib/httpclient_test.go b/client/httplib/httpclient_test.go new file mode 100644 index 00000000..6bb00258 --- /dev/null +++ b/client/httplib/httpclient_test.go @@ -0,0 +1,220 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "encoding/xml" + "io" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewClient(t *testing.T) { + client, err := NewClient("test1", "http://beego.me", WithEnableCookie(true)) + assert.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, true, client.Setting.EnableCookie) +} + +type slideShowResponse struct { + Resp *http.Response + bytes []byte + StatusCode int + Body io.ReadCloser + Header map[string][]string + + Slideshow slideshow `json:"slideshow" yaml:"slideshow"` +} + +func (r *slideShowResponse) SetHTTPResponse(resp *http.Response) { + r.Resp = resp +} + +func (r *slideShowResponse) SetBytes(bytes []byte) { + r.bytes = bytes +} + +func (r *slideShowResponse) SetReader(reader io.ReadCloser) { + r.Body = reader +} + +func (r *slideShowResponse) SetStatusCode(status int) { + r.StatusCode = status +} + +func (r *slideShowResponse) SetHeader(header map[string][]string) { + r.Header = header +} + +func (r *slideShowResponse) String() string { + return string(r.bytes) +} + +type slideshow struct { + XMLName xml.Name `xml:"slideshow"` + + Title string `json:"title" yaml:"title" xml:"title,attr"` + Author string `json:"author" yaml:"author" xml:"author,attr"` + Date string `json:"date" yaml:"date" xml:"date,attr"` + Slides []slide `json:"slides" yaml:"slides" xml:"slide"` +} + +type slide struct { + XMLName xml.Name `xml:"slide"` + + Title string `json:"title" yaml:"title" xml:"title"` +} + +func TestClient_handleCarrier(t *testing.T) { + v := "beego" + client, err := NewClient("test", "http://httpbin.org/", + WithUserAgent(v)) + if err != nil { + t.Fatal(err) + } + + s := &slideShowResponse{} + err = client.Get(s, "/json") + if err != nil { + t.Fatal(err) + } + defer s.Body.Close() + + assert.NotNil(t, s.Resp) + assert.NotNil(t, s.Body) + assert.Equal(t, "429", s.Header["Content-Length"][0]) + assert.Equal(t, 200, s.StatusCode) + + b, err := ioutil.ReadAll(s.Body) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 429, len(b)) + assert.Equal(t, s.String(), string(b)) +} + +func TestClient_Get(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + // json + var s *slideShowResponse + err = client.Get(&s, "/json") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "Sample Slide Show", s.Slideshow.Title) + assert.Equal(t, 2, len(s.Slideshow.Slides)) + assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title) + + // xml + var ssp *slideshow + err = client.Get(&ssp, "/base64/PD94bWwgPz48c2xpZGVzaG93CnRpdGxlPSJTYW1wbGUgU2xpZGUgU2hvdyIKZGF0ZT0iRGF0ZSBvZiBwdWJsaWNhdGlvbiIKYXV0aG9yPSJZb3VycyBUcnVseSI+PHNsaWRlIHR5cGU9ImFsbCI+PHRpdGxlPldha2UgdXAgdG8gV29uZGVyV2lkZ2V0cyE8L3RpdGxlPjwvc2xpZGU+PHNsaWRlIHR5cGU9ImFsbCI+PHRpdGxlPk92ZXJ2aWV3PC90aXRsZT48aXRlbT5XaHkgPGVtPldvbmRlcldpZGdldHM8L2VtPiBhcmUgZ3JlYXQ8L2l0ZW0+PGl0ZW0vPjxpdGVtPldobyA8ZW0+YnV5czwvZW0+IFdvbmRlcldpZGdldHM8L2l0ZW0+PC9zbGlkZT48L3NsaWRlc2hvdz4=") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "Sample Slide Show", ssp.Title) + assert.Equal(t, 2, len(ssp.Slides)) + assert.Equal(t, "Overview", ssp.Slides[1].Title) + + // yaml + s = nil + err = client.Get(&s, "/base64/c2xpZGVzaG93OgogIGF1dGhvcjogWW91cnMgVHJ1bHkKICBkYXRlOiBkYXRlIG9mIHB1YmxpY2F0aW9uCiAgc2xpZGVzOgogIC0gdGl0bGU6IFdha2UgdXAgdG8gV29uZGVyV2lkZ2V0cyEKICAgIHR5cGU6IGFsbAogIC0gaXRlbXM6CiAgICAtIFdoeSA8ZW0+V29uZGVyV2lkZ2V0czwvZW0+IGFyZSBncmVhdAogICAgLSBXaG8gPGVtPmJ1eXM8L2VtPiBXb25kZXJXaWRnZXRzCiAgICB0aXRsZTogT3ZlcnZpZXcKICAgIHR5cGU6IGFsbAogIHRpdGxlOiBTYW1wbGUgU2xpZGUgU2hvdw==") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "Sample Slide Show", s.Slideshow.Title) + assert.Equal(t, 2, len(s.Slideshow.Slides)) + assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title) +} + +func TestClient_Post(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Get(resp, "/json") + if err != nil { + t.Fatal(err) + } + + jsonStr := resp.String() + err = client.Post(resp, "/post", jsonStr) + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, resp) + assert.Equal(t, http.MethodPost, resp.Resp.Request.Method) +} + +func TestClient_Put(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Get(resp, "/json") + if err != nil { + t.Fatal(err) + } + + jsonStr := resp.String() + err = client.Put(resp, "/put", jsonStr) + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, resp) + assert.Equal(t, http.MethodPut, resp.Resp.Request.Method) +} + +func TestClient_Delete(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Delete(resp, "/delete") + if err != nil { + t.Fatal(err) + } + defer resp.Resp.Body.Close() + assert.NotNil(t, resp) + assert.Equal(t, http.MethodDelete, resp.Resp.Request.Method) +} + +func TestClient_Head(t *testing.T) { + client, err := NewClient("test", "http://beego.me") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Head(resp, "") + if err != nil { + t.Fatal(err) + } + defer resp.Resp.Body.Close() + assert.NotNil(t, resp) + assert.Equal(t, http.MethodHead, resp.Resp.Request.Method) +} diff --git a/client/httplib/httplib.go b/client/httplib/httplib.go index 434d74c1..b102f687 100644 --- a/client/httplib/httplib.go +++ b/client/httplib/httplib.go @@ -124,7 +124,6 @@ type BeegoHTTPRequest struct { setting BeegoHTTPSettings resp *http.Response body []byte - dump []byte } // GetRequest returns the request object @@ -199,7 +198,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { // SetProtocolVersion sets the protocol version for incoming requests. // Client requests always use HTTP/1.1 func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { - if len(vers) == 0 { + if vers == "" { vers = "HTTP/1.1" } @@ -511,18 +510,16 @@ func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper { DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), MaxIdleConnsPerHost: 100, } - } else { + } else if t, ok := trans.(*http.Transport); ok { // if b.transport is *http.Transport then set the settings. - if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TLSClientConfig - } - if t.Proxy == nil { - t.Proxy = b.setting.Proxy - } - if t.DialContext == nil { - t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) - } + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.DialContext == nil { + t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) } } return trans @@ -656,6 +653,40 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.") } +// ToValue attempts to resolve the response body to value using an existing method. +// Calls Response inner. +// If response header contain Content-Type, func will call ToJSON\ToXML\ToYAML. +// Else it will try to parse body as json\yaml\xml, If all attempts fail, an error will be returned +func (b *BeegoHTTPRequest) ToValue(value interface{}) error { + if value == nil { + return nil + } + + contentType := strings.Split(b.resp.Header.Get(contentTypeKey), ";")[0] + // try to parse it as content type + switch contentType { + case "application/json": + return b.ToJSON(value) + case "text/xml", "application/xml": + return b.ToXML(value) + case "text/yaml", "application/x-yaml", "application/x+yaml": + return b.ToYAML(value) + } + + // try to parse it anyway + if err := b.ToJSON(value); err == nil { + return nil + } + if err := b.ToYAML(value); err == nil { + return nil + } + if err := b.ToXML(value); err == nil { + return nil + } + + return berror.Error(UnmarshalResponseToObjectFailed, "unmarshal body to object failed.") +} + // Response executes request client gets response manually. func (b *BeegoHTTPRequest) Response() (*http.Response, error) { return b.getResponse() diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index be702fb6..491b1b9f 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -433,3 +433,7 @@ func TestBeegoHTTPRequestXMLBody(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, req.req.GetBody) } + +// TODO +func TestBeegoHTTPRequest_ResponseForValue(t *testing.T) { +} diff --git a/client/httplib/mock/mock_condition.go b/client/httplib/mock/mock_condition.go index 53d3d703..912699ff 100644 --- a/client/httplib/mock/mock_condition.go +++ b/client/httplib/mock/mock_condition.go @@ -57,7 +57,7 @@ func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondi } func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { - res := true + var res bool if len(sc.path) > 0 { res = sc.matchPath(ctx, req) } else if len(sc.pathReg) > 0 { diff --git a/client/httplib/setting.go b/client/httplib/setting.go index df8eff4b..2d7a0eed 100644 --- a/client/httplib/setting.go +++ b/client/httplib/setting.go @@ -55,6 +55,11 @@ func SetDefaultSetting(setting BeegoHTTPSettings) { defaultSetting = setting } +// GetDefaultSetting return current default setting +func GetDefaultSetting() BeegoHTTPSettings { + return defaultSetting +} + var defaultSetting = BeegoHTTPSettings{ UserAgent: "beegoServer", ConnectTimeout: 60 * time.Second, diff --git a/client/orm/clauses/order_clause/order_test.go b/client/orm/clauses/order_clause/order_test.go index 172e7492..7854948e 100644 --- a/client/orm/clauses/order_clause/order_test.go +++ b/client/orm/clauses/order_clause/order_test.go @@ -120,25 +120,21 @@ func TestOrder_GetColumn(t *testing.T) { } } -func TestOrder_GetSort(t *testing.T) { - o := Clause( - SortDescending(), - ) - if o.GetSort() != Descending { - t.Error() - } -} +func TestSortString(t *testing.T) { + template := "got: %s, want: %s" -func TestOrder_IsRaw(t *testing.T) { - o1 := Clause() - if o1.IsRaw() { - t.Error() + o1 := Clause(sort(Sort(1))) + if o1.SortString() != "ASC" { + t.Errorf(template, o1.SortString(), "ASC") } - o2 := Clause( - Raw(), - ) - if !o2.IsRaw() { - t.Error() + o2 := Clause(sort(Sort(2))) + if o2.SortString() != "DESC" { + t.Errorf(template, o2.SortString(), "DESC") + } + + o3 := Clause(sort(Sort(3))) + if o3.SortString() != `` { + t.Errorf(template, o3.SortString(), ``) } } diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 58f2a597..5ded367a 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -1845,17 +1845,12 @@ func TestRawQueryRow(t *testing.T) { case "id": throwFail(t, AssertIs(id, 1)) break - case "time": + case "time", "datetime": v = v.(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc) assert.True(t, v.(time.Time).Sub(value) <= time.Second) break case "date": - case "datetime": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - assert.True(t, v.(time.Time).Sub(value) <= time.Second) - break default: throwFail(t, AssertIs(v, dataValues[col])) } @@ -2769,6 +2764,7 @@ func TestStrPkInsert(t *testing.T) { fmt.Println(err) if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { } else if err == ErrLastInsertIdUnavailable { + return } else { throwFailNow(t, err) } diff --git a/core/bean/mock.go b/core/bean/mock.go new file mode 100644 index 00000000..0a8d3e29 --- /dev/null +++ b/core/bean/mock.go @@ -0,0 +1,142 @@ +package bean + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// the mock object must be pointer of struct +// the element in mock object can be slices, structures, basic data types, pointers and interface +func Mock(v interface{}) (err error) { + pv := reflect.ValueOf(v) + // the input must be pointer of struct + if pv.Kind() != reflect.Ptr || pv.IsNil() { + err = fmt.Errorf("not a pointer of struct") + return + } + err = mock(pv) + return +} + +func mock(pv reflect.Value) (err error) { + pt := pv.Type() + for i := 0; i < pt.Elem().NumField(); i++ { + ptt := pt.Elem().Field(i) + pvv := pv.Elem().FieldByName(ptt.Name) + if !pvv.CanSet() { + continue + } + kt := ptt.Type.Kind() + tagValue := ptt.Tag.Get("mock") + switch kt { + case reflect.Map: + continue + case reflect.Interface: + if pvv.IsNil() { // when interface is nil,can not sure the type + continue + } + pvv.Set(reflect.New(pvv.Elem().Type().Elem())) + err = mock(pvv.Elem()) + case reflect.Ptr: + err = mockPtr(pvv, ptt.Type.Elem()) + case reflect.Struct: + err = mock(pvv.Addr()) + case reflect.Array, reflect.Slice: + err = mockSlice(tagValue, pvv) + case reflect.String: + pvv.SetString(tagValue) + case reflect.Bool: + err = mockBool(tagValue, pvv) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value, e := strconv.ParseInt(tagValue, 10, 64) + if e != nil || pvv.OverflowInt(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetInt(value) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + value, e := strconv.ParseUint(tagValue, 10, 64) + if e != nil || pvv.OverflowUint(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetUint(value) + case reflect.Float32, reflect.Float64: + value, e := strconv.ParseFloat(tagValue, pvv.Type().Bits()) + if e != nil || pvv.OverflowFloat(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetFloat(value) + default: + } + if err != nil { + return + } + } + return +} + +// mock slice value +func mockSlice(tagValue string, pvv reflect.Value) (err error) { + if tagValue == "" { + return + } + sliceMetas := strings.Split(tagValue, ":") + if len(sliceMetas) != 2 || sliceMetas[0] != "length" { + err = fmt.Errorf("the value:%s is invalid", tagValue) + return + } + length, e := strconv.Atoi(sliceMetas[1]) + if e != nil { + return e + } + + sliceType := reflect.SliceOf(pvv.Type().Elem()) // get slice type + itemType := sliceType.Elem() // get the type of item in slice + value := reflect.MakeSlice(sliceType, 0, length) + newSliceValue := make([]reflect.Value, 0, length) + for k := 0; k < length; k++ { + itemValue := reflect.New(itemType).Elem() + // if item in slice is struct or pointer,must set zero value + switch itemType.Kind() { + case reflect.Struct: + err = mock(itemValue.Addr()) + case reflect.Ptr: + if itemValue.IsNil() { + itemValue.Set(reflect.New(itemType.Elem())) + if e := mock(itemValue); e != nil { + return e + } + } + } + newSliceValue = append(newSliceValue, itemValue) + if err != nil { + return + } + } + value = reflect.Append(value, newSliceValue...) + pvv.Set(value) + return +} + +// mock bool value +func mockBool(tagValue string, pvv reflect.Value) (err error) { + switch tagValue { + case "true": + pvv.SetBool(true) + case "false": + pvv.SetBool(false) + default: + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + return +} + +// mock pointer +func mockPtr(pvv reflect.Value, ptt reflect.Type) (err error) { + if pvv.IsNil() { + pvv.Set(reflect.New(ptt)) // must set nil value to zero value + } + err = mock(pvv) + return +} diff --git a/core/bean/mock_test.go b/core/bean/mock_test.go new file mode 100644 index 00000000..0fe81f43 --- /dev/null +++ b/core/bean/mock_test.go @@ -0,0 +1,74 @@ +package bean + +import ( + "fmt" + "testing" +) + +func TestMock(t *testing.T) { + type MockSubSubObject struct { + A int `mock:"20"` + } + type MockSubObjectAnoy struct { + Anoy int `mock:"20"` + } + type MockSubObject struct { + A bool `mock:"true"` + B MockSubSubObject + } + type MockObject struct { + A string `mock:"aaaaa"` + B int8 `mock:"10"` + C []*MockSubObject `mock:"length:2"` + D bool `mock:"true"` + E *MockSubObject + F []int `mock:"length:3"` + G InterfaceA + H InterfaceA + MockSubObjectAnoy + } + m := &MockObject{G: &ImplA{}} + err := Mock(m) + if err != nil { + t.Fatalf("mock failed: %v", err) + } + if m.A != "aaaaa" || m.B != 10 || m.C[1].B.A != 20 || + !m.E.A || m.E.B.A != 20 || !m.D || len(m.F) != 3 { + t.Fail() + } + _, ok := m.G.(*ImplA) + if !ok { + t.Fail() + } + _, ok = m.G.(*ImplB) + if ok { + t.Fail() + } + _, ok = m.H.(*ImplA) + if ok { + t.Fail() + } + if m.Anoy != 20 { + t.Fail() + } +} + +type InterfaceA interface { + Item() +} + +type ImplA struct { + A string `mock:"aaa"` +} + +func (i *ImplA) Item() { + fmt.Println("implA") +} + +type ImplB struct { + B string `mock:"bbb"` +} + +func (i *ImplB) Item() { + fmt.Println("implB") +} diff --git a/go.mod b/go.mod index 0305be1e..99b24947 100644 --- a/go.mod +++ b/go.mod @@ -37,4 +37,5 @@ require ( golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a google.golang.org/grpc v1.37.1 gopkg.in/yaml.v2 v2.4.0 + mvdan.cc/gofumpt v0.1.1 // indirect ) diff --git a/go.sum b/go.sum index c144e80d..91616ec7 100644 --- a/go.sum +++ b/go.sum @@ -333,6 +333,7 @@ github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqn github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= @@ -366,6 +367,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112 h1:NBrpnvz0pDPf3+HXZ1C9GcJd1DTpWDLcLWZhNq6uP7o= @@ -427,6 +429,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.0 h1:8pl+sMODzuvGJkmj2W4kZihvVb5mKm8pB/X44PIQHv8= +golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -512,6 +516,7 @@ golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210101214203-2dba1e4ea05c/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a h1:CB3a9Nez8M13wwlr/E2YtwoU+qYHKfC+JrDa45RXXoQ= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -589,6 +594,8 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +mvdan.cc/gofumpt v0.1.1 h1:bi/1aS/5W00E2ny5q65w9SnKpWEF/UIOqDYBILpo9rA= +mvdan.cc/gofumpt v0.1.1/go.mod h1:yXG1r1WqZVKWbVRtBWKWX9+CxGYfA51nSomhM0woR48= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= diff --git a/server/web/controller.go b/server/web/controller.go index a767fac4..c3a79ac8 100644 --- a/server/web/controller.go +++ b/server/web/controller.go @@ -250,7 +250,8 @@ func (c *Controller) Bind(obj interface{}) error { return c.BindJson(obj) } i, l := 0, len(ct[0]) - for ; i < l && ct[0][i] != ';'; i++ { + for i < l && ct[0][i] != ';' { + i++ } switch ct[0][0:i] { case "application/json": diff --git a/server/web/tree.go b/server/web/tree.go index 3716d4f4..4088e83e 100644 --- a/server/web/tree.go +++ b/server/web/tree.go @@ -284,7 +284,7 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, // Match router to runObject & params func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { - if len(pattern) == 0 || pattern[0] != '/' { + if pattern == "" || pattern[0] != '/' { return nil } w := make([]string, 0, 20) @@ -294,12 +294,13 @@ func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { if len(pattern) > 0 { i, l := 0, len(pattern) - for ; i < l && pattern[i] == '/'; i++ { + for i < l && pattern[i] == '/' { + i++ } pattern = pattern[i:] } // Handle leaf nodes: - if len(pattern) == 0 { + if pattern == "" { for _, l := range t.leaves { if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject @@ -316,7 +317,8 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string } var seg string i, l := 0, len(pattern) - for ; i < l && pattern[i] != '/'; i++ { + for i < l && pattern[i] != '/' { + i++ } if i == 0 { seg = pattern @@ -327,7 +329,7 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string } for _, subTree := range t.fixrouters { if subTree.prefix == seg { - if len(pattern) != 0 && pattern[0] == '/' { + if pattern != "" && pattern[0] == '/' { treePattern = pattern[1:] } else { treePattern = pattern diff --git a/task/task.go b/task/task.go index 00e67c4b..2bec0cc7 100644 --- a/task/task.go +++ b/task/task.go @@ -37,6 +37,7 @@ type taskManager struct { stop chan bool changed chan bool started bool + wait sync.WaitGroup } func newTaskManager() *taskManager { @@ -471,6 +472,11 @@ func ClearTask() { globalTaskManager.ClearTask() } +// GracefulShutdown wait all task done +func GracefulShutdown() <-chan struct{} { + return globalTaskManager.GracefulShutdown() +} + // StartTask start all tasks func (m *taskManager) StartTask() { m.taskLock.Lock() @@ -508,7 +514,7 @@ func (m *taskManager) run() { select { case now = <-time.After(effective.Sub(now)): // wait for effective time - runNextTasks(sortList, effective) + m.runNextTasks(sortList, effective) continue case <-m.changed: // tasks have been changed, set all tasks run again now now = time.Now().Local() @@ -540,7 +546,7 @@ func (m *taskManager) markManagerStop() { } // runNextTasks it runs next task which next run time is equal to effective -func runNextTasks(sortList *MapSorter, effective time.Time) { +func (m *taskManager) runNextTasks(sortList *MapSorter, effective time.Time) { // Run every entry whose next time was this effective time. var i = 0 for _, e := range sortList.Vals { @@ -551,8 +557,10 @@ func runNextTasks(sortList *MapSorter, effective time.Time) { // check if timeout is on, if yes passing the timeout context ctx := context.Background() + m.wait.Add(1) if duration := e.GetTimeout(ctx); duration != 0 { go func(e Tasker) { + defer m.wait.Done() ctx, cancelFunc := context.WithTimeout(ctx, duration) defer cancelFunc() err := e.Run(ctx) @@ -562,6 +570,7 @@ func runNextTasks(sortList *MapSorter, effective time.Time) { }(e) } else { go func(e Tasker) { + defer m.wait.Done() err := e.Run(ctx) if err != nil { log.Printf("tasker.run err: %s\n", err.Error()) @@ -581,6 +590,17 @@ func (m *taskManager) StopTask() { }() } +// GracefulShutdown wait all task done +func (m *taskManager) GracefulShutdown() <-chan struct{} { + done := make(chan struct{}) + go func() { + m.stop <- true + m.wait.Wait() + close(done) + }() + return done +} + // AddTask add task with name func (m *taskManager) AddTask(taskname string, t Tasker) { isChanged := false diff --git a/task/task_test.go b/task/task_test.go index 1078aa01..8d274e8f 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "testing" "time" @@ -177,6 +178,26 @@ func TestCrudTask(t *testing.T) { assert.Equal(t, 0, len(m.adminTaskList)) } +func TestGracefulShutdown(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + waitDone := atomic.Value{} + waitDone.Store(false) + tk := NewTask("everySecond", "* * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + time.Sleep(2 * time.Second) + waitDone.Store(true) + return nil + }) + m.AddTask("taska", tk) + m.StartTask() + time.Sleep(1 * time.Second) + shutdown := m.GracefulShutdown() + assert.False(t, waitDone.Load().(bool)) + <-shutdown + assert.True(t, waitDone.Load().(bool)) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() {