Merge branch 'develop' into develop

This commit is contained in:
jianzhiyao 2021-06-01 10:25:42 +08:00 committed by GitHub
commit 402bda1aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1200 additions and 51 deletions

17
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,17 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "daily"
open-pull-requests-limit: 10
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"

View File

@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v1
- uses: actions/stale@v3.0.19
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'This issue is inactive for a long time.'

View File

@ -2,6 +2,8 @@ language: go
go:
- "1.14.x"
- "1.15.x"
- "1.16.x"
services:
- redis-server
- mysql
@ -12,7 +14,7 @@ env:
global:
- GO_REPO_FULLNAME="github.com/beego/beego/v2"
matrix:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
- ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8"
before_install:

View File

@ -1,7 +1,11 @@
# developing
- Add template functions eq,lt to support uint and int compare. [4607](https://github.com/beego/beego/pull/4607)
- Add http client and option func. [4455](https://github.com/beego/beego/issues/4455)
- Add: Convenient way to generate mock object [4620](https://github.com/beego/beego/issues/4620)
- Infra: use dependabot to update dependencies. [4623](https://github.com/beego/beego/pull/4623)
- Lint: use golangci-lint. [4619](https://github.com/beego/beego/pull/4619)
- Chore: format code. [4615](https://github.com/beego/beego/pull/4615)
- Test on Go v1.15.x & v1.16.x. [4614](https://github.com/beego/beego/pull/4614)
- Env: non-empty GOBIN & GOPATH. [4613](https://github.com/beego/beego/pull/4613)
- Chore: update dependencies. [4611](https://github.com/beego/beego/pull/4611)
- Update orm_test.go/TestInsertOrUpdate with table-driven. [4609](https://github.com/beego/beego/pull/4609)
@ -45,10 +49,13 @@
- Optimize AddAutoPrefix: only register one router in case-insensitive mode. [4582](https://github.com/beego/beego/pull/4582)
- Init exceptMethod by using reflection. [4583](https://github.com/beego/beego/pull/4583)
- Deprecated BeeMap and replace all usage with `sync.map` [4616](https://github.com/beego/beego/pull/4616)
- TaskManager support graceful shutdown [4635](https://github.com/beego/beego/pull/4635)
## Fix Sonar
- [4624](https://github.com/beego/beego/pull/4624)
- [4608](https://github.com/beego/beego/pull/4608)
- [4473](https://github.com/beego/beego/pull/4473)
- [4474](https://github.com/beego/beego/pull/4474)
- [4479](https://github.com/beego/beego/pull/4479)
- [4639](https://github.com/beego/beego/pull/4639)

View File

@ -12,6 +12,11 @@ var (
ErrNotIntegerType = berror.Error(NotIntegerType, "item val is not (u)int (u)int32 (u)int64")
)
const (
MinUint32 uint32 = 0
MinUint64 uint64 = 0
)
func incr(originVal interface{}) (interface{}, error) {
switch val := originVal.(type) {
case int:
@ -75,12 +80,12 @@ func decr(originVal interface{}) (interface{}, error) {
}
return val - 1, nil
case uint32:
if val == 0 {
if val == MinUint32 {
return nil, ErrDecrementOverflow
}
return val - 1, nil
case uint64:
if val == 0 {
if val == MinUint64 {
return nil, ErrDecrementOverflow
}
return val - 1, nil

View File

@ -9,7 +9,7 @@ httplib is an libs help you to curl remote url.
you can use Get to crawl data.
import "github.com/beego/beego/v2/client/httplib"
str, err := httplib.Get("http://beego.me/").String()
if err != nil {
// error
@ -39,7 +39,7 @@ Example:
// GET
httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
// POST
httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)

View File

@ -0,0 +1,155 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"crypto/tls"
"net/http"
"net/url"
"time"
)
type (
ClientOption func(client *Client)
BeegoHTTPRequestOption func(request *BeegoHTTPRequest)
)
// WithEnableCookie will enable cookie in all subsequent request
func WithEnableCookie(enable bool) ClientOption {
return func(client *Client) {
client.Setting.EnableCookie = enable
}
}
// WithEnableCookie will adds UA in all subsequent request
func WithUserAgent(userAgent string) ClientOption {
return func(client *Client) {
client.Setting.UserAgent = userAgent
}
}
// WithTLSClientConfig will adds tls config in all subsequent request
func WithTLSClientConfig(config *tls.Config) ClientOption {
return func(client *Client) {
client.Setting.TLSClientConfig = config
}
}
// WithTransport will set transport field in all subsequent request
func WithTransport(transport http.RoundTripper) ClientOption {
return func(client *Client) {
client.Setting.Transport = transport
}
}
// WithProxy will set http proxy field in all subsequent request
func WithProxy(proxy func(*http.Request) (*url.URL, error)) ClientOption {
return func(client *Client) {
client.Setting.Proxy = proxy
}
}
// WithCheckRedirect will specifies the policy for handling redirects in all subsequent request
func WithCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) ClientOption {
return func(client *Client) {
client.Setting.CheckRedirect = redirect
}
}
// WithHTTPSetting can replace beegoHTTPSeting
func WithHTTPSetting(setting BeegoHTTPSettings) ClientOption {
return func(client *Client) {
client.Setting = setting
}
}
// WithEnableGzip will enable gzip in all subsequent request
func WithEnableGzip(enable bool) ClientOption {
return func(client *Client) {
client.Setting.Gzip = enable
}
}
// BeegoHttpRequestOption
// WithTimeout sets connect time out and read-write time out for BeegoRequest.
func WithTimeout(connectTimeout, readWriteTimeout time.Duration) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
request.SetTimeout(connectTimeout, readWriteTimeout)
}
}
// WithHeader adds header item string in request.
func WithHeader(key, value string) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
request.Header(key, value)
}
}
// WithCookie adds a cookie to the request.
func WithCookie(cookie *http.Cookie) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
request.Header("Cookie", cookie.String())
}
}
// Withtokenfactory adds a custom function to set Authorization
func WithTokenFactory(tokenFactory func() string) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
t := tokenFactory()
request.Header("Authorization", t)
}
}
// WithBasicAuth adds a custom function to set basic auth
func WithBasicAuth(basicAuth func() (string, string)) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
username, password := basicAuth()
request.SetBasicAuth(username, password)
}
}
// WithFilters will use the filter as the invocation filters
func WithFilters(fcs ...FilterChain) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
request.SetFilters(fcs...)
}
}
// WithContentType adds ContentType in header
func WithContentType(contentType string) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
request.Header(contentTypeKey, contentType)
}
}
// WithParam adds query param in to request.
func WithParam(key, value string) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
request.Param(key, value)
}
}
// WithRetry set retry times and delay for the request
// default is 0 (never retry)
// -1 retry indefinitely (forever)
// Other numbers specify the exact retry amount
func WithRetry(times int, delay time.Duration) BeegoHTTPRequestOption {
return func(request *BeegoHTTPRequest) {
request.Retries(times)
request.RetryDelay(delay)
}
}

View File

@ -0,0 +1,261 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"errors"
"net"
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type respCarrier struct {
bytes []byte
}
func (r *respCarrier) SetBytes(bytes []byte) {
r.bytes = bytes
}
func (r *respCarrier) String() string {
return string(r.bytes)
}
func TestOption_WithEnableCookie(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org/",
WithEnableCookie(true))
if err != nil {
t.Fatal(err)
}
v := "smallfish"
resp := &respCarrier{}
err = client.Get(resp, "/cookies/set?k1="+v)
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
err = client.Get(resp, "/cookies")
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), v)
if n == -1 {
t.Fatal(v + " not found in cookie")
}
}
func TestOption_WithUserAgent(t *testing.T) {
v := "beego"
client, err := NewClient("test", "http://httpbin.org/",
WithUserAgent(v))
if err != nil {
t.Fatal(err)
}
resp := &respCarrier{}
err = client.Get(resp, "/headers")
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), v)
if n == -1 {
t.Fatal(v + " not found in user-agent")
}
}
func TestOption_WithCheckRedirect(t *testing.T) {
client, err := NewClient("test", "https://goolnk.com/33BD2j",
WithCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error {
return errors.New("Redirect triggered")
}))
if err != nil {
t.Fatal(err)
}
err = client.Get(nil, "")
assert.NotNil(t, err)
}
func TestOption_WithHTTPSetting(t *testing.T) {
v := "beego"
var setting BeegoHTTPSettings
setting.EnableCookie = true
setting.UserAgent = v
setting.Transport = &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 50,
IdleConnTimeout: 90 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
setting.ReadWriteTimeout = 5 * time.Second
client, err := NewClient("test", "http://httpbin.org/",
WithHTTPSetting(setting))
if err != nil {
t.Fatal(err)
}
resp := &respCarrier{}
err = client.Get(resp, "/get")
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), v)
if n == -1 {
t.Fatal(v + " not found in user-agent")
}
}
func TestOption_WithHeader(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org/")
if err != nil {
t.Fatal(err)
}
client.CommonOpts = append(client.CommonOpts, WithHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36"))
resp := &respCarrier{}
err = client.Get(resp, "/headers")
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), "Mozilla/5.0")
if n == -1 {
t.Fatal("Mozilla/5.0 not found in user-agent")
}
}
func TestOption_WithTokenFactory(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org/")
if err != nil {
t.Fatal(err)
}
client.CommonOpts = append(client.CommonOpts,
WithTokenFactory(func() string {
return "testauth"
}))
resp := &respCarrier{}
err = client.Get(resp, "/headers")
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), "testauth")
if n == -1 {
t.Fatal("Auth is not set in request")
}
}
func TestOption_WithBasicAuth(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org/")
if err != nil {
t.Fatal(err)
}
resp := &respCarrier{}
err = client.Get(resp, "/basic-auth/user/passwd",
WithBasicAuth(func() (string, string) {
return "user", "passwd"
}))
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), "authenticated")
if n == -1 {
t.Fatal("authenticated not found in response")
}
}
func TestOption_WithContentType(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org/")
if err != nil {
t.Fatal(err)
}
v := "application/json"
resp := &respCarrier{}
err = client.Get(resp, "/headers", WithContentType(v))
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), v)
if n == -1 {
t.Fatal(v + " not found in header")
}
}
func TestOption_WithParam(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org/")
if err != nil {
t.Fatal(err)
}
v := "smallfish"
resp := &respCarrier{}
err = client.Get(resp, "/get", WithParam("username", v))
if err != nil {
t.Fatal(err)
}
t.Log(resp.String())
n := strings.Index(resp.String(), v)
if n == -1 {
t.Fatal(v + " not found in header")
}
}
func TestOption_WithRetry(t *testing.T) {
client, err := NewClient("test", "https://goolnk.com/33BD2j",
WithCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error {
return errors.New("Redirect triggered")
}))
if err != nil {
t.Fatal(err)
}
retryAmount := 1
retryDelay := 800 * time.Millisecond
startTime := time.Now().UnixNano() / int64(time.Millisecond)
_ = client.Get(nil, "", WithRetry(retryAmount, retryDelay))
endTime := time.Now().UnixNano() / int64(time.Millisecond)
elapsedTime := endTime - startTime
delayedTime := int64(retryAmount) * retryDelay.Milliseconds()
if elapsedTime < delayedTime {
t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime)
}
}

