diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index 7e4b1032..50e91510 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -8,7 +8,7 @@ on: pull_request: types: [opened, synchronize, reopened, labeled, unlabeled] branches: - - master + - develop jobs: changelog: diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 00000000..85b159db --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,32 @@ +name: golangci-lint +on: + push: + tags: + - v* + branches: + - master + - main + pull_request: +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. + version: v1.29 + + # Optional: working directory, useful for monorepos +# working-directory: ./ + + # Optional: golangci-lint command line arguments. + args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true + + # Optional: show only new issues if it's a pull request. The default value is `false`. + only-new-issues: true + + # Optional: if set to true then the action will use pre-installed Go + # skip-go-installation: true \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index deb5c76c..e4a5971b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,19 @@ # developing +- ORM mock. [4407](https://github.com/beego/beego/pull/4407) +- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433) +- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) - Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) - Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) +- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) +- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) +- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) +- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) +- Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441) +- Error Module brief design & using httplib module to validate this design. [4453](https://github.com/beego/beego/pull/4453) +- Fix 4444: panic when 404 not found. [4446](https://github.com/beego/beego/pull/4446) +- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452) +- Fix 4456: Fix router method expression [4456](https://github.com/beego/beego/pull/4456) - Fix 4451: support QueryExecutor interface. [4461](https://github.com/beego/beego/pull/4461) diff --git a/ERROR_SPECIFICATION.md b/ERROR_SPECIFICATION.md new file mode 100644 index 00000000..cf2f94b2 --- /dev/null +++ b/ERROR_SPECIFICATION.md @@ -0,0 +1 @@ +# Error Module diff --git a/adapter/router.go b/adapter/router.go index 900e3eb7..9a615efe 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) + (*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...)) } // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller diff --git a/adapter/toolbox/task.go b/adapter/toolbox/task.go index bdd6679f..7b7cd68a 100644 --- a/adapter/toolbox/task.go +++ b/adapter/toolbox/task.go @@ -289,3 +289,7 @@ func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { return o.delegate.GetPrev() } + +func (o *oldToNewAdapter) GetTimeout(ctx context.Context) time.Duration { + return 0 +} diff --git a/client/httplib/error_code.go b/client/httplib/error_code.go new file mode 100644 index 00000000..961e7e34 --- /dev/null +++ b/client/httplib/error_code.go @@ -0,0 +1,131 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "github.com/beego/beego/v2/core/berror" +) + +var InvalidUrl = berror.DefineCode(4001001, moduleName, "InvalidUrl", ` +You pass a invalid url to httplib module. Please check your url, be careful about special character. +`) + +var InvalidUrlProtocolVersion = berror.DefineCode(4001002, moduleName, "InvalidUrlProtocolVersion", ` +You pass a invalid protocol version. In practice, we use HTTP/1.0, HTTP/1.1, HTTP/1.2 +But something like HTTP/3.2 is valid for client, and the major version is 3, minor version is 2. +but you must confirm that server support those abnormal protocol version. +`) + +var UnsupportedBodyType = berror.DefineCode(4001003, moduleName, "UnsupportedBodyType", ` +You use a invalid data as request body. +For now, we only support type string and byte[]. +`) + +var InvalidXMLBody = berror.DefineCode(4001004, moduleName, "InvalidXMLBody", ` +You pass invalid data which could not be converted to XML documents. In general, if you pass structure, it works well. +Sometimes you got XML document and you want to make it as request body. So you call XMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidYAMLBody = berror.DefineCode(4001005, moduleName, "InvalidYAMLBody", ` +You pass invalid data which could not be converted to YAML documents. In general, if you pass structure, it works well. +Sometimes you got YAML document and you want to make it as request body. So you call YAMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidJSONBody = berror.DefineCode(4001006, moduleName, "InvalidJSONBody", ` +You pass invalid data which could not be converted to JSON documents. In general, if you pass structure, it works well. +Sometimes you got JSON document and you want to make it as request body. So you call JSONBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + + +// start with 5 -------------------------------------------------------------------------- + +var CreateFormFileFailed = berror.DefineCode(5001001, moduleName, "CreateFormFileFailed", ` +In normal case than handling files with BeegoRequest, you should not see this error code. +Unexpected EOF, invalid characters, bad file descriptor may cause this error. +`) + +var ReadFileFailed = berror.DefineCode(5001002, moduleName, "ReadFileFailed", ` +There are several cases that cause this error: +1. file not found. Please check the file name; +2. file not found, but file name is correct. If you use relative file path, it's very possible for you to see this code. +make sure that this file is in correct directory which Beego looks for; +3. Beego don't have the privilege to read this file, please change file mode; +`) + +var CopyFileFailed = berror.DefineCode(5001003, moduleName, "CopyFileFailed", ` +When we try to read file content and then copy it to another writer, and failed. +1. Unexpected EOF; +2. Bad file descriptor; +3. Write conflict; + +Please check your file content, and confirm that file is not processed by other process (or by user manually). +`) + +var CloseFileFailed = berror.DefineCode(5001004, moduleName, "CloseFileFailed", ` +After handling files, Beego try to close file but failed. Usually it was caused by bad file descriptor. +`) + +var SendRequestFailed = berror.DefineCode(5001005, moduleName, "SendRequestRetryExhausted", ` +Beego send HTTP request, but it failed. +If you config retry times, it means that Beego had retried and failed. +When you got this error, there are vary kind of reason: +1. Network unstable and timeout. In this case, sometimes server has received the request. +2. Server error. Make sure that server works well. +3. The request is invalid, which means that you pass some invalid parameter. +`) + +var ReadGzipBodyFailed = berror.DefineCode(5001006, moduleName, "BuildGzipReaderFailed", ` +Beego parse gzip-encode body failed. Usually Beego got invalid response. +Please confirm that server returns gzip data. +`) + +var CreateFileIfNotExistFailed = berror.DefineCode(5001007, moduleName, "CreateFileIfNotExist", ` +Beego want to create file if not exist and failed. +In most cases, it means that Beego doesn't have the privilege to create this file. +Please change file mode to ensure that Beego is able to create files on specific directory. +Or you can run Beego with higher authority. +In some cases, you pass invalid filename. Make sure that the file name is valid on your system. +`) + +var UnmarshalJSONResponseToObjectFailed = berror.DefineCode(5001008, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid json document +`) + +var UnmarshalXMLResponseToObjectFailed = berror.DefineCode(5001009, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid XML document +`) + +var UnmarshalYAMLResponseToObjectFailed = berror.DefineCode(5001010, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid YAML document +`) + + + + diff --git a/client/httplib/filter/prometheus/filter.go b/client/httplib/filter/prometheus/filter.go index 3d5acf12..5761eb7e 100644 --- a/client/httplib/filter/prometheus/filter.go +++ b/client/httplib/filter/prometheus/filter.go @@ -18,6 +18,7 @@ import ( "context" "net/http" "strconv" + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -26,24 +27,30 @@ import ( ) type FilterChainBuilder struct { - summaryVec prometheus.ObserverVec AppName string ServerName string RunMode string } +var summaryVec prometheus.ObserverVec +var initSummaryVec sync.Once + func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { - builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ - Name: "beego", - Subsystem: "remote_http_request", - ConstLabels: map[string]string{ - "server": builder.ServerName, - "env": builder.RunMode, - "appname": builder.AppName, - }, - Help: "The statics info for remote http requests", - }, []string{"proto", "scheme", "method", "host", "path", "status", "isError"}) + initSummaryVec.Do(func() { + summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "remote_http_request", + ConstLabels: map[string]string{ + "server": builder.ServerName, + "env": builder.RunMode, + "appname": builder.AppName, + }, + Help: "The statics info for remote http requests", + }, []string{"proto", "scheme", "method", "host", "path", "status", "isError"}) + + prometheus.MustRegister(summaryVec) + }) return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { startTime := time.Now() @@ -72,6 +79,6 @@ func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time dur := int(endTime.Sub(startTime) / time.Millisecond) - builder.summaryVec.WithLabelValues(proto, scheme, method, host, path, + summaryVec.WithLabelValues(proto, scheme, method, host, path, strconv.Itoa(status), strconv.FormatBool(err != nil)).Observe(float64(dur)) } diff --git a/client/httplib/http_response_test.go b/client/httplib/http_response_test.go new file mode 100644 index 00000000..90db3fca --- /dev/null +++ b/client/httplib/http_response_test.go @@ -0,0 +1,36 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + + +func TestNewHttpResponseWithJsonBody(t *testing.T) { + // string + resp := NewHttpResponseWithJsonBody("{}") + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody([]byte("{}")) + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody(&user{ + Name: "Tom", + }) + assert.True(t, resp.ContentLength > 0) +} diff --git a/client/httplib/httplib.go b/client/httplib/httplib.go index 9402eca6..346de7af 100644 --- a/client/httplib/httplib.go +++ b/client/httplib/httplib.go @@ -40,59 +40,36 @@ import ( "encoding/xml" "io" "io/ioutil" - "log" "mime/multipart" "net" "net/http" - "net/http/cookiejar" "net/http/httputil" "net/url" "os" "path" "strings" - "sync" "time" "gopkg.in/yaml.v2" + + "github.com/beego/beego/v2/core/berror" + "github.com/beego/beego/v2/core/logs" ) -var defaultSetting = BeegoHTTPSettings{ - UserAgent: "beegoServer", - ConnectTimeout: 60 * time.Second, - ReadWriteTimeout: 60 * time.Second, - Gzip: true, - DumpBody: true, - FilterChains: []FilterChain{mockFilter.FilterChain}, -} - -var defaultCookieJar http.CookieJar -var settingMutex sync.Mutex - // it will be the last filter and execute request.Do var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { return req.doRequest(ctx) } -// createDefaultCookie creates a global cookiejar to store cookies. -func createDefaultCookie() { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultCookieJar, _ = cookiejar.New(nil) -} - -// SetDefaultSetting overwrites default settings -func SetDefaultSetting(setting BeegoHTTPSettings) { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultSetting = setting -} - // NewBeegoRequest returns *BeegoHttpRequest with specific method +// TODO add error as return value +// I think if we don't return error +// users are hard to check whether we create Beego request successfully func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { var resp http.Response u, err := url.Parse(rawurl) if err != nil { - log.Println("Httplib:", err) + logs.Error("%+v", berror.Wrapf(err, InvalidUrl, "invalid raw url: %s", rawurl)) } req := http.Request{ URL: u, @@ -137,24 +114,6 @@ func Head(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "HEAD") } -// BeegoHTTPSettings is the http.Client setting -type BeegoHTTPSettings struct { - ShowDebug bool - UserAgent string - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - TLSClientConfig *tls.Config - Proxy func(*http.Request) (*url.URL, error) - Transport http.RoundTripper - CheckRedirect func(req *http.Request, via []*http.Request) error - EnableCookie bool - Gzip bool - DumpBody bool - Retries int // if set to -1 means will retry forever - RetryDelay time.Duration - FilterChains []FilterChain -} - // BeegoHTTPRequest provides more useful methods than http.Request for requesting a url. type BeegoHTTPRequest struct { url string @@ -254,7 +213,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { } // SetProtocolVersion sets the protocol version for incoming requests. -// Client requests always use HTTP/1.1. +// Client requests always use HTTP/1.1 func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { if len(vers) == 0 { vers = "HTTP/1.1" @@ -265,8 +224,9 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { b.req.Proto = vers b.req.ProtoMajor = major b.req.ProtoMinor = minor + return b } - + logs.Error("%+v", berror.Errorf(InvalidUrlProtocolVersion, "invalid protocol: %s", vers)) return b } @@ -334,6 +294,7 @@ func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest // Body adds request raw body. // Supports string and []byte. +// TODO return error if data is invalid func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: @@ -350,6 +311,8 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { return ioutil.NopCloser(bf), nil } b.req.ContentLength = int64(len(t)) + default: + logs.Error("%+v", berror.Errorf(UnsupportedBodyType, "unsupported body data type: %s", t)) } return b } @@ -359,9 +322,12 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := xml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidXMLBody, "obj could not be converted to XML data") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(byts)), nil + } b.req.ContentLength = int64(len(byts)) b.req.Header.Set("Content-Type", "application/xml") } @@ -373,7 +339,7 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) if b.req.Body == nil && obj != nil { byts, err := yaml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidYAMLBody, "obj could not be converted to YAML data") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) b.req.ContentLength = int64(len(byts)) @@ -387,7 +353,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) if b.req.Body == nil && obj != nil { byts, err := json.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidJSONBody, "obj could not be converted to JSON body") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) b.req.ContentLength = int64(len(byts)) @@ -415,28 +381,15 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { bodyWriter := multipart.NewWriter(pw) go func() { for formname, filename := range b.files { - fileWriter, err := bodyWriter.CreateFormFile(formname, filename) - if err != nil { - log.Println("Httplib:", err) - } - fh, err := os.Open(filename) - if err != nil { - log.Println("Httplib:", err) - } - // iocopy - _, err = io.Copy(fileWriter, fh) - fh.Close() - if err != nil { - log.Println("Httplib:", err) - } + b.handleFileToBody(bodyWriter, formname, filename) } for k, v := range b.params { for _, vv := range v { - bodyWriter.WriteField(k, vv) + _ = bodyWriter.WriteField(k, vv) } } - bodyWriter.Close() - pw.Close() + _ = bodyWriter.Close() + _ = pw.Close() }() b.Header("Content-Type", bodyWriter.FormDataContentType()) b.req.Body = ioutil.NopCloser(pr) @@ -452,6 +405,29 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { } } +func (b *BeegoHTTPRequest) handleFileToBody(bodyWriter *multipart.Writer, formname string, filename string) { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + const errFmt = "Httplib: %+v" + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CreateFormFileFailed, + "could not create form file, formname: %s, filename: %s", formname, filename)) + } + fh, err := os.Open(filename) + + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, ReadFileFailed, "could not open this file %s", filename)) + } + // iocopy + _, err = io.Copy(fileWriter, fh) + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CopyFileFailed, "could not copy this file %s", filename)) + } + err = fh.Close() + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CloseFileFailed, "could not close this file %s", filename)) + } +} + func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { if b.resp.StatusCode != 0 { return b.resp, nil @@ -480,62 +456,20 @@ func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Res return root(ctx, b) } -func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) { - var paramBody string - if len(b.params) > 0 { - var buf bytes.Buffer - for k, v := range b.params { - for _, vv := range v { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(vv)) - buf.WriteByte('&') - } - } - paramBody = buf.String() - paramBody = paramBody[0 : len(paramBody)-1] - } +func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (*http.Response, error) { + paramBody := b.buildParamBody() b.buildURL(paramBody) urlParsed, err := url.Parse(b.url) if err != nil { - return nil, err + return nil, berror.Wrapf(err, InvalidUrl, "parse url failed, the url is %s", b.url) } b.req.URL = urlParsed - trans := b.setting.Transport + trans := b.buildTrans() - if trans == nil { - // create default transport - trans = &http.Transport{ - TLSClientConfig: b.setting.TLSClientConfig, - Proxy: b.setting.Proxy, - Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), - MaxIdleConnsPerHost: 100, - } - } else { - // if b.transport is *http.Transport then set the settings. - if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TLSClientConfig - } - if t.Proxy == nil { - t.Proxy = b.setting.Proxy - } - if t.Dial == nil { - t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) - } - } - } - - var jar http.CookieJar - if b.setting.EnableCookie { - if defaultCookieJar == nil { - createDefaultCookie() - } - jar = defaultCookieJar - } + jar := b.buildCookieJar() client := &http.Client{ Transport: trans, @@ -551,12 +485,16 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, } if b.setting.ShowDebug { - dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) - if err != nil { - log.Println(err.Error()) + dump, e := httputil.DumpRequest(b.req, b.setting.DumpBody) + if e != nil { + logs.Error("%+v", e) } b.dump = dump } + return b.sendRequest(client) +} + +func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response, err error) { // retries default value is 0, it will run once. // retries equal to -1, it will run forever until success // retries is setted, it will retries fixed times. @@ -564,11 +502,68 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { resp, err = client.Do(b.req) if err == nil { - break + return } time.Sleep(b.setting.RetryDelay) } - return resp, err + return nil, berror.Wrap(err, SendRequestFailed, "sending request fail") +} + +func (b *BeegoHTTPRequest) buildCookieJar() http.CookieJar { + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + return jar +} + +func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper { + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else { + // if b.transport is *http.Transport then set the settings. + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.DialContext == nil { + t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + } + return trans +} + +func (b *BeegoHTTPRequest) buildParamBody() string { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } + } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] + } + return paramBody } // String returns the body string in response. @@ -599,10 +594,10 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { reader, err := gzip.NewReader(resp.Body) if err != nil { - return nil, err + return nil, berror.Wrap(err, ReadGzipBodyFailed, "building gzip reader failed") } b.body, err = ioutil.ReadAll(reader) - return b.body, err + return b.body, berror.Wrap(err, ReadGzipBodyFailed, "reading gzip data failed") } b.body, err = ioutil.ReadAll(resp.Body) return b.body, err @@ -645,7 +640,7 @@ func pathExistAndMkdir(filename string) (err error) { return nil } } - return err + return berror.Wrapf(err, CreateFileIfNotExistFailed, "try to create(if not exist) failed: %s", filename) } // ToJSON returns the map that marshals from the body bytes as json in response. @@ -655,7 +650,8 @@ func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { if err != nil { return err } - return json.Unmarshal(data, v) + return berror.Wrap(json.Unmarshal(data, v), + UnmarshalJSONResponseToObjectFailed, "unmarshal json body to object failed.") } // ToXML returns the map that marshals from the body bytes as xml in response . @@ -665,7 +661,8 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error { if err != nil { return err } - return xml.Unmarshal(data, v) + return berror.Wrap(xml.Unmarshal(data, v), + UnmarshalXMLResponseToObjectFailed, "unmarshal xml body to object failed.") } // ToYAML returns the map that marshals from the body bytes as yaml in response . @@ -675,7 +672,8 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { if err != nil { return err } - return yaml.Unmarshal(data, v) + return berror.Wrap(yaml.Unmarshal(data, v), + UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.") } // Response executes request client gets response manually. @@ -684,8 +682,18 @@ func (b *BeegoHTTPRequest) Response() (*http.Response, error) { } // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +// Deprecated +// we will move this at the end of 2021 +// please use TimeoutDialerCtx func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { return func(netw, addr string) (net.Conn, error) { + return TimeoutDialerCtx(cTimeout, rwTimeout)(context.Background(), netw, addr) + } +} + +func TimeoutDialerCtx(cTimeout time.Duration, + rwTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) { + return func(ctx context.Context, netw, addr string) (net.Conn, error) { conn, err := net.DialTimeout(netw, addr, cTimeout) if err != nil { return nil, err diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index d0f826cb..36948de7 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -300,3 +300,136 @@ func TestAddFilter(t *testing.T) { r := Get("http://beego.me") assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) } + +func TestFilterChainOrder(t *testing.T) { + req := Get("http://beego.me") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("first"), nil + } + }) + + + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("second"), nil + } + }) + + resp, err := req.DoRequestWithCtx(context.Background()) + assert.Nil(t, err) + data := make([]byte, 5) + _, _ = resp.Body.Read(data) + assert.Equal(t, "first", string(data)) +} + +func TestHead(t *testing.T) { + req := Head("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "HEAD", req.req.Method) +} + +func TestDelete(t *testing.T) { + req := Delete("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "DELETE", req.req.Method) +} + +func TestPost(t *testing.T) { + req := Post("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "POST", req.req.Method) +} + +func TestNewBeegoRequest(t *testing.T) { + req := NewBeegoRequest("http://beego.me", "GET") + assert.NotNil(t, req) + assert.Equal(t, "GET", req.req.Method) + + // invalid case but still go request + req = NewBeegoRequest("httpa\ta://beego.me", "GET") + assert.NotNil(t, req) +} + +func TestBeegoHTTPRequest_SetProtocolVersion(t *testing.T) { + req := NewBeegoRequest("http://beego.me", "GET") + req.SetProtocolVersion("HTTP/3.10") + assert.Equal(t, "HTTP/3.10", req.req.Proto) + assert.Equal(t, 3, req.req.ProtoMajor) + assert.Equal(t, 10, req.req.ProtoMinor) + + req.SetProtocolVersion("") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) + + // invalid case + req.SetProtocolVersion("HTTP/aaa1.1") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) +} + +func TestPut(t *testing.T) { + req := Put("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "PUT", req.req.Method) +} + +func TestBeegoHTTPRequest_Header(t *testing.T) { + req := Post("http://beego.me") + key, value := "test-header", "test-header-value" + req.Header(key, value) + assert.Equal(t, value, req.req.Header.Get(key)) +} + +func TestBeegoHTTPRequest_SetHost(t *testing.T) { + req := Post("http://beego.me") + host := "test-hose" + req.SetHost(host) + assert.Equal(t, host, req.req.Host) +} + +func TestBeegoHTTPRequest_Param(t *testing.T) { + req := Post("http://beego.me") + key, value := "test-param", "test-param-value" + req.Param(key, value) + assert.Equal(t, value, req.params[key][0]) + + value1 := "test-param-value-1" + req.Param(key, value1) + assert.Equal(t, value1, req.params[key][1]) +} + +func TestBeegoHTTPRequest_Body(t *testing.T) { + req := Post("http://beego.me") + body := `hello, world` + req.Body([]byte(body)) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + body = "hhhh, i am test" + req.Body(body) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + // invalid case + req.Body(13) +} + + +type user struct { + Name string `xml:"name"` +} +func TestBeegoHTTPRequest_XMLBody(t *testing.T) { + req := Post("http://beego.me") + body := &user{ + Name: "Tom", + } + _, err := req.XMLBody(body) + assert.True(t, req.req.ContentLength > 0) + assert.Nil(t, err) + assert.NotNil(t, req.req.GetBody) +} \ No newline at end of file diff --git a/client/httplib/mock.go b/client/httplib/mock/mock.go similarity index 90% rename from client/httplib/mock.go rename to client/httplib/mock/mock.go index 691f03d2..7640e454 100644 --- a/client/httplib/mock.go +++ b/client/httplib/mock/mock.go @@ -12,17 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package httplib +package mock import ( "context" "net/http" + "github.com/beego/beego/v2/client/httplib" "github.com/beego/beego/v2/core/logs" ) const mockCtxKey = "beego-httplib-mock" +func init() { + InitMockSetting() +} + type Stub interface { Mock(cond RequestCondition, resp *http.Response, err error) Clear() @@ -31,6 +36,10 @@ type Stub interface { var mockFilter = &MockResponseFilter{} +func InitMockSetting() { + httplib.AddDefaultFilter(mockFilter.FilterChain) +} + func StartMock() Stub { return mockFilter } diff --git a/client/httplib/mock_condition.go b/client/httplib/mock/mock_condition.go similarity index 88% rename from client/httplib/mock_condition.go rename to client/httplib/mock/mock_condition.go index 5e6ff455..639b45a3 100644 --- a/client/httplib/mock_condition.go +++ b/client/httplib/mock/mock_condition.go @@ -12,17 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -package httplib +package mock import ( "context" "encoding/json" "net/textproto" "regexp" + + "github.com/beego/beego/v2/client/httplib" ) type RequestCondition interface { - Match(ctx context.Context, req *BeegoHTTPRequest) bool + Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool } // reqCondition create condition @@ -54,7 +56,7 @@ func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondi return sc } -func (sc *SimpleCondition) Match(ctx context.Context, req *BeegoHTTPRequest) bool { +func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { res := true if len(sc.path) > 0 { res = sc.matchPath(ctx, req) @@ -70,12 +72,12 @@ func (sc *SimpleCondition) Match(ctx context.Context, req *BeegoHTTPRequest) boo sc.matchBodyFields(ctx, req) } -func (sc *SimpleCondition) matchPath(ctx context.Context, req *BeegoHTTPRequest) bool { +func (sc *SimpleCondition) matchPath(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { path := req.GetRequest().URL.Path return path == sc.path } -func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *BeegoHTTPRequest) bool { +func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { path := req.GetRequest().URL.Path if b, err := regexp.Match(sc.pathReg, []byte(path)); err == nil { return b @@ -83,7 +85,7 @@ func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *BeegoHTTPReque return false } -func (sc *SimpleCondition) matchQuery(ctx context.Context, req *BeegoHTTPRequest) bool { +func (sc *SimpleCondition) matchQuery(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { qs := req.GetRequest().URL.Query() for k, v := range sc.query { if uv, ok := qs[k]; !ok || uv[0] != v { @@ -93,7 +95,7 @@ func (sc *SimpleCondition) matchQuery(ctx context.Context, req *BeegoHTTPRequest return true } -func (sc *SimpleCondition) matchHeader(ctx context.Context, req *BeegoHTTPRequest) bool { +func (sc *SimpleCondition) matchHeader(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { headers := req.GetRequest().Header for k, v := range sc.header { if uv, ok := headers[k]; !ok || uv[0] != v { @@ -103,7 +105,7 @@ func (sc *SimpleCondition) matchHeader(ctx context.Context, req *BeegoHTTPReques return true } -func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *BeegoHTTPRequest) bool { +func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { if len(sc.body) == 0 { return true } @@ -135,7 +137,7 @@ func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *BeegoHTTPRe return true } -func (sc *SimpleCondition) matchMethod(ctx context.Context, req *BeegoHTTPRequest) bool { +func (sc *SimpleCondition) matchMethod(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { if len(sc.method) > 0 { return sc.method == req.GetRequest().Method } diff --git a/client/httplib/mock_condition_test.go b/client/httplib/mock/mock_condition_test.go similarity index 74% rename from client/httplib/mock_condition_test.go rename to client/httplib/mock/mock_condition_test.go index 643dc353..4fc6d377 100644 --- a/client/httplib/mock_condition_test.go +++ b/client/httplib/mock/mock_condition_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package httplib +package mock import ( "context" @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/client/httplib" ) func init() { @@ -28,37 +29,37 @@ func init() { func TestSimpleCondition_MatchPath(t *testing.T) { sc := NewSimpleCondition("/abc/s") - res := sc.Match(context.Background(), Get("http://localhost:8080/abc/s")) + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s")) assert.True(t, res) } func TestSimpleCondition_MatchQuery(t *testing.T) { k, v := "my-key", "my-value" sc := NewSimpleCondition("/abc/s") - res := sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value")) + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) assert.True(t, res) sc = NewSimpleCondition("/abc/s", WithQuery(k, v)) - res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value")) + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) assert.True(t, res) - res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-valuesss")) + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-valuesss")) assert.False(t, res) - res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key-a=my-value")) + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key-a=my-value")) assert.False(t, res) - res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello")) + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello")) assert.True(t, res) } func TestSimpleCondition_MatchHeader(t *testing.T) { k, v := "my-header", "my-header-value" sc := NewSimpleCondition("/abc/s") - req := Get("http://localhost:8080/abc/s") + req := httplib.Get("http://localhost:8080/abc/s") assert.True(t, sc.Match(context.Background(), req)) - req = Get("http://localhost:8080/abc/s") + req = httplib.Get("http://localhost:8080/abc/s") req.Header(k, v) assert.True(t, sc.Match(context.Background(), req)) @@ -73,7 +74,7 @@ func TestSimpleCondition_MatchHeader(t *testing.T) { func TestSimpleCondition_MatchBodyField(t *testing.T) { sc := NewSimpleCondition("/abc/s") - req := Post("http://localhost:8080/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") assert.True(t, sc.Match(context.Background(), req)) @@ -102,7 +103,7 @@ func TestSimpleCondition_MatchBodyField(t *testing.T) { func TestSimpleCondition_Match(t *testing.T) { sc := NewSimpleCondition("/abc/s") - req := Post("http://localhost:8080/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") assert.True(t, sc.Match(context.Background(), req)) @@ -115,9 +116,9 @@ func TestSimpleCondition_Match(t *testing.T) { func TestSimpleCondition_MatchPathReg(t *testing.T) { sc := NewSimpleCondition("", WithPathReg(`\/abc\/.*`)) - req := Post("http://localhost:8080/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") assert.True(t, sc.Match(context.Background(), req)) - req = Post("http://localhost:8080/abcd/s") + req = httplib.Post("http://localhost:8080/abcd/s") assert.False(t, sc.Match(context.Background(), req)) } diff --git a/client/httplib/mock_filter.go b/client/httplib/mock/mock_filter.go similarity index 87% rename from client/httplib/mock_filter.go rename to client/httplib/mock/mock_filter.go index 83a7b71b..225d65f3 100644 --- a/client/httplib/mock_filter.go +++ b/client/httplib/mock/mock_filter.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package httplib +package mock import ( "context" - "fmt" "net/http" + + "github.com/beego/beego/v2/client/httplib" ) // MockResponse will return mock response if find any suitable mock data @@ -32,13 +33,10 @@ func NewMockResponseFilter() *MockResponseFilter { } } -func (m *MockResponseFilter) FilterChain(next Filter) Filter { - return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { - +func (m *MockResponseFilter) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { ms := mockFromCtx(ctx) ms = append(ms, m.ms...) - - fmt.Printf("url: %s, mock: %d \n", req.url, len(ms)) for _, mock := range ms { if mock.cond.Match(ctx, req) { return mock.resp, mock.err diff --git a/client/httplib/mock_filter_test.go b/client/httplib/mock/mock_filter_test.go similarity index 79% rename from client/httplib/mock_filter_test.go rename to client/httplib/mock/mock_filter_test.go index 40a2185e..b27e772e 100644 --- a/client/httplib/mock_filter_test.go +++ b/client/httplib/mock/mock_filter_test.go @@ -12,20 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package httplib +package mock import ( "errors" "testing" "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" ) func TestMockResponseFilter_FilterChain(t *testing.T) { - req := Get("http://localhost:8080/abc/s") + req := httplib.Get("http://localhost:8080/abc/s") ft := NewMockResponseFilter() - expectedResp := NewHttpResponseWithJsonBody(`{}`) + expectedResp := httplib.NewHttpResponseWithJsonBody(`{}`) expectedErr := errors.New("expected error") ft.Mock(NewSimpleCondition("/abc/s"), expectedResp, expectedErr) @@ -35,16 +37,16 @@ func TestMockResponseFilter_FilterChain(t *testing.T) { assert.Equal(t, expectedErr, err) assert.Equal(t, expectedResp, resp) - req = Get("http://localhost:8080/abcd/s") + req = httplib.Get("http://localhost:8080/abcd/s") req.AddFilters(ft.FilterChain) resp, err = req.DoRequest() assert.NotEqual(t, expectedErr, err) assert.NotEqual(t, expectedResp, resp) - req = Get("http://localhost:8080/abc/s") + req = httplib.Get("http://localhost:8080/abc/s") req.AddFilters(ft.FilterChain) - expectedResp1 := NewHttpResponseWithJsonBody(map[string]string{}) + expectedResp1 := httplib.NewHttpResponseWithJsonBody(map[string]string{}) expectedErr1 := errors.New("expected error") ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) @@ -52,7 +54,7 @@ func TestMockResponseFilter_FilterChain(t *testing.T) { assert.Equal(t, expectedErr, err) assert.Equal(t, expectedResp, resp) - req = Get("http://localhost:8080/abc/abs/bbc") + req = httplib.Get("http://localhost:8080/abc/abs/bbc") req.AddFilters(ft.FilterChain) ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) resp, err = req.DoRequest() diff --git a/client/httplib/mock_test.go b/client/httplib/mock/mock_test.go similarity index 75% rename from client/httplib/mock_test.go rename to client/httplib/mock/mock_test.go index 1d913b29..e73e8a6a 100644 --- a/client/httplib/mock_test.go +++ b/client/httplib/mock/mock_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package httplib +package mock import ( "context" @@ -21,16 +21,18 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" ) func TestStartMock(t *testing.T) { - defaultSetting.FilterChains = []FilterChain{mockFilter.FilterChain} + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} stub := StartMock() // defer stub.Clear() - expectedResp := NewHttpResponseWithJsonBody([]byte(`{}`)) + expectedResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) expectedErr := errors.New("expected err") stub.Mock(NewSimpleCondition("/abc"), expectedResp, expectedErr) @@ -45,14 +47,14 @@ func TestStartMock(t *testing.T) { // TestStartMock_Isolation Test StartMock that // mock only work for this request func TestStartMock_Isolation(t *testing.T) { - defaultSetting.FilterChains = []FilterChain{mockFilter.FilterChain} + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} // setup global stub stub := StartMock() - globalMockResp := NewHttpResponseWithJsonBody([]byte(`{}`)) + globalMockResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) globalMockErr := errors.New("expected err") stub.Mock(NewSimpleCondition("/abc"), globalMockResp, globalMockErr) - expectedResp := NewHttpResponseWithJsonBody(struct { + expectedResp := httplib.NewHttpResponseWithJsonBody(struct { A string `json:"a"` }{ A: "aaa", @@ -67,9 +69,9 @@ func TestStartMock_Isolation(t *testing.T) { } func OriginnalCodeUsingHttplibPassCtx(ctx context.Context) (*http.Response, error) { - return Get("http://localhost:7777/abc").DoRequestWithCtx(ctx) + return httplib.Get("http://localhost:7777/abc").DoRequestWithCtx(ctx) } func OriginalCodeUsingHttplib() (*http.Response, error){ - return Get("http://localhost:7777/abc").DoRequest() + return httplib.Get("http://localhost:7777/abc").DoRequest() } diff --git a/client/httplib/module.go b/client/httplib/module.go new file mode 100644 index 00000000..8503133c --- /dev/null +++ b/client/httplib/module.go @@ -0,0 +1,17 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +const moduleName = "httplib" diff --git a/client/httplib/setting.go b/client/httplib/setting.go new file mode 100644 index 00000000..c8d049e0 --- /dev/null +++ b/client/httplib/setting.go @@ -0,0 +1,81 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/http/cookiejar" + "net/url" + "sync" + "time" +) + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + ShowDebug bool + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + DumpBody bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration + FilterChains []FilterChain +} + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting overwrites default settings +// Keep in mind that when you invoke the SetDefaultSetting +// some methods invoked before SetDefaultSetting +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + DumpBody: true, + FilterChains: make([]FilterChain, 0, 4), +} + +var defaultCookieJar http.CookieJar +var settingMutex sync.Mutex + +// AddDefaultFilter add a new filter into defaultSetting +// Be careful about using this method if you invoke SetDefaultSetting somewhere +func AddDefaultFilter(fc FilterChain) { + settingMutex.Lock() + defer settingMutex.Unlock() + if defaultSetting.FilterChains == nil { + defaultSetting.FilterChains = make([]FilterChain, 0, 4) + } + defaultSetting.FilterChains = append(defaultSetting.FilterChains, fc) +} \ No newline at end of file diff --git a/client/orm/clauses/const.go b/client/orm/clauses/const.go new file mode 100644 index 00000000..747d3fd7 --- /dev/null +++ b/client/orm/clauses/const.go @@ -0,0 +1,6 @@ +package clauses + +const ( + ExprSep = "__" + ExprDot = "." +) diff --git a/client/orm/clauses/order_clause/order.go b/client/orm/clauses/order_clause/order.go new file mode 100644 index 00000000..e45c2f85 --- /dev/null +++ b/client/orm/clauses/order_clause/order.go @@ -0,0 +1,103 @@ +package order_clause + +import ( + "github.com/beego/beego/v2/client/orm/clauses" + "strings" +) + +type Sort int8 + +const ( + None Sort = 0 + Ascending Sort = 1 + Descending Sort = 2 +) + +type Option func(order *Order) + +type Order struct { + column string + sort Sort + isRaw bool +} + +func Clause(options ...Option) *Order { + o := &Order{} + for _, option := range options { + option(o) + } + + return o +} + +func (o *Order) GetColumn() string { + return o.column +} + +func (o *Order) GetSort() Sort { + return o.sort +} + +func (o *Order) SortString() string { + switch o.GetSort() { + case Ascending: + return "ASC" + case Descending: + return "DESC" + } + + return `` +} + +func (o *Order) IsRaw() bool { + return o.isRaw +} + +func ParseOrder(expressions ...string) []*Order { + var orders []*Order + for _, expression := range expressions { + sort := Ascending + column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot) + if column[0] == '-' { + sort = Descending + column = column[1:] + } + + orders = append(orders, &Order{ + column: column, + sort: sort, + }) + } + + return orders +} + +func Column(column string) Option { + return func(order *Order) { + order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot) + } +} + +func sort(sort Sort) Option { + return func(order *Order) { + order.sort = sort + } +} + +func SortAscending() Option { + return sort(Ascending) +} + +func SortDescending() Option { + return sort(Descending) +} + +func SortNone() Option { + return sort(None) +} + +func Raw() Option { + return func(order *Order) { + order.isRaw = true + } +} diff --git a/client/orm/clauses/order_clause/order_test.go b/client/orm/clauses/order_clause/order_test.go new file mode 100644 index 00000000..172e7492 --- /dev/null +++ b/client/orm/clauses/order_clause/order_test.go @@ -0,0 +1,144 @@ +package order_clause + +import ( + "testing" +) + +func TestClause(t *testing.T) { + var ( + column = `a` + ) + + o := Clause( + Column(column), + ) + + if o.GetColumn() != column { + t.Error() + } +} + +func TestSortAscending(t *testing.T) { + o := Clause( + SortAscending(), + ) + + if o.GetSort() != Ascending { + t.Error() + } +} + +func TestSortDescending(t *testing.T) { + o := Clause( + SortDescending(), + ) + + if o.GetSort() != Descending { + t.Error() + } +} + +func TestSortNone(t *testing.T) { + o1 := Clause( + SortNone(), + ) + + if o1.GetSort() != None { + t.Error() + } + + o2 := Clause() + + if o2.GetSort() != None { + t.Error() + } +} + +func TestRaw(t *testing.T) { + o1 := Clause() + + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + + if !o2.IsRaw() { + t.Error() + } +} + +func TestColumn(t *testing.T) { + o1 := Clause( + Column(`aaa`), + ) + + if o1.GetColumn() != `aaa` { + t.Error() + } +} + +func TestParseOrder(t *testing.T) { + orders := ParseOrder( + `-user__status`, + `status`, + `user__status`, + ) + + t.Log(orders) + + if orders[0].GetSort() != Descending { + t.Error() + } + + if orders[0].GetColumn() != `user.status` { + t.Error() + } + + if orders[1].GetColumn() != `status` { + t.Error() + } + + if orders[1].GetSort() != Ascending { + t.Error() + } + + if orders[2].GetColumn() != `user.status` { + t.Error() + } + +} + +func TestOrder_GetColumn(t *testing.T) { + o := Clause( + Column(`user__id`), + ) + if o.GetColumn() != `user.id` { + t.Error() + } +} + +func TestOrder_GetSort(t *testing.T) { + o := Clause( + SortDescending(), + ) + if o.GetSort() != Descending { + t.Error() + } +} + +func TestOrder_IsRaw(t *testing.T) { + o1 := Clause() + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + if !o2.IsRaw() { + t.Error() + } +} diff --git a/client/orm/cmd.go b/client/orm/cmd.go index b0661971..b377a5f2 100644 --- a/client/orm/cmd.go +++ b/client/orm/cmd.go @@ -15,6 +15,7 @@ package orm import ( + "context" "flag" "fmt" "os" @@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } + ctx := context.Background() for i, mi := range modelCache.allOrdered() { if !isApplicableTableForDB(mi.addrField, d.al.Name) { @@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error { } var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) + columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) if err != nil { if d.rtOnError { return err @@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } diff --git a/client/orm/db.go b/client/orm/db.go index 4080f292..a49d6df7 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "errors" "fmt" @@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } // create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { +func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() dbcols := make([]string, 0, len(mi.fields.dbcols)) @@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, d.ins.HasReturningID(mi, &query) - stmt, err := q.Prepare(query) + stmt, err := q.PrepareContext(ctx, query) return stmt, query, err } // insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err @@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, err := row.Scan(&id) return id, err } - res, err := stmt.Exec(values...) + res, err := stmt.ExecContext(ctx, values...) if err == nil { return res.LastInsertId() } @@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, args...) + row := q.QueryRowContext(ctx, query, args...) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { return ErrNoRows @@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } // execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - id, err := d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(ctx, q, mi, false, names, values) if err != nil { return 0, err } if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return id, err } // multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { +func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 nums int @@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) + num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) if err != nil { return cnt, err } @@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul var err error if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return cnt, err @@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err @@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s // InsertOrUpdate a row // If your primary key or unique column conflict will update // If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { args0 := "" iouStr := "" argsMap := map[string]string{} @@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) if err != nil && err.Error() == `pq: syntax error at or near "ON"` { @@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a } // execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if !ok { return 0, ErrMissPK @@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, setValues...) + res, err := q.ExecContext(ctx, query, setValues...) if err == nil { return res.RowsAffected() } @@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { var whereCols []string var args []interface{} // if specify cols length > 0, then use it for where condition. @@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { @@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) } } - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // update table-related record by querySet. // need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { +func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { @@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } + res, err := q.ExecContext(ctx, query, values...) if err == nil { return res.RowsAffected() } @@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con // delete related records. // do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { +func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { case odCascade: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) if err != nil { return err } @@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * if fi.onDelete == odSetDefault { params[fi.column] = fi.initial.String() } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) if err != nil { return err } @@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * } // delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { +func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con d.ins.ReplaceMarks(&query) var rs *sql.Rows - r, err := q.Query(query, args...) + r, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -943,14 +933,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } // read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) - errTyp := true + unregister := true one := true isPtr := true + name := "" if val.Kind() == reflect.Ptr { fn := "" @@ -963,19 +954,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi case reflect.Struct: isPtr = false fn = getFullName(typ) + name = getTableName(reflect.New(typ)) } } else { fn = getFullName(ind.Type()) + name = getTableName(ind) } - errTyp = fn != mi.fullName + unregister = fn != mi.fullName } - if errTyp { - if one { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) - } else { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) - } + if unregister { + RegisterModel(container) } rlimit := qs.limit @@ -1040,6 +1029,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if qs.distinct { sqlSelect += " DISTINCT" } + if qs.aggregate != "" { + sels = qs.aggregate + } query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, specifyIndexes, join, where, groupBy, orderBy, limit) @@ -1050,18 +1042,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi d.ins.ReplaceMarks(&query) - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } + rs, err := q.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + slice := ind + if unregister { + mi, _ = modelCache.get(name) + tCols = mi.fields.dbcols + colsNum = len(tCols) } refs := make([]interface{}, colsNum) @@ -1069,11 +1061,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi var ref interface{} refs[i] = &ref } - - defer rs.Close() - - slice := ind - var cnt int64 for rs.Next() { if one && cnt == 0 || !one { @@ -1172,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } // excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -1194,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition d.ins.ReplaceMarks(&query) - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) - } else { - row = q.QueryRow(query, args...) - } + row := q.QueryRowContext(ctx, query, args...) err = row.Scan(&cnt) return } @@ -1649,7 +1631,7 @@ setValue: } // query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { +func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params @@ -1732,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond d.ins.ReplaceMarks(&query) - rs, err := q.Query(query, args...) + rs, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -1847,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { } // sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { return nil } @@ -1892,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { } // get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return columns, err } @@ -1934,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string { } // not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { +func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/client/orm/db_mysql.go b/client/orm/db_mysql.go index ee68baf7..c89b1e52 100644 --- a/client/orm/db_mysql.go +++ b/client/orm/db_mysql.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" "strings" @@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) @@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool // If your primary key or unique column conflict will update // If no will insert // Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { var iouStr string argsMap := map[string]string{} @@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) return id, err diff --git a/client/orm/db_oracle.go b/client/orm/db_oracle.go index 1de440b6..a3b93ff3 100644 --- a/client/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strings" @@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string { } // check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ +func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) @@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err diff --git a/client/orm/db_postgres.go b/client/orm/db_postgres.go index 12431d6e..b2f321db 100644 --- a/client/orm/db_postgres.go +++ b/client/orm/db_postgres.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strconv" ) @@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { } // sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { if len(autoFields) == 0 { return nil } @@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string mi.table, name, Q, name, Q, Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { + if _, err := db.ExecContext(ctx, query); err != nil { return err } } @@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string { } // check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) + row := db.QueryRowContext(ctx, query) var cnt int row.Scan(&cnt) return cnt > 0 diff --git a/client/orm/db_sqlite.go b/client/orm/db_sqlite.go index aff713a5..6a4b3131 100644 --- a/client/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -73,11 +74,11 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) // override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { if isForUpdate { DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") } - return d.dbBase.Read(q, mi, ind, tz, cols, false) + return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) } // get sqlite operator. @@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string { } // get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { } // check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { panic(err) } diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 5fd472d1..d62d8106 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -16,6 +16,8 @@ package orm import ( "fmt" + "github.com/beego/beego/v2/client/orm/clauses" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "strings" "time" ) @@ -421,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { if len(orders) == 0 { return } @@ -430,19 +432,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { - asc := "ASC" - if order[0] == '-' { - asc = "DESC" - order = order[1:] - } - exprs := strings.Split(order, ExprSep) + column := order.GetColumn() + clause := strings.Split(column, clauses.ExprDot) - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } + if order.IsRaw() { + if len(clause) == 2 { + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) + } else if len(clause) == 1 { + orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) + } else { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } + } else { + index, _, fi, suc := t.parseExprs(t.mi, clause) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) + } } orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) diff --git a/client/orm/db_tidb.go b/client/orm/db_tidb.go index 6020a488..48c5b4e7 100644 --- a/client/orm/db_tidb.go +++ b/client/orm/db_tidb.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" ) @@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) diff --git a/client/orm/do_nothing_orm.go b/client/orm/do_nothing_orm.go index c6da420d..59ffe877 100644 --- a/client/orm/do_nothing_orm.go +++ b/client/orm/do_nothing_orm.go @@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { return nil } @@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { return nil } diff --git a/client/orm/do_nothing_orm_test.go b/client/orm/do_nothing_orm_test.go index 4d477353..e10f70af 100644 --- a/client/orm/do_nothing_orm_test.go +++ b/client/orm/do_nothing_orm_test.go @@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, o.Driver()) - assert.Nil(t, o.QueryM2MWithCtx(nil, nil, "")) assert.Nil(t, o.QueryM2M(nil, "")) assert.Nil(t, o.ReadWithCtx(nil, nil)) assert.Nil(t, o.Read(nil)) @@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(0), i) - assert.Nil(t, o.QueryTableWithCtx(nil, nil)) assert.Nil(t, o.QueryTable(nil)) assert.Nil(t, o.Read(nil)) diff --git a/client/orm/filter/opentracing/filter.go b/client/orm/filter/opentracing/filter.go index 75852c63..7afa07f1 100644 --- a/client/orm/filter/opentracing/filter.go +++ b/client/orm/filter/opentracing/filter.go @@ -27,7 +27,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to trace QuerySetter -// actually we trace invoking "QueryTable" and "QueryTableWithCtx" +// actually we trace invoking "QueryTable" // the method Begin*, Commit and Rollback are ignored. // When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. type FilterChainBuilder struct { diff --git a/client/orm/filter/prometheus/filter.go b/client/orm/filter/prometheus/filter.go index db60876e..270deaf6 100644 --- a/client/orm/filter/prometheus/filter.go +++ b/client/orm/filter/prometheus/filter.go @@ -18,6 +18,7 @@ import ( "context" "strconv" "strings" + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -31,26 +32,31 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to records the metrics of QuerySetter -// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" +// actually we only records metrics of invoking "QueryTable" type FilterChainBuilder struct { - summaryVec prometheus.ObserverVec AppName string ServerName string RunMode string } +var summaryVec prometheus.ObserverVec +var initSummaryVec sync.Once + func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { - builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ - Name: "beego", - Subsystem: "orm_operation", - ConstLabels: map[string]string{ - "server": builder.ServerName, - "env": builder.RunMode, - "appname": builder.AppName, - }, - Help: "The statics info for orm operation", - }, []string{"method", "name", "insideTx", "txName"}) + initSummaryVec.Do(func() { + summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "orm_operation", + ConstLabels: map[string]string{ + "server": builder.ServerName, + "env": builder.RunMode, + "appname": builder.AppName, + }, + Help: "The statics info for orm operation", + }, []string{"method", "name", "insideTx", "txName"}) + prometheus.MustRegister(summaryVec) + }) return func(ctx context.Context, inv *orm.Invocation) []interface{} { startTime := time.Now() @@ -74,12 +80,12 @@ func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocati builder.reportTxn(ctx, inv) return } - builder.summaryVec.WithLabelValues(inv.Method, inv.GetTableName(), + summaryVec.WithLabelValues(inv.Method, inv.GetTableName(), strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur)) } func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) { dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond - builder.summaryVec.WithLabelValues(inv.Method, inv.TxName, + summaryVec.WithLabelValues(inv.Method, inv.TxName, strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur)) } diff --git a/client/orm/filter/prometheus/filter_test.go b/client/orm/filter/prometheus/filter_test.go index 92316557..a25515a7 100644 --- a/client/orm/filter/prometheus/filter_test.go +++ b/client/orm/filter/prometheus/filter_test.go @@ -32,7 +32,7 @@ func TestFilterChainBuilder_FilterChain1(t *testing.T) { builder := &FilterChainBuilder{} filter := builder.FilterChain(next) - assert.NotNil(t, builder.summaryVec) + assert.NotNil(t, summaryVec) assert.NotNil(t, filter) inv := &orm.Invocation{} diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index a60390a1..caf2b3f9 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/utils" ) @@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac } func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { - return f.QueryM2MWithCtx(context.Background(), md, name) -} - -func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, _ := modelCache.getByMd(md) inv := &Invocation{ - Method: "QueryM2MWithCtx", + Method: "QueryM2M", Args: []interface{}{md, name}, Md: md, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, f: func(c context.Context) []interface{} { - res := f.ormer.QueryM2MWithCtx(c, md, name) + res := f.ormer.QueryM2M(md, name) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil } return res[0].(QueryM2Mer) } -func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { - return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.") + return f.QueryM2M(md, name) } -func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { +func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { var ( name string md interface{} @@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT } inv := &Invocation{ - Method: "QueryTableWithCtx", + Method: "QueryTable", Args: []interface{}{ptrStructOrTableName}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, Md: md, mi: mi, f: func(c context.Context) []interface{} { - res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) + res := f.ormer.QueryTable(ptrStructOrTableName) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil @@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT return res[0].(QuerySeter) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.") + return f.QueryTable(ptrStructOrTableName) +} + func (f *filterOrmDecorator) DBStats() *sql.DBStats { inv := &Invocation{ Method: "DBStats", diff --git a/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go index 9e223358..566499dd 100644 --- a/client/orm/filter_orm_decorator_test.go +++ b/client/orm/filter_orm_decorator_test.go @@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryM2MWithCtx", inv.Method) + assert.Equal(t, "QueryM2M", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) @@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryTableWithCtx", inv.Method) + assert.Equal(t, "QueryTable", inv.Method) assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) diff --git a/client/orm/mock/condition.go b/client/orm/mock/condition.go new file mode 100644 index 00000000..486849d4 --- /dev/null +++ b/client/orm/mock/condition.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +type Mock struct { + cond Condition + resp []interface{} + cb func(inv *orm.Invocation) +} + +func NewMock(cond Condition, resp []interface{}, cb func(inv *orm.Invocation)) *Mock { + return &Mock{ + cond: cond, + resp: resp, + cb: cb, + } +} + +type Condition interface { + Match(ctx context.Context, inv *orm.Invocation) bool +} + +type SimpleCondition struct { + tableName string + method string +} + +func NewSimpleCondition(tableName string, methodName string) Condition { + return &SimpleCondition{ + tableName: tableName, + method: methodName, + } +} + +func (s *SimpleCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(s.tableName) != 0 { + res = res && (s.tableName == inv.GetTableName()) + } + + if len(s.method) != 0 { + res = res && (s.method == inv.Method) + } + return res +} diff --git a/client/orm/mock/condition_test.go b/client/orm/mock/condition_test.go new file mode 100644 index 00000000..7f646e70 --- /dev/null +++ b/client/orm/mock/condition_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestSimpleCondition_Match(t *testing.T) { + cond := NewSimpleCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewSimpleCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewSimpleCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "B", + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "A", + })) +} diff --git a/client/orm/mock/context.go b/client/orm/mock/context.go new file mode 100644 index 00000000..ca251c5d --- /dev/null +++ b/client/orm/mock/context.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/core/logs" +) + +type mockCtxKeyType string + +const mockCtxKey = mockCtxKeyType("beego-orm-mock") + +func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context { + return context.WithValue(ctx, mockCtxKey, mock) +} + +func mockFromCtx(ctx context.Context) []*Mock { + ms := ctx.Value(mockCtxKey) + if ms != nil { + if res, ok := ms.([]*Mock); ok { + return res + } + logs.Error("mockCtxKey found in context, but value is not type []*Mock") + } + return nil +} diff --git a/client/orm/mock/context_test.go b/client/orm/mock/context_test.go new file mode 100644 index 00000000..a3ed1e90 --- /dev/null +++ b/client/orm/mock/context_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCtx(t *testing.T) { + ms := make([]*Mock, 0, 4) + ctx := CtxWithMock(context.Background(), ms...) + res := mockFromCtx(ctx) + assert.Equal(t, ms, res) +} diff --git a/client/orm/mock/mock.go b/client/orm/mock/mock.go new file mode 100644 index 00000000..072488b2 --- /dev/null +++ b/client/orm/mock/mock.go @@ -0,0 +1,72 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +var stub = newOrmStub() + +func init() { + orm.AddGlobalFilterChain(stub.FilterChain) +} + +type Stub interface { + Mock(m *Mock) + Clear() +} + +type OrmStub struct { + ms []*Mock +} + +func StartMock() Stub { + return stub +} + +func newOrmStub() *OrmStub { + return &OrmStub{ + ms: make([]*Mock, 0, 4), + } +} + +func (o *OrmStub) Mock(m *Mock) { + o.ms = append(o.ms, m) +} + +func (o *OrmStub) Clear() { + o.ms = make([]*Mock, 0, 4) +} + +func (o *OrmStub) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + + ms := mockFromCtx(ctx) + ms = append(ms, o.ms...) + + for _, mock := range ms { + if mock.cond.Match(ctx, inv) { + if mock.cb != nil { + mock.cb(inv) + } + return mock.resp + } + } + return next(ctx, inv) + } +} diff --git a/client/orm/mock/mock_orm.go b/client/orm/mock/mock_orm.go new file mode 100644 index 00000000..16ae8612 --- /dev/null +++ b/client/orm/mock/mock_orm.go @@ -0,0 +1,162 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + "os" + "path/filepath" + + _ "github.com/mattn/go-sqlite3" + + "github.com/beego/beego/v2/client/orm" +) + +func init() { + RegisterMockDB("default") +} + +// RegisterMockDB create an "virtual DB" by using sqllite +// you should not +func RegisterMockDB(name string) { + source := filepath.Join(os.TempDir(), name+".db") + _ = orm.RegisterDataBase(name, "sqlite3", source) +} + +// MockTable only check table name +func MockTable(tableName string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition(tableName, ""), resp, nil) +} + +// MockMethod only check method name +func MockMethod(method string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition("", method), resp, nil) +} + +// MockOrmRead support orm.Read and orm.ReadWithCtx +// cb is used to mock read data from DB +func MockRead(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadWithCtx"), []interface{}{err}, func(inv *orm.Invocation) { + if cb != nil { + cb(inv.Args[0]) + } + }) +} + +// MockReadForUpdateWithCtx support ReadForUpdate and ReadForUpdateWithCtx +// cb is used to mock read data from DB +func MockReadForUpdateWithCtx(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadForUpdateWithCtx"), + []interface{}{err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockReadOrCreateWithCtx support ReadOrCreate and ReadOrCreateWithCtx +// cb is used to mock read data from DB +func MockReadOrCreateWithCtx(tableName string, + cb func(data interface{}), + insert bool, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadOrCreateWithCtx"), + []interface{}{insert, id, err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockInsertWithCtx support Insert and InsertWithCtx +func MockInsertWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertWithCtx"), []interface{}{id, err}, nil) +} + +// MockInsertMultiWithCtx support InsertMulti and InsertMultiWithCtx +func MockInsertMultiWithCtx(tableName string, cnt int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertMultiWithCtx"), []interface{}{cnt, err}, nil) +} + +// MockInsertOrUpdateWithCtx support InsertOrUpdate and InsertOrUpdateWithCtx +func MockInsertOrUpdateWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertOrUpdateWithCtx"), []interface{}{id, err}, nil) +} + +// MockUpdateWithCtx support UpdateWithCtx and Update +func MockUpdateWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "UpdateWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockDeleteWithCtx support Delete and DeleteWithCtx +func MockDeleteWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "DeleteWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockQueryM2MWithCtx support QueryM2MWithCtx and QueryM2M +// Now you may be need to use golang/mock to generate QueryM2M mock instance +// Or use DoNothingQueryM2Mer +// for example: +// post := Post{Id: 4} +// m2m := Ormer.QueryM2M(&post, "Tags") +// when you write test code: +// MockQueryM2MWithCtx("post", "Tags", mockM2Mer) +// "post" is the table name of model Post structure +// TODO provide orm.QueryM2Mer +func MockQueryM2MWithCtx(tableName string, name string, res orm.QueryM2Mer) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{res}, nil) +} + +// MockLoadRelatedWithCtx support LoadRelatedWithCtx and LoadRelated +func MockLoadRelatedWithCtx(tableName string, name string, rows int64, err error) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{rows, err}, nil) +} + +// MockQueryTableWithCtx support QueryTableWithCtx and QueryTable +func MockQueryTableWithCtx(tableName string, qs orm.QuerySeter) *Mock { + return NewMock(NewSimpleCondition(tableName, "QueryTable"), []interface{}{qs}, nil) +} + +// MockRawWithCtx support RawWithCtx and Raw +func MockRawWithCtx(rs orm.RawSeter) *Mock { + return NewMock(NewSimpleCondition("", "RawWithCtx"), []interface{}{rs}, nil) +} + +// MockDriver support Driver +// func MockDriver(driver orm.Driver) *Mock { +// return NewMock(NewSimpleCondition("", "Driver"), []interface{}{driver}) +// } + +// MockDBStats support DBStats +func MockDBStats(stats *sql.DBStats) *Mock { + return NewMock(NewSimpleCondition("", "DBStats"), []interface{}{stats}, nil) +} + +// MockBeginWithCtxAndOpts support Begin, BeginWithCtx, BeginWithOpts, BeginWithCtxAndOpts +// func MockBeginWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return NewMock(NewSimpleCondition("", "BeginWithCtxAndOpts"), []interface{}{txOrm, err}) +// } + +// MockDoTxWithCtxAndOpts support DoTx, DoTxWithCtx, DoTxWithOpts, DoTxWithCtxAndOpts +// func MockDoTxWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return MockBeginWithCtxAndOpts(txOrm, err) +// } + +// MockCommit support Commit +func MockCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "Commit"), []interface{}{err}, nil) +} + +// MockRollback support Rollback +func MockRollback(err error) *Mock { + return NewMock(NewSimpleCondition("", "Rollback"), []interface{}{err}, nil) +} diff --git a/client/orm/mock/mock_orm_test.go b/client/orm/mock/mock_orm_test.go new file mode 100644 index 00000000..d65855cb --- /dev/null +++ b/client/orm/mock/mock_orm_test.go @@ -0,0 +1,297 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) +const mockErrorMsg = "mock error" +func init() { + orm.RegisterModel(&User{}) +} + +func TestMockDBStats(t *testing.T) { + s := StartMock() + defer s.Clear() + stats := &sql.DBStats{} + s.Mock(MockDBStats(stats)) + + o := orm.NewOrm() + + res := o.DBStats() + + assert.Equal(t, stats, res) +} + +func TestMockDeleteWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockDeleteWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + rows, err := o.Delete(&User{}) + assert.Equal(t, int64(12), rows) + assert.Nil(t, err) +} + +func TestMockInsertOrUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockInsertOrUpdateWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + id, err := o.InsertOrUpdate(&User{}) + assert.Equal(t, int64(12), id) + assert.Nil(t, err) +} + +func TestMockRead(t *testing.T) { + s := StartMock() + defer s.Clear() + err := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, err)) + o := orm.NewOrm() + u := &User{} + e := o.Read(u) + assert.Equal(t, err, e) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockQueryM2MWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQueryM2Mer{} + s.Mock(MockQueryM2MWithCtx((&User{}).TableName(), "Tags", mock)) + o := orm.NewOrm() + res := o.QueryM2M(&User{}, "Tags") + assert.Equal(t, mock, res) +} + +func TestMockQueryTableWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQuerySetter{} + s.Mock(MockQueryTableWithCtx((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.QueryTable(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockTable(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockTable((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.Read(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockInsertMultiWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertMultiWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.InsertMulti(11, []interface{}{&User{}}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockInsertWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertWithCtx((&User{}).TableName(), 13, mock)) + o := orm.NewOrm() + res, err := o.Insert(&User{}) + assert.Equal(t, int64(13), res) + assert.Equal(t, mock, err) +} + +func TestMockUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockUpdateWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.Update(&User{}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockLoadRelatedWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockLoadRelatedWithCtx((&User{}).TableName(), "T", 12, mock)) + o := orm.NewOrm() + res, err := o.LoadRelated(&User{}, "T") + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockMethod(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockMethod("ReadWithCtx", mock)) + o := orm.NewOrm() + err := o.Read(&User{}) + assert.Equal(t, mock, err) +} + +func TestMockReadForUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadForUpdateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + o := orm.NewOrm() + u := &User{} + err := o.ReadForUpdate(u) + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockRawWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingRawSetter{} + s.Mock(MockRawWithCtx(mock)) + o := orm.NewOrm() + res := o.Raw("") + assert.Equal(t, mock, res) +} + +func TestMockReadOrCreateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadOrCreateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, false, 12, mock)) + o := orm.NewOrm() + u := &User{} + inserted, id, err := o.ReadOrCreate(u, "") + assert.Equal(t, mock, err) + assert.Equal(t, int64(12), id) + assert.False(t, inserted) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionClosure(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxUsingClosure() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionManually(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxManually() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionRollback(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), nil, errors.New("read error"))) + s.Mock(MockRollback(mock)) + _, err := originalTx() + assert.Equal(t, mock, err) +} + +func TestTransactionCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, nil)) + s.Mock(MockCommit(mock)) + u, err := originalTx() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func originalTx() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + if err == nil { + err = txOrm.Commit() + return u, err + } else { + err = txOrm.Rollback() + return nil, err + } +} + +func originalTxManually() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + _ = txOrm.Commit() + return u, err +} + +func originalTxUsingClosure() (*User, error) { + u := &User{} + var err error + o := orm.NewOrm() + _ = o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error { + err = txOrm.Read(u) + return nil + }) + return u, err +} + +type User struct { + Id int + Name string +} + +func (u *User) TableName() string { + return "user" +} diff --git a/client/orm/mock/mock_queryM2Mer.go b/client/orm/mock/mock_queryM2Mer.go new file mode 100644 index 00000000..16648fee --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer.go @@ -0,0 +1,92 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +// DoNothingQueryM2Mer do nothing +// use it to build mock orm.QueryM2Mer +type DoNothingQueryM2Mer struct { + +} + +func (d *DoNothingQueryM2Mer) AddWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) RemoveWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) ExistWithCtx(ctx context.Context, i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) ClearWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Add(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Remove(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Exist(i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) Clear() (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Count() (int64, error) { + return 0, nil +} + +type QueryM2MerCondition struct { + tableName string + name string +} + +func NewQueryM2MerCondition(tableName string, name string) *QueryM2MerCondition { + return &QueryM2MerCondition{ + tableName: tableName, + name: name, + } +} + +func (q *QueryM2MerCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(q.tableName) > 0 { + res = res && (q.tableName == inv.GetTableName()) + } + if len(q.name) > 0 { + res = res && (len(inv.Args) > 1) && (q.name == inv.Args[1].(string)) + } + return res +} + + diff --git a/client/orm/mock/mock_queryM2Mer_test.go b/client/orm/mock/mock_queryM2Mer_test.go new file mode 100644 index 00000000..82776e76 --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestDoNothingQueryM2Mer(t *testing.T) { + m2m := &DoNothingQueryM2Mer{} + + i, err := m2m.Clear() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Add() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Remove() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + assert.True(t, m2m.Exist(nil)) +} + +func TestNewQueryM2MerCondition(t *testing.T) { + cond := NewQueryM2MerCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewQueryM2MerCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewQueryM2MerCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "B"}, + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "A"}, + })) +} \ No newline at end of file diff --git a/client/orm/mock/mock_querySetter.go b/client/orm/mock/mock_querySetter.go new file mode 100644 index 00000000..074b6211 --- /dev/null +++ b/client/orm/mock/mock_querySetter.go @@ -0,0 +1,183 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" +) + +// DoNothingQuerySetter do nothing +// usually you use this to build your mock QuerySetter +type DoNothingQuerySetter struct { +} + +func (d *DoNothingQuerySetter) OrderClauses(orders ...*order_clause.Order) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ExistWithCtx(ctx context.Context) bool { + return true +} + +func (d *DoNothingQuerySetter) UpdateWithCtx(ctx context.Context, values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) DeleteWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsertWithCtx(ctx context.Context) (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) ValuesWithCtx(ctx context.Context, results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesListWithCtx(ctx context.Context, results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlatWithCtx(ctx context.Context, result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Aggregate(s string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Filter(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) FilterRaw(s string, s2 string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Exclude(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) SetCond(condition *orm.Condition) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GetCond() *orm.Condition { + return orm.NewCondition() +} + +func (d *DoNothingQuerySetter) Limit(limit interface{}, args ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Offset(offset interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GroupBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) OrderBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) RelatedSel(params ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Distinct() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForUpdate() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Count() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Exist() bool { + return true +} + +func (d *DoNothingQuerySetter) Update(values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Delete() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsert() (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) All(container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) One(container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) Values(results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesList(results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlat(result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} diff --git a/client/orm/mock/mock_querySetter_test.go b/client/orm/mock/mock_querySetter_test.go new file mode 100644 index 00000000..09e5ad8c --- /dev/null +++ b/client/orm/mock/mock_querySetter_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingQuerySetter(t *testing.T) { + setter := &DoNothingQuerySetter{} + setter.GroupBy().Filter("").Limit(10). + Distinct().Exclude("a").FilterRaw("", ""). + ForceIndex().ForUpdate().IgnoreIndex(). + Offset(11).OrderBy().RelatedSel().SetCond(nil).UseIndex() + + assert.True(t, setter.Exist()) + err := setter.One(nil) + assert.Nil(t, err) + i, err := setter.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Delete() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.All(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Update(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesFlat(nil, "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + ins, err := setter.PrepareInsert() + assert.Nil(t, err) + assert.Nil(t, ins) + + assert.NotNil(t, setter.GetCond()) +} diff --git a/client/orm/mock/mock_rawSetter.go b/client/orm/mock/mock_rawSetter.go new file mode 100644 index 00000000..016fde47 --- /dev/null +++ b/client/orm/mock/mock_rawSetter.go @@ -0,0 +1,64 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + + "github.com/beego/beego/v2/client/orm" +) + +type DoNothingRawSetter struct { +} + +func (d *DoNothingRawSetter) Exec() (sql.Result, error) { + return nil, nil +} + +func (d *DoNothingRawSetter) QueryRow(containers ...interface{}) error { + return nil +} + +func (d *DoNothingRawSetter) QueryRows(containers ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) SetArgs(i ...interface{}) orm.RawSeter { + return d +} + +func (d *DoNothingRawSetter) Values(container *[]orm.Params, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesList(container *[]orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesFlat(container *orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) Prepare() (orm.RawPreparer, error) { + return nil, nil +} \ No newline at end of file diff --git a/client/orm/mock/mock_rawSetter_test.go b/client/orm/mock/mock_rawSetter_test.go new file mode 100644 index 00000000..dd98edbd --- /dev/null +++ b/client/orm/mock/mock_rawSetter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingRawSetter(t *testing.T) { + rs := &DoNothingRawSetter{} + i, err := rs.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.ValuesFlat(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.QueryRows() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + err = rs.QueryRow() + // assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + s, err := rs.Exec() + assert.Nil(t, err) + assert.Nil(t, s) + + p, err := rs.Prepare() + assert.Nil(t, err) + assert.Nil(t, p) + + rrs := rs.SetArgs() + assert.Equal(t, rrs, rs) +} diff --git a/client/orm/mock/mock_test.go b/client/orm/mock/mock_test.go new file mode 100644 index 00000000..73bce4e5 --- /dev/null +++ b/client/orm/mock/mock_test.go @@ -0,0 +1,58 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestOrmStub_FilterChain(t *testing.T) { + os := newOrmStub() + inv := &orm.Invocation{ + Args: []interface{}{10}, + } + i := 1 + os.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} { + i++ + return nil + })(context.Background(), inv) + + assert.Equal(t, 2, i) + + m := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 1 + }) + os.Mock(m) + + os.FilterChain(nil)(context.Background(), inv) + assert.Equal(t, 11, inv.Args[0]) + + inv.Args[0] = 10 + ctxMock := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 3 + }) + + os.FilterChain(nil)(CtxWithMock(context.Background(), ctxMock), inv) + assert.Equal(t, 13, inv.Args[0]) +} diff --git a/client/orm/models.go b/client/orm/models.go index 64dfab09..0f07e24d 100644 --- a/client/orm/models.go +++ b/client/orm/models.go @@ -332,10 +332,6 @@ end: // register register models to model cache func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { - if mc.done { - err = fmt.Errorf("register must be run before BootStrap") - return - } for _, model := range models { val := reflect.ValueOf(model) @@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) return } - + if val.Elem().Kind() == reflect.Slice { + val = reflect.New(val.Elem().Type().Elem()) + } table := getTableName(val) if prefixOrSuffixStr != "" { @@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } if _, ok := mc.get(table); ok { - err = fmt.Errorf(" table name `%s` repeat register, must be unique\n", table) - return + return nil } mi := newModelInfo(val) @@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } } } - - if mi.fields.pk == nil { - err = fmt.Errorf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - return - } - } mi.table = table diff --git a/client/orm/models_test.go b/client/orm/models_test.go index e3f74c0b..3fd35765 100644 --- a/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -255,6 +255,22 @@ func NewTM() *TM { return obj } +type DeptInfo struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + +type UnregisterModel struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + type User struct { ID int `orm:"column(id)"` UserName string `orm:"size(30);unique"` @@ -476,45 +492,45 @@ var ( helpinfo = `need driver and source! Default DB Drivers. - + driver: url mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 postgres: https://github.com/lib/pq tidb: https://github.com/pingcap/tidb - + usage: - + go get -u github.com/beego/beego/v2/client/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq go get -u github.com/pingcap/tidb - + #### MySQL mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" go test -v github.com/beego/beego/v2/client/orm - - + + #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' go test -v github.com/beego/beego/v2/client/orm - - + + #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" go test -v github.com/beego/beego/v2/client/orm - + #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' go test -v github.com/beego/beego/v2/pgk/orm - + ` ) diff --git a/client/orm/models_utils.go b/client/orm/models_utils.go index 950ca243..b2e5760e 100644 --- a/client/orm/models_utils.go +++ b/client/orm/models_utils.go @@ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string { // get whether the table needs to be created for the database alias func isApplicableTableForDB(val reflect.Value, db string) bool { + if !val.IsValid() { + return true + } fun := val.MethodByName("IsApplicableTableForDB") if fun.IsValid() { vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) diff --git a/client/orm/orm.go b/client/orm/orm.go index 1adf84e2..660f2939 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -58,6 +58,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "os" "reflect" "time" @@ -135,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error { } func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form @@ -144,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { } func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist @@ -154,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create id, err := o.InsertWithCtx(ctx, md) @@ -179,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) { } func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } @@ -222,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac for i := 0; i < sind.Len(); i++ { ind := reflect.Indirect(sind.Index(i)) mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err } @@ -233,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac } } else { mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil } @@ -244,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) ( } func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { return id, err } @@ -261,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) + return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) } // delete model in database @@ -271,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) + num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err } @@ -283,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str // create a models to models queryer func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { - return o.QueryM2MWithCtx(context.Background(), md, name) -} -func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -299,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri return newQueryM2M(md, o, mi, fi, ind) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") + return o.QueryM2M(md, name) +} + // load related models to md model. // args are limit, offset int and order string. // @@ -351,7 +355,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s qs.relDepth = relDepth if len(order) > 0 { - qs.orders = []string{order} + qs.orders = order_clause.ParseOrder(order) } find := ind.FieldByIndex(fi.fieldIndex) @@ -451,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) -} -func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -469,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in if qs == nil { panic(fmt.Errorf(" table name: `%s` not exists", name)) } - return + return qs +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") + return o.QueryTable(ptrStructOrTableName) } // return a raw query seter for raw sql string. @@ -595,9 +602,8 @@ func NewOrm() Ormer { func NewOrmUsingDB(aliasName string) Ormer { if al, ok := dataBaseCache.get(aliasName); ok { return newDBWithAlias(al) - } else { - panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } + panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } // NewOrmWithDB create a new ormer object with specify *sql.DB for query diff --git a/client/orm/orm_conds.go b/client/orm/orm_conds.go index 5409406e..0080d53c 100644 --- a/client/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -16,12 +16,13 @@ package orm import ( "fmt" + "github.com/beego/beego/v2/client/orm/clauses" "strings" ) // ExprSep define the expression separation const ( - ExprSep = "__" + ExprSep = clauses.ExprSep ) type condValue struct { diff --git a/client/orm/orm_log.go b/client/orm/orm_log.go index 61addeb5..8ac373b5 100644 --- a/client/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error { } func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + return d.ExecContext(context.Background(), args...) +} + +func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { a := time.Now() - res, err := d.stmt.Exec(args...) + res, err := d.stmt.ExecContext(ctx, args...) debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) return res, err } - func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + return d.QueryContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { a := time.Now() - res, err := d.stmt.Query(args...) + res, err := d.stmt.QueryContext(ctx, args...) debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + return d.QueryRowContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { a := time.Now() res := d.stmt.QueryRow(args...) debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) diff --git a/client/orm/orm_object.go b/client/orm/orm_object.go index 6f9798d3..50c1ca41 100644 --- a/client/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" ) @@ -31,6 +32,10 @@ var _ Inserter = new(insertSet) // insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} + +func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed } @@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { if name != o.mi.fullName { panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } @@ -70,11 +75,11 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { +func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) if err != nil { return nil, err } diff --git a/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go index 17e1b5d1..9da49bba 100644 --- a/client/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -14,7 +14,10 @@ package orm -import "reflect" +import ( + "context" + "reflect" +) // model to model struct type queryM2M struct { @@ -33,6 +36,10 @@ type queryM2M struct { // // make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + return o.AddWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi mi := fi.relThroughModelInfo mfi := fi.reverseFieldInfo @@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } names = append(names, otherNames...) values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) + return dbase.InsertValue(ctx, orm.db, mi, true, names, values) } // remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + return o.RemoveWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) @@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { // check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { + return o.ExistWithCtx(context.Background(), md) +} + +func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() + Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) } // clean all models in related of origin model func (o *queryM2M) Clear() (int64, error) { + return o.ClearWithCtx(context.Background()) +} + +func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) } // count all related models of origin model func (o *queryM2M) Count() (int64, error) { + return o.CountWithCtx(context.Background()) +} + +func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) } var _ QueryM2Mer = new(queryM2M) diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 177cfc3a..9f7b8441 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,6 +18,7 @@ import ( "context" "fmt" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/hints" ) @@ -64,21 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - forUpdate bool - useIndex int - indexes []string - orm *ormBase - ctx context.Context - forContext bool + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []*order_clause.Order + distinct bool + forUpdate bool + useIndex int + indexes []string + orm *ormBase + aggregate string } var _ QuerySeter = new(querySet) @@ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter { // add ORDER expression. // "column" means ASC, "-column" means DESC. -func (o querySet) OrderBy(exprs ...string) QuerySeter { - o.orders = exprs +func (o querySet) OrderBy(expressions ...string) QuerySeter { + if len(expressions) <= 0 { + return &o + } + o.orders = order_clause.ParseOrder(expressions...) + return &o +} + +// add ORDER expression. +func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter { + if len(orders) <= 0 { + return &o + } + o.orders = orders return &o } @@ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition { // return QuerySeter execution result number func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.CountWithCtx(context.Background()) +} + +func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.ExistWithCtx(context.Background()) +} + +func (o *querySet) ExistWithCtx(ctx context.Context) bool { + cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } // execute update with parameters func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) + return o.UpdateWithCtx(context.Background(), values) +} + +func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } // execute delete func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.DeleteWithCtx(context.Background()) +} + +func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // return a insert queryer. @@ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) { // i,err := sq.PrepareInsert() // i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) + return o.PrepareInsertWithCtx(context.Background()) +} + +func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { + return newInsertSet(ctx, o.orm, o.mi) } // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + return o.AllWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } // query one row data and map to containers. // cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { + return o.OneWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { return err } @@ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error { // expres means condition expression. // it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to [][]interface // it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesListWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to []interface. // it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) + return o.ValuesFlatWithCtx(context.Background(), result, expr) +} + +func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } // query all rows into map[string]interface with specify key and value column name. @@ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) panic(ErrNotImplement) } -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - o.forContext = true - return &o -} - // create new QuerySeter. func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) @@ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o.orm = orm return o } + +// aggregate func +func (o querySet) Aggregate(s string) QuerySeter { + o.aggregate = s + return &o +} diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index f1074973..e2e25ac4 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "fmt" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "io/ioutil" "math" "os" @@ -205,6 +206,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -232,6 +234,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) BootStrap() @@ -333,6 +336,73 @@ func TestTM(t *testing.T) { throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) } +func TestUnregisterModel(t *testing.T) { + data := []*DeptInfo{ + { + DeptName: "A", + EmployeeName: "A1", + Salary: 1000, + }, + { + DeptName: "A", + EmployeeName: "A2", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B1", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B2", + Salary: 4000, + }, + { + DeptName: "B", + EmployeeName: "B3", + Salary: 3000, + }, + } + qs := dORM.QueryTable("dept_info") + i, _ := qs.PrepareInsert() + for _, d := range data { + _, err := i.Insert(d) + if err != nil { + throwFail(t, err) + } + } + + f := func() { + var res []UnregisterModel + n, err := dORM.QueryTable("dept_info").All(&res) + throwFail(t, err) + throwFail(t, AssertIs(n, 5)) + throwFail(t, AssertIs(res[0].EmployeeName, "A1")) + + type Sum struct { + DeptName string + Total int + } + var sun []Sum + qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun) + throwFail(t, AssertIs(sun[0].DeptName, "A")) + throwFail(t, AssertIs(sun[0].Total, 3000)) + + type Max struct { + DeptName string + Max float64 + } + var max []Max + qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max) + throwFail(t, AssertIs(max[1].DeptName, "B")) + throwFail(t, AssertIs(max[1].Max, 4000)) + } + for i := 0; i < 5; i++ { + f() + } +} + func TestNullDataTypes(t *testing.T) { d := DataNull{} @@ -1077,6 +1147,26 @@ func TestOrderBy(t *testing.T) { num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`profile__age`), + order_clause.SortDescending(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsMysql { + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`rand()`), + order_clause.Raw(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } } func TestAll(t *testing.T) { @@ -1163,6 +1253,19 @@ func TestValues(t *testing.T) { throwFail(t, AssertIs(maps[2]["Profile"], nil)) } + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column("Id"), + order_clause.SortAscending(), + ), + ).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") throwFail(t, err) throwFail(t, AssertIs(num, 3)) @@ -2717,3 +2820,23 @@ func TestCondition(t *testing.T) { throwFail(t, AssertIs(!cycleFlag, true)) return } + +func TestContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + user := User{UserName: "slene"} + + err := dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, err) + + cancel() + err = dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, AssertIs(err, context.Canceled)) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + qs := dORM.QueryTable(user) + _, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) + throwFail(t, AssertIs(err, context.Canceled)) +} diff --git a/client/orm/types.go b/client/orm/types.go index be402971..ab3ddac4 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/core/utils" ) @@ -196,12 +197,16 @@ type DQL interface { // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") QueryM2M(md interface{}, name string) QueryM2Mer + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter DBStats() *sql.DBStats @@ -236,6 +241,7 @@ type TxOrmer interface { // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) + InsertWithCtx(context.Context, interface{}) (int64, error) Close() error } @@ -295,6 +301,28 @@ type QuerySeter interface { // for example: // qs.OrderBy("-status") OrderBy(exprs ...string) QuerySeter + // add ORDER expression by order clauses + // for example: + // OrderClauses( + // order_clause.Clause( + // order.Column("Id"), + // order.SortAscending(), + // ), + // order_clause.Clause( + // order.Column("status"), + // order.SortDescending(), + // ), + // ) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`user__status`), + // order_clause.SortDescending(),//default None + // )) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`random()`), + // order_clause.SortNone(),//default None + // order_clause.Raw(),//default false.if true, do not check field is valid or not + // )) + OrderClauses(orders ...*order_clause.Order) QuerySeter // add FORCE INDEX expression. // for example: // qs.ForceIndex(`idx_name1`,`idx_name2`) @@ -333,9 +361,11 @@ type QuerySeter interface { // for example: // num, err = qs.Filter("profile__age__gt", 28).Count() Count() (int64, error) + CountWithCtx(context.Context) (int64, error) // check result empty or not after QuerySeter executed // the same as QuerySeter.Count > 0 Exist() bool + ExistWithCtx(context.Context) bool // execute update with parameters // for example: // num, err = qs.Filter("user_name", "slene").Update(Params{ @@ -345,11 +375,13 @@ type QuerySeter interface { // "user_name": "slene2" // }) // user slene's name will change to slene2 Update(values Params) (int64, error) + UpdateWithCtx(ctx context.Context, values Params) (int64, error) // delete from table // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) + DeleteWithCtx(context.Context) (int64, error) // return a insert queryer. // it can be used in times. // example: @@ -358,18 +390,21 @@ type QuerySeter interface { // num, err = i.Insert(&user2) // user table will add one record user2 at once // err = i.Close() //don't forget call Close PrepareInsert() (Inserter, error) + PrepareInsertWithCtx(context.Context) (Inserter, error) // query all data and map to containers. // cols means the columns when querying. // for example: // var users []*User // qs.All(&users) // users[0],users[1],users[2] ... All(container interface{}, cols ...string) (int64, error) + AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) // query one row data and map to containers. // cols means the columns when querying. // for example: // var user User // qs.One(&user) //user.UserName == "slene" One(container interface{}, cols ...string) error + OneWithCtx(ctx context.Context, container interface{}, cols ...string) error // query all data and map to []map[string]interface. // expres means condition expression. // it converts data to []map[column]value. @@ -377,18 +412,21 @@ type QuerySeter interface { // var maps []Params // qs.Values(&maps) //maps[0]["UserName"]=="slene" Values(results *[]Params, exprs ...string) (int64, error) + ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) // query all data and map to [][]interface // it converts data to [][column_index]value // for example: // var list []ParamsList // qs.ValuesList(&list) // list[0][1] == "slene" ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) // query all data and map to []interface. // it's designed for one column record set, auto change to []value, not [][column]value. // for example: // var list ParamsList // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" ValuesFlat(result *ParamsList, expr string) (int64, error) + ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) // query all rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data @@ -411,6 +449,15 @@ type QuerySeter interface { // Found int // } RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + // aggregate func. + // for example: + // type result struct { + // DeptName string + // Total int + // } + // var res []result + // o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res) + Aggregate(s string) QuerySeter } // QueryM2Mer model to model query struct @@ -428,18 +475,23 @@ type QueryM2Mer interface { // insert one or more rows to m2m table // make sure the relation is defined in post model struct tag. Add(...interface{}) (int64, error) + AddWithCtx(context.Context, ...interface{}) (int64, error) // remove models following the origin model relationship // only delete rows from m2m table // for example: // tag3 := &Tag{Id:5,Name: "TestTag3"} // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) + RemoveWithCtx(context.Context, ...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool + ExistWithCtx(context.Context, interface{}) bool // clean all models in related of origin model Clear() (int64, error) + ClearWithCtx(context.Context) (int64, error) // count all related models of origin model Count() (int64, error) + CountWithCtx(context.Context) (int64, error) } // RawPreparer raw query statement @@ -513,11 +565,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - // QueryContext(args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier @@ -554,28 +606,28 @@ type txEnder interface { // base database struct type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) SupportUpdateJoin() bool OperatorSQL(string) string GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) @@ -584,12 +636,12 @@ type dbBaser interface { TimeToDB(*time.Time, *time.Location) DbTypes() map[string]string GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) + GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error) ShowTablesQuery() string ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool + IndexExists(context.Context, dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error + setval(context.Context, dbQuerier, *modelInfo, []string) error GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string } diff --git a/core/berror/codes.go b/core/berror/codes.go new file mode 100644 index 00000000..e3a616e8 --- /dev/null +++ b/core/berror/codes.go @@ -0,0 +1,91 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "sync" +) + + + +// A Code is an unsigned 32-bit error code as defined in the beego spec. +type Code interface { + Code() uint32 + Module() string + Desc() string + Name() string +} + +var defaultCodeRegistry = &codeRegistry{ + codes: make(map[uint32]*codeDefinition, 127), +} + +// DefineCode defining a new Code +// Before defining a new code, please read Beego specification. +// desc could be markdown doc +func DefineCode(code uint32, module string, name string, desc string) Code { + res := &codeDefinition{ + code: code, + module: module, + desc: desc, + } + defaultCodeRegistry.lock.Lock() + defer defaultCodeRegistry.lock.Unlock() + + if _, ok := defaultCodeRegistry.codes[code]; ok { + panic(fmt.Sprintf("duplicate code, code %d has been registered", code)) + } + defaultCodeRegistry.codes[code] = res + return res +} + +type codeRegistry struct { + lock sync.RWMutex + codes map[uint32]*codeDefinition +} + + +func (cr *codeRegistry) Get(code uint32) (Code, bool) { + cr.lock.RLock() + defer cr.lock.RUnlock() + c, ok := cr.codes[code] + return c, ok +} + +type codeDefinition struct { + code uint32 + module string + desc string + name string +} + + +func (c *codeDefinition) Name() string { + return c.name +} + +func (c *codeDefinition) Code() uint32 { + return c.code +} + +func (c *codeDefinition) Module() string { + return c.module +} + +func (c *codeDefinition) Desc() string { + return c.desc +} + diff --git a/core/berror/error.go b/core/berror/error.go new file mode 100644 index 00000000..ca09798a --- /dev/null +++ b/core/berror/error.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +// code, msg +const errFmt = "ERROR-%d, %s" + +// Err returns an error representing c and msg. If c is OK, returns nil. +func Error(c Code, msg string) error { + return fmt.Errorf(errFmt, c.Code(), msg) +} + +// Errorf returns error +func Errorf(c Code, format string, a ...interface{}) error { + return Error(c, fmt.Sprintf(format, a...)) +} + +func Wrap(err error, c Code, msg string) error { + if err == nil { + return nil + } + return errors.Wrap(err, fmt.Sprintf(errFmt, c.Code(), msg)) +} + +func Wrapf(err error, c Code, format string, a ...interface{}) error { + return Wrap(err, c, fmt.Sprintf(format, a...)) +} + +// FromError is very simple. It just parse error msg and check whether code has been register +// if code not being register, return unknown +// if err.Error() is not valid beego error code, return unknown +func FromError(err error) (Code, bool) { + msg := err.Error() + codeSeg := strings.SplitN(msg, ",", 2) + if strings.HasPrefix(codeSeg[0], "ERROR-") { + codeStr := strings.SplitN(codeSeg[0], "-", 2) + if len(codeStr) < 2 { + return Unknown, false + } + codeInt, e := strconv.ParseUint(codeStr[1], 10, 32) + if e != nil { + return Unknown, false + } + if code, ok := defaultCodeRegistry.Get(uint32(codeInt)); ok { + return code, true + } + } + return Unknown, false +} diff --git a/core/berror/error_test.go b/core/berror/error_test.go new file mode 100644 index 00000000..7a14d933 --- /dev/null +++ b/core/berror/error_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var testCode1 = DefineCode(1, "unit_test", "TestError", "Hello, test code1") + +var testErr = errors.New("hello, this is error") + +func TestErrorf(t *testing.T) { + msg := Errorf(testCode1, "errorf %s", "aaaa") + assert.NotNil(t, msg) + assert.Equal(t, "ERROR-1, errorf aaaa", msg.Error()) +} + +func TestWrapf(t *testing.T) { + err := Wrapf(testErr, testCode1, "Wrapf %s", "aaaa") + assert.NotNil(t, err) + assert.True(t, errors.Is(err, testErr)) +} + +func TestFromError(t *testing.T) { + err := errors.New("ERROR-1, errorf aaaa") + code, ok := FromError(err) + assert.True(t, ok) + assert.Equal(t, testCode1, code) + assert.Equal(t, "unit_test", code.Module()) + assert.Equal(t, "Hello, test code1", code.Desc()) + + err = errors.New("not beego error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2, not register") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-aaa, invalid code") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("aaaaaaaaaaaaaa") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2-3, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) +} diff --git a/core/berror/pre_define_code.go b/core/berror/pre_define_code.go new file mode 100644 index 00000000..01992957 --- /dev/null +++ b/core/berror/pre_define_code.go @@ -0,0 +1,52 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" +) + +// pre define code + +// Unknown indicates got some error which is not defined +var Unknown = DefineCode(5000001, "error", "Unknown",fmt.Sprintf(` +Unknown error code. Usually you will see this code in three cases: +1. You forget to define Code or function DefineCode not being executed; +2. This is not Beego's error but you call FromError(); +3. Beego got unexpected error and don't know how to handle it, and then return Unknown error + +A common practice to DefineCode looks like: +%s + +In this way, you may forget to import this package, and got Unknown error. + +Sometimes, you believe you got Beego error, but actually you don't, and then you call FromError(err) + +`, goCodeBlock(` +import your_package + +func init() { + DefineCode(5100100, "your_module", "detail") + // ... +} +`))) + +func goCodeBlock(code string) string { + return codeBlock("go", code) +} +func codeBlock(lan string, code string) string { + return fmt.Sprintf("```%s\n%s\n```", lan, code) +} + diff --git a/core/logs/formatter.go b/core/logs/formatter.go index 67500b2b..80b30fa0 100644 --- a/core/logs/formatter.go +++ b/core/logs/formatter.go @@ -69,8 +69,8 @@ func (p *PatternLogFormatter) ToString(lm *LogMsg) string { 'm': lm.Msg, 'n': strconv.Itoa(lm.LineNumber), 'l': strconv.Itoa(lm.Level), - 't': levelPrefix[lm.Level-1], - 'T': levelNames[lm.Level-1], + 't': levelPrefix[lm.Level], + 'T': levelNames[lm.Level], 'F': lm.FilePath, } _, m['f'] = path.Split(lm.FilePath) diff --git a/core/logs/formatter_test.go b/core/logs/formatter_test.go index a97765ac..a1853d72 100644 --- a/core/logs/formatter_test.go +++ b/core/logs/formatter_test.go @@ -88,7 +88,7 @@ func TestPatternLogFormatter(t *testing.T) { } got := tes.ToString(lm) want := lm.FilePath + ":" + strconv.Itoa(lm.LineNumber) + "|" + - when.Format(tes.WhenFormat) + levelPrefix[lm.Level-1] + ">> " + lm.Msg + when.Format(tes.WhenFormat) + levelPrefix[lm.Level] + ">> " + lm.Msg if got != want { t.Errorf("want %s, got %s", want, got) } diff --git a/go.mod b/go.mod index 89baa406..ce67b7d2 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-cmp v0.5.0 // indirect - github.com/google/uuid v1.1.1 // indirect + github.com/google/uuid v1.1.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru v0.5.4 github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 diff --git a/server/web/admin.go b/server/web/admin.go index 89c9ddb9..3f665b09 100644 --- a/server/web/admin.go +++ b/server/web/admin.go @@ -108,8 +108,11 @@ func registerAdmin() error { c := &adminController{ servers: make([]*HttpServer, 0, 2), } + + // copy config to avoid conflict + adminCfg := *BConfig beeAdminApp = &adminApp{ - HttpServer: NewHttpServerWithCfg(BConfig), + HttpServer: NewHttpServerWithCfg(&adminCfg), } // keep in mind that all data should be html escaped to avoid XSS attack beeAdminApp.Router("/", c, "get:AdminIndex") diff --git a/server/web/context/context.go b/server/web/context/context.go index 6070c996..099729d0 100644 --- a/server/web/context/context.go +++ b/server/web/context/context.go @@ -29,6 +29,7 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/beego/beego/v2/server/web/session" "net" "net/http" "strconv" @@ -195,6 +196,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) { } } +// Session return session store of this context of request +func (ctx *Context) Session() (store session.Store, err error) { + if ctx.Input != nil { + if ctx.Input.CruSession != nil { + store = ctx.Input.CruSession + return + } else { + err = errors.New(`no valid session store(please initialize session)`) + return + } + } else { + err = errors.New(`no valid input`) + return + } +} + // Response is a wrapper for the http.ResponseWriter // Started: if true, response was already written to so the other handler will not be executed type Response struct { diff --git a/server/web/context/context_test.go b/server/web/context/context_test.go index 7c0535e0..977c3cbf 100644 --- a/server/web/context/context_test.go +++ b/server/web/context/context_test.go @@ -15,6 +15,7 @@ package context import ( + "github.com/beego/beego/v2/server/web/session" "net/http" "net/http/httptest" "testing" @@ -45,3 +46,26 @@ func TestXsrfReset_01(t *testing.T) { t.FailNow() } } + +func TestContext_Session(t *testing.T) { + c := NewContext() + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session1(t *testing.T) { + c := Context{} + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session2(t *testing.T) { + c := NewContext() + c.Input.CruSession = &session.MemSessionStore{} + + if store, err := c.Session(); store == nil || err != nil { + t.FailNow() + } +} \ No newline at end of file diff --git a/server/web/filter/prometheus/filter.go b/server/web/filter/prometheus/filter.go index 59a673ac..520170e0 100644 --- a/server/web/filter/prometheus/filter.go +++ b/server/web/filter/prometheus/filter.go @@ -17,23 +17,49 @@ package prometheus import ( "strconv" "strings" + "sync" "time" "github.com/prometheus/client_golang/prometheus" "github.com/beego/beego/v2" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/server/web" "github.com/beego/beego/v2/server/web/context" ) +const unknownRouterPattern = "UnknownRouterPattern" + // FilterChainBuilder is an extension point, // when we want to support some configuration, // please use this structure type FilterChainBuilder struct { } +var summaryVec prometheus.ObserverVec +var initSummaryVec sync.Once + // FilterChain returns a FilterFunc. The filter will records some metrics func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc { + + initSummaryVec.Do(func() { + summaryVec = builder.buildVec() + err := prometheus.Register(summaryVec) + if _, ok := err.(*prometheus.AlreadyRegisteredError); err != nil && !ok { + logs.Error("web module register prometheus vector failed, %+v", err) + } + registerBuildInfo() + }) + + return func(ctx *context.Context) { + startTime := time.Now() + next(ctx) + endTime := time.Now() + go report(endTime.Sub(startTime), ctx, summaryVec) + } +} + +func (builder *FilterChainBuilder) buildVec() *prometheus.SummaryVec { summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "beego", Subsystem: "http_request", @@ -44,17 +70,7 @@ func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFu }, Help: "The statics info for http request", }, []string{"pattern", "method", "status"}) - - prometheus.MustRegister(summaryVec) - - registerBuildInfo() - - return func(ctx *context.Context) { - startTime := time.Now() - next(ctx) - endTime := time.Now() - go report(endTime.Sub(startTime), ctx, summaryVec) - } + return summaryVec } func registerBuildInfo() { @@ -75,13 +91,17 @@ func registerBuildInfo() { }, }, []string{}) - prometheus.MustRegister(buildInfo) + _ = prometheus.Register(buildInfo) buildInfo.WithLabelValues().Set(1) } -func report(dur time.Duration, ctx *context.Context, vec *prometheus.SummaryVec) { +func report(dur time.Duration, ctx *context.Context, vec prometheus.ObserverVec) { status := ctx.Output.Status - ptn := ctx.Input.GetData("RouterPattern").(string) + ptnItf := ctx.Input.GetData("RouterPattern") + ptn := unknownRouterPattern + if ptnItf != nil { + ptn = ptnItf.(string) + } ms := dur / time.Millisecond vec.WithLabelValues(ptn, ctx.Input.Method(), strconv.Itoa(status)).Observe(float64(ms)) } diff --git a/server/web/filter/prometheus/filter_test.go b/server/web/filter/prometheus/filter_test.go index f00f20e7..7d88d843 100644 --- a/server/web/filter/prometheus/filter_test.go +++ b/server/web/filter/prometheus/filter_test.go @@ -18,6 +18,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" @@ -37,4 +38,19 @@ func TestFilterChain(t *testing.T) { ctx.Input.SetData("RouterPattern", "my-route") filter(ctx) assert.True(t, ctx.Input.GetData("invocation").(bool)) + time.Sleep(1 * time.Second) } + +func TestFilterChainBuilder_report(t *testing.T) { + + ctx := context.NewContext() + r, _ := http.NewRequest("GET", "/prometheus/user", nil) + w := httptest.NewRecorder() + ctx.Reset(w, r) + fb := &FilterChainBuilder{} + // without router info + report(time.Second, ctx, fb.buildVec()) + + ctx.Input.SetData("RouterPattern", "my-route") + report(time.Second, ctx, fb.buildVec()) +} \ No newline at end of file diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go new file mode 100644 index 00000000..bcf9edf4 --- /dev/null +++ b/server/web/filter/session/filter.go @@ -0,0 +1,35 @@ +package session + +import ( + "context" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" +) + +//Session maintain session for web service +//Session new a session storage and store it into webContext.Context +func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { + sessionConfig := session.NewManagerConfig(options...) + sessionManager, _ := session.NewManager(string(providerType), sessionConfig) + go sessionManager.GC() + + return func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if ctx.Input.CruSession != nil { + return + } + + if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { + logs.Error(`init session error:%s`, err.Error()) + } else { + //release session at the end of request + defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) + ctx.Input.CruSession = sess + } + + next(ctx) + } + } +} \ No newline at end of file diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go new file mode 100644 index 00000000..687789a5 --- /dev/null +++ b/server/web/filter/session/filter_test.go @@ -0,0 +1,86 @@ +package session + +import ( + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" + "github.com/google/uuid" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) + } +} + +func TestSession(t *testing.T) { + storeKey := uuid.New().String() + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store := ctx.Input.GetData(storeKey); store == nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestSession1(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store, err := ctx.Session(); store == nil || err != nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} diff --git a/server/web/filter_chain_test.go b/server/web/filter_chain_test.go index 2a428b78..5dd38fc5 100644 --- a/server/web/filter_chain_test.go +++ b/server/web/filter_chain_test.go @@ -15,9 +15,12 @@ package web import ( + "fmt" "net/http" "net/http/httptest" + "strconv" "testing" + "time" "github.com/stretchr/testify/assert" @@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) { ns := NewNamespace("/chain") ns.Get("/*", func(ctx *context.Context) { - ctx.Output.Body([]byte("hello")) + _ = ctx.Output.Body([]byte("hello")) }) r, _ := http.NewRequest("GET", "/chain/user", nil) w := httptest.NewRecorder() + BeeApp.Handlers.Init() BeeApp.Handlers.ServeHTTP(w, r) assert.Equal(t, "filter-chain", w.Header().Get("filter")) } + +func TestControllerRegister_InsertFilterChain_Order(t *testing.T) { + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + r, _ := http.NewRequest("GET", "/abc", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.Init() + BeeApp.Handlers.ServeHTTP(w, r) + first := w.Header().Get("first") + second := w.Header().Get("second") + + ft, _ := strconv.ParseInt(first, 10, 64) + st, _ := strconv.ParseInt(second, 10, 64) + + assert.True(t, st > ft) +} diff --git a/server/web/flash_test.go b/server/web/flash_test.go index 2deef54e..3e20c8fb 100644 --- a/server/web/flash_test.go +++ b/server/web/flash_test.go @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/server/web/hooks.go b/server/web/hooks.go index ae32b9f8..f91bec1e 100644 --- a/server/web/hooks.go +++ b/server/web/hooks.go @@ -6,6 +6,8 @@ import ( "net/http" "path/filepath" + "github.com/coreos/etcd/pkg/fileutil" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/session" @@ -99,7 +101,12 @@ func registerGzip() error { func registerCommentRouter() error { if BConfig.RunMode == DEV { - if err := parserPkg(filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath)); err != nil { + ctrlDir := filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath) + if !fileutil.Exist(ctrlDir) { + logs.Warn("controller package not found, won't generate router: ", ctrlDir) + return nil + } + if err := parserPkg(ctrlDir); err != nil { return err } } diff --git a/server/web/namespace.go b/server/web/namespace.go index 4e0c3b85..96037b4d 100644 --- a/server/web/namespace.go +++ b/server/web/namespace.go @@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/beego/beego/v2#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) + n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...)) return n } @@ -187,6 +187,54 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { return n } +// RouterGet same as beego.RouterGet +func (n *Namespace) RouterGet(rootpath string, f interface{}) *Namespace { + n.handlers.RouterGet(rootpath, f) + return n +} + +// RouterPost same as beego.RouterPost +func (n *Namespace) RouterPost(rootpath string, f interface{}) *Namespace { + n.handlers.RouterPost(rootpath, f) + return n +} + +// RouterDelete same as beego.RouterDelete +func (n *Namespace) RouterDelete(rootpath string, f interface{}) *Namespace { + n.handlers.RouterDelete(rootpath, f) + return n +} + +// RouterPut same as beego.RouterPut +func (n *Namespace) RouterPut(rootpath string, f interface{}) *Namespace { + n.handlers.RouterPut(rootpath, f) + return n +} + +// RouterHead same as beego.RouterHead +func (n *Namespace) RouterHead(rootpath string, f interface{}) *Namespace { + n.handlers.RouterHead(rootpath, f) + return n +} + +// RouterOptions same as beego.RouterOptions +func (n *Namespace) RouterOptions(rootpath string, f interface{}) *Namespace { + n.handlers.RouterOptions(rootpath, f) + return n +} + +// RouterPatch same as beego.RouterPatch +func (n *Namespace) RouterPatch(rootpath string, f interface{}) *Namespace { + n.handlers.RouterPatch(rootpath, f) + return n +} + +// Any same as beego.RouterAny +func (n *Namespace) RouterAny(rootpath string, f interface{}) *Namespace { + n.handlers.RouterAny(rootpath, f) + return n +} + // Namespace add nest Namespace // usage: // ns := beego.NewNamespace(“/v1”). @@ -366,6 +414,62 @@ func NSPatch(rootpath string, f FilterFunc) LinkNamespace { } } +// NSRouterGet call Namespace RouterGet +func NSRouterGet(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterGet(rootpath, f) + } +} + +// NSRouterPost call Namespace RouterPost +func NSRouterPost(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterPost(rootpath, f) + } +} + +// NSRouterHead call Namespace RouterHead +func NSRouterHead(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterHead(rootpath, f) + } +} + +// NSRouterPut call Namespace RouterPut +func NSRouterPut(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterPut(rootpath, f) + } +} + +// NSRouterDelete call Namespace RouterDelete +func NSRouterDelete(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterDelete(rootpath, f) + } +} + +// NSRouterAny call Namespace RouterAny +func NSRouterAny(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterAny(rootpath, f) + } +} + +// NSRouterOptions call Namespace RouterOptions +func NSRouterOptions(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterOptions(rootpath, f) + } +} + +// NSRouterPatch call Namespace RouterPatch +func NSRouterPatch(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterPatch(rootpath, f) + } +} + // NSAutoRouter call Namespace AutoRouter func NSAutoRouter(c ControllerInterface) LinkNamespace { return func(ns *Namespace) { diff --git a/server/web/namespace_test.go b/server/web/namespace_test.go index 05042c96..30d17cb2 100644 --- a/server/web/namespace_test.go +++ b/server/web/namespace_test.go @@ -15,6 +15,7 @@ package web import ( + "fmt" "net/http" "net/http/httptest" "strconv" @@ -23,6 +24,40 @@ import ( "github.com/beego/beego/v2/server/web/context" ) +const ( + exampleBody = "hello world" + examplePointerBody = "hello world pointer" + + nsNamespace = "/router" + nsPath = "/user" + nsNamespacePath = "/router/user" +) + +type ExampleController struct { + Controller +} + +func (m ExampleController) Ping() { + err := m.Ctx.Output.Body([]byte(exampleBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m *ExampleController) PingPointer() { + err := m.Ctx.Output.Body([]byte(examplePointerBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m ExampleController) ping() { + err := m.Ctx.Output.Body([]byte("ping method")) + if err != nil { + fmt.Println(err) + } +} + func TestNamespaceGet(t *testing.T) { r, _ := http.NewRequest("GET", "/v1/user", nil) w := httptest.NewRecorder() @@ -166,3 +201,215 @@ func TestNamespaceInside(t *testing.T) { t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) } } + +func TestNamespaceRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterGet(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterPost(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterDelete(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterPut(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterHead(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterOptions(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterPatch(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + ns.RouterAny(nsPath, ExampleController.Ping) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestNamespaceNSRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterGet(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/router") + NSRouterPost(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterDelete(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterPut(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterHead(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterOptions(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterPatch("/user", ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + NSRouterAny(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterAny can't run, get the response is " + w.Body.String()) + } + } +} diff --git a/server/web/router.go b/server/web/router.go index fd94a769..e80694c0 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -20,6 +20,7 @@ import ( "net/http" "path" "reflect" + "runtime" "strconv" "strings" "sync" @@ -118,12 +119,33 @@ type ControllerInfo struct { routerType int initialize func() ControllerInterface methodParams []*param.MethodParam + sessionOn bool } +type ControllerOption func(*ControllerInfo) + func (c *ControllerInfo) GetPattern() string { return c.pattern } +func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { + return func(c *ControllerInfo) { + c.methods = parseMappingMethods(ctrlInterface, mappingMethod) + } +} + +func WithRouterSessionOn(sessionOn bool) ControllerOption { + return func(c *ControllerInfo) { + c.sessionOn = sessionOn + } +} + +type filterChainConfig struct { + pattern string + chain FilterChain + opts []FilterOpt +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree @@ -136,6 +158,9 @@ type ControllerRegister struct { // the filter created by FilterChain chainRoot *FilterRouter + // keep registered chain and build it when serve http + filterChains []filterChainConfig + cfg *Config } @@ -155,12 +180,24 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { return beecontext.NewContext() }, }, - cfg: cfg, + cfg: cfg, + filterChains: make([]filterChainConfig, 0, 4), } res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } +// Init will be executed when HttpServer start running +func (p *ControllerRegister) Init() { + for i := len(p.filterChains) - 1; i >= 0; i-- { + fc := p.filterChains[i] + root := p.chainRoot + filterFunc := fc.chain(root.filterFunc) + p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...) + p.chainRoot.next = root + } +} + // Add controller handler and pattern rules to ControllerRegister. // usage: // default methods is the same name as method @@ -171,41 +208,64 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { // Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) { + p.addWithMethodParams(pattern, c, nil, opts...) } -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { +func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") + + if len(mappingMethods) == 0 { + return methods + } + + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + continue } + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) } } - route := &ControllerInfo{} - route.pattern = pattern - route.methods = methods - route.routerType = routerTypeBeego - route.controllerType = t + return methods +} + +func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { + if len(route.methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + return + } + for k := range route.methods { + if k != "*" { + p.addToRouter(k, route.pattern, route) + continue + } + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + } +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + + route := p.createBeegoRouter(t, pattern) route.initialize = func() ControllerInterface { vc := reflect.New(route.controllerType) execController, ok := vc.Interface().(ControllerInterface) @@ -229,23 +289,18 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt return execController } - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } + for i := range opts { + opts[i](route) } + + globalSessionOn := p.cfg.WebConfig.Session.SessionOn + if !globalSessionOn && route.sessionOn { + logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern) + route.sessionOn = globalSessionOn + } + + p.addRouterForMethod(route) } func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { @@ -273,7 +328,8 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { for _, f := range a.Filters { p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + + p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) } } } @@ -294,6 +350,261 @@ func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { p.pool.Put(ctx) } +// RouterGet add get method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterGet("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterGet(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodGet, pattern, f) +} + +// RouterPost add post method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPost("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterPost(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPost, pattern, f) +} + +// RouterHead add head method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterHead("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterHead(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodHead, pattern, f) +} + +// RouterPut add put method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPut("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterPut(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPut, pattern, f) +} + +// RouterPatch add patch method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPatch("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterPatch(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPatch, pattern, f) +} + +// RouterDelete add delete method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterDelete("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterDelete(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodDelete, pattern, f) +} + +// RouterOptions add options method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterOptions("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterOptions(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodOptions, pattern, f) +} + +// RouterAny add all method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterAny("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterAny(pattern string, f interface{}) { + p.AddRouterMethod("*", pattern, f) +} + +// AddRouterMethod add http method router +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// AddRouterMethod("get","/api/:id", MyController.Ping) +func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) { + httpMethod = p.getUpperMethodString(httpMethod) + ct, methodName := getReflectTypeAndMethod(f) + + p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern) +} + +// addBeegoTypeRouter add beego type router +func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) { + route := p.createBeegoRouter(ct, pattern) + methods := p.getHttpMethodMapMethod(httpMethod, ctMethod) + route.methods = methods + + p.addRouterForMethod(route) +} + +// createBeegoRouter create beego router base on reflect type and pattern +func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeBeego + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.controllerType = ct + return route +} + +// createRestfulRouter create restful router with filter function and pattern +func (p *ControllerRegister) createRestfulRouter(f FilterFunc, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.runFunction = f + return route +} + +// createHandlerRouter create handler router with handler and pattern +func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.handler = h + return route +} + +// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will +// use http method as the controller method +func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string { + methods := make(map[string]string) + // not match-all sign, only add for the http method + if httpMethod != "*" { + + if ctMethod == "" { + ctMethod = httpMethod + } + methods[httpMethod] = ctMethod + return methods + } + + // add all http method + for val := range HTTPMETHOD { + if ctMethod == "" { + methods[val] = val + } else { + methods[val] = ctMethod + } + } + return methods +} + +// getUpperMethodString get upper string of method, and panic if the method +// is not valid +func (p *ControllerRegister) getUpperMethodString(method string) string { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + return method +} + +// get reflect controller type and method by controller method expression +func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method string) { + // check f is a function + funcType := reflect.TypeOf(f) + if funcType.Kind() != reflect.Func { + panic("not a method") + } + + // get function name + funcObj := runtime.FuncForPC(reflect.ValueOf(f).Pointer()) + if funcObj == nil { + panic("cannot find the method") + } + funcNameSli := strings.Split(funcObj.Name(), ".") + lFuncSli := len(funcNameSli) + if lFuncSli == 0 { + panic("invalid method full name: " + funcObj.Name()) + } + + method = funcNameSli[lFuncSli-1] + if len(method) == 0 { + panic("method name is empty") + } else if method[0] > 96 || method[0] < 65 { + panic(fmt.Sprintf("%s is not a public method", method)) + } + + // check only one param which is the method receiver + if numIn := funcType.NumIn(); numIn != 1 { + panic("invalid number of param in") + } + + controllerType = funcType.In(0) + + // check controller has the method + _, exists := controllerType.MethodByName(method) + if !exists { + panic(controllerType.String() + " has no method " + method) + } + + // check the receiver implement ControllerInterface + if controllerType.Kind() == reflect.Ptr { + controllerType = controllerType.Elem() + } + controller := reflect.New(controllerType) + _, ok := controller.Interface().(ControllerInterface) + if !ok { + panic(controllerType.String() + " is not implemented ControllerInterface") + } + + return +} + // Get add get method // usage: // Get("/", func(ctx *context.Context){ @@ -372,34 +683,18 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) { // ctx.Output.Body("hello world") // }) func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { - method = strings.ToUpper(method) - if method != "*" && !HTTPMETHOD[method] { - panic("not support http method: " + method) - } - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeRESTFul - route.runFunction = f - methods := make(map[string]string) - if method == "*" { - for val := range HTTPMETHOD { - methods[val] = val - } - } else { - methods[method] = method - } + method = p.getUpperMethodString(method) + + route := p.createRestfulRouter(f, pattern) + methods := p.getHttpMethodMapMethod(method, "") route.methods = methods - for k := range methods { - p.addToRouter(k, pattern, route) - } + + p.addRouterForMethod(route) } // Handler add user defined Handler func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeHandler - route.handler = h + route := p.createHandlerRouter(h, pattern) if len(options) > 0 { if _, ok := options[0].(bool); ok { pattern = path.Join(pattern, "?:all(.*)") @@ -431,15 +726,13 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) controllerName := strings.TrimSuffix(ct.Name(), "Controller") for i := 0; i < rt.NumMethod(); i++ { if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &ControllerInfo{} - route.routerType = routerTypeBeego - route.methods = map[string]string{"*": rt.Method(i).Name} - route.controllerType = ct pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) - route.pattern = pattern + + route := p.createBeegoRouter(ct, pattern) + route.methods = map[string]string{"*": rt.Method(i).Name} for m := range HTTPMETHOD { p.addToRouter(m, pattern, route) p.addToRouter(m, patternInit, route) @@ -472,12 +765,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // } // } func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { - root := p.chainRoot - filterFunc := chain(root.filterFunc) - opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) - p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) - p.chainRoot.next = root + opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) + p.filterChains = append(p.filterChains, filterChainConfig{ + pattern: pattern, + chain: chain, + opts: opts, + }) } // add Filter into @@ -542,7 +836,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str for _, l := range t.leaves { if c, ok := l.runObject.(*ControllerInfo); ok { if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) { find := false if HTTPMETHOD[strings.ToUpper(methodName)] { if len(c.methods) == 0 { @@ -664,12 +958,15 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { r := ctx.Request rw := ctx.ResponseWriter.ResponseWriter var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + currentSessionOn bool + originRouterInfo *ControllerInfo + originFindRouter bool ) if p.cfg.RecoverFunc != nil { @@ -735,7 +1032,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } // session init - if p.cfg.WebConfig.Session.SessionOn { + currentSessionOn = p.cfg.WebConfig.Session.SessionOn + originRouterInfo, originFindRouter = p.FindRouter(ctx) + if originFindRouter { + currentSessionOn = originRouterInfo.sessionOn + } + if currentSessionOn { ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) diff --git a/server/web/router_test.go b/server/web/router_test.go index 87997322..3633aee7 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -16,6 +16,7 @@ package web import ( "bytes" + "fmt" "net/http" "net/http/httptest" "strings" @@ -26,6 +27,25 @@ import ( "github.com/beego/beego/v2/server/web/context" ) +type PrefixTestController struct { + Controller +} + +func (ptc *PrefixTestController) PrefixList() { + ptc.Ctx.Output.Body([]byte("i am list in prefix test")) +} + +type TestControllerWithInterface struct { +} + +func (m TestControllerWithInterface) Ping() { + fmt.Println("pong") +} + +func (m *TestControllerWithInterface) PingPointer() { + fmt.Println("pong pointer") +} + type TestController struct { Controller } @@ -87,10 +107,24 @@ func (jc *JSONController) Get() { jc.Ctx.Output.Body([]byte("ok")) } +func TestPrefixUrlFor(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList")) + + if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { + logs.Info(a) + t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list") + } + if a := handler.URLFor(`TestController.PrefixList`); a != `` { + logs.Info(a) + t.Errorf("TestController.PrefixList must equal to empty string") + } +} + func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -113,9 +147,9 @@ func TestUrlFor3(t *testing.T) { func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -145,7 +179,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -235,7 +269,7 @@ func TestRouteOk(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -249,7 +283,7 @@ func TestManyRoute(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -266,7 +300,7 @@ func TestEmptyResponse(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { @@ -750,3 +784,321 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { t.Errorf("TestRouterRequestEntityTooLarge can't run") } } + +func TestRouterSessionSet(t *testing.T) { + oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn + defer func() { + BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn + }() + + // global sessionOn = false, router sessionOn = false + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = false, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + BConfig.WebConfig.Session.SessionOn = true + if err := registerSession(); err != nil { + t.Errorf("register session failed, error: %s", err.Error()) + } + // global sessionOn = true, router sessionOn = false + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = true, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") == "" { + t.Errorf("TestRotuerSessionSet failed") + } + +} + +func TestRouterRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterGet("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterGet can't run") + } +} + +func TestRouterRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPost("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterPost can't run") + } +} + +func TestRouterRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterHead("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterHead can't run") + } +} + +func TestRouterRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPut("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterPut can't run") + } +} + +func TestRouterRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPatch("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterPatch can't run") + } +} + +func TestRouterRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterDelete("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterDelete can't run") + } +} + +func TestRouterRouterAny(t *testing.T) { + handler := NewControllerRegister() + handler.RouterAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterRouterGetPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterGet("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterGetPointerMethod can't run") + } +} + +func TestRouterRouterPostPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPost("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPostPointerMethod can't run") + } +} + +func TestRouterRouterHeadPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterHead("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterHeadPointerMethod can't run") + } +} + +func TestRouterRouterPutPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPut("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPutPointerMethod can't run") + } +} + +func TestRouterRouterPatchPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPatch("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPatchPointerMethod can't run") + } +} + +func TestRouterRouterDeletePointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterDelete("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterDeletePointerMethod can't run") + } +} + +func TestRouterRouterAnyPointerMethod(t *testing.T) { + handler := NewControllerRegister() + handler.RouterAny("/user", (*ExampleController).PingPointer) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterAnyPointerMethod can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterAddRouterMethodPanicInvalidMethod(t *testing.T) { + method := "some random method" + message := "not support http method: " + strings.ToUpper(method) + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicInvalidMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.Ping) +} + +func TestRouterAddRouterMethodPanicNotAMethod(t *testing.T) { + method := http.MethodGet + message := "not a method" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotAMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController{}) +} + +func TestRouterAddRouterMethodPanicNotPublicMethod(t *testing.T) { + method := http.MethodGet + message := "ping is not a public method" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotPublicMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.ping) +} + +func TestRouterAddRouterMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", TestControllerWithInterface.Ping) +} + +func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterPointerMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) +} diff --git a/server/web/server.go b/server/web/server.go index f0a4f4ea..e1ef1a03 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { initBeforeHTTPRun() + // init... app.initAddr(addr) + app.Handlers.Init() addr = app.Cfg.Listen.HTTPAddr @@ -267,7 +269,11 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { // Router see HttpServer.Router func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - return BeeApp.Router(rootpath, c, mappingMethods...) + return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...)) +} + +func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + return BeeApp.RouterWithOpts(rootpath, c, opts...) } // Router adds a patterned controller handler to BeeApp. @@ -287,7 +293,11 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *H // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - app.Handlers.Add(rootPath, c, mappingMethods...) + return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...)) +} + +func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + app.Handlers.Add(rootPath, c, opts...) return app } @@ -453,6 +463,166 @@ func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpSer return app } +// RouterGet see HttpServer.RouterGet +func RouterGet(rootpath string, f interface{}) { + BeeApp.RouterGet(rootpath, f) +} + +// RouterGet used to register router for RouterGet method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterGet("/api/:id", MyController.Ping) +func (app *HttpServer) RouterGet(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterGet(rootpath, f) + return app +} + +// RouterPost see HttpServer.RouterGet +func RouterPost(rootpath string, f interface{}) { + BeeApp.RouterPost(rootpath, f) +} + +// RouterPost used to register router for RouterPost method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPost("/api/:id", MyController.Ping) +func (app *HttpServer) RouterPost(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterPost(rootpath, f) + return app +} + +// RouterHead see HttpServer.RouterHead +func RouterHead(rootpath string, f interface{}) { + BeeApp.RouterHead(rootpath, f) +} + +// RouterHead used to register router for RouterHead method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterHead("/api/:id", MyController.Ping) +func (app *HttpServer) RouterHead(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterHead(rootpath, f) + return app +} + +// RouterPut see HttpServer.RouterPut +func RouterPut(rootpath string, f interface{}) { + BeeApp.RouterPut(rootpath, f) +} + +// RouterPut used to register router for RouterPut method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPut("/api/:id", MyController.Ping) +func (app *HttpServer) RouterPut(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterPut(rootpath, f) + return app +} + +// RouterPatch see HttpServer.RouterPatch +func RouterPatch(rootpath string, f interface{}) { + BeeApp.RouterPatch(rootpath, f) +} + +// RouterPatch used to register router for RouterPatch method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPatch("/api/:id", MyController.Ping) +func (app *HttpServer) RouterPatch(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterPatch(rootpath, f) + return app +} + +// RouterDelete see HttpServer.RouterDelete +func RouterDelete(rootpath string, f interface{}) { + BeeApp.RouterDelete(rootpath, f) +} + +// RouterDelete used to register router for RouterDelete method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterDelete("/api/:id", MyController.Ping) +func (app *HttpServer) RouterDelete(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterDelete(rootpath, f) + return app +} + +// RouterOptions see HttpServer.RouterOptions +func RouterOptions(rootpath string, f interface{}) { + BeeApp.RouterOptions(rootpath, f) +} + +// RouterOptions used to register router for RouterOptions method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterOptions("/api/:id", MyController.Ping) +func (app *HttpServer) RouterOptions(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterOptions(rootpath, f) + return app +} + +// RouterAny see HttpServer.RouterAny +func RouterAny(rootpath string, f interface{}) { + BeeApp.RouterAny(rootpath, f) +} + +// RouterAny used to register router for RouterAny method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterAny("/api/:id", MyController.Ping) +func (app *HttpServer) RouterAny(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterAny(rootpath, f) + return app +} + // Get see HttpServer.Get func Get(rootpath string, f FilterFunc) *HttpServer { return BeeApp.Get(rootpath, f) diff --git a/server/web/server_test.go b/server/web/server_test.go index 0b0c601c..0734be77 100644 --- a/server/web/server_test.go +++ b/server/web/server_test.go @@ -15,6 +15,8 @@ package web import ( + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -28,3 +30,82 @@ func TestNewHttpServerWithCfg(t *testing.T) { assert.Equal(t, "hello", BConfig.AppName) } + +func TestServerRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + RouterGet("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterGet can't run") + } +} + +func TestServerRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + RouterPost("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterPost can't run") + } +} + +func TestServerRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + RouterHead("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterHead can't run") + } +} + +func TestServerRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + RouterPut("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterPut can't run") + } +} + +func TestServerRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + RouterPatch("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterPatch can't run") + } +} + +func TestServerRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + RouterDelete("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterDelete can't run") + } +} + +func TestServerRouterAny(t *testing.T) { + RouterAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + r, _ := http.NewRequest(method, "/user", nil) + w := httptest.NewRecorder() + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterAny can't run") + } + } +} diff --git a/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go index fe5c363b..2b15eef1 100644 --- a/server/web/session/redis/sess_redis_test.go +++ b/server/web/session/redis/sess_redis_test.go @@ -15,21 +15,22 @@ import ( ) func TestRedis(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - } - redisAddr := os.Getenv("REDIS_ADDR") if redisAddr == "" { redisAddr = "127.0.0.1:6379" } + redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) + + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig(redisConfig), + ) - sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr) globalSession, err := session.NewManager("redis", sessionConfig) if err != nil { t.Fatal("could not create manager:", err) diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go index 0a8030ce..489e8998 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -13,15 +13,15 @@ import ( ) func TestRedisSentinel(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,master", - } + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), + ) globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) if e != nil { t.Log(e) diff --git a/server/web/session/session.go b/server/web/session/session.go index 6b53ec29..ca0407e8 100644 --- a/server/web/session/session.go +++ b/server/web/session/session.go @@ -91,25 +91,6 @@ func GetProvider(name string) (Provider, error) { return provider, nil } -// ManagerConfig define the session config -type ManagerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - DisableHTTPOnly bool `json:"disableHTTPOnly"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` - SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` - EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` - SessionIDPrefix string `json:"sessionIDPrefix"` - CookieSameSite http.SameSite `json:"cookieSameSite"` -} - // Manager contains Provider and its configuration. type Manager struct { provider Provider diff --git a/server/web/session/session_config.go b/server/web/session/session_config.go new file mode 100644 index 00000000..e42247db --- /dev/null +++ b/server/web/session/session_config.go @@ -0,0 +1,143 @@ +package session + +import "net/http" + +// ManagerConfig define the session config +type ManagerConfig struct { + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + SessionIDPrefix string `json:"sessionIDPrefix"` + CookieSameSite http.SameSite `json:"cookieSameSite"` +} + +func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) { + for _, opt := range opts { + opt(c) + } +} + +type ManagerConfigOpt func(config *ManagerConfig) + +func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig { + config := &ManagerConfig{} + for _, opt := range opts { + opt(config) + } + return config +} + +// CfgCookieName set key of session id +func CfgCookieName(cookieName string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieName = cookieName + } +} + +// CfgCookieName set len of session id +func CfgSessionIdLength(len int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDLength = len + } +} + +// CfgSessionIdPrefix set prefix of session id +func CfgSessionIdPrefix(prefix string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDPrefix = prefix + } +} + +//CfgSetCookie whether set `Set-Cookie` header in HTTP response +func CfgSetCookie(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSetCookie = enable + } +} + +//CfgGcLifeTime set session gc lift time +func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Gclifetime = lifeTime + } +} + +//CfgMaxLifeTime set session lift time +func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Maxlifetime = lifeTime + } +} + +//CfgGcLifeTime set session lift time +func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieLifeTime = lifeTime + } +} + +//CfgProviderConfig configure session provider +func CfgProviderConfig(providerConfig string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.ProviderConfig = providerConfig + } +} + +//CfgDomain set cookie domain +func CfgDomain(domain string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Domain = domain + } +} + +//CfgSessionIdInHTTPHeader enable session id in http header +func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInHTTPHeader = enable + } +} + +//CfgSetSessionNameInHTTPHeader set key of session id in http header +func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionNameInHTTPHeader = name + } +} + +//EnableSidInURLQuery enable session id in query string +func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInURLQuery = enable + } +} + +//DisableHTTPOnly set HTTPOnly for http.Cookie +func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.DisableHTTPOnly = !HTTPOnly + } +} + +//CfgSecure set Secure for http.Cookie +func CfgSecure(Enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Secure = Enable + } +} + +//CfgSameSite set http.SameSite +func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieSameSite = sameSite + } +} diff --git a/server/web/session/session_config_test.go b/server/web/session/session_config_test.go new file mode 100644 index 00000000..a596c5c6 --- /dev/null +++ b/server/web/session/session_config_test.go @@ -0,0 +1,222 @@ +package session + +import ( + "net/http" + "testing" +) + +func TestCfgCookieLifeTime(t *testing.T) { + value := 8754 + c := NewManagerConfig( + CfgCookieLifeTime(value), + ) + + if c.CookieLifeTime != value { + t.Error() + } +} + +func TestCfgDomain(t *testing.T) { + value := `http://domain.com` + c := NewManagerConfig( + CfgDomain(value), + ) + + if c.Domain != value { + t.Error() + } +} + +func TestCfgSameSite(t *testing.T) { + value := http.SameSiteLaxMode + c := NewManagerConfig( + CfgSameSite(value), + ) + + if c.CookieSameSite != value { + t.Error() + } +} + +func TestCfgSecure(t *testing.T) { + c := NewManagerConfig( + CfgSecure(true), + ) + + if c.Secure != true { + t.Error() + } +} + +func TestCfgSecure1(t *testing.T) { + c := NewManagerConfig( + CfgSecure(false), + ) + + if c.Secure != false { + t.Error() + } +} + +func TestCfgSessionIdPrefix(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSessionIdPrefix(value), + ) + + if c.SessionIDPrefix != value { + t.Error() + } +} + +func TestCfgSetSessionNameInHTTPHeader(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSetSessionNameInHTTPHeader(value), + ) + + if c.SessionNameInHTTPHeader != value { + t.Error() + } +} + +func TestCfgCookieName(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgCookieName(value), + ) + + if c.CookieName != value { + t.Error() + } +} + +func TestCfgEnableSidInURLQuery(t *testing.T) { + c := NewManagerConfig( + CfgEnableSidInURLQuery(true), + ) + + if c.EnableSidInURLQuery != true { + t.Error() + } +} + +func TestCfgGcLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgGcLifeTime(value), + ) + + if c.Gclifetime != value { + t.Error() + } +} + +func TestCfgHTTPOnly(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(true), + ) + + if c.DisableHTTPOnly != false { + t.Error() + } +} + +func TestCfgHTTPOnly2(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(false), + ) + + if c.DisableHTTPOnly != true { + t.Error() + } +} + +func TestCfgMaxLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgMaxLifeTime(value), + ) + + if c.Maxlifetime != value { + t.Error() + } +} + +func TestCfgProviderConfig(t *testing.T) { + value := `asodiuasldkj12i39809as` + c := NewManagerConfig( + CfgProviderConfig(value), + ) + + if c.ProviderConfig != value { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(true), + ) + + if c.EnableSidInHTTPHeader != true { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader1(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(false), + ) + + if c.EnableSidInHTTPHeader != false { + t.Error() + } +} + +func TestCfgSessionIdLength(t *testing.T) { + value := int64(100) + c := NewManagerConfig( + CfgSessionIdLength(value), + ) + + if c.SessionIDLength != value { + t.Error() + } +} + +func TestCfgSetCookie(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(true), + ) + + if c.EnableSetCookie != true { + t.Error() + } +} + +func TestCfgSetCookie1(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(false), + ) + + if c.EnableSetCookie != false { + t.Error() + } +} + +func TestNewManagerConfig(t *testing.T) { + c := NewManagerConfig() + if c == nil { + t.Error() + } +} + +func TestManagerConfig_Opts(t *testing.T) { + c := NewManagerConfig() + c.Opts(CfgSetCookie(true)) + + if c.EnableSetCookie != true { + t.Error() + } +} \ No newline at end of file diff --git a/server/web/session/session_provider_type.go b/server/web/session/session_provider_type.go new file mode 100644 index 00000000..78dc116d --- /dev/null +++ b/server/web/session/session_provider_type.go @@ -0,0 +1,18 @@ +package session + +type ProviderType string + +const ( + ProviderCookie ProviderType = `cookie` + ProviderFile ProviderType = `file` + ProviderMemory ProviderType = `memory` + ProviderCouchbase ProviderType = `couchbase` + ProviderLedis ProviderType = `ledis` + ProviderMemcache ProviderType = `memcache` + ProviderMysql ProviderType = `mysql` + ProviderPostgresql ProviderType = `postgresql` + ProviderRedis ProviderType = `redis` + ProviderRedisCluster ProviderType = `redis_cluster` + ProviderRedisSentinel ProviderType = `redis_sentinel` + ProviderSsdb ProviderType = `ssdb` +) diff --git a/server/web/tree.go b/server/web/tree.go index dc459c49..5a765fd3 100644 --- a/server/web/tree.go +++ b/server/web/tree.go @@ -210,9 +210,9 @@ func (t *Tree) AddRouter(pattern string, runObject interface{}) { func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { if len(segments) == 0 { if reg != "" { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}}, t.leaves...) } else { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards}}, t.leaves...) } } else { seg := segments[0] diff --git a/server/web/tree_test.go b/server/web/tree_test.go index 3cb39c60..43511ad8 100644 --- a/server/web/tree_test.go +++ b/server/web/tree_test.go @@ -90,7 +90,17 @@ func init() { routers = append(routers, matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"})) routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"})) routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"})) - + routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11", map[string]string{":year": "2020", ":month": "11"})) + routers = append(routers, matchTestInfo("/?:year", "/2020", map[string]string{":year": "2020"})) + routers = append(routers, matchTestInfo("/?:year([0-9]+)/?:month([0-9]+)/mid/?:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"})) + routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/mid/10", map[string]string{":year": "2020", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/11/mid", map[string]string{":year": "2020", ":month": "11"})) + routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/mid/10/24", map[string]string{":day": "10", ":hour": "24"})) + routers = append(routers, matchTestInfo("/?:year([0-9]+)/:month([0-9]+)/mid/:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"})) + routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10/24", map[string]string{":month": "11", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/2020/11/mid/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10", map[string]string{":month": "11", ":day": "10"})) // not match example // https://github.com/beego/beego/v2/issues/3865 diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index c675ae7d..226cffb8 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 00000000..1a12fb33 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,7 @@ +sonar.organization=beego +sonar.projectKey=beego_beego + +# relative paths to source directories. More details and properties are described +# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ +sonar.sources=. +sonar.exclusions=**/*_test.go \ No newline at end of file diff --git a/task/governor_command_test.go b/task/governor_command_test.go index 00ed37f2..c3547cdf 100644 --- a/task/governor_command_test.go +++ b/task/governor_command_test.go @@ -55,6 +55,10 @@ func (c *countTask) GetPrev(ctx context.Context) time.Time { return time.Now() } +func (c *countTask) GetTimeout(ctx context.Context) time.Duration { + return 0 +} + func TestRunTaskCommand_Execute(t *testing.T) { task := &countTask{} AddTask("count", task) diff --git a/task/task.go b/task/task.go index 2ea34f24..00e67c4b 100644 --- a/task/task.go +++ b/task/task.go @@ -109,6 +109,7 @@ type Tasker interface { GetNext(ctx context.Context) time.Time SetPrev(context.Context, time.Time) GetPrev(ctx context.Context) time.Time + GetTimeout(ctx context.Context) time.Duration } // task error @@ -127,13 +128,14 @@ type Task struct { DoFunc TaskFunc Prev time.Time Next time.Time - Errlist []*taskerr // like errtime:errinfo - ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + Timeout time.Duration // timeout duration + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { +func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task { task := &Task{ Taskname: tname, @@ -144,6 +146,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { // we only store the pointer, so it won't use too many space Errlist: make([]*taskerr, 100, 100), } + + for _, opt := range opts { + opt.apply(task) + } + task.SetCron(spec) return task } @@ -196,6 +203,31 @@ func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } +// GetTimeout get timeout duration of this task +func (t *Task) GetTimeout(context.Context) time.Duration { + return t.Timeout +} + +// Option interface +type Option interface { + apply(*Task) +} + +// optionFunc return a function to set task element +type optionFunc func(*Task) + +// apply option to task +func (f optionFunc) apply(t *Task) { + f(t) +} + +// TimeoutOption return a option to set timeout duration for task +func TimeoutOption(timeout time.Duration) Option { + return optionFunc(func(t *Task) { + t.Timeout = timeout + }) +} + // six columns mean: // second:0-59 // minute:0-59 @@ -455,14 +487,12 @@ func (m *taskManager) StartTask() { func (m *taskManager) run() { now := time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + // first run the tasks, so set all tasks next run time. + m.setTasksStartTime(now) for { // we only use RLock here because NewMapSorter copy the reference, do not change any thing + // here, we sort all task and get first task running time (effective). m.taskLock.RLock() sortList := NewMapSorter(m.adminTaskList) m.taskLock.RUnlock() @@ -475,37 +505,75 @@ func (m *taskManager) run() { } else { effective = sortList.Vals[0].GetNext(context.Background()) } + select { - case now = <-time.After(effective.Sub(now)): - // Run every entry whose next time was this effective time. - for _, e := range sortList.Vals { - if e.GetNext(context.Background()) != effective { - break - } - go e.Run(nil) - e.SetPrev(context.Background(), e.GetNext(context.Background())) - e.SetNext(nil, effective) - } + case now = <-time.After(effective.Sub(now)): // wait for effective time + runNextTasks(sortList, effective) continue - case <-m.changed: + case <-m.changed: // tasks have been changed, set all tasks run again now now = time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + m.setTasksStartTime(now) continue - case <-m.stop: - m.taskLock.Lock() - if m.started { - m.started = false - } - m.taskLock.Unlock() + case <-m.stop: // manager is stopped, and mark manager is stopped + m.markManagerStop() return } } } +// setTasksStartTime is set all tasks next running time +func (m *taskManager) setTasksStartTime(now time.Time) { + m.taskLock.Lock() + for _, task := range m.adminTaskList { + task.SetNext(context.Background(), now) + } + m.taskLock.Unlock() +} + +// markManagerStop it sets manager to be stopped +func (m *taskManager) markManagerStop() { + m.taskLock.Lock() + if m.started { + m.started = false + } + m.taskLock.Unlock() +} + +// runNextTasks it runs next task which next run time is equal to effective +func runNextTasks(sortList *MapSorter, effective time.Time) { + // Run every entry whose next time was this effective time. + var i = 0 + for _, e := range sortList.Vals { + i++ + if e.GetNext(context.Background()) != effective { + break + } + + // check if timeout is on, if yes passing the timeout context + ctx := context.Background() + if duration := e.GetTimeout(ctx); duration != 0 { + go func(e Tasker) { + ctx, cancelFunc := context.WithTimeout(ctx, duration) + defer cancelFunc() + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } else { + go func(e Tasker) { + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } + + e.SetPrev(context.Background(), e.GetNext(context.Background())) + e.SetNext(context.Background(), effective) + } +} + // StopTask stop all tasks func (m *taskManager) StopTask() { go func() { diff --git a/task/task_test.go b/task/task_test.go index 5e117cbd..1078aa01 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -90,6 +90,57 @@ func TestSpec(t *testing.T) { } } +func TestTimeout(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + wg := &sync.WaitGroup{} + wg.Add(2) + once1, once2 := sync.Once{}, sync.Once{} + + tk1 := NewTask("tk1", "0/10 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + once1.Do(func() { + fmt.Println("tk1 done") + wg.Done() + }) + return errors.New("timeout") + default: + } + return nil + }, TimeoutOption(3*time.Second), + ) + + tk2 := NewTask("tk2", "0/11 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + return errors.New("timeout") + default: + once2.Do(func() { + fmt.Println("tk2 done") + wg.Done() + }) + } + return nil + }, + ) + + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.StartTask() + defer m.StopTask() + + select { + case <-time.After(19 * time.Second): + t.Error("TestTimeout failed") + case <-wait(wg): + } +} + func TestTask_Run(t *testing.T) { cnt := -1 task := func(ctx context.Context) error { @@ -109,6 +160,23 @@ func TestTask_Run(t *testing.T) { assert.Equal(t, "Hello, world! 101", l[1].errinfo) } +func TestCrudTask(t *testing.T) { + m := newTaskManager() + m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.DeleteTask("my-task1") + assert.Equal(t, 1, len(m.adminTaskList)) + + m.ClearTask() + assert.Equal(t, 0, len(m.adminTaskList)) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() {