View File

@ -124,3 +124,11 @@ Make sure that:
1. You pass valid structure pointer to the function;
2. The body is valid YAML document
`)
var UnmarshalResponseToObjectFailed = berror.DefineCode(5001011, moduleName,
"UnmarshalResponseToObjectFailed", `
Beego trying to unmarshal response's body to structure but failed.
There are several cases that cause this error:
1. You pass valid structure pointer to the function;
2. The body is valid json, Yaml or XML document
`)

View File

@ -0,0 +1,174 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"bytes"
"io"
"io/ioutil"
"net/http"
)
// Client provides an HTTP client supporting chain call
type Client struct {
Name string
Endpoint string
CommonOpts []BeegoHTTPRequestOption
Setting BeegoHTTPSettings
}
// HTTPResponseCarrier If value implement HTTPResponseCarrier. http.Response will pass to SetHTTPResponse
type HTTPResponseCarrier interface {
SetHTTPResponse(resp *http.Response)
}
// HTTPBodyCarrier If value implement HTTPBodyCarrier. http.Response.Body will pass to SetReader
type HTTPBodyCarrier interface {
SetReader(r io.ReadCloser)
}
// HTTPBytesCarrier If value implement HTTPBytesCarrier.
// All the byte in http.Response.Body will pass to SetBytes
type HTTPBytesCarrier interface {
SetBytes(bytes []byte)
}
// HTTPStatusCarrier If value implement HTTPStatusCarrier. http.Response.StatusCode will pass to SetStatusCode
type HTTPStatusCarrier interface {
SetStatusCode(status int)
}
// HttpHeaderCarrier If value implement HttpHeaderCarrier. http.Response.Header will pass to SetHeader
type HTTPHeadersCarrier interface {
SetHeader(header map[string][]string)
}
// NewClient return a new http client
func NewClient(name string, endpoint string, opts ...ClientOption) (*Client, error) {
res := &Client{
Name: name,
Endpoint: endpoint,
}
setting := GetDefaultSetting()
res.Setting = setting
for _, o := range opts {
o(res)
}
return res, nil
}
func (c *Client) customReq(req *BeegoHTTPRequest, opts []BeegoHTTPRequestOption) {
req.Setting(c.Setting)
opts = append(c.CommonOpts, opts...)
for _, o := range opts {
o(req)
}
}
// handleResponse try to parse body to meaningful value
func (c *Client) handleResponse(value interface{}, req *BeegoHTTPRequest) error {
// make sure req.resp is not nil
_, err := req.Bytes()
if err != nil {
return err
}
err = c.handleCarrier(value, req)
if err != nil {
return err
}
return req.ToValue(value)
}
// handleCarrier set http data to value
func (c *Client) handleCarrier(value interface{}, req *BeegoHTTPRequest) error {
if value == nil {
return nil
}
if carrier, ok := value.(HTTPResponseCarrier); ok {
b, err := req.Bytes()
if err != nil {
return err
}
req.resp.Body = ioutil.NopCloser(bytes.NewReader(b))
carrier.SetHTTPResponse(req.resp)
}
if carrier, ok := value.(HTTPBodyCarrier); ok {
b, err := req.Bytes()
if err != nil {
return err
}
reader := ioutil.NopCloser(bytes.NewReader(b))
carrier.SetReader(reader)
}
if carrier, ok := value.(HTTPBytesCarrier); ok {
b, err := req.Bytes()
if err != nil {
return err
}
carrier.SetBytes(b)
}
if carrier, ok := value.(HTTPStatusCarrier); ok {
carrier.SetStatusCode(req.resp.StatusCode)
}
if carrier, ok := value.(HTTPHeadersCarrier); ok {
carrier.SetHeader(req.resp.Header)
}
return nil
}
// Get Send a GET request and try to give its result value
func (c *Client) Get(value interface{}, path string, opts ...BeegoHTTPRequestOption) error {
req := Get(c.Endpoint + path)
c.customReq(req, opts)
return c.handleResponse(value, req)
}
// Post Send a POST request and try to give its result value
func (c *Client) Post(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error {
req := Post(c.Endpoint + path)
c.customReq(req, opts)
if body != nil {
req = req.Body(body)
}
return c.handleResponse(value, req)
}
// Put Send a Put request and try to give its result value
func (c *Client) Put(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error {
req := Put(c.Endpoint + path)
c.customReq(req, opts)
if body != nil {
req = req.Body(body)
}
return c.handleResponse(value, req)
}
// Delete Send a Delete request and try to give its result value
func (c *Client) Delete(value interface{}, path string, opts ...BeegoHTTPRequestOption) error {
req := Delete(c.Endpoint + path)
c.customReq(req, opts)
return c.handleResponse(value, req)
}
// Head Send a Head request and try to give its result value
func (c *Client) Head(value interface{}, path string, opts ...BeegoHTTPRequestOption) error {
req := Head(c.Endpoint + path)
c.customReq(req, opts)
return c.handleResponse(value, req)
}

View File

@ -0,0 +1,220 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"encoding/xml"
"io"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewClient(t *testing.T) {
client, err := NewClient("test1", "http://beego.me", WithEnableCookie(true))
assert.NoError(t, err)
assert.NotNil(t, client)
assert.Equal(t, true, client.Setting.EnableCookie)
}
type slideShowResponse struct {
Resp *http.Response
bytes []byte
StatusCode int
Body io.ReadCloser
Header map[string][]string
Slideshow slideshow `json:"slideshow" yaml:"slideshow"`
}
func (r *slideShowResponse) SetHTTPResponse(resp *http.Response) {
r.Resp = resp
}
func (r *slideShowResponse) SetBytes(bytes []byte) {
r.bytes = bytes
}
func (r *slideShowResponse) SetReader(reader io.ReadCloser) {
r.Body = reader
}
func (r *slideShowResponse) SetStatusCode(status int) {
r.StatusCode = status
}
func (r *slideShowResponse) SetHeader(header map[string][]string) {
r.Header = header
}
func (r *slideShowResponse) String() string {
return string(r.bytes)
}
type slideshow struct {
XMLName xml.Name `xml:"slideshow"`
Title string `json:"title" yaml:"title" xml:"title,attr"`
Author string `json:"author" yaml:"author" xml:"author,attr"`
Date string `json:"date" yaml:"date" xml:"date,attr"`
Slides []slide `json:"slides" yaml:"slides" xml:"slide"`
}
type slide struct {
XMLName xml.Name `xml:"slide"`
Title string `json:"title" yaml:"title" xml:"title"`
}
func TestClient_handleCarrier(t *testing.T) {
v := "beego"
client, err := NewClient("test", "http://httpbin.org/",
WithUserAgent(v))
if err != nil {
t.Fatal(err)
}
s := &slideShowResponse{}
err = client.Get(s, "/json")
if err != nil {
t.Fatal(err)
}
defer s.Body.Close()
assert.NotNil(t, s.Resp)
assert.NotNil(t, s.Body)
assert.Equal(t, "429", s.Header["Content-Length"][0])
assert.Equal(t, 200, s.StatusCode)
b, err := ioutil.ReadAll(s.Body)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 429, len(b))
assert.Equal(t, s.String(), string(b))
}
func TestClient_Get(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org/")
if err != nil {
t.Fatal(err)
}
// json
var s *slideShowResponse
err = client.Get(&s, "/json")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "Sample Slide Show", s.Slideshow.Title)
assert.Equal(t, 2, len(s.Slideshow.Slides))
assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title)
// xml
var ssp *slideshow
err = client.Get(&ssp, "/base64/PD94bWwgPz48c2xpZGVzaG93CnRpdGxlPSJTYW1wbGUgU2xpZGUgU2hvdyIKZGF0ZT0iRGF0ZSBvZiBwdWJsaWNhdGlvbiIKYXV0aG9yPSJZb3VycyBUcnVseSI+PHNsaWRlIHR5cGU9ImFsbCI+PHRpdGxlPldha2UgdXAgdG8gV29uZGVyV2lkZ2V0cyE8L3RpdGxlPjwvc2xpZGU+PHNsaWRlIHR5cGU9ImFsbCI+PHRpdGxlPk92ZXJ2aWV3PC90aXRsZT48aXRlbT5XaHkgPGVtPldvbmRlcldpZGdldHM8L2VtPiBhcmUgZ3JlYXQ8L2l0ZW0+PGl0ZW0vPjxpdGVtPldobyA8ZW0+YnV5czwvZW0+IFdvbmRlcldpZGdldHM8L2l0ZW0+PC9zbGlkZT48L3NsaWRlc2hvdz4=")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "Sample Slide Show", ssp.Title)
assert.Equal(t, 2, len(ssp.Slides))
assert.Equal(t, "Overview", ssp.Slides[1].Title)
// yaml
s = nil
err = client.Get(&s, "/base64/c2xpZGVzaG93OgogIGF1dGhvcjogWW91cnMgVHJ1bHkKICBkYXRlOiBkYXRlIG9mIHB1YmxpY2F0aW9uCiAgc2xpZGVzOgogIC0gdGl0bGU6IFdha2UgdXAgdG8gV29uZGVyV2lkZ2V0cyEKICAgIHR5cGU6IGFsbAogIC0gaXRlbXM6CiAgICAtIFdoeSA8ZW0+V29uZGVyV2lkZ2V0czwvZW0+IGFyZSBncmVhdAogICAgLSBXaG8gPGVtPmJ1eXM8L2VtPiBXb25kZXJXaWRnZXRzCiAgICB0aXRsZTogT3ZlcnZpZXcKICAgIHR5cGU6IGFsbAogIHRpdGxlOiBTYW1wbGUgU2xpZGUgU2hvdw==")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "Sample Slide Show", s.Slideshow.Title)
assert.Equal(t, 2, len(s.Slideshow.Slides))
assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title)
}
func TestClient_Post(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org")
if err != nil {
t.Fatal(err)
}
resp := &slideShowResponse{}
err = client.Get(resp, "/json")
if err != nil {
t.Fatal(err)
}
jsonStr := resp.String()
err = client.Post(resp, "/post", jsonStr)
if err != nil {
t.Fatal(err)
}
assert.NotNil(t, resp)
assert.Equal(t, http.MethodPost, resp.Resp.Request.Method)
}
func TestClient_Put(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org")
if err != nil {
t.Fatal(err)
}
resp := &slideShowResponse{}
err = client.Get(resp, "/json")
if err != nil {
t.Fatal(err)
}
jsonStr := resp.String()
err = client.Put(resp, "/put", jsonStr)
if err != nil {
t.Fatal(err)
}
assert.NotNil(t, resp)
assert.Equal(t, http.MethodPut, resp.Resp.Request.Method)
}
func TestClient_Delete(t *testing.T) {
client, err := NewClient("test", "http://httpbin.org")
if err != nil {
t.Fatal(err)
}
resp := &slideShowResponse{}
err = client.Delete(resp, "/delete")
if err != nil {
t.Fatal(err)
}
defer resp.Resp.Body.Close()
assert.NotNil(t, resp)
assert.Equal(t, http.MethodDelete, resp.Resp.Request.Method)
}
func TestClient_Head(t *testing.T) {
client, err := NewClient("test", "http://beego.me")
if err != nil {
t.Fatal(err)
}
resp := &slideShowResponse{}
err = client.Head(resp, "")
if err != nil {
t.Fatal(err)
}
defer resp.Resp.Body.Close()
assert.NotNil(t, resp)
assert.Equal(t, http.MethodHead, resp.Resp.Request.Method)
}

View File

@ -124,7 +124,6 @@ type BeegoHTTPRequest struct {
setting BeegoHTTPSettings
resp *http.Response
body []byte
dump []byte
}
// GetRequest returns the request object
@ -199,7 +198,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
// SetProtocolVersion sets the protocol version for incoming requests.
// Client requests always use HTTP/1.1
func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
if len(vers) == 0 {
if vers == "" {
vers = "HTTP/1.1"
}
@ -511,18 +510,16 @@ func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper {
DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
MaxIdleConnsPerHost: 100,
}
} else {
} else if t, ok := trans.(*http.Transport); ok {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
}
if t.DialContext == nil {
t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
}
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
}
if t.DialContext == nil {
t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
}
}
return trans
@ -656,6 +653,40 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.")
}
// ToValue attempts to resolve the response body to value using an existing method.
// Calls Response inner.
// If response header contain Content-Type, func will call ToJSON\ToXML\ToYAML.
// Else it will try to parse body as json\yaml\xml, If all attempts fail, an error will be returned
func (b *BeegoHTTPRequest) ToValue(value interface{}) error {
if value == nil {
return nil
}
contentType := strings.Split(b.resp.Header.Get(contentTypeKey), ";")[0]
// try to parse it as content type
switch contentType {
case "application/json":
return b.ToJSON(value)
case "text/xml", "application/xml":
return b.ToXML(value)
case "text/yaml", "application/x-yaml", "application/x+yaml":
return b.ToYAML(value)
}
// try to parse it anyway
if err := b.ToJSON(value); err == nil {
return nil
}
if err := b.ToYAML(value); err == nil {
return nil
}
if err := b.ToXML(value); err == nil {
return nil
}
return berror.Error(UnmarshalResponseToObjectFailed, "unmarshal body to object failed.")
}
// Response executes request client gets response manually.
func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
return b.getResponse()

View File

@ -433,3 +433,7 @@ func TestBeegoHTTPRequestXMLBody(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, req.req.GetBody)
}
// TODO
func TestBeegoHTTPRequest_ResponseForValue(t *testing.T) {
}

View File

@ -57,7 +57,7 @@ func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondi
}
func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
res := true
var res bool
if len(sc.path) > 0 {
res = sc.matchPath(ctx, req)
} else if len(sc.pathReg) > 0 {

View File

@ -55,6 +55,11 @@ func SetDefaultSetting(setting BeegoHTTPSettings) {
defaultSetting = setting
}
// GetDefaultSetting return current default setting
func GetDefaultSetting() BeegoHTTPSettings {
return defaultSetting
}
var defaultSetting = BeegoHTTPSettings{
UserAgent: "beegoServer",
ConnectTimeout: 60 * time.Second,

View File

@ -120,25 +120,21 @@ func TestOrder_GetColumn(t *testing.T) {
}
}
func TestOrder_GetSort(t *testing.T) {
o := Clause(
SortDescending(),
)
if o.GetSort() != Descending {
t.Error()
}
}
func TestSortString(t *testing.T) {
template := "got: %s, want: %s"
func TestOrder_IsRaw(t *testing.T) {
o1 := Clause()
if o1.IsRaw() {
t.Error()
o1 := Clause(sort(Sort(1)))
if o1.SortString() != "ASC" {
t.Errorf(template, o1.SortString(), "ASC")
}
o2 := Clause(
Raw(),
)
if !o2.IsRaw() {
t.Error()
o2 := Clause(sort(Sort(2)))
if o2.SortString() != "DESC" {
t.Errorf(template, o2.SortString(), "DESC")
}
o3 := Clause(sort(Sort(3)))
if o3.SortString() != `` {
t.Errorf(template, o3.SortString(), ``)
}
}

View File

@ -1845,17 +1845,12 @@ func TestRawQueryRow(t *testing.T) {
case "id":
throwFail(t, AssertIs(id, 1))
break
case "time":
case "time", "datetime":
v = v.(time.Time).In(DefaultTimeLoc)
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
assert.True(t, v.(time.Time).Sub(value) <= time.Second)
break
case "date":
case "datetime":
v = v.(time.Time).In(DefaultTimeLoc)
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
assert.True(t, v.(time.Time).Sub(value) <= time.Second)
break
default:
throwFail(t, AssertIs(v, dataValues[col]))
}
@ -2769,6 +2764,7 @@ func TestStrPkInsert(t *testing.T) {
fmt.Println(err)
if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
} else if err == ErrLastInsertIdUnavailable {
return
} else {
throwFailNow(t, err)
}

142
core/bean/mock.go Normal file
View File

@ -0,0 +1,142 @@
package bean
import (
"fmt"
"reflect"
"strconv"
"strings"
)
// the mock object must be pointer of struct
// the element in mock object can be slices, structures, basic data types, pointers and interface
func Mock(v interface{}) (err error) {
pv := reflect.ValueOf(v)
// the input must be pointer of struct
if pv.Kind() != reflect.Ptr || pv.IsNil() {
err = fmt.Errorf("not a pointer of struct")
return
}
err = mock(pv)
return
}
func mock(pv reflect.Value) (err error) {
pt := pv.Type()
for i := 0; i < pt.Elem().NumField(); i++ {
ptt := pt.Elem().Field(i)
pvv := pv.Elem().FieldByName(ptt.Name)
if !pvv.CanSet() {
continue
}
kt := ptt.Type.Kind()
tagValue := ptt.Tag.Get("mock")
switch kt {
case reflect.Map:
continue
case reflect.Interface:
if pvv.IsNil() { // when interface is nil,can not sure the type
continue
}
pvv.Set(reflect.New(pvv.Elem().Type().Elem()))
err = mock(pvv.Elem())
case reflect.Ptr:
err = mockPtr(pvv, ptt.Type.Elem())
case reflect.Struct:
err = mock(pvv.Addr())
case reflect.Array, reflect.Slice:
err = mockSlice(tagValue, pvv)
case reflect.String:
pvv.SetString(tagValue)
case reflect.Bool:
err = mockBool(tagValue, pvv)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
value, e := strconv.ParseInt(tagValue, 10, 64)
if e != nil || pvv.OverflowInt(value) {
err = fmt.Errorf("the value:%s is invalid", tagValue)
}
pvv.SetInt(value)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
value, e := strconv.ParseUint(tagValue, 10, 64)
if e != nil || pvv.OverflowUint(value) {
err = fmt.Errorf("the value:%s is invalid", tagValue)
}
pvv.SetUint(value)
case reflect.Float32, reflect.Float64:
value, e := strconv.ParseFloat(tagValue, pvv.Type().Bits())
if e != nil || pvv.OverflowFloat(value) {
err = fmt.Errorf("the value:%s is invalid", tagValue)
}
pvv.SetFloat(value)
default:
}
if err != nil {
return
}
}
return
}
// mock slice value
func mockSlice(tagValue string, pvv reflect.Value) (err error) {
if tagValue == "" {
return
}
sliceMetas := strings.Split(tagValue, ":")
if len(sliceMetas) != 2 || sliceMetas[0] != "length" {
err = fmt.Errorf("the value:%s is invalid", tagValue)
return
}
length, e := strconv.Atoi(sliceMetas[1])
if e != nil {
return e
}
sliceType := reflect.SliceOf(pvv.Type().Elem()) // get slice type
itemType := sliceType.Elem() // get the type of item in slice
value := reflect.MakeSlice(sliceType, 0, length)
newSliceValue := make([]reflect.Value, 0, length)
for k := 0; k < length; k++ {
itemValue := reflect.New(itemType).Elem()
// if item in slice is struct or pointer,must set zero value
switch itemType.Kind() {
case reflect.Struct:
err = mock(itemValue.Addr())
case reflect.Ptr:
if itemValue.IsNil() {
itemValue.Set(reflect.New(itemType.Elem()))
if e := mock(itemValue); e != nil {
return e
}
}
}
newSliceValue = append(newSliceValue, itemValue)
if err != nil {
return
}
}
value = reflect.Append(value, newSliceValue...)
pvv.Set(value)
return
}
// mock bool value
func mockBool(tagValue string, pvv reflect.Value) (err error) {
switch tagValue {
case "true":
pvv.SetBool(true)
case "false":
pvv.SetBool(false)
default:
err = fmt.Errorf("the value:%s is invalid", tagValue)
}
return
}
// mock pointer
func mockPtr(pvv reflect.Value, ptt reflect.Type) (err error) {
if pvv.IsNil() {
pvv.Set(reflect.New(ptt)) // must set nil value to zero value
}
err = mock(pvv)
return
}

74
core/bean/mock_test.go Normal file
View File

@ -0,0 +1,74 @@
package bean
import (
"fmt"
"testing"
)
func TestMock(t *testing.T) {
type MockSubSubObject struct {
A int `mock:"20"`
}
type MockSubObjectAnoy struct {
Anoy int `mock:"20"`
}
type MockSubObject struct {
A bool `mock:"true"`
B MockSubSubObject
}
type MockObject struct {
A string `mock:"aaaaa"`
B int8 `mock:"10"`
C []*MockSubObject `mock:"length:2"`
D bool `mock:"true"`
E *MockSubObject
F []int `mock:"length:3"`
G InterfaceA
H InterfaceA
MockSubObjectAnoy
}
m := &MockObject{G: &ImplA{}}
err := Mock(m)
if err != nil {
t.Fatalf("mock failed: %v", err)
}
if m.A != "aaaaa" || m.B != 10 || m.C[1].B.A != 20 ||
!m.E.A || m.E.B.A != 20 || !m.D || len(m.F) != 3 {
t.Fail()
}
_, ok := m.G.(*ImplA)
if !ok {
t.Fail()
}
_, ok = m.G.(*ImplB)
if ok {
t.Fail()
}
_, ok = m.H.(*ImplA)
if ok {
t.Fail()
}
if m.Anoy != 20 {
t.Fail()
}
}
type InterfaceA interface {
Item()
}
type ImplA struct {
A string `mock:"aaa"`
}
func (i *ImplA) Item() {
fmt.Println("implA")
}
type ImplB struct {
B string `mock:"bbb"`
}
func (i *ImplB) Item() {
fmt.Println("implB")
}

1
go.mod
View File

@ -37,4 +37,5 @@ require (
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a
google.golang.org/grpc v1.37.1
gopkg.in/yaml.v2 v2.4.0
mvdan.cc/gofumpt v0.1.1 // indirect
)

7
go.sum
View File

@ -333,6 +333,7 @@ github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqn
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
@ -366,6 +367,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112 h1:NBrpnvz0pDPf3+HXZ1C9GcJd1DTpWDLcLWZhNq6uP7o=
@ -427,6 +429,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.0 h1:8pl+sMODzuvGJkmj2W4kZihvVb5mKm8pB/X44PIQHv8=
golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -512,6 +516,7 @@ golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210101214203-2dba1e4ea05c/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a h1:CB3a9Nez8M13wwlr/E2YtwoU+qYHKfC+JrDa45RXXoQ=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -589,6 +594,8 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
mvdan.cc/gofumpt v0.1.1 h1:bi/1aS/5W00E2ny5q65w9SnKpWEF/UIOqDYBILpo9rA=
mvdan.cc/gofumpt v0.1.1/go.mod h1:yXG1r1WqZVKWbVRtBWKWX9+CxGYfA51nSomhM0woR48=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc=
sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU=

View File

@ -250,7 +250,8 @@ func (c *Controller) Bind(obj interface{}) error {
return c.BindJson(obj)
}
i, l := 0, len(ct[0])
for ; i < l && ct[0][i] != ';'; i++ {
for i < l && ct[0][i] != ';' {
i++
}
switch ct[0][0:i] {
case "application/json":

View File

@ -284,7 +284,7 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string,
// Match router to runObject & params
func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) {
if len(pattern) == 0 || pattern[0] != '/' {
if pattern == "" || pattern[0] != '/' {
return nil
}
w := make([]string, 0, 20)
@ -294,12 +294,13 @@ func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{
func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) {
if len(pattern) > 0 {
i, l := 0, len(pattern)
for ; i < l && pattern[i] == '/'; i++ {
for i < l && pattern[i] == '/' {
i++
}
pattern = pattern[i:]
}
// Handle leaf nodes:
if len(pattern) == 0 {
if pattern == "" {
for _, l := range t.leaves {
if ok := l.match(treePattern, wildcardValues, ctx); ok {
return l.runObject
@ -316,7 +317,8 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string
}
var seg string
i, l := 0, len(pattern)
for ; i < l && pattern[i] != '/'; i++ {
for i < l && pattern[i] != '/' {
i++
}
if i == 0 {
seg = pattern
@ -327,7 +329,7 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string
}
for _, subTree := range t.fixrouters {
if subTree.prefix == seg {
if len(pattern) != 0 && pattern[0] == '/' {
if pattern != "" && pattern[0] == '/' {
treePattern = pattern[1:]
} else {
treePattern = pattern

View File

@ -37,6 +37,7 @@ type taskManager struct {
stop chan bool
changed chan bool
started bool
wait sync.WaitGroup
}
func newTaskManager() *taskManager {
@ -471,6 +472,11 @@ func ClearTask() {
globalTaskManager.ClearTask()
}
// GracefulShutdown wait all task done
func GracefulShutdown() <-chan struct{} {
return globalTaskManager.GracefulShutdown()
}
// StartTask start all tasks
func (m *taskManager) StartTask() {
m.taskLock.Lock()
@ -508,7 +514,7 @@ func (m *taskManager) run() {
select {
case now = <-time.After(effective.Sub(now)): // wait for effective time
runNextTasks(sortList, effective)
m.runNextTasks(sortList, effective)
continue
case <-m.changed: // tasks have been changed, set all tasks run again now
now = time.Now().Local()
@ -540,7 +546,7 @@ func (m *taskManager) markManagerStop() {
}
// runNextTasks it runs next task which next run time is equal to effective
func runNextTasks(sortList *MapSorter, effective time.Time) {
func (m *taskManager) runNextTasks(sortList *MapSorter, effective time.Time) {
// Run every entry whose next time was this effective time.
var i = 0
for _, e := range sortList.Vals {
@ -551,8 +557,10 @@ func runNextTasks(sortList *MapSorter, effective time.Time) {
// check if timeout is on, if yes passing the timeout context
ctx := context.Background()
m.wait.Add(1)
if duration := e.GetTimeout(ctx); duration != 0 {
go func(e Tasker) {
defer m.wait.Done()
ctx, cancelFunc := context.WithTimeout(ctx, duration)
defer cancelFunc()
err := e.Run(ctx)
@ -562,6 +570,7 @@ func runNextTasks(sortList *MapSorter, effective time.Time) {
}(e)
} else {
go func(e Tasker) {
defer m.wait.Done()
err := e.Run(ctx)
if err != nil {
log.Printf("tasker.run err: %s\n", err.Error())
@ -581,6 +590,17 @@ func (m *taskManager) StopTask() {
}()
}
// GracefulShutdown wait all task done
func (m *taskManager) GracefulShutdown() <-chan struct{} {
done := make(chan struct{})
go func() {
m.stop <- true
m.wait.Wait()
close(done)
}()
return done
}
// AddTask add task with name
func (m *taskManager) AddTask(taskname string, t Tasker) {
isChanged := false

View File

@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@ -177,6 +178,26 @@ func TestCrudTask(t *testing.T) {
assert.Equal(t, 0, len(m.adminTaskList))
}
func TestGracefulShutdown(t *testing.T) {
m := newTaskManager()
defer m.ClearTask()
waitDone := atomic.Value{}
waitDone.Store(false)
tk := NewTask("everySecond", "* * * * * *", func(ctx context.Context) error {
fmt.Println("hello world")
time.Sleep(2 * time.Second)
waitDone.Store(true)
return nil
})
m.AddTask("taska", tk)
m.StartTask()
time.Sleep(1 * time.Second)
shutdown := m.GracefulShutdown()
assert.False(t, waitDone.Load().(bool))
<-shutdown
assert.True(t, waitDone.Load().(bool))
}
func wait(wg *sync.WaitGroup) chan bool {
ch := make(chan bool)
go func() {