Merge branch 'develop' of https://gitclone.com/github.com/beego/beego into frt/delete-txorm
# Conflicts: # CHANGELOG.md
This commit is contained in:
commit
e8448a520f
2
.github/workflows/changelog.yml
vendored
2
.github/workflows/changelog.yml
vendored
@ -8,7 +8,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, labeled, unlabeled]
|
types: [opened, synchronize, reopened, labeled, unlabeled]
|
||||||
branches:
|
branches:
|
||||||
- master
|
- develop
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
changelog:
|
changelog:
|
||||||
|
|||||||
32
.github/workflows/golangci-lint.yml
vendored
Normal file
32
.github/workflows/golangci-lint.yml
vendored
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
name: golangci-lint
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- v*
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
jobs:
|
||||||
|
golangci:
|
||||||
|
name: lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v2
|
||||||
|
with:
|
||||||
|
# Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
|
||||||
|
version: v1.29
|
||||||
|
|
||||||
|
# Optional: working directory, useful for monorepos
|
||||||
|
# working-directory: ./
|
||||||
|
|
||||||
|
# Optional: golangci-lint command line arguments.
|
||||||
|
args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true
|
||||||
|
|
||||||
|
# Optional: show only new issues if it's a pull request. The default value is `false`.
|
||||||
|
only-new-issues: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action will use pre-installed Go
|
||||||
|
# skip-go-installation: true
|
||||||
12
CHANGELOG.md
12
CHANGELOG.md
@ -1,7 +1,19 @@
|
|||||||
# developing
|
# developing
|
||||||
|
- ORM mock. [4407](https://github.com/beego/beego/pull/4407)
|
||||||
|
- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433)
|
||||||
|
- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427)
|
||||||
- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398)
|
- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398)
|
||||||
- Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391)
|
- Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391)
|
||||||
- Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385)
|
- Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385)
|
||||||
- Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385)
|
- Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385)
|
||||||
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
|
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
|
||||||
|
- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294)
|
||||||
|
- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404)
|
||||||
|
- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416)
|
||||||
|
- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424)
|
||||||
|
- Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441)
|
||||||
|
- Error Module brief design & using httplib module to validate this design. [4453](https://github.com/beego/beego/pull/4453)
|
||||||
|
- Fix 4444: panic when 404 not found. [4446](https://github.com/beego/beego/pull/4446)
|
||||||
|
- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452)
|
||||||
|
- Fix 4456: Fix router method expression [4456](https://github.com/beego/beego/pull/4456)
|
||||||
- Fix 4451: support QueryExecutor interface. [4461](https://github.com/beego/beego/pull/4461)
|
- Fix 4451: support QueryExecutor interface. [4461](https://github.com/beego/beego/pull/4461)
|
||||||
|
|||||||
1
ERROR_SPECIFICATION.md
Normal file
1
ERROR_SPECIFICATION.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Error Module
|
||||||
@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister {
|
|||||||
// Add("/api",&RestController{},"get,post:ApiFunc"
|
// Add("/api",&RestController{},"get,post:ApiFunc"
|
||||||
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
||||||
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
||||||
(*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...)
|
(*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller
|
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller
|
||||||
|
|||||||
@ -289,3 +289,7 @@ func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) {
|
|||||||
func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time {
|
func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time {
|
||||||
return o.delegate.GetPrev()
|
return o.delegate.GetPrev()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *oldToNewAdapter) GetTimeout(ctx context.Context) time.Duration {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|||||||
131
client/httplib/error_code.go
Normal file
131
client/httplib/error_code.go
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package httplib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/beego/beego/v2/core/berror"
|
||||||
|
)
|
||||||
|
|
||||||
|
var InvalidUrl = berror.DefineCode(4001001, moduleName, "InvalidUrl", `
|
||||||
|
You pass a invalid url to httplib module. Please check your url, be careful about special character.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var InvalidUrlProtocolVersion = berror.DefineCode(4001002, moduleName, "InvalidUrlProtocolVersion", `
|
||||||
|
You pass a invalid protocol version. In practice, we use HTTP/1.0, HTTP/1.1, HTTP/1.2
|
||||||
|
But something like HTTP/3.2 is valid for client, and the major version is 3, minor version is 2.
|
||||||
|
but you must confirm that server support those abnormal protocol version.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var UnsupportedBodyType = berror.DefineCode(4001003, moduleName, "UnsupportedBodyType", `
|
||||||
|
You use a invalid data as request body.
|
||||||
|
For now, we only support type string and byte[].
|
||||||
|
`)
|
||||||
|
|
||||||
|
var InvalidXMLBody = berror.DefineCode(4001004, moduleName, "InvalidXMLBody", `
|
||||||
|
You pass invalid data which could not be converted to XML documents. In general, if you pass structure, it works well.
|
||||||
|
Sometimes you got XML document and you want to make it as request body. So you call XMLBody.
|
||||||
|
If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var InvalidYAMLBody = berror.DefineCode(4001005, moduleName, "InvalidYAMLBody", `
|
||||||
|
You pass invalid data which could not be converted to YAML documents. In general, if you pass structure, it works well.
|
||||||
|
Sometimes you got YAML document and you want to make it as request body. So you call YAMLBody.
|
||||||
|
If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var InvalidJSONBody = berror.DefineCode(4001006, moduleName, "InvalidJSONBody", `
|
||||||
|
You pass invalid data which could not be converted to JSON documents. In general, if you pass structure, it works well.
|
||||||
|
Sometimes you got JSON document and you want to make it as request body. So you call JSONBody.
|
||||||
|
If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data.
|
||||||
|
`)
|
||||||
|
|
||||||
|
|
||||||
|
// start with 5 --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
var CreateFormFileFailed = berror.DefineCode(5001001, moduleName, "CreateFormFileFailed", `
|
||||||
|
In normal case than handling files with BeegoRequest, you should not see this error code.
|
||||||
|
Unexpected EOF, invalid characters, bad file descriptor may cause this error.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var ReadFileFailed = berror.DefineCode(5001002, moduleName, "ReadFileFailed", `
|
||||||
|
There are several cases that cause this error:
|
||||||
|
1. file not found. Please check the file name;
|
||||||
|
2. file not found, but file name is correct. If you use relative file path, it's very possible for you to see this code.
|
||||||
|
make sure that this file is in correct directory which Beego looks for;
|
||||||
|
3. Beego don't have the privilege to read this file, please change file mode;
|
||||||
|
`)
|
||||||
|
|
||||||
|
var CopyFileFailed = berror.DefineCode(5001003, moduleName, "CopyFileFailed", `
|
||||||
|
When we try to read file content and then copy it to another writer, and failed.
|
||||||
|
1. Unexpected EOF;
|
||||||
|
2. Bad file descriptor;
|
||||||
|
3. Write conflict;
|
||||||
|
|
||||||
|
Please check your file content, and confirm that file is not processed by other process (or by user manually).
|
||||||
|
`)
|
||||||
|
|
||||||
|
var CloseFileFailed = berror.DefineCode(5001004, moduleName, "CloseFileFailed", `
|
||||||
|
After handling files, Beego try to close file but failed. Usually it was caused by bad file descriptor.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var SendRequestFailed = berror.DefineCode(5001005, moduleName, "SendRequestRetryExhausted", `
|
||||||
|
Beego send HTTP request, but it failed.
|
||||||
|
If you config retry times, it means that Beego had retried and failed.
|
||||||
|
When you got this error, there are vary kind of reason:
|
||||||
|
1. Network unstable and timeout. In this case, sometimes server has received the request.
|
||||||
|
2. Server error. Make sure that server works well.
|
||||||
|
3. The request is invalid, which means that you pass some invalid parameter.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var ReadGzipBodyFailed = berror.DefineCode(5001006, moduleName, "BuildGzipReaderFailed", `
|
||||||
|
Beego parse gzip-encode body failed. Usually Beego got invalid response.
|
||||||
|
Please confirm that server returns gzip data.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var CreateFileIfNotExistFailed = berror.DefineCode(5001007, moduleName, "CreateFileIfNotExist", `
|
||||||
|
Beego want to create file if not exist and failed.
|
||||||
|
In most cases, it means that Beego doesn't have the privilege to create this file.
|
||||||
|
Please change file mode to ensure that Beego is able to create files on specific directory.
|
||||||
|
Or you can run Beego with higher authority.
|
||||||
|
In some cases, you pass invalid filename. Make sure that the file name is valid on your system.
|
||||||
|
`)
|
||||||
|
|
||||||
|
var UnmarshalJSONResponseToObjectFailed = berror.DefineCode(5001008, moduleName,
|
||||||
|
"UnmarshalResponseToObjectFailed", `
|
||||||
|
Beego trying to unmarshal response's body to structure but failed.
|
||||||
|
Make sure that:
|
||||||
|
1. You pass valid structure pointer to the function;
|
||||||
|
2. The body is valid json document
|
||||||
|
`)
|
||||||
|
|
||||||
|
var UnmarshalXMLResponseToObjectFailed = berror.DefineCode(5001009, moduleName,
|
||||||
|
"UnmarshalResponseToObjectFailed", `
|
||||||
|
Beego trying to unmarshal response's body to structure but failed.
|
||||||
|
Make sure that:
|
||||||
|
1. You pass valid structure pointer to the function;
|
||||||
|
2. The body is valid XML document
|
||||||
|
`)
|
||||||
|
|
||||||
|
var UnmarshalYAMLResponseToObjectFailed = berror.DefineCode(5001010, moduleName,
|
||||||
|
"UnmarshalResponseToObjectFailed", `
|
||||||
|
Beego trying to unmarshal response's body to structure but failed.
|
||||||
|
Make sure that:
|
||||||
|
1. You pass valid structure pointer to the function;
|
||||||
|
2. The body is valid YAML document
|
||||||
|
`)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
@ -26,15 +27,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type FilterChainBuilder struct {
|
type FilterChainBuilder struct {
|
||||||
summaryVec prometheus.ObserverVec
|
|
||||||
AppName string
|
AppName string
|
||||||
ServerName string
|
ServerName string
|
||||||
RunMode string
|
RunMode string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var summaryVec prometheus.ObserverVec
|
||||||
|
var initSummaryVec sync.Once
|
||||||
|
|
||||||
func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter {
|
func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter {
|
||||||
|
|
||||||
builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
initSummaryVec.Do(func() {
|
||||||
|
summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
||||||
Name: "beego",
|
Name: "beego",
|
||||||
Subsystem: "remote_http_request",
|
Subsystem: "remote_http_request",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
@ -45,6 +49,9 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt
|
|||||||
Help: "The statics info for remote http requests",
|
Help: "The statics info for remote http requests",
|
||||||
}, []string{"proto", "scheme", "method", "host", "path", "status", "isError"})
|
}, []string{"proto", "scheme", "method", "host", "path", "status", "isError"})
|
||||||
|
|
||||||
|
prometheus.MustRegister(summaryVec)
|
||||||
|
})
|
||||||
|
|
||||||
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
|
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
resp, err := next(ctx, req)
|
resp, err := next(ctx, req)
|
||||||
@ -72,6 +79,6 @@ func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time
|
|||||||
|
|
||||||
dur := int(endTime.Sub(startTime) / time.Millisecond)
|
dur := int(endTime.Sub(startTime) / time.Millisecond)
|
||||||
|
|
||||||
builder.summaryVec.WithLabelValues(proto, scheme, method, host, path,
|
summaryVec.WithLabelValues(proto, scheme, method, host, path,
|
||||||
strconv.Itoa(status), strconv.FormatBool(err != nil)).Observe(float64(dur))
|
strconv.Itoa(status), strconv.FormatBool(err != nil)).Observe(float64(dur))
|
||||||
}
|
}
|
||||||
|
|||||||
36
client/httplib/http_response_test.go
Normal file
36
client/httplib/http_response_test.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package httplib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
func TestNewHttpResponseWithJsonBody(t *testing.T) {
|
||||||
|
// string
|
||||||
|
resp := NewHttpResponseWithJsonBody("{}")
|
||||||
|
assert.Equal(t, int64(2), resp.ContentLength)
|
||||||
|
|
||||||
|
resp = NewHttpResponseWithJsonBody([]byte("{}"))
|
||||||
|
assert.Equal(t, int64(2), resp.ContentLength)
|
||||||
|
|
||||||
|
resp = NewHttpResponseWithJsonBody(&user{
|
||||||
|
Name: "Tom",
|
||||||
|
})
|
||||||
|
assert.True(t, resp.ContentLength > 0)
|
||||||
|
}
|
||||||
@ -40,59 +40,36 @@ import (
|
|||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/core/berror"
|
||||||
|
"github.com/beego/beego/v2/core/logs"
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultSetting = BeegoHTTPSettings{
|
|
||||||
UserAgent: "beegoServer",
|
|
||||||
ConnectTimeout: 60 * time.Second,
|
|
||||||
ReadWriteTimeout: 60 * time.Second,
|
|
||||||
Gzip: true,
|
|
||||||
DumpBody: true,
|
|
||||||
FilterChains: []FilterChain{mockFilter.FilterChain},
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultCookieJar http.CookieJar
|
|
||||||
var settingMutex sync.Mutex
|
|
||||||
|
|
||||||
// it will be the last filter and execute request.Do
|
// it will be the last filter and execute request.Do
|
||||||
var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
||||||
return req.doRequest(ctx)
|
return req.doRequest(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createDefaultCookie creates a global cookiejar to store cookies.
|
|
||||||
func createDefaultCookie() {
|
|
||||||
settingMutex.Lock()
|
|
||||||
defer settingMutex.Unlock()
|
|
||||||
defaultCookieJar, _ = cookiejar.New(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDefaultSetting overwrites default settings
|
|
||||||
func SetDefaultSetting(setting BeegoHTTPSettings) {
|
|
||||||
settingMutex.Lock()
|
|
||||||
defer settingMutex.Unlock()
|
|
||||||
defaultSetting = setting
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBeegoRequest returns *BeegoHttpRequest with specific method
|
// NewBeegoRequest returns *BeegoHttpRequest with specific method
|
||||||
|
// TODO add error as return value
|
||||||
|
// I think if we don't return error
|
||||||
|
// users are hard to check whether we create Beego request successfully
|
||||||
func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
|
func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
|
||||||
var resp http.Response
|
var resp http.Response
|
||||||
u, err := url.Parse(rawurl)
|
u, err := url.Parse(rawurl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Httplib:", err)
|
logs.Error("%+v", berror.Wrapf(err, InvalidUrl, "invalid raw url: %s", rawurl))
|
||||||
}
|
}
|
||||||
req := http.Request{
|
req := http.Request{
|
||||||
URL: u,
|
URL: u,
|
||||||
@ -137,24 +114,6 @@ func Head(url string) *BeegoHTTPRequest {
|
|||||||
return NewBeegoRequest(url, "HEAD")
|
return NewBeegoRequest(url, "HEAD")
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeegoHTTPSettings is the http.Client setting
|
|
||||||
type BeegoHTTPSettings struct {
|
|
||||||
ShowDebug bool
|
|
||||||
UserAgent string
|
|
||||||
ConnectTimeout time.Duration
|
|
||||||
ReadWriteTimeout time.Duration
|
|
||||||
TLSClientConfig *tls.Config
|
|
||||||
Proxy func(*http.Request) (*url.URL, error)
|
|
||||||
Transport http.RoundTripper
|
|
||||||
CheckRedirect func(req *http.Request, via []*http.Request) error
|
|
||||||
EnableCookie bool
|
|
||||||
Gzip bool
|
|
||||||
DumpBody bool
|
|
||||||
Retries int // if set to -1 means will retry forever
|
|
||||||
RetryDelay time.Duration
|
|
||||||
FilterChains []FilterChain
|
|
||||||
}
|
|
||||||
|
|
||||||
// BeegoHTTPRequest provides more useful methods than http.Request for requesting a url.
|
// BeegoHTTPRequest provides more useful methods than http.Request for requesting a url.
|
||||||
type BeegoHTTPRequest struct {
|
type BeegoHTTPRequest struct {
|
||||||
url string
|
url string
|
||||||
@ -254,7 +213,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetProtocolVersion sets the protocol version for incoming requests.
|
// SetProtocolVersion sets the protocol version for incoming requests.
|
||||||
// Client requests always use HTTP/1.1.
|
// Client requests always use HTTP/1.1
|
||||||
func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
|
func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
|
||||||
if len(vers) == 0 {
|
if len(vers) == 0 {
|
||||||
vers = "HTTP/1.1"
|
vers = "HTTP/1.1"
|
||||||
@ -265,8 +224,9 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
|
|||||||
b.req.Proto = vers
|
b.req.Proto = vers
|
||||||
b.req.ProtoMajor = major
|
b.req.ProtoMajor = major
|
||||||
b.req.ProtoMinor = minor
|
b.req.ProtoMinor = minor
|
||||||
|
return b
|
||||||
}
|
}
|
||||||
|
logs.Error("%+v", berror.Errorf(InvalidUrlProtocolVersion, "invalid protocol: %s", vers))
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -334,6 +294,7 @@ func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest
|
|||||||
|
|
||||||
// Body adds request raw body.
|
// Body adds request raw body.
|
||||||
// Supports string and []byte.
|
// Supports string and []byte.
|
||||||
|
// TODO return error if data is invalid
|
||||||
func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
|
func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
|
||||||
switch t := data.(type) {
|
switch t := data.(type) {
|
||||||
case string:
|
case string:
|
||||||
@ -350,6 +311,8 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
|
|||||||
return ioutil.NopCloser(bf), nil
|
return ioutil.NopCloser(bf), nil
|
||||||
}
|
}
|
||||||
b.req.ContentLength = int64(len(t))
|
b.req.ContentLength = int64(len(t))
|
||||||
|
default:
|
||||||
|
logs.Error("%+v", berror.Errorf(UnsupportedBodyType, "unsupported body data type: %s", t))
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
@ -359,9 +322,12 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
|
|||||||
if b.req.Body == nil && obj != nil {
|
if b.req.Body == nil && obj != nil {
|
||||||
byts, err := xml.Marshal(obj)
|
byts, err := xml.Marshal(obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return b, err
|
return b, berror.Wrap(err, InvalidXMLBody, "obj could not be converted to XML data")
|
||||||
}
|
}
|
||||||
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
||||||
|
b.req.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return ioutil.NopCloser(bytes.NewReader(byts)), nil
|
||||||
|
}
|
||||||
b.req.ContentLength = int64(len(byts))
|
b.req.ContentLength = int64(len(byts))
|
||||||
b.req.Header.Set("Content-Type", "application/xml")
|
b.req.Header.Set("Content-Type", "application/xml")
|
||||||
}
|
}
|
||||||
@ -373,7 +339,7 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error)
|
|||||||
if b.req.Body == nil && obj != nil {
|
if b.req.Body == nil && obj != nil {
|
||||||
byts, err := yaml.Marshal(obj)
|
byts, err := yaml.Marshal(obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return b, err
|
return b, berror.Wrap(err, InvalidYAMLBody, "obj could not be converted to YAML data")
|
||||||
}
|
}
|
||||||
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
||||||
b.req.ContentLength = int64(len(byts))
|
b.req.ContentLength = int64(len(byts))
|
||||||
@ -387,7 +353,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error)
|
|||||||
if b.req.Body == nil && obj != nil {
|
if b.req.Body == nil && obj != nil {
|
||||||
byts, err := json.Marshal(obj)
|
byts, err := json.Marshal(obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return b, err
|
return b, berror.Wrap(err, InvalidJSONBody, "obj could not be converted to JSON body")
|
||||||
}
|
}
|
||||||
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
||||||
b.req.ContentLength = int64(len(byts))
|
b.req.ContentLength = int64(len(byts))
|
||||||
@ -415,28 +381,15 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
|
|||||||
bodyWriter := multipart.NewWriter(pw)
|
bodyWriter := multipart.NewWriter(pw)
|
||||||
go func() {
|
go func() {
|
||||||
for formname, filename := range b.files {
|
for formname, filename := range b.files {
|
||||||
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
|
b.handleFileToBody(bodyWriter, formname, filename)
|
||||||
if err != nil {
|
|
||||||
log.Println("Httplib:", err)
|
|
||||||
}
|
|
||||||
fh, err := os.Open(filename)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Httplib:", err)
|
|
||||||
}
|
|
||||||
// iocopy
|
|
||||||
_, err = io.Copy(fileWriter, fh)
|
|
||||||
fh.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Httplib:", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for k, v := range b.params {
|
for k, v := range b.params {
|
||||||
for _, vv := range v {
|
for _, vv := range v {
|
||||||
bodyWriter.WriteField(k, vv)
|
_ = bodyWriter.WriteField(k, vv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bodyWriter.Close()
|
_ = bodyWriter.Close()
|
||||||
pw.Close()
|
_ = pw.Close()
|
||||||
}()
|
}()
|
||||||
b.Header("Content-Type", bodyWriter.FormDataContentType())
|
b.Header("Content-Type", bodyWriter.FormDataContentType())
|
||||||
b.req.Body = ioutil.NopCloser(pr)
|
b.req.Body = ioutil.NopCloser(pr)
|
||||||
@ -452,6 +405,29 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BeegoHTTPRequest) handleFileToBody(bodyWriter *multipart.Writer, formname string, filename string) {
|
||||||
|
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
|
||||||
|
const errFmt = "Httplib: %+v"
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(errFmt, berror.Wrapf(err, CreateFormFileFailed,
|
||||||
|
"could not create form file, formname: %s, filename: %s", formname, filename))
|
||||||
|
}
|
||||||
|
fh, err := os.Open(filename)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(errFmt, berror.Wrapf(err, ReadFileFailed, "could not open this file %s", filename))
|
||||||
|
}
|
||||||
|
// iocopy
|
||||||
|
_, err = io.Copy(fileWriter, fh)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(errFmt, berror.Wrapf(err, CopyFileFailed, "could not copy this file %s", filename))
|
||||||
|
}
|
||||||
|
err = fh.Close()
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(errFmt, berror.Wrapf(err, CloseFileFailed, "could not close this file %s", filename))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
|
func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
|
||||||
if b.resp.StatusCode != 0 {
|
if b.resp.StatusCode != 0 {
|
||||||
return b.resp, nil
|
return b.resp, nil
|
||||||
@ -480,62 +456,20 @@ func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Res
|
|||||||
return root(ctx, b)
|
return root(ctx, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) {
|
func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (*http.Response, error) {
|
||||||
var paramBody string
|
paramBody := b.buildParamBody()
|
||||||
if len(b.params) > 0 {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
for k, v := range b.params {
|
|
||||||
for _, vv := range v {
|
|
||||||
buf.WriteString(url.QueryEscape(k))
|
|
||||||
buf.WriteByte('=')
|
|
||||||
buf.WriteString(url.QueryEscape(vv))
|
|
||||||
buf.WriteByte('&')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
paramBody = buf.String()
|
|
||||||
paramBody = paramBody[0 : len(paramBody)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
b.buildURL(paramBody)
|
b.buildURL(paramBody)
|
||||||
urlParsed, err := url.Parse(b.url)
|
urlParsed, err := url.Parse(b.url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, berror.Wrapf(err, InvalidUrl, "parse url failed, the url is %s", b.url)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.req.URL = urlParsed
|
b.req.URL = urlParsed
|
||||||
|
|
||||||
trans := b.setting.Transport
|
trans := b.buildTrans()
|
||||||
|
|
||||||
if trans == nil {
|
jar := b.buildCookieJar()
|
||||||
// create default transport
|
|
||||||
trans = &http.Transport{
|
|
||||||
TLSClientConfig: b.setting.TLSClientConfig,
|
|
||||||
Proxy: b.setting.Proxy,
|
|
||||||
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
|
|
||||||
MaxIdleConnsPerHost: 100,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// if b.transport is *http.Transport then set the settings.
|
|
||||||
if t, ok := trans.(*http.Transport); ok {
|
|
||||||
if t.TLSClientConfig == nil {
|
|
||||||
t.TLSClientConfig = b.setting.TLSClientConfig
|
|
||||||
}
|
|
||||||
if t.Proxy == nil {
|
|
||||||
t.Proxy = b.setting.Proxy
|
|
||||||
}
|
|
||||||
if t.Dial == nil {
|
|
||||||
t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var jar http.CookieJar
|
|
||||||
if b.setting.EnableCookie {
|
|
||||||
if defaultCookieJar == nil {
|
|
||||||
createDefaultCookie()
|
|
||||||
}
|
|
||||||
jar = defaultCookieJar
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Transport: trans,
|
Transport: trans,
|
||||||
@ -551,12 +485,16 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if b.setting.ShowDebug {
|
if b.setting.ShowDebug {
|
||||||
dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody)
|
dump, e := httputil.DumpRequest(b.req, b.setting.DumpBody)
|
||||||
if err != nil {
|
if e != nil {
|
||||||
log.Println(err.Error())
|
logs.Error("%+v", e)
|
||||||
}
|
}
|
||||||
b.dump = dump
|
b.dump = dump
|
||||||
}
|
}
|
||||||
|
return b.sendRequest(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response, err error) {
|
||||||
// retries default value is 0, it will run once.
|
// retries default value is 0, it will run once.
|
||||||
// retries equal to -1, it will run forever until success
|
// retries equal to -1, it will run forever until success
|
||||||
// retries is setted, it will retries fixed times.
|
// retries is setted, it will retries fixed times.
|
||||||
@ -564,11 +502,68 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response,
|
|||||||
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
|
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
|
||||||
resp, err = client.Do(b.req)
|
resp, err = client.Do(b.req)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
return
|
||||||
}
|
}
|
||||||
time.Sleep(b.setting.RetryDelay)
|
time.Sleep(b.setting.RetryDelay)
|
||||||
}
|
}
|
||||||
return resp, err
|
return nil, berror.Wrap(err, SendRequestFailed, "sending request fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BeegoHTTPRequest) buildCookieJar() http.CookieJar {
|
||||||
|
var jar http.CookieJar
|
||||||
|
if b.setting.EnableCookie {
|
||||||
|
if defaultCookieJar == nil {
|
||||||
|
createDefaultCookie()
|
||||||
|
}
|
||||||
|
jar = defaultCookieJar
|
||||||
|
}
|
||||||
|
return jar
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper {
|
||||||
|
trans := b.setting.Transport
|
||||||
|
|
||||||
|
if trans == nil {
|
||||||
|
// create default transport
|
||||||
|
trans = &http.Transport{
|
||||||
|
TLSClientConfig: b.setting.TLSClientConfig,
|
||||||
|
Proxy: b.setting.Proxy,
|
||||||
|
DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
|
||||||
|
MaxIdleConnsPerHost: 100,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// if b.transport is *http.Transport then set the settings.
|
||||||
|
if t, ok := trans.(*http.Transport); ok {
|
||||||
|
if t.TLSClientConfig == nil {
|
||||||
|
t.TLSClientConfig = b.setting.TLSClientConfig
|
||||||
|
}
|
||||||
|
if t.Proxy == nil {
|
||||||
|
t.Proxy = b.setting.Proxy
|
||||||
|
}
|
||||||
|
if t.DialContext == nil {
|
||||||
|
t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return trans
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BeegoHTTPRequest) buildParamBody() string {
|
||||||
|
var paramBody string
|
||||||
|
if len(b.params) > 0 {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
for k, v := range b.params {
|
||||||
|
for _, vv := range v {
|
||||||
|
buf.WriteString(url.QueryEscape(k))
|
||||||
|
buf.WriteByte('=')
|
||||||
|
buf.WriteString(url.QueryEscape(vv))
|
||||||
|
buf.WriteByte('&')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
paramBody = buf.String()
|
||||||
|
paramBody = paramBody[0 : len(paramBody)-1]
|
||||||
|
}
|
||||||
|
return paramBody
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the body string in response.
|
// String returns the body string in response.
|
||||||
@ -599,10 +594,10 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
|
|||||||
if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
|
if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
|
||||||
reader, err := gzip.NewReader(resp.Body)
|
reader, err := gzip.NewReader(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, berror.Wrap(err, ReadGzipBodyFailed, "building gzip reader failed")
|
||||||
}
|
}
|
||||||
b.body, err = ioutil.ReadAll(reader)
|
b.body, err = ioutil.ReadAll(reader)
|
||||||
return b.body, err
|
return b.body, berror.Wrap(err, ReadGzipBodyFailed, "reading gzip data failed")
|
||||||
}
|
}
|
||||||
b.body, err = ioutil.ReadAll(resp.Body)
|
b.body, err = ioutil.ReadAll(resp.Body)
|
||||||
return b.body, err
|
return b.body, err
|
||||||
@ -645,7 +640,7 @@ func pathExistAndMkdir(filename string) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return berror.Wrapf(err, CreateFileIfNotExistFailed, "try to create(if not exist) failed: %s", filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToJSON returns the map that marshals from the body bytes as json in response.
|
// ToJSON returns the map that marshals from the body bytes as json in response.
|
||||||
@ -655,7 +650,8 @@ func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return json.Unmarshal(data, v)
|
return berror.Wrap(json.Unmarshal(data, v),
|
||||||
|
UnmarshalJSONResponseToObjectFailed, "unmarshal json body to object failed.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToXML returns the map that marshals from the body bytes as xml in response .
|
// ToXML returns the map that marshals from the body bytes as xml in response .
|
||||||
@ -665,7 +661,8 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return xml.Unmarshal(data, v)
|
return berror.Wrap(xml.Unmarshal(data, v),
|
||||||
|
UnmarshalXMLResponseToObjectFailed, "unmarshal xml body to object failed.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToYAML returns the map that marshals from the body bytes as yaml in response .
|
// ToYAML returns the map that marshals from the body bytes as yaml in response .
|
||||||
@ -675,7 +672,8 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return yaml.Unmarshal(data, v)
|
return berror.Wrap(yaml.Unmarshal(data, v),
|
||||||
|
UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Response executes request client gets response manually.
|
// Response executes request client gets response manually.
|
||||||
@ -684,8 +682,18 @@ func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
|
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
|
||||||
|
// Deprecated
|
||||||
|
// we will move this at the end of 2021
|
||||||
|
// please use TimeoutDialerCtx
|
||||||
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
|
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
|
||||||
return func(netw, addr string) (net.Conn, error) {
|
return func(netw, addr string) (net.Conn, error) {
|
||||||
|
return TimeoutDialerCtx(cTimeout, rwTimeout)(context.Background(), netw, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TimeoutDialerCtx(cTimeout time.Duration,
|
||||||
|
rwTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) {
|
||||||
|
return func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||||
conn, err := net.DialTimeout(netw, addr, cTimeout)
|
conn, err := net.DialTimeout(netw, addr, cTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -300,3 +300,136 @@ func TestAddFilter(t *testing.T) {
|
|||||||
r := Get("http://beego.me")
|
r := Get("http://beego.me")
|
||||||
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
|
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilterChainOrder(t *testing.T) {
|
||||||
|
req := Get("http://beego.me")
|
||||||
|
req.AddFilters(func(next Filter) Filter {
|
||||||
|
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
||||||
|
return NewHttpResponseWithJsonBody("first"), nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
req.AddFilters(func(next Filter) Filter {
|
||||||
|
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
||||||
|
return NewHttpResponseWithJsonBody("second"), nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := req.DoRequestWithCtx(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
data := make([]byte, 5)
|
||||||
|
_, _ = resp.Body.Read(data)
|
||||||
|
assert.Equal(t, "first", string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHead(t *testing.T) {
|
||||||
|
req := Head("http://beego.me")
|
||||||
|
assert.NotNil(t, req)
|
||||||
|
assert.Equal(t, "HEAD", req.req.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDelete(t *testing.T) {
|
||||||
|
req := Delete("http://beego.me")
|
||||||
|
assert.NotNil(t, req)
|
||||||
|
assert.Equal(t, "DELETE", req.req.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPost(t *testing.T) {
|
||||||
|
req := Post("http://beego.me")
|
||||||
|
assert.NotNil(t, req)
|
||||||
|
assert.Equal(t, "POST", req.req.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewBeegoRequest(t *testing.T) {
|
||||||
|
req := NewBeegoRequest("http://beego.me", "GET")
|
||||||
|
assert.NotNil(t, req)
|
||||||
|
assert.Equal(t, "GET", req.req.Method)
|
||||||
|
|
||||||
|
// invalid case but still go request
|
||||||
|
req = NewBeegoRequest("httpa\ta://beego.me", "GET")
|
||||||
|
assert.NotNil(t, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeegoHTTPRequest_SetProtocolVersion(t *testing.T) {
|
||||||
|
req := NewBeegoRequest("http://beego.me", "GET")
|
||||||
|
req.SetProtocolVersion("HTTP/3.10")
|
||||||
|
assert.Equal(t, "HTTP/3.10", req.req.Proto)
|
||||||
|
assert.Equal(t, 3, req.req.ProtoMajor)
|
||||||
|
assert.Equal(t, 10, req.req.ProtoMinor)
|
||||||
|
|
||||||
|
req.SetProtocolVersion("")
|
||||||
|
assert.Equal(t, "HTTP/1.1", req.req.Proto)
|
||||||
|
assert.Equal(t, 1, req.req.ProtoMajor)
|
||||||
|
assert.Equal(t, 1, req.req.ProtoMinor)
|
||||||
|
|
||||||
|
// invalid case
|
||||||
|
req.SetProtocolVersion("HTTP/aaa1.1")
|
||||||
|
assert.Equal(t, "HTTP/1.1", req.req.Proto)
|
||||||
|
assert.Equal(t, 1, req.req.ProtoMajor)
|
||||||
|
assert.Equal(t, 1, req.req.ProtoMinor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPut(t *testing.T) {
|
||||||
|
req := Put("http://beego.me")
|
||||||
|
assert.NotNil(t, req)
|
||||||
|
assert.Equal(t, "PUT", req.req.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeegoHTTPRequest_Header(t *testing.T) {
|
||||||
|
req := Post("http://beego.me")
|
||||||
|
key, value := "test-header", "test-header-value"
|
||||||
|
req.Header(key, value)
|
||||||
|
assert.Equal(t, value, req.req.Header.Get(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeegoHTTPRequest_SetHost(t *testing.T) {
|
||||||
|
req := Post("http://beego.me")
|
||||||
|
host := "test-hose"
|
||||||
|
req.SetHost(host)
|
||||||
|
assert.Equal(t, host, req.req.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeegoHTTPRequest_Param(t *testing.T) {
|
||||||
|
req := Post("http://beego.me")
|
||||||
|
key, value := "test-param", "test-param-value"
|
||||||
|
req.Param(key, value)
|
||||||
|
assert.Equal(t, value, req.params[key][0])
|
||||||
|
|
||||||
|
value1 := "test-param-value-1"
|
||||||
|
req.Param(key, value1)
|
||||||
|
assert.Equal(t, value1, req.params[key][1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeegoHTTPRequest_Body(t *testing.T) {
|
||||||
|
req := Post("http://beego.me")
|
||||||
|
body := `hello, world`
|
||||||
|
req.Body([]byte(body))
|
||||||
|
assert.Equal(t, int64(len(body)), req.req.ContentLength)
|
||||||
|
assert.NotNil(t, req.req.GetBody)
|
||||||
|
assert.NotNil(t, req.req.Body)
|
||||||
|
|
||||||
|
body = "hhhh, i am test"
|
||||||
|
req.Body(body)
|
||||||
|
assert.Equal(t, int64(len(body)), req.req.ContentLength)
|
||||||
|
assert.NotNil(t, req.req.GetBody)
|
||||||
|
assert.NotNil(t, req.req.Body)
|
||||||
|
|
||||||
|
// invalid case
|
||||||
|
req.Body(13)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
type user struct {
|
||||||
|
Name string `xml:"name"`
|
||||||
|
}
|
||||||
|
func TestBeegoHTTPRequest_XMLBody(t *testing.T) {
|
||||||
|
req := Post("http://beego.me")
|
||||||
|
body := &user{
|
||||||
|
Name: "Tom",
|
||||||
|
}
|
||||||
|
_, err := req.XMLBody(body)
|
||||||
|
assert.True(t, req.req.ContentLength > 0)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, req.req.GetBody)
|
||||||
|
}
|
||||||
@ -12,17 +12,22 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package httplib
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/httplib"
|
||||||
"github.com/beego/beego/v2/core/logs"
|
"github.com/beego/beego/v2/core/logs"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mockCtxKey = "beego-httplib-mock"
|
const mockCtxKey = "beego-httplib-mock"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
InitMockSetting()
|
||||||
|
}
|
||||||
|
|
||||||
type Stub interface {
|
type Stub interface {
|
||||||
Mock(cond RequestCondition, resp *http.Response, err error)
|
Mock(cond RequestCondition, resp *http.Response, err error)
|
||||||
Clear()
|
Clear()
|
||||||
@ -31,6 +36,10 @@ type Stub interface {
|
|||||||
|
|
||||||
var mockFilter = &MockResponseFilter{}
|
var mockFilter = &MockResponseFilter{}
|
||||||
|
|
||||||
|
func InitMockSetting() {
|
||||||
|
httplib.AddDefaultFilter(mockFilter.FilterChain)
|
||||||
|
}
|
||||||
|
|
||||||
func StartMock() Stub {
|
func StartMock() Stub {
|
||||||
return mockFilter
|
return mockFilter
|
||||||
}
|
}
|
||||||
@ -12,17 +12,19 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package httplib
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/httplib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestCondition interface {
|
type RequestCondition interface {
|
||||||
Match(ctx context.Context, req *BeegoHTTPRequest) bool
|
Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// reqCondition create condition
|
// reqCondition create condition
|
||||||
@ -54,7 +56,7 @@ func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondi
|
|||||||
return sc
|
return sc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SimpleCondition) Match(ctx context.Context, req *BeegoHTTPRequest) bool {
|
func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
|
||||||
res := true
|
res := true
|
||||||
if len(sc.path) > 0 {
|
if len(sc.path) > 0 {
|
||||||
res = sc.matchPath(ctx, req)
|
res = sc.matchPath(ctx, req)
|
||||||
@ -70,12 +72,12 @@ func (sc *SimpleCondition) Match(ctx context.Context, req *BeegoHTTPRequest) boo
|
|||||||
sc.matchBodyFields(ctx, req)
|
sc.matchBodyFields(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SimpleCondition) matchPath(ctx context.Context, req *BeegoHTTPRequest) bool {
|
func (sc *SimpleCondition) matchPath(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
|
||||||
path := req.GetRequest().URL.Path
|
path := req.GetRequest().URL.Path
|
||||||
return path == sc.path
|
return path == sc.path
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *BeegoHTTPRequest) bool {
|
func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
|
||||||
path := req.GetRequest().URL.Path
|
path := req.GetRequest().URL.Path
|
||||||
if b, err := regexp.Match(sc.pathReg, []byte(path)); err == nil {
|
if b, err := regexp.Match(sc.pathReg, []byte(path)); err == nil {
|
||||||
return b
|
return b
|
||||||
@ -83,7 +85,7 @@ func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *BeegoHTTPReque
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SimpleCondition) matchQuery(ctx context.Context, req *BeegoHTTPRequest) bool {
|
func (sc *SimpleCondition) matchQuery(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
|
||||||
qs := req.GetRequest().URL.Query()
|
qs := req.GetRequest().URL.Query()
|
||||||
for k, v := range sc.query {
|
for k, v := range sc.query {
|
||||||
if uv, ok := qs[k]; !ok || uv[0] != v {
|
if uv, ok := qs[k]; !ok || uv[0] != v {
|
||||||
@ -93,7 +95,7 @@ func (sc *SimpleCondition) matchQuery(ctx context.Context, req *BeegoHTTPRequest
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SimpleCondition) matchHeader(ctx context.Context, req *BeegoHTTPRequest) bool {
|
func (sc *SimpleCondition) matchHeader(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
|
||||||
headers := req.GetRequest().Header
|
headers := req.GetRequest().Header
|
||||||
for k, v := range sc.header {
|
for k, v := range sc.header {
|
||||||
if uv, ok := headers[k]; !ok || uv[0] != v {
|
if uv, ok := headers[k]; !ok || uv[0] != v {
|
||||||
@ -103,7 +105,7 @@ func (sc *SimpleCondition) matchHeader(ctx context.Context, req *BeegoHTTPReques
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *BeegoHTTPRequest) bool {
|
func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
|
||||||
if len(sc.body) == 0 {
|
if len(sc.body) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -135,7 +137,7 @@ func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *BeegoHTTPRe
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SimpleCondition) matchMethod(ctx context.Context, req *BeegoHTTPRequest) bool {
|
func (sc *SimpleCondition) matchMethod(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
|
||||||
if len(sc.method) > 0 {
|
if len(sc.method) > 0 {
|
||||||
return sc.method == req.GetRequest().Method
|
return sc.method == req.GetRequest().Method
|
||||||
}
|
}
|
||||||
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package httplib
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/httplib"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -28,37 +29,37 @@ func init() {
|
|||||||
|
|
||||||
func TestSimpleCondition_MatchPath(t *testing.T) {
|
func TestSimpleCondition_MatchPath(t *testing.T) {
|
||||||
sc := NewSimpleCondition("/abc/s")
|
sc := NewSimpleCondition("/abc/s")
|
||||||
res := sc.Match(context.Background(), Get("http://localhost:8080/abc/s"))
|
res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s"))
|
||||||
assert.True(t, res)
|
assert.True(t, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSimpleCondition_MatchQuery(t *testing.T) {
|
func TestSimpleCondition_MatchQuery(t *testing.T) {
|
||||||
k, v := "my-key", "my-value"
|
k, v := "my-key", "my-value"
|
||||||
sc := NewSimpleCondition("/abc/s")
|
sc := NewSimpleCondition("/abc/s")
|
||||||
res := sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value"))
|
res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value"))
|
||||||
assert.True(t, res)
|
assert.True(t, res)
|
||||||
|
|
||||||
sc = NewSimpleCondition("/abc/s", WithQuery(k, v))
|
sc = NewSimpleCondition("/abc/s", WithQuery(k, v))
|
||||||
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value"))
|
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value"))
|
||||||
assert.True(t, res)
|
assert.True(t, res)
|
||||||
|
|
||||||
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-valuesss"))
|
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-valuesss"))
|
||||||
assert.False(t, res)
|
assert.False(t, res)
|
||||||
|
|
||||||
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key-a=my-value"))
|
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key-a=my-value"))
|
||||||
assert.False(t, res)
|
assert.False(t, res)
|
||||||
|
|
||||||
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello"))
|
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello"))
|
||||||
assert.True(t, res)
|
assert.True(t, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSimpleCondition_MatchHeader(t *testing.T) {
|
func TestSimpleCondition_MatchHeader(t *testing.T) {
|
||||||
k, v := "my-header", "my-header-value"
|
k, v := "my-header", "my-header-value"
|
||||||
sc := NewSimpleCondition("/abc/s")
|
sc := NewSimpleCondition("/abc/s")
|
||||||
req := Get("http://localhost:8080/abc/s")
|
req := httplib.Get("http://localhost:8080/abc/s")
|
||||||
assert.True(t, sc.Match(context.Background(), req))
|
assert.True(t, sc.Match(context.Background(), req))
|
||||||
|
|
||||||
req = Get("http://localhost:8080/abc/s")
|
req = httplib.Get("http://localhost:8080/abc/s")
|
||||||
req.Header(k, v)
|
req.Header(k, v)
|
||||||
assert.True(t, sc.Match(context.Background(), req))
|
assert.True(t, sc.Match(context.Background(), req))
|
||||||
|
|
||||||
@ -73,7 +74,7 @@ func TestSimpleCondition_MatchHeader(t *testing.T) {
|
|||||||
func TestSimpleCondition_MatchBodyField(t *testing.T) {
|
func TestSimpleCondition_MatchBodyField(t *testing.T) {
|
||||||
|
|
||||||
sc := NewSimpleCondition("/abc/s")
|
sc := NewSimpleCondition("/abc/s")
|
||||||
req := Post("http://localhost:8080/abc/s")
|
req := httplib.Post("http://localhost:8080/abc/s")
|
||||||
|
|
||||||
assert.True(t, sc.Match(context.Background(), req))
|
assert.True(t, sc.Match(context.Background(), req))
|
||||||
|
|
||||||
@ -102,7 +103,7 @@ func TestSimpleCondition_MatchBodyField(t *testing.T) {
|
|||||||
|
|
||||||
func TestSimpleCondition_Match(t *testing.T) {
|
func TestSimpleCondition_Match(t *testing.T) {
|
||||||
sc := NewSimpleCondition("/abc/s")
|
sc := NewSimpleCondition("/abc/s")
|
||||||
req := Post("http://localhost:8080/abc/s")
|
req := httplib.Post("http://localhost:8080/abc/s")
|
||||||
|
|
||||||
assert.True(t, sc.Match(context.Background(), req))
|
assert.True(t, sc.Match(context.Background(), req))
|
||||||
|
|
||||||
@ -115,9 +116,9 @@ func TestSimpleCondition_Match(t *testing.T) {
|
|||||||
|
|
||||||
func TestSimpleCondition_MatchPathReg(t *testing.T) {
|
func TestSimpleCondition_MatchPathReg(t *testing.T) {
|
||||||
sc := NewSimpleCondition("", WithPathReg(`\/abc\/.*`))
|
sc := NewSimpleCondition("", WithPathReg(`\/abc\/.*`))
|
||||||
req := Post("http://localhost:8080/abc/s")
|
req := httplib.Post("http://localhost:8080/abc/s")
|
||||||
assert.True(t, sc.Match(context.Background(), req))
|
assert.True(t, sc.Match(context.Background(), req))
|
||||||
|
|
||||||
req = Post("http://localhost:8080/abcd/s")
|
req = httplib.Post("http://localhost:8080/abcd/s")
|
||||||
assert.False(t, sc.Match(context.Background(), req))
|
assert.False(t, sc.Match(context.Background(), req))
|
||||||
}
|
}
|
||||||
@ -12,12 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package httplib
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/httplib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockResponse will return mock response if find any suitable mock data
|
// MockResponse will return mock response if find any suitable mock data
|
||||||
@ -32,13 +33,10 @@ func NewMockResponseFilter() *MockResponseFilter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockResponseFilter) FilterChain(next Filter) Filter {
|
func (m *MockResponseFilter) FilterChain(next httplib.Filter) httplib.Filter {
|
||||||
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
|
||||||
|
|
||||||
ms := mockFromCtx(ctx)
|
ms := mockFromCtx(ctx)
|
||||||
ms = append(ms, m.ms...)
|
ms = append(ms, m.ms...)
|
||||||
|
|
||||||
fmt.Printf("url: %s, mock: %d \n", req.url, len(ms))
|
|
||||||
for _, mock := range ms {
|
for _, mock := range ms {
|
||||||
if mock.cond.Match(ctx, req) {
|
if mock.cond.Match(ctx, req) {
|
||||||
return mock.resp, mock.err
|
return mock.resp, mock.err
|
||||||
@ -12,20 +12,22 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package httplib
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/httplib"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMockResponseFilter_FilterChain(t *testing.T) {
|
func TestMockResponseFilter_FilterChain(t *testing.T) {
|
||||||
req := Get("http://localhost:8080/abc/s")
|
req := httplib.Get("http://localhost:8080/abc/s")
|
||||||
ft := NewMockResponseFilter()
|
ft := NewMockResponseFilter()
|
||||||
|
|
||||||
expectedResp := NewHttpResponseWithJsonBody(`{}`)
|
expectedResp := httplib.NewHttpResponseWithJsonBody(`{}`)
|
||||||
expectedErr := errors.New("expected error")
|
expectedErr := errors.New("expected error")
|
||||||
ft.Mock(NewSimpleCondition("/abc/s"), expectedResp, expectedErr)
|
ft.Mock(NewSimpleCondition("/abc/s"), expectedResp, expectedErr)
|
||||||
|
|
||||||
@ -35,16 +37,16 @@ func TestMockResponseFilter_FilterChain(t *testing.T) {
|
|||||||
assert.Equal(t, expectedErr, err)
|
assert.Equal(t, expectedErr, err)
|
||||||
assert.Equal(t, expectedResp, resp)
|
assert.Equal(t, expectedResp, resp)
|
||||||
|
|
||||||
req = Get("http://localhost:8080/abcd/s")
|
req = httplib.Get("http://localhost:8080/abcd/s")
|
||||||
req.AddFilters(ft.FilterChain)
|
req.AddFilters(ft.FilterChain)
|
||||||
|
|
||||||
resp, err = req.DoRequest()
|
resp, err = req.DoRequest()
|
||||||
assert.NotEqual(t, expectedErr, err)
|
assert.NotEqual(t, expectedErr, err)
|
||||||
assert.NotEqual(t, expectedResp, resp)
|
assert.NotEqual(t, expectedResp, resp)
|
||||||
|
|
||||||
req = Get("http://localhost:8080/abc/s")
|
req = httplib.Get("http://localhost:8080/abc/s")
|
||||||
req.AddFilters(ft.FilterChain)
|
req.AddFilters(ft.FilterChain)
|
||||||
expectedResp1 := NewHttpResponseWithJsonBody(map[string]string{})
|
expectedResp1 := httplib.NewHttpResponseWithJsonBody(map[string]string{})
|
||||||
expectedErr1 := errors.New("expected error")
|
expectedErr1 := errors.New("expected error")
|
||||||
ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1)
|
ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1)
|
||||||
|
|
||||||
@ -52,7 +54,7 @@ func TestMockResponseFilter_FilterChain(t *testing.T) {
|
|||||||
assert.Equal(t, expectedErr, err)
|
assert.Equal(t, expectedErr, err)
|
||||||
assert.Equal(t, expectedResp, resp)
|
assert.Equal(t, expectedResp, resp)
|
||||||
|
|
||||||
req = Get("http://localhost:8080/abc/abs/bbc")
|
req = httplib.Get("http://localhost:8080/abc/abs/bbc")
|
||||||
req.AddFilters(ft.FilterChain)
|
req.AddFilters(ft.FilterChain)
|
||||||
ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1)
|
ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1)
|
||||||
resp, err = req.DoRequest()
|
resp, err = req.DoRequest()
|
||||||
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package httplib
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -21,16 +21,18 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/httplib"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStartMock(t *testing.T) {
|
func TestStartMock(t *testing.T) {
|
||||||
|
|
||||||
defaultSetting.FilterChains = []FilterChain{mockFilter.FilterChain}
|
// httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain}
|
||||||
|
|
||||||
stub := StartMock()
|
stub := StartMock()
|
||||||
// defer stub.Clear()
|
// defer stub.Clear()
|
||||||
|
|
||||||
expectedResp := NewHttpResponseWithJsonBody([]byte(`{}`))
|
expectedResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`))
|
||||||
expectedErr := errors.New("expected err")
|
expectedErr := errors.New("expected err")
|
||||||
|
|
||||||
stub.Mock(NewSimpleCondition("/abc"), expectedResp, expectedErr)
|
stub.Mock(NewSimpleCondition("/abc"), expectedResp, expectedErr)
|
||||||
@ -45,14 +47,14 @@ func TestStartMock(t *testing.T) {
|
|||||||
// TestStartMock_Isolation Test StartMock that
|
// TestStartMock_Isolation Test StartMock that
|
||||||
// mock only work for this request
|
// mock only work for this request
|
||||||
func TestStartMock_Isolation(t *testing.T) {
|
func TestStartMock_Isolation(t *testing.T) {
|
||||||
defaultSetting.FilterChains = []FilterChain{mockFilter.FilterChain}
|
// httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain}
|
||||||
// setup global stub
|
// setup global stub
|
||||||
stub := StartMock()
|
stub := StartMock()
|
||||||
globalMockResp := NewHttpResponseWithJsonBody([]byte(`{}`))
|
globalMockResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`))
|
||||||
globalMockErr := errors.New("expected err")
|
globalMockErr := errors.New("expected err")
|
||||||
stub.Mock(NewSimpleCondition("/abc"), globalMockResp, globalMockErr)
|
stub.Mock(NewSimpleCondition("/abc"), globalMockResp, globalMockErr)
|
||||||
|
|
||||||
expectedResp := NewHttpResponseWithJsonBody(struct {
|
expectedResp := httplib.NewHttpResponseWithJsonBody(struct {
|
||||||
A string `json:"a"`
|
A string `json:"a"`
|
||||||
}{
|
}{
|
||||||
A: "aaa",
|
A: "aaa",
|
||||||
@ -67,9 +69,9 @@ func TestStartMock_Isolation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OriginnalCodeUsingHttplibPassCtx(ctx context.Context) (*http.Response, error) {
|
func OriginnalCodeUsingHttplibPassCtx(ctx context.Context) (*http.Response, error) {
|
||||||
return Get("http://localhost:7777/abc").DoRequestWithCtx(ctx)
|
return httplib.Get("http://localhost:7777/abc").DoRequestWithCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OriginalCodeUsingHttplib() (*http.Response, error){
|
func OriginalCodeUsingHttplib() (*http.Response, error){
|
||||||
return Get("http://localhost:7777/abc").DoRequest()
|
return httplib.Get("http://localhost:7777/abc").DoRequest()
|
||||||
}
|
}
|
||||||
17
client/httplib/module.go
Normal file
17
client/httplib/module.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
// Copyright 2021 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package httplib
|
||||||
|
|
||||||
|
const moduleName = "httplib"
|
||||||
81
client/httplib/setting.go
Normal file
81
client/httplib/setting.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package httplib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BeegoHTTPSettings is the http.Client setting
|
||||||
|
type BeegoHTTPSettings struct {
|
||||||
|
ShowDebug bool
|
||||||
|
UserAgent string
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
ReadWriteTimeout time.Duration
|
||||||
|
TLSClientConfig *tls.Config
|
||||||
|
Proxy func(*http.Request) (*url.URL, error)
|
||||||
|
Transport http.RoundTripper
|
||||||
|
CheckRedirect func(req *http.Request, via []*http.Request) error
|
||||||
|
EnableCookie bool
|
||||||
|
Gzip bool
|
||||||
|
DumpBody bool
|
||||||
|
Retries int // if set to -1 means will retry forever
|
||||||
|
RetryDelay time.Duration
|
||||||
|
FilterChains []FilterChain
|
||||||
|
}
|
||||||
|
|
||||||
|
// createDefaultCookie creates a global cookiejar to store cookies.
|
||||||
|
func createDefaultCookie() {
|
||||||
|
settingMutex.Lock()
|
||||||
|
defer settingMutex.Unlock()
|
||||||
|
defaultCookieJar, _ = cookiejar.New(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultSetting overwrites default settings
|
||||||
|
// Keep in mind that when you invoke the SetDefaultSetting
|
||||||
|
// some methods invoked before SetDefaultSetting
|
||||||
|
func SetDefaultSetting(setting BeegoHTTPSettings) {
|
||||||
|
settingMutex.Lock()
|
||||||
|
defer settingMutex.Unlock()
|
||||||
|
defaultSetting = setting
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultSetting = BeegoHTTPSettings{
|
||||||
|
UserAgent: "beegoServer",
|
||||||
|
ConnectTimeout: 60 * time.Second,
|
||||||
|
ReadWriteTimeout: 60 * time.Second,
|
||||||
|
Gzip: true,
|
||||||
|
DumpBody: true,
|
||||||
|
FilterChains: make([]FilterChain, 0, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultCookieJar http.CookieJar
|
||||||
|
var settingMutex sync.Mutex
|
||||||
|
|
||||||
|
// AddDefaultFilter add a new filter into defaultSetting
|
||||||
|
// Be careful about using this method if you invoke SetDefaultSetting somewhere
|
||||||
|
func AddDefaultFilter(fc FilterChain) {
|
||||||
|
settingMutex.Lock()
|
||||||
|
defer settingMutex.Unlock()
|
||||||
|
if defaultSetting.FilterChains == nil {
|
||||||
|
defaultSetting.FilterChains = make([]FilterChain, 0, 4)
|
||||||
|
}
|
||||||
|
defaultSetting.FilterChains = append(defaultSetting.FilterChains, fc)
|
||||||
|
}
|
||||||
6
client/orm/clauses/const.go
Normal file
6
client/orm/clauses/const.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package clauses
|
||||||
|
|
||||||
|
const (
|
||||||
|
ExprSep = "__"
|
||||||
|
ExprDot = "."
|
||||||
|
)
|
||||||
103
client/orm/clauses/order_clause/order.go
Normal file
103
client/orm/clauses/order_clause/order.go
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package order_clause
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sort int8
|
||||||
|
|
||||||
|
const (
|
||||||
|
None Sort = 0
|
||||||
|
Ascending Sort = 1
|
||||||
|
Descending Sort = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Option func(order *Order)
|
||||||
|
|
||||||
|
type Order struct {
|
||||||
|
column string
|
||||||
|
sort Sort
|
||||||
|
isRaw bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func Clause(options ...Option) *Order {
|
||||||
|
o := &Order{}
|
||||||
|
for _, option := range options {
|
||||||
|
option(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) GetColumn() string {
|
||||||
|
return o.column
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) GetSort() Sort {
|
||||||
|
return o.sort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) SortString() string {
|
||||||
|
switch o.GetSort() {
|
||||||
|
case Ascending:
|
||||||
|
return "ASC"
|
||||||
|
case Descending:
|
||||||
|
return "DESC"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ``
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) IsRaw() bool {
|
||||||
|
return o.isRaw
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseOrder(expressions ...string) []*Order {
|
||||||
|
var orders []*Order
|
||||||
|
for _, expression := range expressions {
|
||||||
|
sort := Ascending
|
||||||
|
column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot)
|
||||||
|
if column[0] == '-' {
|
||||||
|
sort = Descending
|
||||||
|
column = column[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
orders = append(orders, &Order{
|
||||||
|
column: column,
|
||||||
|
sort: sort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return orders
|
||||||
|
}
|
||||||
|
|
||||||
|
func Column(column string) Option {
|
||||||
|
return func(order *Order) {
|
||||||
|
order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sort(sort Sort) Option {
|
||||||
|
return func(order *Order) {
|
||||||
|
order.sort = sort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortAscending() Option {
|
||||||
|
return sort(Ascending)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortDescending() Option {
|
||||||
|
return sort(Descending)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortNone() Option {
|
||||||
|
return sort(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Raw() Option {
|
||||||
|
return func(order *Order) {
|
||||||
|
order.isRaw = true
|
||||||
|
}
|
||||||
|
}
|
||||||
144
client/orm/clauses/order_clause/order_test.go
Normal file
144
client/orm/clauses/order_clause/order_test.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package order_clause
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClause(t *testing.T) {
|
||||||
|
var (
|
||||||
|
column = `a`
|
||||||
|
)
|
||||||
|
|
||||||
|
o := Clause(
|
||||||
|
Column(column),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o.GetColumn() != column {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortAscending(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
SortAscending(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o.GetSort() != Ascending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortDescending(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
SortDescending(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o.GetSort() != Descending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortNone(t *testing.T) {
|
||||||
|
o1 := Clause(
|
||||||
|
SortNone(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o1.GetSort() != None {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
o2 := Clause()
|
||||||
|
|
||||||
|
if o2.GetSort() != None {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRaw(t *testing.T) {
|
||||||
|
o1 := Clause()
|
||||||
|
|
||||||
|
if o1.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
o2 := Clause(
|
||||||
|
Raw(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if !o2.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestColumn(t *testing.T) {
|
||||||
|
o1 := Clause(
|
||||||
|
Column(`aaa`),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o1.GetColumn() != `aaa` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOrder(t *testing.T) {
|
||||||
|
orders := ParseOrder(
|
||||||
|
`-user__status`,
|
||||||
|
`status`,
|
||||||
|
`user__status`,
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Log(orders)
|
||||||
|
|
||||||
|
if orders[0].GetSort() != Descending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[0].GetColumn() != `user.status` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[1].GetColumn() != `status` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[1].GetSort() != Ascending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[2].GetColumn() != `user.status` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrder_GetColumn(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
Column(`user__id`),
|
||||||
|
)
|
||||||
|
if o.GetColumn() != `user.id` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrder_GetSort(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
SortDescending(),
|
||||||
|
)
|
||||||
|
if o.GetSort() != Descending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrder_IsRaw(t *testing.T) {
|
||||||
|
o1 := Clause()
|
||||||
|
if o1.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
o2 := Clause(
|
||||||
|
Raw(),
|
||||||
|
)
|
||||||
|
if !o2.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
fmt.Printf(" %s\n", err.Error())
|
fmt.Printf(" %s\n", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
for i, mi := range modelCache.allOrdered() {
|
for i, mi := range modelCache.allOrdered() {
|
||||||
|
|
||||||
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
|
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
|
||||||
@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var fields []*fieldInfo
|
var fields []*fieldInfo
|
||||||
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
|
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if d.rtOnError {
|
if d.rtOnError {
|
||||||
return err
|
return err
|
||||||
@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, idx := range indexes[mi.table] {
|
for _, idx := range indexes[mi.table] {
|
||||||
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
|
if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
|
||||||
if !d.noInfo {
|
if !d.noInfo {
|
||||||
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
||||||
}
|
}
|
||||||
|
|||||||
140
client/orm/db.go
140
client/orm/db.go
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create insert sql preparation statement object.
|
// create insert sql preparation statement object.
|
||||||
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
||||||
@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
|||||||
|
|
||||||
d.ins.HasReturningID(mi, &query)
|
d.ins.HasReturningID(mi, &query)
|
||||||
|
|
||||||
stmt, err := q.Prepare(query)
|
stmt, err := q.PrepareContext(ctx, query)
|
||||||
return stmt, query, err
|
return stmt, query, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert struct with prepared statement and given struct reflect value.
|
// insert struct with prepared statement and given struct reflect value.
|
||||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
|||||||
err := row.Scan(&id)
|
err := row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
res, err := stmt.Exec(values...)
|
res, err := stmt.ExecContext(ctx, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return res.LastInsertId()
|
return res.LastInsertId()
|
||||||
}
|
}
|
||||||
@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// query sql ,read records and persist in dbBaser.
|
// query sql ,read records and persist in dbBaser.
|
||||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
||||||
var whereCols []string
|
var whereCols []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
row := q.QueryRow(query, args...)
|
row := q.QueryRowContext(ctx, query, args...)
|
||||||
if err := row.Scan(refs...); err != nil {
|
if err := row.Scan(refs...); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return ErrNoRows
|
return ErrNoRows
|
||||||
@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute insert sql dbQuerier with given struct reflect.Value.
|
// execute insert sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
names := make([]string, 0, len(mi.fields.dbcols))
|
names := make([]string, 0, len(mi.fields.dbcols))
|
||||||
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := d.InsertValue(q, mi, false, names, values)
|
id, err := d.InsertValue(ctx, q, mi, false, names, values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(autoFields) > 0 {
|
if len(autoFields) > 0 {
|
||||||
err = d.ins.setval(q, mi, autoFields)
|
err = d.ins.setval(ctx, q, mi, autoFields)
|
||||||
}
|
}
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// multi-insert sql with given slice struct reflect.Value.
|
// multi-insert sql with given slice struct reflect.Value.
|
||||||
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
||||||
var (
|
var (
|
||||||
cnt int64
|
cnt int64
|
||||||
nums int
|
nums int
|
||||||
@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
}
|
}
|
||||||
|
|
||||||
if i > 1 && i%bulk == 0 || length == i {
|
if i > 1 && i%bulk == 0 || length == i {
|
||||||
num, err := d.InsertValue(q, mi, true, names, values[:nums])
|
num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cnt, err
|
return cnt, err
|
||||||
}
|
}
|
||||||
@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
if len(autoFields) > 0 {
|
if len(autoFields) > 0 {
|
||||||
err = d.ins.setval(q, mi, autoFields)
|
err = d.ins.setval(ctx, q, mi, autoFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cnt, err
|
return cnt, err
|
||||||
@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
|
|
||||||
// execute insert sql with given struct and given values.
|
// execute insert sql with given struct and given values.
|
||||||
// insert the given values, not the field values in struct.
|
// insert the given values, not the field values in struct.
|
||||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
marks := make([]string, len(names))
|
marks := make([]string, len(names))
|
||||||
@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err := row.Scan(&id)
|
err := row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
// InsertOrUpdate a row
|
// InsertOrUpdate a row
|
||||||
// If your primary key or unique column conflict will update
|
// If your primary key or unique column conflict will update
|
||||||
// If no will insert
|
// If no will insert
|
||||||
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
args0 := ""
|
args0 := ""
|
||||||
iouStr := ""
|
iouStr := ""
|
||||||
argsMap := map[string]string{}
|
argsMap := map[string]string{}
|
||||||
@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err = row.Scan(&id)
|
err = row.Scan(&id)
|
||||||
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
|
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
|
||||||
@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute update sql dbQuerier with given struct reflect.Value.
|
// execute update sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, ErrMissPK
|
return 0, ErrMissPK
|
||||||
@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
res, err := q.Exec(query, setValues...)
|
res, err := q.ExecContext(ctx, query, setValues...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
}
|
}
|
||||||
@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
|
|
||||||
// execute delete sql dbQuerier with given struct reflect.Value.
|
// execute delete sql dbQuerier with given struct reflect.Value.
|
||||||
// delete index is pk.
|
// delete index is pk.
|
||||||
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
var whereCols []string
|
var whereCols []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
// if specify cols length > 0, then use it for where condition.
|
// if specify cols length > 0, then use it for where condition.
|
||||||
@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
|
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
res, err := q.Exec(query, args...)
|
res, err := q.ExecContext(ctx, query, args...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
num, err := res.RowsAffected()
|
num, err := res.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
|
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := d.deleteRels(q, mi, args, tz)
|
err := d.deleteRels(ctx, q, mi, args, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
|
|
||||||
// update table-related record by querySet.
|
// update table-related record by querySet.
|
||||||
// need querySet not struct reflect.Value to update related records.
|
// need querySet not struct reflect.Value to update related records.
|
||||||
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
||||||
columns := make([]string, 0, len(params))
|
columns := make([]string, 0, len(params))
|
||||||
values := make([]interface{}, 0, len(params))
|
values := make([]interface{}, 0, len(params))
|
||||||
for col, val := range params {
|
for col, val := range params {
|
||||||
@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
var err error
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
var res sql.Result
|
|
||||||
if qs != nil && qs.forContext {
|
|
||||||
res, err = q.ExecContext(qs.ctx, query, values...)
|
|
||||||
} else {
|
|
||||||
res, err = q.Exec(query, values...)
|
|
||||||
}
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
}
|
}
|
||||||
@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
|
|
||||||
// delete related records.
|
// delete related records.
|
||||||
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
|
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
|
||||||
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
||||||
for _, fi := range mi.fields.fieldsReverse {
|
for _, fi := range mi.fields.fieldsReverse {
|
||||||
fi = fi.reverseFieldInfo
|
fi = fi.reverseFieldInfo
|
||||||
switch fi.onDelete {
|
switch fi.onDelete {
|
||||||
case odCascade:
|
case odCascade:
|
||||||
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
|
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
|
||||||
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
|
_, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
|
|||||||
if fi.onDelete == odSetDefault {
|
if fi.onDelete == odSetDefault {
|
||||||
params[fi.column] = fi.initial.String()
|
params[fi.column] = fi.initial.String()
|
||||||
}
|
}
|
||||||
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
|
_, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// delete table-related records.
|
// delete table-related records.
|
||||||
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
tables.skipEnd = true
|
tables.skipEnd = true
|
||||||
|
|
||||||
@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var rs *sql.Rows
|
var rs *sql.Rows
|
||||||
r, err := q.Query(query, args...)
|
r, err := q.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
|
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
var res sql.Result
|
res, err := q.ExecContext(ctx, query, args...)
|
||||||
if qs != nil && qs.forContext {
|
|
||||||
res, err = q.ExecContext(qs.ctx, query, args...)
|
|
||||||
} else {
|
|
||||||
res, err = q.Exec(query, args...)
|
|
||||||
}
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
num, err := res.RowsAffected()
|
num, err := res.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if num > 0 {
|
if num > 0 {
|
||||||
err := d.deleteRels(q, mi, args, tz)
|
err := d.deleteRels(ctx, q, mi, args, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
@ -943,14 +933,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
// read related records.
|
// read related records.
|
||||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||||
|
|
||||||
val := reflect.ValueOf(container)
|
val := reflect.ValueOf(container)
|
||||||
ind := reflect.Indirect(val)
|
ind := reflect.Indirect(val)
|
||||||
|
|
||||||
errTyp := true
|
unregister := true
|
||||||
one := true
|
one := true
|
||||||
isPtr := true
|
isPtr := true
|
||||||
|
name := ""
|
||||||
|
|
||||||
if val.Kind() == reflect.Ptr {
|
if val.Kind() == reflect.Ptr {
|
||||||
fn := ""
|
fn := ""
|
||||||
@ -963,19 +954,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
isPtr = false
|
isPtr = false
|
||||||
fn = getFullName(typ)
|
fn = getFullName(typ)
|
||||||
|
name = getTableName(reflect.New(typ))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fn = getFullName(ind.Type())
|
fn = getFullName(ind.Type())
|
||||||
|
name = getTableName(ind)
|
||||||
}
|
}
|
||||||
errTyp = fn != mi.fullName
|
unregister = fn != mi.fullName
|
||||||
}
|
}
|
||||||
|
|
||||||
if errTyp {
|
if unregister {
|
||||||
if one {
|
RegisterModel(container)
|
||||||
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName))
|
|
||||||
} else {
|
|
||||||
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rlimit := qs.limit
|
rlimit := qs.limit
|
||||||
@ -1040,6 +1029,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
if qs.distinct {
|
if qs.distinct {
|
||||||
sqlSelect += " DISTINCT"
|
sqlSelect += " DISTINCT"
|
||||||
}
|
}
|
||||||
|
if qs.aggregate != "" {
|
||||||
|
sels = qs.aggregate
|
||||||
|
}
|
||||||
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
|
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
|
||||||
sqlSelect, sels, Q, mi.table, Q,
|
sqlSelect, sels, Q, mi.table, Q,
|
||||||
specifyIndexes, join, where, groupBy, orderBy, limit)
|
specifyIndexes, join, where, groupBy, orderBy, limit)
|
||||||
@ -1050,18 +1042,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var rs *sql.Rows
|
rs, err := q.QueryContext(ctx, query, args...)
|
||||||
var err error
|
|
||||||
if qs != nil && qs.forContext {
|
|
||||||
rs, err = q.QueryContext(qs.ctx, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rs, err = q.Query(query, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
|
slice := ind
|
||||||
|
if unregister {
|
||||||
|
mi, _ = modelCache.get(name)
|
||||||
|
tCols = mi.fields.dbcols
|
||||||
|
colsNum = len(tCols)
|
||||||
}
|
}
|
||||||
|
|
||||||
refs := make([]interface{}, colsNum)
|
refs := make([]interface{}, colsNum)
|
||||||
@ -1069,11 +1061,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
var ref interface{}
|
var ref interface{}
|
||||||
refs[i] = &ref
|
refs[i] = &ref
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rs.Close()
|
|
||||||
|
|
||||||
slice := ind
|
|
||||||
|
|
||||||
var cnt int64
|
var cnt int64
|
||||||
for rs.Next() {
|
for rs.Next() {
|
||||||
if one && cnt == 0 || !one {
|
if one && cnt == 0 || !one {
|
||||||
@ -1172,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// excute count sql and return count result int64.
|
// excute count sql and return count result int64.
|
||||||
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
tables.parseRelated(qs.related, qs.relDepth)
|
tables.parseRelated(qs.related, qs.relDepth)
|
||||||
|
|
||||||
@ -1194,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var row *sql.Row
|
row := q.QueryRowContext(ctx, query, args...)
|
||||||
if qs != nil && qs.forContext {
|
|
||||||
row = q.QueryRowContext(qs.ctx, query, args...)
|
|
||||||
} else {
|
|
||||||
row = q.QueryRow(query, args...)
|
|
||||||
}
|
|
||||||
err = row.Scan(&cnt)
|
err = row.Scan(&cnt)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1649,7 +1631,7 @@ setValue:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// query sql, read values , save to *[]ParamList.
|
// query sql, read values , save to *[]ParamList.
|
||||||
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
maps []Params
|
maps []Params
|
||||||
@ -1732,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
rs, err := q.Query(query, args...)
|
rs, err := q.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -1847,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sync auto key
|
// sync auto key
|
||||||
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1892,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get all cloumns in table.
|
// get all cloumns in table.
|
||||||
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
columns := make(map[string][3]string)
|
columns := make(map[string][3]string)
|
||||||
query := d.ins.ShowColumnsQuery(table)
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return columns, err
|
return columns, err
|
||||||
}
|
}
|
||||||
@ -1934,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// not implement.
|
// not implement.
|
||||||
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool {
|
||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute sql to check index exist.
|
// execute sql to check index exist.
|
||||||
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
|
||||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||||
var cnt int
|
var cnt int
|
||||||
row.Scan(&cnt)
|
row.Scan(&cnt)
|
||||||
@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
|
|||||||
// If your primary key or unique column conflict will update
|
// If your primary key or unique column conflict will update
|
||||||
// If no will insert
|
// If no will insert
|
||||||
// Add "`" for mysql sql building
|
// Add "`" for mysql sql building
|
||||||
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
var iouStr string
|
var iouStr string
|
||||||
argsMap := map[string]string{}
|
argsMap := map[string]string{}
|
||||||
|
|
||||||
@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err = row.Scan(&id)
|
err = row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check index is exist
|
// check index is exist
|
||||||
func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
|
row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
|
||||||
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
|
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
|
||||||
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
|
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
|
||||||
|
|
||||||
@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde
|
|||||||
|
|
||||||
// execute insert sql with given struct and given values.
|
// execute insert sql with given struct and given values.
|
||||||
// insert the given values, not the field values in struct.
|
// insert the given values, not the field values in struct.
|
||||||
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
marks := make([]string, len(names))
|
marks := make([]string, len(names))
|
||||||
@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
|
|||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err := row.Scan(&id)
|
err := row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sync auto key
|
// sync auto key
|
||||||
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||||
if len(autoFields) == 0 {
|
if len(autoFields) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string
|
|||||||
mi.table, name,
|
mi.table, name,
|
||||||
Q, name, Q,
|
Q, name, Q,
|
||||||
Q, mi.table, Q)
|
Q, mi.table, Q)
|
||||||
if _, err := db.Exec(query); err != nil {
|
if _, err := db.ExecContext(ctx, query); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check index exist in postgresql.
|
// check index exist in postgresql.
|
||||||
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||||
row := db.QueryRow(query)
|
row := db.QueryRowContext(ctx, query)
|
||||||
var cnt int
|
var cnt int
|
||||||
row.Scan(&cnt)
|
row.Scan(&cnt)
|
||||||
return cnt > 0
|
return cnt > 0
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -73,11 +74,11 @@ type dbBaseSqlite struct {
|
|||||||
var _ dbBaser = new(dbBaseSqlite)
|
var _ dbBaser = new(dbBaseSqlite)
|
||||||
|
|
||||||
// override base db read for update behavior as SQlite does not support syntax
|
// override base db read for update behavior as SQlite does not support syntax
|
||||||
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
||||||
if isForUpdate {
|
if isForUpdate {
|
||||||
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
|
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
|
||||||
}
|
}
|
||||||
return d.dbBase.Read(q, mi, ind, tz, cols, false)
|
return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// get sqlite operator.
|
// get sqlite operator.
|
||||||
@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get columns in sqlite.
|
// get columns in sqlite.
|
||||||
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
query := d.ins.ShowColumnsQuery(table)
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check index exist in sqlite.
|
// check index exist in sqlite.
|
||||||
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,6 +16,8 @@ package orm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -421,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generate order sql.
|
// generate order sql.
|
||||||
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
|
func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
|
||||||
if len(orders) == 0 {
|
if len(orders) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -430,19 +432,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
|
|||||||
|
|
||||||
orderSqls := make([]string, 0, len(orders))
|
orderSqls := make([]string, 0, len(orders))
|
||||||
for _, order := range orders {
|
for _, order := range orders {
|
||||||
asc := "ASC"
|
column := order.GetColumn()
|
||||||
if order[0] == '-' {
|
clause := strings.Split(column, clauses.ExprDot)
|
||||||
asc = "DESC"
|
|
||||||
order = order[1:]
|
|
||||||
}
|
|
||||||
exprs := strings.Split(order, ExprSep)
|
|
||||||
|
|
||||||
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
if order.IsRaw() {
|
||||||
|
if len(clause) == 2 {
|
||||||
|
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
|
||||||
|
} else if len(clause) == 1 {
|
||||||
|
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
index, _, fi, suc := t.parseExprs(t.mi, clause)
|
||||||
if !suc {
|
if !suc {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
|
||||||
}
|
}
|
||||||
|
|
||||||
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
|
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute sql to check index exist.
|
// execute sql to check index exist.
|
||||||
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
|
||||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||||
var cnt int
|
var cnt int
|
||||||
row.Scan(&cnt)
|
row.Scan(&cnt)
|
||||||
|
|||||||
@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) {
|
|||||||
|
|
||||||
assert.Nil(t, o.Driver())
|
assert.Nil(t, o.Driver())
|
||||||
|
|
||||||
assert.Nil(t, o.QueryM2MWithCtx(nil, nil, ""))
|
|
||||||
assert.Nil(t, o.QueryM2M(nil, ""))
|
assert.Nil(t, o.QueryM2M(nil, ""))
|
||||||
assert.Nil(t, o.ReadWithCtx(nil, nil))
|
assert.Nil(t, o.ReadWithCtx(nil, nil))
|
||||||
assert.Nil(t, o.Read(nil))
|
assert.Nil(t, o.Read(nil))
|
||||||
@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, int64(0), i)
|
assert.Equal(t, int64(0), i)
|
||||||
|
|
||||||
assert.Nil(t, o.QueryTableWithCtx(nil, nil))
|
|
||||||
assert.Nil(t, o.QueryTable(nil))
|
assert.Nil(t, o.QueryTable(nil))
|
||||||
|
|
||||||
assert.Nil(t, o.Read(nil))
|
assert.Nil(t, o.Read(nil))
|
||||||
|
|||||||
@ -27,7 +27,7 @@ import (
|
|||||||
// this Filter's behavior looks a little bit strange
|
// this Filter's behavior looks a little bit strange
|
||||||
// for example:
|
// for example:
|
||||||
// if we want to trace QuerySetter
|
// if we want to trace QuerySetter
|
||||||
// actually we trace invoking "QueryTable" and "QueryTableWithCtx"
|
// actually we trace invoking "QueryTable"
|
||||||
// the method Begin*, Commit and Rollback are ignored.
|
// the method Begin*, Commit and Rollback are ignored.
|
||||||
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
|
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
|
||||||
type FilterChainBuilder struct {
|
type FilterChainBuilder struct {
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
@ -31,17 +32,20 @@ import (
|
|||||||
// this Filter's behavior looks a little bit strange
|
// this Filter's behavior looks a little bit strange
|
||||||
// for example:
|
// for example:
|
||||||
// if we want to records the metrics of QuerySetter
|
// if we want to records the metrics of QuerySetter
|
||||||
// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx"
|
// actually we only records metrics of invoking "QueryTable"
|
||||||
type FilterChainBuilder struct {
|
type FilterChainBuilder struct {
|
||||||
summaryVec prometheus.ObserverVec
|
|
||||||
AppName string
|
AppName string
|
||||||
ServerName string
|
ServerName string
|
||||||
RunMode string
|
RunMode string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var summaryVec prometheus.ObserverVec
|
||||||
|
var initSummaryVec sync.Once
|
||||||
|
|
||||||
func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
|
func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
|
||||||
|
|
||||||
builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
initSummaryVec.Do(func() {
|
||||||
|
summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
||||||
Name: "beego",
|
Name: "beego",
|
||||||
Subsystem: "orm_operation",
|
Subsystem: "orm_operation",
|
||||||
ConstLabels: map[string]string{
|
ConstLabels: map[string]string{
|
||||||
@ -51,6 +55,8 @@ func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
|
|||||||
},
|
},
|
||||||
Help: "The statics info for orm operation",
|
Help: "The statics info for orm operation",
|
||||||
}, []string{"method", "name", "insideTx", "txName"})
|
}, []string{"method", "name", "insideTx", "txName"})
|
||||||
|
prometheus.MustRegister(summaryVec)
|
||||||
|
})
|
||||||
|
|
||||||
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@ -74,12 +80,12 @@ func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocati
|
|||||||
builder.reportTxn(ctx, inv)
|
builder.reportTxn(ctx, inv)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
builder.summaryVec.WithLabelValues(inv.Method, inv.GetTableName(),
|
summaryVec.WithLabelValues(inv.Method, inv.GetTableName(),
|
||||||
strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur))
|
strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) {
|
func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) {
|
||||||
dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond
|
dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond
|
||||||
builder.summaryVec.WithLabelValues(inv.Method, inv.TxName,
|
summaryVec.WithLabelValues(inv.Method, inv.TxName,
|
||||||
strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur))
|
strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,7 +32,7 @@ func TestFilterChainBuilder_FilterChain1(t *testing.T) {
|
|||||||
builder := &FilterChainBuilder{}
|
builder := &FilterChainBuilder{}
|
||||||
filter := builder.FilterChain(next)
|
filter := builder.FilterChain(next)
|
||||||
|
|
||||||
assert.NotNil(t, builder.summaryVec)
|
assert.NotNil(t, summaryVec)
|
||||||
assert.NotNil(t, filter)
|
assert.NotNil(t, filter)
|
||||||
|
|
||||||
inv := &orm.Invocation{}
|
inv := &orm.Invocation{}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/core/logs"
|
||||||
"github.com/beego/beego/v2/core/utils"
|
"github.com/beego/beego/v2/core/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
|
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||||
return f.QueryM2MWithCtx(context.Background(), md, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
|
||||||
|
|
||||||
mi, _ := modelCache.getByMd(md)
|
mi, _ := modelCache.getByMd(md)
|
||||||
inv := &Invocation{
|
inv := &Invocation{
|
||||||
Method: "QueryM2MWithCtx",
|
Method: "QueryM2M",
|
||||||
Args: []interface{}{md, name},
|
Args: []interface{}{md, name},
|
||||||
Md: md,
|
Md: md,
|
||||||
mi: mi,
|
mi: mi,
|
||||||
InsideTx: f.insideTx,
|
InsideTx: f.insideTx,
|
||||||
TxStartTime: f.txStartTime,
|
TxStartTime: f.txStartTime,
|
||||||
f: func(c context.Context) []interface{} {
|
f: func(c context.Context) []interface{} {
|
||||||
res := f.ormer.QueryM2MWithCtx(c, md, name)
|
res := f.ormer.QueryM2M(md, name)
|
||||||
return []interface{}{res}
|
return []interface{}{res}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
res := f.root(ctx, inv)
|
res := f.root(context.Background(), inv)
|
||||||
if res[0] == nil {
|
if res[0] == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return res[0].(QueryM2Mer)
|
return res[0].(QueryM2Mer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
|
func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
|
||||||
|
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.")
|
||||||
|
return f.QueryM2M(md, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
||||||
var (
|
var (
|
||||||
name string
|
name string
|
||||||
md interface{}
|
md interface{}
|
||||||
@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
|
|||||||
}
|
}
|
||||||
|
|
||||||
inv := &Invocation{
|
inv := &Invocation{
|
||||||
Method: "QueryTableWithCtx",
|
Method: "QueryTable",
|
||||||
Args: []interface{}{ptrStructOrTableName},
|
Args: []interface{}{ptrStructOrTableName},
|
||||||
InsideTx: f.insideTx,
|
InsideTx: f.insideTx,
|
||||||
TxStartTime: f.txStartTime,
|
TxStartTime: f.txStartTime,
|
||||||
Md: md,
|
Md: md,
|
||||||
mi: mi,
|
mi: mi,
|
||||||
f: func(c context.Context) []interface{} {
|
f: func(c context.Context) []interface{} {
|
||||||
res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName)
|
res := f.ormer.QueryTable(ptrStructOrTableName)
|
||||||
return []interface{}{res}
|
return []interface{}{res}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
res := f.root(ctx, inv)
|
res := f.root(context.Background(), inv)
|
||||||
|
|
||||||
if res[0] == nil {
|
if res[0] == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
|
|||||||
return res[0].(QuerySeter)
|
return res[0].(QuerySeter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
||||||
|
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.")
|
||||||
|
return f.QueryTable(ptrStructOrTableName)
|
||||||
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
|
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
|
||||||
inv := &Invocation{
|
inv := &Invocation{
|
||||||
Method: "DBStats",
|
Method: "DBStats",
|
||||||
|
|||||||
@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) {
|
|||||||
o := &filterMockOrm{}
|
o := &filterMockOrm{}
|
||||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||||
assert.Equal(t, "QueryM2MWithCtx", inv.Method)
|
assert.Equal(t, "QueryM2M", inv.Method)
|
||||||
assert.Equal(t, 2, len(inv.Args))
|
assert.Equal(t, 2, len(inv.Args))
|
||||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||||
assert.False(t, inv.InsideTx)
|
assert.False(t, inv.InsideTx)
|
||||||
@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) {
|
|||||||
o := &filterMockOrm{}
|
o := &filterMockOrm{}
|
||||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||||
assert.Equal(t, "QueryTableWithCtx", inv.Method)
|
assert.Equal(t, "QueryTable", inv.Method)
|
||||||
assert.Equal(t, 1, len(inv.Args))
|
assert.Equal(t, 1, len(inv.Args))
|
||||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||||
assert.False(t, inv.InsideTx)
|
assert.False(t, inv.InsideTx)
|
||||||
|
|||||||
63
client/orm/mock/condition.go
Normal file
63
client/orm/mock/condition.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Mock struct {
|
||||||
|
cond Condition
|
||||||
|
resp []interface{}
|
||||||
|
cb func(inv *orm.Invocation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMock(cond Condition, resp []interface{}, cb func(inv *orm.Invocation)) *Mock {
|
||||||
|
return &Mock{
|
||||||
|
cond: cond,
|
||||||
|
resp: resp,
|
||||||
|
cb: cb,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Condition interface {
|
||||||
|
Match(ctx context.Context, inv *orm.Invocation) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type SimpleCondition struct {
|
||||||
|
tableName string
|
||||||
|
method string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSimpleCondition(tableName string, methodName string) Condition {
|
||||||
|
return &SimpleCondition{
|
||||||
|
tableName: tableName,
|
||||||
|
method: methodName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SimpleCondition) Match(ctx context.Context, inv *orm.Invocation) bool {
|
||||||
|
res := true
|
||||||
|
if len(s.tableName) != 0 {
|
||||||
|
res = res && (s.tableName == inv.GetTableName())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.method) != 0 {
|
||||||
|
res = res && (s.method == inv.Method)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
41
client/orm/mock/condition_test.go
Normal file
41
client/orm/mock/condition_test.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSimpleCondition_Match(t *testing.T) {
|
||||||
|
cond := NewSimpleCondition("", "")
|
||||||
|
res := cond.Match(context.Background(), &orm.Invocation{})
|
||||||
|
assert.True(t, res)
|
||||||
|
cond = NewSimpleCondition("hello", "")
|
||||||
|
assert.False(t, cond.Match(context.Background(), &orm.Invocation{}))
|
||||||
|
|
||||||
|
cond = NewSimpleCondition("", "A")
|
||||||
|
assert.False(t, cond.Match(context.Background(), &orm.Invocation{
|
||||||
|
Method: "B",
|
||||||
|
}))
|
||||||
|
|
||||||
|
assert.True(t, cond.Match(context.Background(), &orm.Invocation{
|
||||||
|
Method: "A",
|
||||||
|
}))
|
||||||
|
}
|
||||||
40
client/orm/mock/context.go
Normal file
40
client/orm/mock/context.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/core/logs"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockCtxKeyType string
|
||||||
|
|
||||||
|
const mockCtxKey = mockCtxKeyType("beego-orm-mock")
|
||||||
|
|
||||||
|
func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context {
|
||||||
|
return context.WithValue(ctx, mockCtxKey, mock)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockFromCtx(ctx context.Context) []*Mock {
|
||||||
|
ms := ctx.Value(mockCtxKey)
|
||||||
|
if ms != nil {
|
||||||
|
if res, ok := ms.([]*Mock); ok {
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
logs.Error("mockCtxKey found in context, but value is not type []*Mock")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
29
client/orm/mock/context_test.go
Normal file
29
client/orm/mock/context_test.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCtx(t *testing.T) {
|
||||||
|
ms := make([]*Mock, 0, 4)
|
||||||
|
ctx := CtxWithMock(context.Background(), ms...)
|
||||||
|
res := mockFromCtx(ctx)
|
||||||
|
assert.Equal(t, ms, res)
|
||||||
|
}
|
||||||
72
client/orm/mock/mock.go
Normal file
72
client/orm/mock/mock.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stub = newOrmStub()
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
orm.AddGlobalFilterChain(stub.FilterChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stub interface {
|
||||||
|
Mock(m *Mock)
|
||||||
|
Clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
type OrmStub struct {
|
||||||
|
ms []*Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func StartMock() Stub {
|
||||||
|
return stub
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOrmStub() *OrmStub {
|
||||||
|
return &OrmStub{
|
||||||
|
ms: make([]*Mock, 0, 4),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OrmStub) Mock(m *Mock) {
|
||||||
|
o.ms = append(o.ms, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OrmStub) Clear() {
|
||||||
|
o.ms = make([]*Mock, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OrmStub) FilterChain(next orm.Filter) orm.Filter {
|
||||||
|
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||||
|
|
||||||
|
ms := mockFromCtx(ctx)
|
||||||
|
ms = append(ms, o.ms...)
|
||||||
|
|
||||||
|
for _, mock := range ms {
|
||||||
|
if mock.cond.Match(ctx, inv) {
|
||||||
|
if mock.cb != nil {
|
||||||
|
mock.cb(inv)
|
||||||
|
}
|
||||||
|
return mock.resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next(ctx, inv)
|
||||||
|
}
|
||||||
|
}
|
||||||
162
client/orm/mock/mock_orm.go
Normal file
162
client/orm/mock/mock_orm.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterMockDB("default")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterMockDB create an "virtual DB" by using sqllite
|
||||||
|
// you should not
|
||||||
|
func RegisterMockDB(name string) {
|
||||||
|
source := filepath.Join(os.TempDir(), name+".db")
|
||||||
|
_ = orm.RegisterDataBase(name, "sqlite3", source)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockTable only check table name
|
||||||
|
func MockTable(tableName string, resp ...interface{}) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, ""), resp, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockMethod only check method name
|
||||||
|
func MockMethod(method string, resp ...interface{}) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition("", method), resp, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockOrmRead support orm.Read and orm.ReadWithCtx
|
||||||
|
// cb is used to mock read data from DB
|
||||||
|
func MockRead(tableName string, cb func(data interface{}), err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "ReadWithCtx"), []interface{}{err}, func(inv *orm.Invocation) {
|
||||||
|
if cb != nil {
|
||||||
|
cb(inv.Args[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockReadForUpdateWithCtx support ReadForUpdate and ReadForUpdateWithCtx
|
||||||
|
// cb is used to mock read data from DB
|
||||||
|
func MockReadForUpdateWithCtx(tableName string, cb func(data interface{}), err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "ReadForUpdateWithCtx"),
|
||||||
|
[]interface{}{err},
|
||||||
|
func(inv *orm.Invocation) {
|
||||||
|
cb(inv.Args[0])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockReadOrCreateWithCtx support ReadOrCreate and ReadOrCreateWithCtx
|
||||||
|
// cb is used to mock read data from DB
|
||||||
|
func MockReadOrCreateWithCtx(tableName string,
|
||||||
|
cb func(data interface{}),
|
||||||
|
insert bool, id int64, err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "ReadOrCreateWithCtx"),
|
||||||
|
[]interface{}{insert, id, err},
|
||||||
|
func(inv *orm.Invocation) {
|
||||||
|
cb(inv.Args[0])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockInsertWithCtx support Insert and InsertWithCtx
|
||||||
|
func MockInsertWithCtx(tableName string, id int64, err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "InsertWithCtx"), []interface{}{id, err}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockInsertMultiWithCtx support InsertMulti and InsertMultiWithCtx
|
||||||
|
func MockInsertMultiWithCtx(tableName string, cnt int64, err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "InsertMultiWithCtx"), []interface{}{cnt, err}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockInsertOrUpdateWithCtx support InsertOrUpdate and InsertOrUpdateWithCtx
|
||||||
|
func MockInsertOrUpdateWithCtx(tableName string, id int64, err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "InsertOrUpdateWithCtx"), []interface{}{id, err}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockUpdateWithCtx support UpdateWithCtx and Update
|
||||||
|
func MockUpdateWithCtx(tableName string, affectedRow int64, err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "UpdateWithCtx"), []interface{}{affectedRow, err}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDeleteWithCtx support Delete and DeleteWithCtx
|
||||||
|
func MockDeleteWithCtx(tableName string, affectedRow int64, err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "DeleteWithCtx"), []interface{}{affectedRow, err}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockQueryM2MWithCtx support QueryM2MWithCtx and QueryM2M
|
||||||
|
// Now you may be need to use golang/mock to generate QueryM2M mock instance
|
||||||
|
// Or use DoNothingQueryM2Mer
|
||||||
|
// for example:
|
||||||
|
// post := Post{Id: 4}
|
||||||
|
// m2m := Ormer.QueryM2M(&post, "Tags")
|
||||||
|
// when you write test code:
|
||||||
|
// MockQueryM2MWithCtx("post", "Tags", mockM2Mer)
|
||||||
|
// "post" is the table name of model Post structure
|
||||||
|
// TODO provide orm.QueryM2Mer
|
||||||
|
func MockQueryM2MWithCtx(tableName string, name string, res orm.QueryM2Mer) *Mock {
|
||||||
|
return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{res}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockLoadRelatedWithCtx support LoadRelatedWithCtx and LoadRelated
|
||||||
|
func MockLoadRelatedWithCtx(tableName string, name string, rows int64, err error) *Mock {
|
||||||
|
return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{rows, err}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockQueryTableWithCtx support QueryTableWithCtx and QueryTable
|
||||||
|
func MockQueryTableWithCtx(tableName string, qs orm.QuerySeter) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition(tableName, "QueryTable"), []interface{}{qs}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockRawWithCtx support RawWithCtx and Raw
|
||||||
|
func MockRawWithCtx(rs orm.RawSeter) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition("", "RawWithCtx"), []interface{}{rs}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDriver support Driver
|
||||||
|
// func MockDriver(driver orm.Driver) *Mock {
|
||||||
|
// return NewMock(NewSimpleCondition("", "Driver"), []interface{}{driver})
|
||||||
|
// }
|
||||||
|
|
||||||
|
// MockDBStats support DBStats
|
||||||
|
func MockDBStats(stats *sql.DBStats) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition("", "DBStats"), []interface{}{stats}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockBeginWithCtxAndOpts support Begin, BeginWithCtx, BeginWithOpts, BeginWithCtxAndOpts
|
||||||
|
// func MockBeginWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock {
|
||||||
|
// return NewMock(NewSimpleCondition("", "BeginWithCtxAndOpts"), []interface{}{txOrm, err})
|
||||||
|
// }
|
||||||
|
|
||||||
|
// MockDoTxWithCtxAndOpts support DoTx, DoTxWithCtx, DoTxWithOpts, DoTxWithCtxAndOpts
|
||||||
|
// func MockDoTxWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock {
|
||||||
|
// return MockBeginWithCtxAndOpts(txOrm, err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// MockCommit support Commit
|
||||||
|
func MockCommit(err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition("", "Commit"), []interface{}{err}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockRollback support Rollback
|
||||||
|
func MockRollback(err error) *Mock {
|
||||||
|
return NewMock(NewSimpleCondition("", "Rollback"), []interface{}{err}, nil)
|
||||||
|
}
|
||||||
297
client/orm/mock/mock_orm_test.go
Normal file
297
client/orm/mock/mock_orm_test.go
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
const mockErrorMsg = "mock error"
|
||||||
|
func init() {
|
||||||
|
orm.RegisterModel(&User{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockDBStats(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
stats := &sql.DBStats{}
|
||||||
|
s.Mock(MockDBStats(stats))
|
||||||
|
|
||||||
|
o := orm.NewOrm()
|
||||||
|
|
||||||
|
res := o.DBStats()
|
||||||
|
|
||||||
|
assert.Equal(t, stats, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockDeleteWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
s.Mock(MockDeleteWithCtx((&User{}).TableName(), 12, nil))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
rows, err := o.Delete(&User{})
|
||||||
|
assert.Equal(t, int64(12), rows)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockInsertOrUpdateWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
s.Mock(MockInsertOrUpdateWithCtx((&User{}).TableName(), 12, nil))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
id, err := o.InsertOrUpdate(&User{})
|
||||||
|
assert.Equal(t, int64(12), id)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockRead(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
err := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
|
||||||
|
u := data.(*User)
|
||||||
|
u.Name = "Tom"
|
||||||
|
}, err))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
u := &User{}
|
||||||
|
e := o.Read(u)
|
||||||
|
assert.Equal(t, err, e)
|
||||||
|
assert.Equal(t, "Tom", u.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockQueryM2MWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := &DoNothingQueryM2Mer{}
|
||||||
|
s.Mock(MockQueryM2MWithCtx((&User{}).TableName(), "Tags", mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res := o.QueryM2M(&User{}, "Tags")
|
||||||
|
assert.Equal(t, mock, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockQueryTableWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := &DoNothingQuerySetter{}
|
||||||
|
s.Mock(MockQueryTableWithCtx((&User{}).TableName(), mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res := o.QueryTable(&User{})
|
||||||
|
assert.Equal(t, mock, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockTable(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockTable((&User{}).TableName(), mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res := o.Read(&User{})
|
||||||
|
assert.Equal(t, mock, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockInsertMultiWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockInsertMultiWithCtx((&User{}).TableName(), 12, mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res, err := o.InsertMulti(11, []interface{}{&User{}})
|
||||||
|
assert.Equal(t, int64(12), res)
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockInsertWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockInsertWithCtx((&User{}).TableName(), 13, mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res, err := o.Insert(&User{})
|
||||||
|
assert.Equal(t, int64(13), res)
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockUpdateWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockUpdateWithCtx((&User{}).TableName(), 12, mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res, err := o.Update(&User{})
|
||||||
|
assert.Equal(t, int64(12), res)
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockLoadRelatedWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockLoadRelatedWithCtx((&User{}).TableName(), "T", 12, mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res, err := o.LoadRelated(&User{}, "T")
|
||||||
|
assert.Equal(t, int64(12), res)
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockMethod(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockMethod("ReadWithCtx", mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
err := o.Read(&User{})
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockReadForUpdateWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockReadForUpdateWithCtx((&User{}).TableName(), func(data interface{}) {
|
||||||
|
u := data.(*User)
|
||||||
|
u.Name = "Tom"
|
||||||
|
}, mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
u := &User{}
|
||||||
|
err := o.ReadForUpdate(u)
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
assert.Equal(t, "Tom", u.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockRawWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := &DoNothingRawSetter{}
|
||||||
|
s.Mock(MockRawWithCtx(mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
res := o.Raw("")
|
||||||
|
assert.Equal(t, mock, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMockReadOrCreateWithCtx(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockReadOrCreateWithCtx((&User{}).TableName(), func(data interface{}) {
|
||||||
|
u := data.(*User)
|
||||||
|
u.Name = "Tom"
|
||||||
|
}, false, 12, mock))
|
||||||
|
o := orm.NewOrm()
|
||||||
|
u := &User{}
|
||||||
|
inserted, id, err := o.ReadOrCreate(u, "")
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
assert.Equal(t, int64(12), id)
|
||||||
|
assert.False(t, inserted)
|
||||||
|
assert.Equal(t, "Tom", u.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionClosure(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
|
||||||
|
u := data.(*User)
|
||||||
|
u.Name = "Tom"
|
||||||
|
}, mock))
|
||||||
|
u, err := originalTxUsingClosure()
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
assert.Equal(t, "Tom", u.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionManually(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
|
||||||
|
u := data.(*User)
|
||||||
|
u.Name = "Tom"
|
||||||
|
}, mock))
|
||||||
|
u, err := originalTxManually()
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
assert.Equal(t, "Tom", u.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionRollback(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockRead((&User{}).TableName(), nil, errors.New("read error")))
|
||||||
|
s.Mock(MockRollback(mock))
|
||||||
|
_, err := originalTx()
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionCommit(t *testing.T) {
|
||||||
|
s := StartMock()
|
||||||
|
defer s.Clear()
|
||||||
|
mock := errors.New(mockErrorMsg)
|
||||||
|
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
|
||||||
|
u := data.(*User)
|
||||||
|
u.Name = "Tom"
|
||||||
|
}, nil))
|
||||||
|
s.Mock(MockCommit(mock))
|
||||||
|
u, err := originalTx()
|
||||||
|
assert.Equal(t, mock, err)
|
||||||
|
assert.Equal(t, "Tom", u.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func originalTx() (*User, error) {
|
||||||
|
u := &User{}
|
||||||
|
o := orm.NewOrm()
|
||||||
|
txOrm, _ := o.Begin()
|
||||||
|
err := txOrm.Read(u)
|
||||||
|
if err == nil {
|
||||||
|
err = txOrm.Commit()
|
||||||
|
return u, err
|
||||||
|
} else {
|
||||||
|
err = txOrm.Rollback()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func originalTxManually() (*User, error) {
|
||||||
|
u := &User{}
|
||||||
|
o := orm.NewOrm()
|
||||||
|
txOrm, _ := o.Begin()
|
||||||
|
err := txOrm.Read(u)
|
||||||
|
_ = txOrm.Commit()
|
||||||
|
return u, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func originalTxUsingClosure() (*User, error) {
|
||||||
|
u := &User{}
|
||||||
|
var err error
|
||||||
|
o := orm.NewOrm()
|
||||||
|
_ = o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error {
|
||||||
|
err = txOrm.Read(u)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return u, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) TableName() string {
|
||||||
|
return "user"
|
||||||
|
}
|
||||||
92
client/orm/mock/mock_queryM2Mer.go
Normal file
92
client/orm/mock/mock_queryM2Mer.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoNothingQueryM2Mer do nothing
|
||||||
|
// use it to build mock orm.QueryM2Mer
|
||||||
|
type DoNothingQueryM2Mer struct {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) AddWithCtx(ctx context.Context, i ...interface{}) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) RemoveWithCtx(ctx context.Context, i ...interface{}) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) ExistWithCtx(ctx context.Context, i interface{}) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) ClearWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) CountWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) Add(i ...interface{}) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) Remove(i ...interface{}) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) Exist(i interface{}) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) Clear() (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQueryM2Mer) Count() (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryM2MerCondition struct {
|
||||||
|
tableName string
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueryM2MerCondition(tableName string, name string) *QueryM2MerCondition {
|
||||||
|
return &QueryM2MerCondition{
|
||||||
|
tableName: tableName,
|
||||||
|
name: name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryM2MerCondition) Match(ctx context.Context, inv *orm.Invocation) bool {
|
||||||
|
res := true
|
||||||
|
if len(q.tableName) > 0 {
|
||||||
|
res = res && (q.tableName == inv.GetTableName())
|
||||||
|
}
|
||||||
|
if len(q.name) > 0 {
|
||||||
|
res = res && (len(inv.Args) > 1) && (q.name == inv.Args[1].(string))
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
63
client/orm/mock/mock_queryM2Mer_test.go
Normal file
63
client/orm/mock/mock_queryM2Mer_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDoNothingQueryM2Mer(t *testing.T) {
|
||||||
|
m2m := &DoNothingQueryM2Mer{}
|
||||||
|
|
||||||
|
i, err := m2m.Clear()
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = m2m.Count()
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = m2m.Add()
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = m2m.Remove()
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.True(t, m2m.Exist(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewQueryM2MerCondition(t *testing.T) {
|
||||||
|
cond := NewQueryM2MerCondition("", "")
|
||||||
|
res := cond.Match(context.Background(), &orm.Invocation{})
|
||||||
|
assert.True(t, res)
|
||||||
|
cond = NewQueryM2MerCondition("hello", "")
|
||||||
|
assert.False(t, cond.Match(context.Background(), &orm.Invocation{}))
|
||||||
|
|
||||||
|
cond = NewQueryM2MerCondition("", "A")
|
||||||
|
assert.False(t, cond.Match(context.Background(), &orm.Invocation{
|
||||||
|
Args: []interface{}{0, "B"},
|
||||||
|
}))
|
||||||
|
|
||||||
|
assert.True(t, cond.Match(context.Background(), &orm.Invocation{
|
||||||
|
Args: []interface{}{0, "A"},
|
||||||
|
}))
|
||||||
|
}
|
||||||
183
client/orm/mock/mock_querySetter.go
Normal file
183
client/orm/mock/mock_querySetter.go
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoNothingQuerySetter do nothing
|
||||||
|
// usually you use this to build your mock QuerySetter
|
||||||
|
type DoNothingQuerySetter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) OrderClauses(orders ...*order_clause.Order) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) CountWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ExistWithCtx(ctx context.Context) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) UpdateWithCtx(ctx context.Context, values orm.Params) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) DeleteWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) PrepareInsertWithCtx(ctx context.Context) (orm.Inserter, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ValuesWithCtx(ctx context.Context, results *[]orm.Params, exprs ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ValuesListWithCtx(ctx context.Context, results *[]orm.ParamsList, exprs ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ValuesFlatWithCtx(ctx context.Context, result *orm.ParamsList, expr string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Aggregate(s string) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Filter(s string, i ...interface{}) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) FilterRaw(s string, s2 string) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Exclude(s string, i ...interface{}) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) SetCond(condition *orm.Condition) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) GetCond() *orm.Condition {
|
||||||
|
return orm.NewCondition()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Limit(limit interface{}, args ...interface{}) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Offset(offset interface{}) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) GroupBy(exprs ...string) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) OrderBy(exprs ...string) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) UseIndex(indexes ...string) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) RelatedSel(params ...interface{}) orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Distinct() orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ForUpdate() orm.QuerySeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Count() (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Exist() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Update(values orm.Params) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Delete() (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) PrepareInsert() (orm.Inserter, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) All(container interface{}, cols ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) One(container interface{}, cols ...string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) Values(results *[]orm.Params, exprs ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ValuesList(results *[]orm.ParamsList, exprs ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) ValuesFlat(result *orm.ParamsList, expr string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingQuerySetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
74
client/orm/mock/mock_querySetter_test.go
Normal file
74
client/orm/mock/mock_querySetter_test.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDoNothingQuerySetter(t *testing.T) {
|
||||||
|
setter := &DoNothingQuerySetter{}
|
||||||
|
setter.GroupBy().Filter("").Limit(10).
|
||||||
|
Distinct().Exclude("a").FilterRaw("", "").
|
||||||
|
ForceIndex().ForUpdate().IgnoreIndex().
|
||||||
|
Offset(11).OrderBy().RelatedSel().SetCond(nil).UseIndex()
|
||||||
|
|
||||||
|
assert.True(t, setter.Exist())
|
||||||
|
err := setter.One(nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
i, err := setter.Count()
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.Delete()
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.All(nil)
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.Update(nil)
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.RowsToMap(nil, "", "")
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.RowsToStruct(nil, "", "")
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.Values(nil)
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.ValuesFlat(nil, "")
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = setter.ValuesList(nil)
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
ins, err := setter.PrepareInsert()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Nil(t, ins)
|
||||||
|
|
||||||
|
assert.NotNil(t, setter.GetCond())
|
||||||
|
}
|
||||||
64
client/orm/mock/mock_rawSetter.go
Normal file
64
client/orm/mock/mock_rawSetter.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DoNothingRawSetter struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) Exec() (sql.Result, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) QueryRow(containers ...interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) QueryRows(containers ...interface{}) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) SetArgs(i ...interface{}) orm.RawSeter {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) Values(container *[]orm.Params, cols ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) ValuesList(container *[]orm.ParamsList, cols ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) ValuesFlat(container *orm.ParamsList, cols ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DoNothingRawSetter) Prepare() (orm.RawPreparer, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
63
client/orm/mock/mock_rawSetter_test.go
Normal file
63
client/orm/mock/mock_rawSetter_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDoNothingRawSetter(t *testing.T) {
|
||||||
|
rs := &DoNothingRawSetter{}
|
||||||
|
i, err := rs.ValuesList(nil)
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = rs.Values(nil)
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = rs.ValuesFlat(nil)
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = rs.RowsToStruct(nil, "", "")
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = rs.RowsToMap(nil, "", "")
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
i, err = rs.QueryRows()
|
||||||
|
assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = rs.QueryRow()
|
||||||
|
// assert.Equal(t, int64(0), i)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
s, err := rs.Exec()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Nil(t, s)
|
||||||
|
|
||||||
|
p, err := rs.Prepare()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Nil(t, p)
|
||||||
|
|
||||||
|
rrs := rs.SetArgs()
|
||||||
|
assert.Equal(t, rrs, rs)
|
||||||
|
}
|
||||||
58
client/orm/mock/mock_test.go
Normal file
58
client/orm/mock/mock_test.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOrmStub_FilterChain(t *testing.T) {
|
||||||
|
os := newOrmStub()
|
||||||
|
inv := &orm.Invocation{
|
||||||
|
Args: []interface{}{10},
|
||||||
|
}
|
||||||
|
i := 1
|
||||||
|
os.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||||
|
i++
|
||||||
|
return nil
|
||||||
|
})(context.Background(), inv)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, i)
|
||||||
|
|
||||||
|
m := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) {
|
||||||
|
arg := inv.Args[0]
|
||||||
|
j := arg.(int)
|
||||||
|
inv.Args[0] = j + 1
|
||||||
|
})
|
||||||
|
os.Mock(m)
|
||||||
|
|
||||||
|
os.FilterChain(nil)(context.Background(), inv)
|
||||||
|
assert.Equal(t, 11, inv.Args[0])
|
||||||
|
|
||||||
|
inv.Args[0] = 10
|
||||||
|
ctxMock := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) {
|
||||||
|
arg := inv.Args[0]
|
||||||
|
j := arg.(int)
|
||||||
|
inv.Args[0] = j + 3
|
||||||
|
})
|
||||||
|
|
||||||
|
os.FilterChain(nil)(CtxWithMock(context.Background(), ctxMock), inv)
|
||||||
|
assert.Equal(t, 13, inv.Args[0])
|
||||||
|
}
|
||||||
@ -332,10 +332,6 @@ end:
|
|||||||
|
|
||||||
// register register models to model cache
|
// register register models to model cache
|
||||||
func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
|
func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
|
||||||
if mc.done {
|
|
||||||
err = fmt.Errorf("register must be run before BootStrap")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
val := reflect.ValueOf(model)
|
val := reflect.ValueOf(model)
|
||||||
@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
|
|||||||
err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
|
err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if val.Elem().Kind() == reflect.Slice {
|
||||||
|
val = reflect.New(val.Elem().Type().Elem())
|
||||||
|
}
|
||||||
table := getTableName(val)
|
table := getTableName(val)
|
||||||
|
|
||||||
if prefixOrSuffixStr != "" {
|
if prefixOrSuffixStr != "" {
|
||||||
@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := mc.get(table); ok {
|
if _, ok := mc.get(table); ok {
|
||||||
err = fmt.Errorf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
|
return nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mi := newModelInfo(val)
|
mi := newModelInfo(val)
|
||||||
@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mi.fields.pk == nil {
|
|
||||||
err = fmt.Errorf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mi.table = table
|
mi.table = table
|
||||||
|
|||||||
@ -255,6 +255,22 @@ func NewTM() *TM {
|
|||||||
return obj
|
return obj
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DeptInfo struct {
|
||||||
|
ID int `orm:"column(id)"`
|
||||||
|
Created time.Time `orm:"auto_now_add"`
|
||||||
|
DeptName string
|
||||||
|
EmployeeName string
|
||||||
|
Salary int
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnregisterModel struct {
|
||||||
|
ID int `orm:"column(id)"`
|
||||||
|
Created time.Time `orm:"auto_now_add"`
|
||||||
|
DeptName string
|
||||||
|
EmployeeName string
|
||||||
|
Salary int
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int `orm:"column(id)"`
|
ID int `orm:"column(id)"`
|
||||||
UserName string `orm:"size(30);unique"`
|
UserName string `orm:"size(30);unique"`
|
||||||
|
|||||||
@ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string {
|
|||||||
|
|
||||||
// get whether the table needs to be created for the database alias
|
// get whether the table needs to be created for the database alias
|
||||||
func isApplicableTableForDB(val reflect.Value, db string) bool {
|
func isApplicableTableForDB(val reflect.Value, db string) bool {
|
||||||
|
if !val.IsValid() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
fun := val.MethodByName("IsApplicableTableForDB")
|
fun := val.MethodByName("IsApplicableTableForDB")
|
||||||
if fun.IsValid() {
|
if fun.IsValid() {
|
||||||
vals := fun.Call([]reflect.Value{reflect.ValueOf(db)})
|
vals := fun.Call([]reflect.Value{reflect.ValueOf(db)})
|
||||||
|
|||||||
@ -58,6 +58,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
@ -135,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
|
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
|
||||||
@ -144,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
|
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to read a row from the database, or insert one if it doesn't exist
|
// Try to read a row from the database, or insert one if it doesn't exist
|
||||||
@ -154,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo
|
|||||||
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||||
cols = append([]string{col1}, cols...)
|
cols = append([]string{col1}, cols...)
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
|
||||||
if err == ErrNoRows {
|
if err == ErrNoRows {
|
||||||
// Create
|
// Create
|
||||||
id, err := o.InsertWithCtx(ctx, md)
|
id, err := o.InsertWithCtx(ctx, md)
|
||||||
@ -179,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
@ -222,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
|
|||||||
for i := 0; i < sind.Len(); i++ {
|
for i := 0; i < sind.Len(); i++ {
|
||||||
ind := reflect.Indirect(sind.Index(i))
|
ind := reflect.Indirect(sind.Index(i))
|
||||||
mi, _ := o.getMiInd(ind.Interface(), false)
|
mi, _ := o.getMiInd(ind.Interface(), false)
|
||||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cnt, err
|
return cnt, err
|
||||||
}
|
}
|
||||||
@ -233,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
||||||
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
|
return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ)
|
||||||
}
|
}
|
||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
@ -244,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (
|
|||||||
}
|
}
|
||||||
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
|
id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
@ -261,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete model in database
|
// delete model in database
|
||||||
@ -271,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
|
num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
@ -283,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str
|
|||||||
|
|
||||||
// create a models to models queryer
|
// create a models to models queryer
|
||||||
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
|
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||||
return o.QueryM2MWithCtx(context.Background(), md, name)
|
|
||||||
}
|
|
||||||
func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
fi := o.getFieldInfo(mi, name)
|
fi := o.getFieldInfo(mi, name)
|
||||||
|
|
||||||
@ -299,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri
|
|||||||
return newQueryM2M(md, o, mi, fi, ind)
|
return newQueryM2M(md, o, mi, fi, ind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
|
||||||
|
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.")
|
||||||
|
return o.QueryM2M(md, name)
|
||||||
|
}
|
||||||
|
|
||||||
// load related models to md model.
|
// load related models to md model.
|
||||||
// args are limit, offset int and order string.
|
// args are limit, offset int and order string.
|
||||||
//
|
//
|
||||||
@ -351,7 +355,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s
|
|||||||
qs.relDepth = relDepth
|
qs.relDepth = relDepth
|
||||||
|
|
||||||
if len(order) > 0 {
|
if len(order) > 0 {
|
||||||
qs.orders = []string{order}
|
qs.orders = order_clause.ParseOrder(order)
|
||||||
}
|
}
|
||||||
|
|
||||||
find := ind.FieldByIndex(fi.fieldIndex)
|
find := ind.FieldByIndex(fi.fieldIndex)
|
||||||
@ -451,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
|
|||||||
// table name can be string or struct.
|
// table name can be string or struct.
|
||||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||||
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||||
return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
|
|
||||||
}
|
|
||||||
func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
|
|
||||||
var name string
|
var name string
|
||||||
if table, ok := ptrStructOrTableName.(string); ok {
|
if table, ok := ptrStructOrTableName.(string); ok {
|
||||||
name = nameStrategyMap[defaultNameStrategy](table)
|
name = nameStrategyMap[defaultNameStrategy](table)
|
||||||
@ -469,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in
|
|||||||
if qs == nil {
|
if qs == nil {
|
||||||
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
|
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
|
||||||
}
|
}
|
||||||
return
|
return qs
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||||
|
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.")
|
||||||
|
return o.QueryTable(ptrStructOrTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// return a raw query seter for raw sql string.
|
// return a raw query seter for raw sql string.
|
||||||
@ -595,9 +602,8 @@ func NewOrm() Ormer {
|
|||||||
func NewOrmUsingDB(aliasName string) Ormer {
|
func NewOrmUsingDB(aliasName string) Ormer {
|
||||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||||
return newDBWithAlias(al)
|
return newDBWithAlias(al)
|
||||||
} else {
|
|
||||||
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
|
|
||||||
}
|
}
|
||||||
|
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
|
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
|
||||||
|
|||||||
@ -16,12 +16,13 @@ package orm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExprSep define the expression separation
|
// ExprSep define the expression separation
|
||||||
const (
|
const (
|
||||||
ExprSep = "__"
|
ExprSep = clauses.ExprSep
|
||||||
)
|
)
|
||||||
|
|
||||||
type condValue struct {
|
type condValue struct {
|
||||||
|
|||||||
@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
|
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
|
||||||
|
return d.ExecContext(context.Background(), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||||
a := time.Now()
|
a := time.Now()
|
||||||
res, err := d.stmt.Exec(args...)
|
res, err := d.stmt.ExecContext(ctx, args...)
|
||||||
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
|
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
|
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
|
||||||
|
return d.QueryContext(context.Background(), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||||
a := time.Now()
|
a := time.Now()
|
||||||
res, err := d.stmt.Query(args...)
|
res, err := d.stmt.QueryContext(ctx, args...)
|
||||||
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
|
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
|
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
|
||||||
|
return d.QueryRowContext(context.Background(), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
|
||||||
a := time.Now()
|
a := time.Now()
|
||||||
res := d.stmt.QueryRow(args...)
|
res := d.stmt.QueryRow(args...)
|
||||||
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)
|
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
@ -31,6 +32,10 @@ var _ Inserter = new(insertSet)
|
|||||||
|
|
||||||
// insert model ignore it's registered or not.
|
// insert model ignore it's registered or not.
|
||||||
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||||
|
return o.InsertWithCtx(context.Background(), md)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||||
if o.closed {
|
if o.closed {
|
||||||
return 0, ErrStmtClosed
|
return 0, ErrStmtClosed
|
||||||
}
|
}
|
||||||
@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
|
|||||||
if name != o.mi.fullName {
|
if name != o.mi.fullName {
|
||||||
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
|
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
|
||||||
}
|
}
|
||||||
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
|
id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
@ -70,11 +75,11 @@ func (o *insertSet) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create new insert queryer.
|
// create new insert queryer.
|
||||||
func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) {
|
func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) {
|
||||||
bi := new(insertSet)
|
bi := new(insertSet)
|
||||||
bi.orm = orm
|
bi.orm = orm
|
||||||
bi.mi = mi
|
bi.mi = mi
|
||||||
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
|
st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,7 +14,10 @@
|
|||||||
|
|
||||||
package orm
|
package orm
|
||||||
|
|
||||||
import "reflect"
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
// model to model struct
|
// model to model struct
|
||||||
type queryM2M struct {
|
type queryM2M struct {
|
||||||
@ -33,6 +36,10 @@ type queryM2M struct {
|
|||||||
//
|
//
|
||||||
// make sure the relation is defined in post model struct tag.
|
// make sure the relation is defined in post model struct tag.
|
||||||
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||||
|
return o.AddWithCtx(context.Background(), mds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
mi := fi.relThroughModelInfo
|
mi := fi.relThroughModelInfo
|
||||||
mfi := fi.reverseFieldInfo
|
mfi := fi.reverseFieldInfo
|
||||||
@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
names = append(names, otherNames...)
|
names = append(names, otherNames...)
|
||||||
values = append(values, otherValues...)
|
values = append(values, otherValues...)
|
||||||
return dbase.InsertValue(orm.db, mi, true, names, values)
|
return dbase.InsertValue(ctx, orm.db, mi, true, names, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove models following the origin model relationship
|
// remove models following the origin model relationship
|
||||||
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||||
|
return o.RemoveWithCtx(context.Background(), mds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
||||||
|
|
||||||
@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
|||||||
|
|
||||||
// check model is existed in relationship of origin model
|
// check model is existed in relationship of origin model
|
||||||
func (o *queryM2M) Exist(md interface{}) bool {
|
func (o *queryM2M) Exist(md interface{}) bool {
|
||||||
|
return o.ExistWithCtx(context.Background(), md)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
||||||
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
|
Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean all models in related of origin model
|
// clean all models in related of origin model
|
||||||
func (o *queryM2M) Clear() (int64, error) {
|
func (o *queryM2M) Clear() (int64, error) {
|
||||||
|
return o.ClearWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// count all related models of origin model
|
// count all related models of origin model
|
||||||
func (o *queryM2M) Count() (int64, error) {
|
func (o *queryM2M) Count() (int64, error) {
|
||||||
|
return o.CountWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ QueryM2Mer = new(queryM2M)
|
var _ QueryM2Mer = new(queryM2M)
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"github.com/beego/beego/v2/client/orm/hints"
|
"github.com/beego/beego/v2/client/orm/hints"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,14 +72,13 @@ type querySet struct {
|
|||||||
limit int64
|
limit int64
|
||||||
offset int64
|
offset int64
|
||||||
groups []string
|
groups []string
|
||||||
orders []string
|
orders []*order_clause.Order
|
||||||
distinct bool
|
distinct bool
|
||||||
forUpdate bool
|
forUpdate bool
|
||||||
useIndex int
|
useIndex int
|
||||||
indexes []string
|
indexes []string
|
||||||
orm *ormBase
|
orm *ormBase
|
||||||
ctx context.Context
|
aggregate string
|
||||||
forContext bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ QuerySeter = new(querySet)
|
var _ QuerySeter = new(querySet)
|
||||||
@ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter {
|
|||||||
|
|
||||||
// add ORDER expression.
|
// add ORDER expression.
|
||||||
// "column" means ASC, "-column" means DESC.
|
// "column" means ASC, "-column" means DESC.
|
||||||
func (o querySet) OrderBy(exprs ...string) QuerySeter {
|
func (o querySet) OrderBy(expressions ...string) QuerySeter {
|
||||||
o.orders = exprs
|
if len(expressions) <= 0 {
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
o.orders = order_clause.ParseOrder(expressions...)
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
|
||||||
|
// add ORDER expression.
|
||||||
|
func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter {
|
||||||
|
if len(orders) <= 0 {
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
o.orders = orders
|
||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition {
|
|||||||
|
|
||||||
// return QuerySeter execution result number
|
// return QuerySeter execution result number
|
||||||
func (o *querySet) Count() (int64, error) {
|
func (o *querySet) Count() (int64, error) {
|
||||||
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.CountWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check result empty or not after QuerySeter executed
|
// check result empty or not after QuerySeter executed
|
||||||
func (o *querySet) Exist() bool {
|
func (o *querySet) Exist() bool {
|
||||||
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.ExistWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ExistWithCtx(ctx context.Context) bool {
|
||||||
|
cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
return cnt > 0
|
return cnt > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute update with parameters
|
// execute update with parameters
|
||||||
func (o *querySet) Update(values Params) (int64, error) {
|
func (o *querySet) Update(values Params) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
return o.UpdateWithCtx(context.Background(), values)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute delete
|
// execute delete
|
||||||
func (o *querySet) Delete() (int64, error) {
|
func (o *querySet) Delete() (int64, error) {
|
||||||
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.DeleteWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// return a insert queryer.
|
// return a insert queryer.
|
||||||
@ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) {
|
|||||||
// i,err := sq.PrepareInsert()
|
// i,err := sq.PrepareInsert()
|
||||||
// i.Add(&user1{},&user2{})
|
// i.Add(&user1{},&user2{})
|
||||||
func (o *querySet) PrepareInsert() (Inserter, error) {
|
func (o *querySet) PrepareInsert() (Inserter, error) {
|
||||||
return newInsertSet(o.orm, o.mi)
|
return o.PrepareInsertWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
|
||||||
|
return newInsertSet(ctx, o.orm, o.mi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all data and map to containers.
|
// query all data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
return o.AllWithCtx(context.Background(), container, cols...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query one row data and map to containers.
|
// query one row data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
func (o *querySet) One(container interface{}, cols ...string) error {
|
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||||
|
return o.OneWithCtx(context.Background(), container, cols...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
|
||||||
o.limit = 1
|
o.limit = 1
|
||||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error {
|
|||||||
// expres means condition expression.
|
// expres means condition expression.
|
||||||
// it converts data to []map[column]value.
|
// it converts data to []map[column]value.
|
||||||
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
return o.ValuesWithCtx(context.Background(), results, exprs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all data and map to [][]interface
|
// query all data and map to [][]interface
|
||||||
// it converts data to [][column_index]value
|
// it converts data to [][column_index]value
|
||||||
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
return o.ValuesListWithCtx(context.Background(), results, exprs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all data and map to []interface.
|
// query all data and map to []interface.
|
||||||
// it's designed for one row record set, auto change to []value, not [][column]value.
|
// it's designed for one row record set, auto change to []value, not [][column]value.
|
||||||
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
return o.ValuesFlatWithCtx(context.Background(), result, expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all rows into map[string]interface with specify key and value column name.
|
// query all rows into map[string]interface with specify key and value column name.
|
||||||
@ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
|
|||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
|
||||||
// set context to QuerySeter.
|
|
||||||
func (o querySet) WithContext(ctx context.Context) QuerySeter {
|
|
||||||
o.ctx = ctx
|
|
||||||
o.forContext = true
|
|
||||||
return &o
|
|
||||||
}
|
|
||||||
|
|
||||||
// create new QuerySeter.
|
// create new QuerySeter.
|
||||||
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
|
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
|
||||||
o := new(querySet)
|
o := new(querySet)
|
||||||
@ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
|
|||||||
o.orm = orm
|
o.orm = orm
|
||||||
return o
|
return o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// aggregate func
|
||||||
|
func (o querySet) Aggregate(s string) QuerySeter {
|
||||||
|
o.aggregate = s
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
@ -205,6 +206,7 @@ func TestSyncDb(t *testing.T) {
|
|||||||
RegisterModel(new(Index))
|
RegisterModel(new(Index))
|
||||||
RegisterModel(new(StrPk))
|
RegisterModel(new(StrPk))
|
||||||
RegisterModel(new(TM))
|
RegisterModel(new(TM))
|
||||||
|
RegisterModel(new(DeptInfo))
|
||||||
|
|
||||||
err := RunSyncdb("default", true, Debug)
|
err := RunSyncdb("default", true, Debug)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -232,6 +234,7 @@ func TestRegisterModels(t *testing.T) {
|
|||||||
RegisterModel(new(Index))
|
RegisterModel(new(Index))
|
||||||
RegisterModel(new(StrPk))
|
RegisterModel(new(StrPk))
|
||||||
RegisterModel(new(TM))
|
RegisterModel(new(TM))
|
||||||
|
RegisterModel(new(DeptInfo))
|
||||||
|
|
||||||
BootStrap()
|
BootStrap()
|
||||||
|
|
||||||
@ -333,6 +336,73 @@ func TestTM(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC"))
|
throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnregisterModel(t *testing.T) {
|
||||||
|
data := []*DeptInfo{
|
||||||
|
{
|
||||||
|
DeptName: "A",
|
||||||
|
EmployeeName: "A1",
|
||||||
|
Salary: 1000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "A",
|
||||||
|
EmployeeName: "A2",
|
||||||
|
Salary: 2000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "B",
|
||||||
|
EmployeeName: "B1",
|
||||||
|
Salary: 2000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "B",
|
||||||
|
EmployeeName: "B2",
|
||||||
|
Salary: 4000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "B",
|
||||||
|
EmployeeName: "B3",
|
||||||
|
Salary: 3000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
qs := dORM.QueryTable("dept_info")
|
||||||
|
i, _ := qs.PrepareInsert()
|
||||||
|
for _, d := range data {
|
||||||
|
_, err := i.Insert(d)
|
||||||
|
if err != nil {
|
||||||
|
throwFail(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f := func() {
|
||||||
|
var res []UnregisterModel
|
||||||
|
n, err := dORM.QueryTable("dept_info").All(&res)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(n, 5))
|
||||||
|
throwFail(t, AssertIs(res[0].EmployeeName, "A1"))
|
||||||
|
|
||||||
|
type Sum struct {
|
||||||
|
DeptName string
|
||||||
|
Total int
|
||||||
|
}
|
||||||
|
var sun []Sum
|
||||||
|
qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun)
|
||||||
|
throwFail(t, AssertIs(sun[0].DeptName, "A"))
|
||||||
|
throwFail(t, AssertIs(sun[0].Total, 3000))
|
||||||
|
|
||||||
|
type Max struct {
|
||||||
|
DeptName string
|
||||||
|
Max float64
|
||||||
|
}
|
||||||
|
var max []Max
|
||||||
|
qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max)
|
||||||
|
throwFail(t, AssertIs(max[1].DeptName, "B"))
|
||||||
|
throwFail(t, AssertIs(max[1].Max, 4000))
|
||||||
|
}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNullDataTypes(t *testing.T) {
|
func TestNullDataTypes(t *testing.T) {
|
||||||
d := DataNull{}
|
d := DataNull{}
|
||||||
|
|
||||||
@ -1077,6 +1147,26 @@ func TestOrderBy(t *testing.T) {
|
|||||||
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
|
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column(`profile__age`),
|
||||||
|
order_clause.SortDescending(),
|
||||||
|
),
|
||||||
|
).Filter("user_name", "astaxie").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
if IsMysql {
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column(`rand()`),
|
||||||
|
order_clause.Raw(),
|
||||||
|
),
|
||||||
|
).Filter("user_name", "astaxie").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAll(t *testing.T) {
|
func TestAll(t *testing.T) {
|
||||||
@ -1163,6 +1253,19 @@ func TestValues(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(maps[2]["Profile"], nil))
|
throwFail(t, AssertIs(maps[2]["Profile"], nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column("Id"),
|
||||||
|
order_clause.SortAscending(),
|
||||||
|
),
|
||||||
|
).Values(&maps)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 3))
|
||||||
|
if num == 3 {
|
||||||
|
throwFail(t, AssertIs(maps[0]["UserName"], "slene"))
|
||||||
|
throwFail(t, AssertIs(maps[2]["Profile"], nil))
|
||||||
|
}
|
||||||
|
|
||||||
num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age")
|
num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age")
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 3))
|
throwFail(t, AssertIs(num, 3))
|
||||||
@ -2717,3 +2820,23 @@ func TestCondition(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(!cycleFlag, true))
|
throwFail(t, AssertIs(!cycleFlag, true))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
user := User{UserName: "slene"}
|
||||||
|
|
||||||
|
err := dORM.ReadWithCtx(ctx, &user, "UserName")
|
||||||
|
throwFail(t, err)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
err = dORM.ReadWithCtx(ctx, &user, "UserName")
|
||||||
|
throwFail(t, AssertIs(err, context.Canceled))
|
||||||
|
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
qs := dORM.QueryTable(user)
|
||||||
|
_, err = qs.Filter("UserName", "slene").CountWithCtx(ctx)
|
||||||
|
throwFail(t, AssertIs(err, context.Canceled))
|
||||||
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"github.com/beego/beego/v2/core/utils"
|
"github.com/beego/beego/v2/core/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,12 +197,16 @@ type DQL interface {
|
|||||||
// post := Post{Id: 4}
|
// post := Post{Id: 4}
|
||||||
// m2m := Ormer.QueryM2M(&post, "Tags")
|
// m2m := Ormer.QueryM2M(&post, "Tags")
|
||||||
QueryM2M(md interface{}, name string) QueryM2Mer
|
QueryM2M(md interface{}, name string) QueryM2Mer
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
|
||||||
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
|
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
|
||||||
|
|
||||||
// return a QuerySeter for table operations.
|
// return a QuerySeter for table operations.
|
||||||
// table name can be string or struct.
|
// table name can be string or struct.
|
||||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||||
QueryTable(ptrStructOrTableName interface{}) QuerySeter
|
QueryTable(ptrStructOrTableName interface{}) QuerySeter
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
|
||||||
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
|
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
|
||||||
|
|
||||||
DBStats() *sql.DBStats
|
DBStats() *sql.DBStats
|
||||||
@ -236,6 +241,7 @@ type TxOrmer interface {
|
|||||||
// Inserter insert prepared statement
|
// Inserter insert prepared statement
|
||||||
type Inserter interface {
|
type Inserter interface {
|
||||||
Insert(interface{}) (int64, error)
|
Insert(interface{}) (int64, error)
|
||||||
|
InsertWithCtx(context.Context, interface{}) (int64, error)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,6 +301,28 @@ type QuerySeter interface {
|
|||||||
// for example:
|
// for example:
|
||||||
// qs.OrderBy("-status")
|
// qs.OrderBy("-status")
|
||||||
OrderBy(exprs ...string) QuerySeter
|
OrderBy(exprs ...string) QuerySeter
|
||||||
|
// add ORDER expression by order clauses
|
||||||
|
// for example:
|
||||||
|
// OrderClauses(
|
||||||
|
// order_clause.Clause(
|
||||||
|
// order.Column("Id"),
|
||||||
|
// order.SortAscending(),
|
||||||
|
// ),
|
||||||
|
// order_clause.Clause(
|
||||||
|
// order.Column("status"),
|
||||||
|
// order.SortDescending(),
|
||||||
|
// ),
|
||||||
|
// )
|
||||||
|
// OrderClauses(order_clause.Clause(
|
||||||
|
// order_clause.Column(`user__status`),
|
||||||
|
// order_clause.SortDescending(),//default None
|
||||||
|
// ))
|
||||||
|
// OrderClauses(order_clause.Clause(
|
||||||
|
// order_clause.Column(`random()`),
|
||||||
|
// order_clause.SortNone(),//default None
|
||||||
|
// order_clause.Raw(),//default false.if true, do not check field is valid or not
|
||||||
|
// ))
|
||||||
|
OrderClauses(orders ...*order_clause.Order) QuerySeter
|
||||||
// add FORCE INDEX expression.
|
// add FORCE INDEX expression.
|
||||||
// for example:
|
// for example:
|
||||||
// qs.ForceIndex(`idx_name1`,`idx_name2`)
|
// qs.ForceIndex(`idx_name1`,`idx_name2`)
|
||||||
@ -333,9 +361,11 @@ type QuerySeter interface {
|
|||||||
// for example:
|
// for example:
|
||||||
// num, err = qs.Filter("profile__age__gt", 28).Count()
|
// num, err = qs.Filter("profile__age__gt", 28).Count()
|
||||||
Count() (int64, error)
|
Count() (int64, error)
|
||||||
|
CountWithCtx(context.Context) (int64, error)
|
||||||
// check result empty or not after QuerySeter executed
|
// check result empty or not after QuerySeter executed
|
||||||
// the same as QuerySeter.Count > 0
|
// the same as QuerySeter.Count > 0
|
||||||
Exist() bool
|
Exist() bool
|
||||||
|
ExistWithCtx(context.Context) bool
|
||||||
// execute update with parameters
|
// execute update with parameters
|
||||||
// for example:
|
// for example:
|
||||||
// num, err = qs.Filter("user_name", "slene").Update(Params{
|
// num, err = qs.Filter("user_name", "slene").Update(Params{
|
||||||
@ -345,11 +375,13 @@ type QuerySeter interface {
|
|||||||
// "user_name": "slene2"
|
// "user_name": "slene2"
|
||||||
// }) // user slene's name will change to slene2
|
// }) // user slene's name will change to slene2
|
||||||
Update(values Params) (int64, error)
|
Update(values Params) (int64, error)
|
||||||
|
UpdateWithCtx(ctx context.Context, values Params) (int64, error)
|
||||||
// delete from table
|
// delete from table
|
||||||
// for example:
|
// for example:
|
||||||
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
|
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
|
||||||
// //delete two user who's name is testing1 or testing2
|
// //delete two user who's name is testing1 or testing2
|
||||||
Delete() (int64, error)
|
Delete() (int64, error)
|
||||||
|
DeleteWithCtx(context.Context) (int64, error)
|
||||||
// return a insert queryer.
|
// return a insert queryer.
|
||||||
// it can be used in times.
|
// it can be used in times.
|
||||||
// example:
|
// example:
|
||||||
@ -358,18 +390,21 @@ type QuerySeter interface {
|
|||||||
// num, err = i.Insert(&user2) // user table will add one record user2 at once
|
// num, err = i.Insert(&user2) // user table will add one record user2 at once
|
||||||
// err = i.Close() //don't forget call Close
|
// err = i.Close() //don't forget call Close
|
||||||
PrepareInsert() (Inserter, error)
|
PrepareInsert() (Inserter, error)
|
||||||
|
PrepareInsertWithCtx(context.Context) (Inserter, error)
|
||||||
// query all data and map to containers.
|
// query all data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
// for example:
|
// for example:
|
||||||
// var users []*User
|
// var users []*User
|
||||||
// qs.All(&users) // users[0],users[1],users[2] ...
|
// qs.All(&users) // users[0],users[1],users[2] ...
|
||||||
All(container interface{}, cols ...string) (int64, error)
|
All(container interface{}, cols ...string) (int64, error)
|
||||||
|
AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error)
|
||||||
// query one row data and map to containers.
|
// query one row data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
// for example:
|
// for example:
|
||||||
// var user User
|
// var user User
|
||||||
// qs.One(&user) //user.UserName == "slene"
|
// qs.One(&user) //user.UserName == "slene"
|
||||||
One(container interface{}, cols ...string) error
|
One(container interface{}, cols ...string) error
|
||||||
|
OneWithCtx(ctx context.Context, container interface{}, cols ...string) error
|
||||||
// query all data and map to []map[string]interface.
|
// query all data and map to []map[string]interface.
|
||||||
// expres means condition expression.
|
// expres means condition expression.
|
||||||
// it converts data to []map[column]value.
|
// it converts data to []map[column]value.
|
||||||
@ -377,18 +412,21 @@ type QuerySeter interface {
|
|||||||
// var maps []Params
|
// var maps []Params
|
||||||
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
|
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
|
||||||
Values(results *[]Params, exprs ...string) (int64, error)
|
Values(results *[]Params, exprs ...string) (int64, error)
|
||||||
|
ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error)
|
||||||
// query all data and map to [][]interface
|
// query all data and map to [][]interface
|
||||||
// it converts data to [][column_index]value
|
// it converts data to [][column_index]value
|
||||||
// for example:
|
// for example:
|
||||||
// var list []ParamsList
|
// var list []ParamsList
|
||||||
// qs.ValuesList(&list) // list[0][1] == "slene"
|
// qs.ValuesList(&list) // list[0][1] == "slene"
|
||||||
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
|
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
|
||||||
|
ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error)
|
||||||
// query all data and map to []interface.
|
// query all data and map to []interface.
|
||||||
// it's designed for one column record set, auto change to []value, not [][column]value.
|
// it's designed for one column record set, auto change to []value, not [][column]value.
|
||||||
// for example:
|
// for example:
|
||||||
// var list ParamsList
|
// var list ParamsList
|
||||||
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
|
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
|
||||||
ValuesFlat(result *ParamsList, expr string) (int64, error)
|
ValuesFlat(result *ParamsList, expr string) (int64, error)
|
||||||
|
ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error)
|
||||||
// query all rows into map[string]interface with specify key and value column name.
|
// query all rows into map[string]interface with specify key and value column name.
|
||||||
// keyCol = "name", valueCol = "value"
|
// keyCol = "name", valueCol = "value"
|
||||||
// table data
|
// table data
|
||||||
@ -411,6 +449,15 @@ type QuerySeter interface {
|
|||||||
// Found int
|
// Found int
|
||||||
// }
|
// }
|
||||||
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
|
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
|
||||||
|
// aggregate func.
|
||||||
|
// for example:
|
||||||
|
// type result struct {
|
||||||
|
// DeptName string
|
||||||
|
// Total int
|
||||||
|
// }
|
||||||
|
// var res []result
|
||||||
|
// o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res)
|
||||||
|
Aggregate(s string) QuerySeter
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryM2Mer model to model query struct
|
// QueryM2Mer model to model query struct
|
||||||
@ -428,18 +475,23 @@ type QueryM2Mer interface {
|
|||||||
// insert one or more rows to m2m table
|
// insert one or more rows to m2m table
|
||||||
// make sure the relation is defined in post model struct tag.
|
// make sure the relation is defined in post model struct tag.
|
||||||
Add(...interface{}) (int64, error)
|
Add(...interface{}) (int64, error)
|
||||||
|
AddWithCtx(context.Context, ...interface{}) (int64, error)
|
||||||
// remove models following the origin model relationship
|
// remove models following the origin model relationship
|
||||||
// only delete rows from m2m table
|
// only delete rows from m2m table
|
||||||
// for example:
|
// for example:
|
||||||
// tag3 := &Tag{Id:5,Name: "TestTag3"}
|
// tag3 := &Tag{Id:5,Name: "TestTag3"}
|
||||||
// num, err = m2m.Remove(tag3)
|
// num, err = m2m.Remove(tag3)
|
||||||
Remove(...interface{}) (int64, error)
|
Remove(...interface{}) (int64, error)
|
||||||
|
RemoveWithCtx(context.Context, ...interface{}) (int64, error)
|
||||||
// check model is existed in relationship of origin model
|
// check model is existed in relationship of origin model
|
||||||
Exist(interface{}) bool
|
Exist(interface{}) bool
|
||||||
|
ExistWithCtx(context.Context, interface{}) bool
|
||||||
// clean all models in related of origin model
|
// clean all models in related of origin model
|
||||||
Clear() (int64, error)
|
Clear() (int64, error)
|
||||||
|
ClearWithCtx(context.Context) (int64, error)
|
||||||
// count all related models of origin model
|
// count all related models of origin model
|
||||||
Count() (int64, error)
|
Count() (int64, error)
|
||||||
|
CountWithCtx(context.Context) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RawPreparer raw query statement
|
// RawPreparer raw query statement
|
||||||
@ -513,11 +565,11 @@ type RawSeter interface {
|
|||||||
type stmtQuerier interface {
|
type stmtQuerier interface {
|
||||||
Close() error
|
Close() error
|
||||||
Exec(args ...interface{}) (sql.Result, error)
|
Exec(args ...interface{}) (sql.Result, error)
|
||||||
// ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
|
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
|
||||||
Query(args ...interface{}) (*sql.Rows, error)
|
Query(args ...interface{}) (*sql.Rows, error)
|
||||||
// QueryContext(args ...interface{}) (*sql.Rows, error)
|
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
|
||||||
QueryRow(args ...interface{}) *sql.Row
|
QueryRow(args ...interface{}) *sql.Row
|
||||||
// QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
|
QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
// db querier
|
// db querier
|
||||||
@ -554,28 +606,28 @@ type txEnder interface {
|
|||||||
|
|
||||||
// base database struct
|
// base database struct
|
||||||
type dbBaser interface {
|
type dbBaser interface {
|
||||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
|
Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
|
||||||
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
|
ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
|
||||||
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||||
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
||||||
|
|
||||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
|
InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
|
||||||
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
||||||
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
||||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
|
|
||||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||||
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
||||||
|
|
||||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||||
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||||
|
|
||||||
SupportUpdateJoin() bool
|
SupportUpdateJoin() bool
|
||||||
OperatorSQL(string) string
|
OperatorSQL(string) string
|
||||||
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
|
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
|
||||||
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
||||||
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
||||||
MaxLimit() uint64
|
MaxLimit() uint64
|
||||||
TableQuote() string
|
TableQuote() string
|
||||||
ReplaceMarks(*string)
|
ReplaceMarks(*string)
|
||||||
@ -584,12 +636,12 @@ type dbBaser interface {
|
|||||||
TimeToDB(*time.Time, *time.Location)
|
TimeToDB(*time.Time, *time.Location)
|
||||||
DbTypes() map[string]string
|
DbTypes() map[string]string
|
||||||
GetTables(dbQuerier) (map[string]bool, error)
|
GetTables(dbQuerier) (map[string]bool, error)
|
||||||
GetColumns(dbQuerier, string) (map[string][3]string, error)
|
GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error)
|
||||||
ShowTablesQuery() string
|
ShowTablesQuery() string
|
||||||
ShowColumnsQuery(string) string
|
ShowColumnsQuery(string) string
|
||||||
IndexExists(dbQuerier, string, string) bool
|
IndexExists(context.Context, dbQuerier, string, string) bool
|
||||||
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
||||||
setval(dbQuerier, *modelInfo, []string) error
|
setval(context.Context, dbQuerier, *modelInfo, []string) error
|
||||||
|
|
||||||
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
|
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
|
||||||
}
|
}
|
||||||
|
|||||||
91
core/berror/codes.go
Normal file
91
core/berror/codes.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package berror
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// A Code is an unsigned 32-bit error code as defined in the beego spec.
|
||||||
|
type Code interface {
|
||||||
|
Code() uint32
|
||||||
|
Module() string
|
||||||
|
Desc() string
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultCodeRegistry = &codeRegistry{
|
||||||
|
codes: make(map[uint32]*codeDefinition, 127),
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefineCode defining a new Code
|
||||||
|
// Before defining a new code, please read Beego specification.
|
||||||
|
// desc could be markdown doc
|
||||||
|
func DefineCode(code uint32, module string, name string, desc string) Code {
|
||||||
|
res := &codeDefinition{
|
||||||
|
code: code,
|
||||||
|
module: module,
|
||||||
|
desc: desc,
|
||||||
|
}
|
||||||
|
defaultCodeRegistry.lock.Lock()
|
||||||
|
defer defaultCodeRegistry.lock.Unlock()
|
||||||
|
|
||||||
|
if _, ok := defaultCodeRegistry.codes[code]; ok {
|
||||||
|
panic(fmt.Sprintf("duplicate code, code %d has been registered", code))
|
||||||
|
}
|
||||||
|
defaultCodeRegistry.codes[code] = res
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
type codeRegistry struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
codes map[uint32]*codeDefinition
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func (cr *codeRegistry) Get(code uint32) (Code, bool) {
|
||||||
|
cr.lock.RLock()
|
||||||
|
defer cr.lock.RUnlock()
|
||||||
|
c, ok := cr.codes[code]
|
||||||
|
return c, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
type codeDefinition struct {
|
||||||
|
code uint32
|
||||||
|
module string
|
||||||
|
desc string
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func (c *codeDefinition) Name() string {
|
||||||
|
return c.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codeDefinition) Code() uint32 {
|
||||||
|
return c.code
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codeDefinition) Module() string {
|
||||||
|
return c.module
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codeDefinition) Desc() string {
|
||||||
|
return c.desc
|
||||||
|
}
|
||||||
|
|
||||||
69
core/berror/error.go
Normal file
69
core/berror/error.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package berror
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// code, msg
|
||||||
|
const errFmt = "ERROR-%d, %s"
|
||||||
|
|
||||||
|
// Err returns an error representing c and msg. If c is OK, returns nil.
|
||||||
|
func Error(c Code, msg string) error {
|
||||||
|
return fmt.Errorf(errFmt, c.Code(), msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errorf returns error
|
||||||
|
func Errorf(c Code, format string, a ...interface{}) error {
|
||||||
|
return Error(c, fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Wrap(err error, c Code, msg string) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.Wrap(err, fmt.Sprintf(errFmt, c.Code(), msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Wrapf(err error, c Code, format string, a ...interface{}) error {
|
||||||
|
return Wrap(err, c, fmt.Sprintf(format, a...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromError is very simple. It just parse error msg and check whether code has been register
|
||||||
|
// if code not being register, return unknown
|
||||||
|
// if err.Error() is not valid beego error code, return unknown
|
||||||
|
func FromError(err error) (Code, bool) {
|
||||||
|
msg := err.Error()
|
||||||
|
codeSeg := strings.SplitN(msg, ",", 2)
|
||||||
|
if strings.HasPrefix(codeSeg[0], "ERROR-") {
|
||||||
|
codeStr := strings.SplitN(codeSeg[0], "-", 2)
|
||||||
|
if len(codeStr) < 2 {
|
||||||
|
return Unknown, false
|
||||||
|
}
|
||||||
|
codeInt, e := strconv.ParseUint(codeStr[1], 10, 32)
|
||||||
|
if e != nil {
|
||||||
|
return Unknown, false
|
||||||
|
}
|
||||||
|
if code, ok := defaultCodeRegistry.Get(uint32(codeInt)); ok {
|
||||||
|
return code, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Unknown, false
|
||||||
|
}
|
||||||
77
core/berror/error_test.go
Normal file
77
core/berror/error_test.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
// Copyright 2020 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package berror
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testCode1 = DefineCode(1, "unit_test", "TestError", "Hello, test code1")
|
||||||
|
|
||||||
|
var testErr = errors.New("hello, this is error")
|
||||||
|
|
||||||
|
func TestErrorf(t *testing.T) {
|
||||||
|
msg := Errorf(testCode1, "errorf %s", "aaaa")
|
||||||
|
assert.NotNil(t, msg)
|
||||||
|
assert.Equal(t, "ERROR-1, errorf aaaa", msg.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapf(t *testing.T) {
|
||||||
|
err := Wrapf(testErr, testCode1, "Wrapf %s", "aaaa")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.True(t, errors.Is(err, testErr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromError(t *testing.T) {
|
||||||
|
err := errors.New("ERROR-1, errorf aaaa")
|
||||||
|
code, ok := FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, testCode1, code)
|
||||||
|
assert.Equal(t, "unit_test", code.Module())
|
||||||
|
assert.Equal(t, "Hello, test code1", code.Desc())
|
||||||
|
|
||||||
|
err = errors.New("not beego error")
|
||||||
|
code, ok = FromError(err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, Unknown, code)
|
||||||
|
|
||||||
|
err = errors.New("ERROR-2, not register")
|
||||||
|
code, ok = FromError(err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, Unknown, code)
|
||||||
|
|
||||||
|
err = errors.New("ERROR-aaa, invalid code")
|
||||||
|
code, ok = FromError(err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, Unknown, code)
|
||||||
|
|
||||||
|
err = errors.New("aaaaaaaaaaaaaa")
|
||||||
|
code, ok = FromError(err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, Unknown, code)
|
||||||
|
|
||||||
|
err = errors.New("ERROR-2-3, invalid error")
|
||||||
|
code, ok = FromError(err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, Unknown, code)
|
||||||
|
|
||||||
|
err = errors.New("ERROR, invalid error")
|
||||||
|
code, ok = FromError(err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, Unknown, code)
|
||||||
|
}
|
||||||
52
core/berror/pre_define_code.go
Normal file
52
core/berror/pre_define_code.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
// Copyright 2021 beego
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package berror
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pre define code
|
||||||
|
|
||||||
|
// Unknown indicates got some error which is not defined
|
||||||
|
var Unknown = DefineCode(5000001, "error", "Unknown",fmt.Sprintf(`
|
||||||
|
Unknown error code. Usually you will see this code in three cases:
|
||||||
|
1. You forget to define Code or function DefineCode not being executed;
|
||||||
|
2. This is not Beego's error but you call FromError();
|
||||||
|
3. Beego got unexpected error and don't know how to handle it, and then return Unknown error
|
||||||
|
|
||||||
|
A common practice to DefineCode looks like:
|
||||||
|
%s
|
||||||
|
|
||||||
|
In this way, you may forget to import this package, and got Unknown error.
|
||||||
|
|
||||||
|
Sometimes, you believe you got Beego error, but actually you don't, and then you call FromError(err)
|
||||||
|
|
||||||
|
`, goCodeBlock(`
|
||||||
|
import your_package
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
DefineCode(5100100, "your_module", "detail")
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
`)))
|
||||||
|
|
||||||
|
func goCodeBlock(code string) string {
|
||||||
|
return codeBlock("go", code)
|
||||||
|
}
|
||||||
|
func codeBlock(lan string, code string) string {
|
||||||
|
return fmt.Sprintf("```%s\n%s\n```", lan, code)
|
||||||
|
}
|
||||||
|
|
||||||
@ -69,8 +69,8 @@ func (p *PatternLogFormatter) ToString(lm *LogMsg) string {
|
|||||||
'm': lm.Msg,
|
'm': lm.Msg,
|
||||||
'n': strconv.Itoa(lm.LineNumber),
|
'n': strconv.Itoa(lm.LineNumber),
|
||||||
'l': strconv.Itoa(lm.Level),
|
'l': strconv.Itoa(lm.Level),
|
||||||
't': levelPrefix[lm.Level-1],
|
't': levelPrefix[lm.Level],
|
||||||
'T': levelNames[lm.Level-1],
|
'T': levelNames[lm.Level],
|
||||||
'F': lm.FilePath,
|
'F': lm.FilePath,
|
||||||
}
|
}
|
||||||
_, m['f'] = path.Split(lm.FilePath)
|
_, m['f'] = path.Split(lm.FilePath)
|
||||||
|
|||||||
@ -88,7 +88,7 @@ func TestPatternLogFormatter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
got := tes.ToString(lm)
|
got := tes.ToString(lm)
|
||||||
want := lm.FilePath + ":" + strconv.Itoa(lm.LineNumber) + "|" +
|
want := lm.FilePath + ":" + strconv.Itoa(lm.LineNumber) + "|" +
|
||||||
when.Format(tes.WhenFormat) + levelPrefix[lm.Level-1] + ">> " + lm.Msg
|
when.Format(tes.WhenFormat) + levelPrefix[lm.Level] + ">> " + lm.Msg
|
||||||
if got != want {
|
if got != want {
|
||||||
t.Errorf("want %s, got %s", want, got)
|
t.Errorf("want %s, got %s", want, got)
|
||||||
}
|
}
|
||||||
|
|||||||
2
go.mod
2
go.mod
@ -25,7 +25,7 @@ require (
|
|||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible
|
github.com/gomodule/redigo v2.0.0+incompatible
|
||||||
github.com/google/go-cmp v0.5.0 // indirect
|
github.com/google/go-cmp v0.5.0 // indirect
|
||||||
github.com/google/uuid v1.1.1 // indirect
|
github.com/google/uuid v1.1.1
|
||||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
|
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
|
||||||
github.com/hashicorp/golang-lru v0.5.4
|
github.com/hashicorp/golang-lru v0.5.4
|
||||||
github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6
|
github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6
|
||||||
|
|||||||
@ -108,8 +108,11 @@ func registerAdmin() error {
|
|||||||
c := &adminController{
|
c := &adminController{
|
||||||
servers: make([]*HttpServer, 0, 2),
|
servers: make([]*HttpServer, 0, 2),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copy config to avoid conflict
|
||||||
|
adminCfg := *BConfig
|
||||||
beeAdminApp = &adminApp{
|
beeAdminApp = &adminApp{
|
||||||
HttpServer: NewHttpServerWithCfg(BConfig),
|
HttpServer: NewHttpServerWithCfg(&adminCfg),
|
||||||
}
|
}
|
||||||
// keep in mind that all data should be html escaped to avoid XSS attack
|
// keep in mind that all data should be html escaped to avoid XSS attack
|
||||||
beeAdminApp.Router("/", c, "get:AdminIndex")
|
beeAdminApp.Router("/", c, "get:AdminIndex")
|
||||||
|
|||||||
@ -29,6 +29,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -195,6 +196,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Session return session store of this context of request
|
||||||
|
func (ctx *Context) Session() (store session.Store, err error) {
|
||||||
|
if ctx.Input != nil {
|
||||||
|
if ctx.Input.CruSession != nil {
|
||||||
|
store = ctx.Input.CruSession
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
err = errors.New(`no valid session store(please initialize session)`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = errors.New(`no valid input`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Response is a wrapper for the http.ResponseWriter
|
// Response is a wrapper for the http.ResponseWriter
|
||||||
// Started: if true, response was already written to so the other handler will not be executed
|
// Started: if true, response was already written to so the other handler will not be executed
|
||||||
type Response struct {
|
type Response struct {
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package context
|
package context
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -45,3 +46,26 @@ func TestXsrfReset_01(t *testing.T) {
|
|||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContext_Session(t *testing.T) {
|
||||||
|
c := NewContext()
|
||||||
|
if store, err := c.Session(); store != nil || err == nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContext_Session1(t *testing.T) {
|
||||||
|
c := Context{}
|
||||||
|
if store, err := c.Session(); store != nil || err == nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContext_Session2(t *testing.T) {
|
||||||
|
c := NewContext()
|
||||||
|
c.Input.CruSession = &session.MemSessionStore{}
|
||||||
|
|
||||||
|
if store, err := c.Session(); store == nil || err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -17,23 +17,49 @@ package prometheus
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
"github.com/beego/beego/v2"
|
"github.com/beego/beego/v2"
|
||||||
|
"github.com/beego/beego/v2/core/logs"
|
||||||
"github.com/beego/beego/v2/server/web"
|
"github.com/beego/beego/v2/server/web"
|
||||||
"github.com/beego/beego/v2/server/web/context"
|
"github.com/beego/beego/v2/server/web/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const unknownRouterPattern = "UnknownRouterPattern"
|
||||||
|
|
||||||
// FilterChainBuilder is an extension point,
|
// FilterChainBuilder is an extension point,
|
||||||
// when we want to support some configuration,
|
// when we want to support some configuration,
|
||||||
// please use this structure
|
// please use this structure
|
||||||
type FilterChainBuilder struct {
|
type FilterChainBuilder struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var summaryVec prometheus.ObserverVec
|
||||||
|
var initSummaryVec sync.Once
|
||||||
|
|
||||||
// FilterChain returns a FilterFunc. The filter will records some metrics
|
// FilterChain returns a FilterFunc. The filter will records some metrics
|
||||||
func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc {
|
func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc {
|
||||||
|
|
||||||
|
initSummaryVec.Do(func() {
|
||||||
|
summaryVec = builder.buildVec()
|
||||||
|
err := prometheus.Register(summaryVec)
|
||||||
|
if _, ok := err.(*prometheus.AlreadyRegisteredError); err != nil && !ok {
|
||||||
|
logs.Error("web module register prometheus vector failed, %+v", err)
|
||||||
|
}
|
||||||
|
registerBuildInfo()
|
||||||
|
})
|
||||||
|
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
startTime := time.Now()
|
||||||
|
next(ctx)
|
||||||
|
endTime := time.Now()
|
||||||
|
go report(endTime.Sub(startTime), ctx, summaryVec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (builder *FilterChainBuilder) buildVec() *prometheus.SummaryVec {
|
||||||
summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
||||||
Name: "beego",
|
Name: "beego",
|
||||||
Subsystem: "http_request",
|
Subsystem: "http_request",
|
||||||
@ -44,17 +70,7 @@ func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFu
|
|||||||
},
|
},
|
||||||
Help: "The statics info for http request",
|
Help: "The statics info for http request",
|
||||||
}, []string{"pattern", "method", "status"})
|
}, []string{"pattern", "method", "status"})
|
||||||
|
return summaryVec
|
||||||
prometheus.MustRegister(summaryVec)
|
|
||||||
|
|
||||||
registerBuildInfo()
|
|
||||||
|
|
||||||
return func(ctx *context.Context) {
|
|
||||||
startTime := time.Now()
|
|
||||||
next(ctx)
|
|
||||||
endTime := time.Now()
|
|
||||||
go report(endTime.Sub(startTime), ctx, summaryVec)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerBuildInfo() {
|
func registerBuildInfo() {
|
||||||
@ -75,13 +91,17 @@ func registerBuildInfo() {
|
|||||||
},
|
},
|
||||||
}, []string{})
|
}, []string{})
|
||||||
|
|
||||||
prometheus.MustRegister(buildInfo)
|
_ = prometheus.Register(buildInfo)
|
||||||
buildInfo.WithLabelValues().Set(1)
|
buildInfo.WithLabelValues().Set(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func report(dur time.Duration, ctx *context.Context, vec *prometheus.SummaryVec) {
|
func report(dur time.Duration, ctx *context.Context, vec prometheus.ObserverVec) {
|
||||||
status := ctx.Output.Status
|
status := ctx.Output.Status
|
||||||
ptn := ctx.Input.GetData("RouterPattern").(string)
|
ptnItf := ctx.Input.GetData("RouterPattern")
|
||||||
|
ptn := unknownRouterPattern
|
||||||
|
if ptnItf != nil {
|
||||||
|
ptn = ptnItf.(string)
|
||||||
|
}
|
||||||
ms := dur / time.Millisecond
|
ms := dur / time.Millisecond
|
||||||
vec.WithLabelValues(ptn, ctx.Input.Method(), strconv.Itoa(status)).Observe(float64(ms))
|
vec.WithLabelValues(ptn, ctx.Input.Method(), strconv.Itoa(status)).Observe(float64(ms))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@ -37,4 +38,19 @@ func TestFilterChain(t *testing.T) {
|
|||||||
ctx.Input.SetData("RouterPattern", "my-route")
|
ctx.Input.SetData("RouterPattern", "my-route")
|
||||||
filter(ctx)
|
filter(ctx)
|
||||||
assert.True(t, ctx.Input.GetData("invocation").(bool))
|
assert.True(t, ctx.Input.GetData("invocation").(bool))
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterChainBuilder_report(t *testing.T) {
|
||||||
|
|
||||||
|
ctx := context.NewContext()
|
||||||
|
r, _ := http.NewRequest("GET", "/prometheus/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ctx.Reset(w, r)
|
||||||
|
fb := &FilterChainBuilder{}
|
||||||
|
// without router info
|
||||||
|
report(time.Second, ctx, fb.buildVec())
|
||||||
|
|
||||||
|
ctx.Input.SetData("RouterPattern", "my-route")
|
||||||
|
report(time.Second, ctx, fb.buildVec())
|
||||||
}
|
}
|
||||||
35
server/web/filter/session/filter.go
Normal file
35
server/web/filter/session/filter.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/beego/beego/v2/core/logs"
|
||||||
|
"github.com/beego/beego/v2/server/web"
|
||||||
|
webContext "github.com/beego/beego/v2/server/web/context"
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
//Session maintain session for web service
|
||||||
|
//Session new a session storage and store it into webContext.Context
|
||||||
|
func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain {
|
||||||
|
sessionConfig := session.NewManagerConfig(options...)
|
||||||
|
sessionManager, _ := session.NewManager(string(providerType), sessionConfig)
|
||||||
|
go sessionManager.GC()
|
||||||
|
|
||||||
|
return func(next web.FilterFunc) web.FilterFunc {
|
||||||
|
return func(ctx *webContext.Context) {
|
||||||
|
if ctx.Input.CruSession != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil {
|
||||||
|
logs.Error(`init session error:%s`, err.Error())
|
||||||
|
} else {
|
||||||
|
//release session at the end of request
|
||||||
|
defer sess.SessionRelease(context.Background(), ctx.ResponseWriter)
|
||||||
|
ctx.Input.CruSession = sess
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
86
server/web/filter/session/filter_test.go
Normal file
86
server/web/filter/session/filter_test.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/beego/beego/v2/server/web"
|
||||||
|
webContext "github.com/beego/beego/v2/server/web/context"
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) {
|
||||||
|
r, _ := http.NewRequest(method, path, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Code != code {
|
||||||
|
t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSession(t *testing.T) {
|
||||||
|
storeKey := uuid.New().String()
|
||||||
|
handler := web.NewControllerRegister()
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
Session(
|
||||||
|
session.ProviderMemory,
|
||||||
|
session.CfgCookieName(`go_session_id`),
|
||||||
|
session.CfgSetCookie(true),
|
||||||
|
session.CfgGcLifeTime(3600),
|
||||||
|
session.CfgMaxLifeTime(3600),
|
||||||
|
session.CfgSecure(false),
|
||||||
|
session.CfgCookieLifeTime(3600),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
func(next web.FilterFunc) web.FilterFunc {
|
||||||
|
return func(ctx *webContext.Context) {
|
||||||
|
if store := ctx.Input.GetData(storeKey); store == nil {
|
||||||
|
t.Error(`store should not be nil`)
|
||||||
|
}
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
handler.Any("*", func(ctx *webContext.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSession1(t *testing.T) {
|
||||||
|
handler := web.NewControllerRegister()
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
Session(
|
||||||
|
session.ProviderMemory,
|
||||||
|
session.CfgCookieName(`go_session_id`),
|
||||||
|
session.CfgSetCookie(true),
|
||||||
|
session.CfgGcLifeTime(3600),
|
||||||
|
session.CfgMaxLifeTime(3600),
|
||||||
|
session.CfgSecure(false),
|
||||||
|
session.CfgCookieLifeTime(3600),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
func(next web.FilterFunc) web.FilterFunc {
|
||||||
|
return func(ctx *webContext.Context) {
|
||||||
|
if store, err := ctx.Session(); store == nil || err != nil {
|
||||||
|
t.Error(`store should not be nil`)
|
||||||
|
}
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
handler.Any("*", func(ctx *webContext.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
|
||||||
|
}
|
||||||
@ -15,9 +15,12 @@
|
|||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) {
|
|||||||
ns := NewNamespace("/chain")
|
ns := NewNamespace("/chain")
|
||||||
|
|
||||||
ns.Get("/*", func(ctx *context.Context) {
|
ns.Get("/*", func(ctx *context.Context) {
|
||||||
ctx.Output.Body([]byte("hello"))
|
_ = ctx.Output.Body([]byte("hello"))
|
||||||
})
|
})
|
||||||
|
|
||||||
r, _ := http.NewRequest("GET", "/chain/user", nil)
|
r, _ := http.NewRequest("GET", "/chain/user", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
BeeApp.Handlers.Init()
|
||||||
BeeApp.Handlers.ServeHTTP(w, r)
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
|
||||||
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
|
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestControllerRegister_InsertFilterChain_Order(t *testing.T) {
|
||||||
|
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ := http.NewRequest("GET", "/abc", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
BeeApp.Handlers.Init()
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
first := w.Header().Get("first")
|
||||||
|
second := w.Header().Get("second")
|
||||||
|
|
||||||
|
ft, _ := strconv.ParseInt(first, 10, 64)
|
||||||
|
st, _ := strconv.ParseInt(second, 10, 64)
|
||||||
|
|
||||||
|
assert.True(t, st > ft)
|
||||||
|
}
|
||||||
|
|||||||
@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) {
|
|||||||
|
|
||||||
// setup the handler
|
// setup the handler
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
|
handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
// get the Set-Cookie value
|
// get the Set-Cookie value
|
||||||
|
|||||||
@ -6,6 +6,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/coreos/etcd/pkg/fileutil"
|
||||||
|
|
||||||
"github.com/beego/beego/v2/core/logs"
|
"github.com/beego/beego/v2/core/logs"
|
||||||
"github.com/beego/beego/v2/server/web/context"
|
"github.com/beego/beego/v2/server/web/context"
|
||||||
"github.com/beego/beego/v2/server/web/session"
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
@ -99,7 +101,12 @@ func registerGzip() error {
|
|||||||
|
|
||||||
func registerCommentRouter() error {
|
func registerCommentRouter() error {
|
||||||
if BConfig.RunMode == DEV {
|
if BConfig.RunMode == DEV {
|
||||||
if err := parserPkg(filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath)); err != nil {
|
ctrlDir := filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath)
|
||||||
|
if !fileutil.Exist(ctrlDir) {
|
||||||
|
logs.Warn("controller package not found, won't generate router: ", ctrlDir)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := parserPkg(ctrlDir); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
|
|||||||
// Router same as beego.Rourer
|
// Router same as beego.Rourer
|
||||||
// refer: https://godoc.org/github.com/beego/beego/v2#Router
|
// refer: https://godoc.org/github.com/beego/beego/v2#Router
|
||||||
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
|
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
|
||||||
n.handlers.Add(rootpath, c, mappingMethods...)
|
n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...))
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -187,6 +187,54 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RouterGet same as beego.RouterGet
|
||||||
|
func (n *Namespace) RouterGet(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterGet(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPost same as beego.RouterPost
|
||||||
|
func (n *Namespace) RouterPost(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterPost(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterDelete same as beego.RouterDelete
|
||||||
|
func (n *Namespace) RouterDelete(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterDelete(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPut same as beego.RouterPut
|
||||||
|
func (n *Namespace) RouterPut(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterPut(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterHead same as beego.RouterHead
|
||||||
|
func (n *Namespace) RouterHead(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterHead(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterOptions same as beego.RouterOptions
|
||||||
|
func (n *Namespace) RouterOptions(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterOptions(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPatch same as beego.RouterPatch
|
||||||
|
func (n *Namespace) RouterPatch(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterPatch(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// Any same as beego.RouterAny
|
||||||
|
func (n *Namespace) RouterAny(rootpath string, f interface{}) *Namespace {
|
||||||
|
n.handlers.RouterAny(rootpath, f)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
// Namespace add nest Namespace
|
// Namespace add nest Namespace
|
||||||
// usage:
|
// usage:
|
||||||
// ns := beego.NewNamespace(“/v1”).
|
// ns := beego.NewNamespace(“/v1”).
|
||||||
@ -366,6 +414,62 @@ func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NSRouterGet call Namespace RouterGet
|
||||||
|
func NSRouterGet(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterGet(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSRouterPost call Namespace RouterPost
|
||||||
|
func NSRouterPost(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterPost(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSRouterHead call Namespace RouterHead
|
||||||
|
func NSRouterHead(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterHead(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSRouterPut call Namespace RouterPut
|
||||||
|
func NSRouterPut(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterPut(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSRouterDelete call Namespace RouterDelete
|
||||||
|
func NSRouterDelete(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterDelete(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSRouterAny call Namespace RouterAny
|
||||||
|
func NSRouterAny(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterAny(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSRouterOptions call Namespace RouterOptions
|
||||||
|
func NSRouterOptions(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterOptions(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NSRouterPatch call Namespace RouterPatch
|
||||||
|
func NSRouterPatch(rootpath string, f interface{}) LinkNamespace {
|
||||||
|
return func(ns *Namespace) {
|
||||||
|
ns.RouterPatch(rootpath, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NSAutoRouter call Namespace AutoRouter
|
// NSAutoRouter call Namespace AutoRouter
|
||||||
func NSAutoRouter(c ControllerInterface) LinkNamespace {
|
func NSAutoRouter(c ControllerInterface) LinkNamespace {
|
||||||
return func(ns *Namespace) {
|
return func(ns *Namespace) {
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -23,6 +24,40 @@ import (
|
|||||||
"github.com/beego/beego/v2/server/web/context"
|
"github.com/beego/beego/v2/server/web/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
exampleBody = "hello world"
|
||||||
|
examplePointerBody = "hello world pointer"
|
||||||
|
|
||||||
|
nsNamespace = "/router"
|
||||||
|
nsPath = "/user"
|
||||||
|
nsNamespacePath = "/router/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ExampleController struct {
|
||||||
|
Controller
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ExampleController) Ping() {
|
||||||
|
err := m.Ctx.Output.Body([]byte(exampleBody))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ExampleController) PingPointer() {
|
||||||
|
err := m.Ctx.Output.Body([]byte(examplePointerBody))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ExampleController) ping() {
|
||||||
|
err := m.Ctx.Output.Body([]byte("ping method"))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNamespaceGet(t *testing.T) {
|
func TestNamespaceGet(t *testing.T) {
|
||||||
r, _ := http.NewRequest("GET", "/v1/user", nil)
|
r, _ := http.NewRequest("GET", "/v1/user", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -166,3 +201,215 @@ func TestNamespaceInside(t *testing.T) {
|
|||||||
t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String())
|
t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterGet(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterGet(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterGet can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterPost(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterPost(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterPost can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterDelete(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterDelete(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterDelete can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterPut(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterPut(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterPut can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterHead(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterHead(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterHead can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterOptions(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterOptions(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterOptions can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterPatch(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterPatch(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterPatch can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceRouterAny(t *testing.T) {
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
ns.RouterAny(nsPath, ExampleController.Ping)
|
||||||
|
AddNamespace(ns)
|
||||||
|
|
||||||
|
for method := range HTTPMETHOD {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, _ := http.NewRequest(method, nsNamespacePath, nil)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceRouterAny can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterGet(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
NSRouterGet(nsPath, ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterGet can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterPost(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace("/router")
|
||||||
|
NSRouterPost(nsPath, ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterPost can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterDelete(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
NSRouterDelete(nsPath, ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterDelete can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterPut(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
NSRouterPut(nsPath, ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterPut can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterHead(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
NSRouterHead(nsPath, ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterHead can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterOptions(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
NSRouterOptions(nsPath, ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterOptions can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterPatch(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
NSRouterPatch("/user", ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterPatch can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamespaceNSRouterAny(t *testing.T) {
|
||||||
|
ns := NewNamespace(nsNamespace)
|
||||||
|
NSRouterAny(nsPath, ExampleController.Ping)(ns)
|
||||||
|
AddNamespace(ns)
|
||||||
|
|
||||||
|
for method := range HTTPMETHOD {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, _ := http.NewRequest(method, nsNamespacePath, nil)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestNamespaceNSRouterAny can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -118,12 +119,33 @@ type ControllerInfo struct {
|
|||||||
routerType int
|
routerType int
|
||||||
initialize func() ControllerInterface
|
initialize func() ControllerInterface
|
||||||
methodParams []*param.MethodParam
|
methodParams []*param.MethodParam
|
||||||
|
sessionOn bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ControllerOption func(*ControllerInfo)
|
||||||
|
|
||||||
func (c *ControllerInfo) GetPattern() string {
|
func (c *ControllerInfo) GetPattern() string {
|
||||||
return c.pattern
|
return c.pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
|
||||||
|
return func(c *ControllerInfo) {
|
||||||
|
c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRouterSessionOn(sessionOn bool) ControllerOption {
|
||||||
|
return func(c *ControllerInfo) {
|
||||||
|
c.sessionOn = sessionOn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type filterChainConfig struct {
|
||||||
|
pattern string
|
||||||
|
chain FilterChain
|
||||||
|
opts []FilterOpt
|
||||||
|
}
|
||||||
|
|
||||||
// ControllerRegister containers registered router rules, controller handlers and filters.
|
// ControllerRegister containers registered router rules, controller handlers and filters.
|
||||||
type ControllerRegister struct {
|
type ControllerRegister struct {
|
||||||
routers map[string]*Tree
|
routers map[string]*Tree
|
||||||
@ -136,6 +158,9 @@ type ControllerRegister struct {
|
|||||||
// the filter created by FilterChain
|
// the filter created by FilterChain
|
||||||
chainRoot *FilterRouter
|
chainRoot *FilterRouter
|
||||||
|
|
||||||
|
// keep registered chain and build it when serve http
|
||||||
|
filterChains []filterChainConfig
|
||||||
|
|
||||||
cfg *Config
|
cfg *Config
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,11 +181,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
filterChains: make([]filterChainConfig, 0, 4),
|
||||||
}
|
}
|
||||||
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
|
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init will be executed when HttpServer start running
|
||||||
|
func (p *ControllerRegister) Init() {
|
||||||
|
for i := len(p.filterChains) - 1; i >= 0; i-- {
|
||||||
|
fc := p.filterChains[i]
|
||||||
|
root := p.chainRoot
|
||||||
|
filterFunc := fc.chain(root.filterFunc)
|
||||||
|
p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...)
|
||||||
|
p.chainRoot.next = root
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add controller handler and pattern rules to ControllerRegister.
|
// Add controller handler and pattern rules to ControllerRegister.
|
||||||
// usage:
|
// usage:
|
||||||
// default methods is the same name as method
|
// default methods is the same name as method
|
||||||
@ -171,15 +208,19 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
|
|||||||
// Add("/api/delete",&RestController{},"delete:DeleteFood")
|
// Add("/api/delete",&RestController{},"delete:DeleteFood")
|
||||||
// Add("/api",&RestController{},"get,post:ApiFunc"
|
// Add("/api",&RestController{},"get,post:ApiFunc"
|
||||||
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
||||||
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) {
|
||||||
p.addWithMethodParams(pattern, c, nil, mappingMethods...)
|
p.addWithMethodParams(pattern, c, nil, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
|
func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string {
|
||||||
reflectVal := reflect.ValueOf(c)
|
reflectVal := reflect.ValueOf(c)
|
||||||
t := reflect.Indirect(reflectVal).Type()
|
t := reflect.Indirect(reflectVal).Type()
|
||||||
methods := make(map[string]string)
|
methods := make(map[string]string)
|
||||||
if len(mappingMethods) > 0 {
|
|
||||||
|
if len(mappingMethods) == 0 {
|
||||||
|
return methods
|
||||||
|
}
|
||||||
|
|
||||||
semi := strings.Split(mappingMethods[0], ";")
|
semi := strings.Split(mappingMethods[0], ";")
|
||||||
for _, v := range semi {
|
for _, v := range semi {
|
||||||
colon := strings.Split(v, ":")
|
colon := strings.Split(v, ":")
|
||||||
@ -188,24 +229,43 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
|
|||||||
}
|
}
|
||||||
comma := strings.Split(colon[0], ",")
|
comma := strings.Split(colon[0], ",")
|
||||||
for _, m := range comma {
|
for _, m := range comma {
|
||||||
if m == "*" || HTTPMETHOD[strings.ToUpper(m)] {
|
if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] {
|
||||||
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
|
|
||||||
methods[strings.ToUpper(m)] = colon[1]
|
|
||||||
} else {
|
|
||||||
panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
panic(v + " is an invalid method mapping. Method doesn't exist " + m)
|
panic(v + " is an invalid method mapping. Method doesn't exist " + m)
|
||||||
}
|
}
|
||||||
|
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
|
||||||
|
methods[strings.ToUpper(m)] = colon[1]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return methods
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) {
|
||||||
|
if len(route.methods) == 0 {
|
||||||
|
for m := range HTTPMETHOD {
|
||||||
|
p.addToRouter(m, route.pattern, route)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k := range route.methods {
|
||||||
|
if k != "*" {
|
||||||
|
p.addToRouter(k, route.pattern, route)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for m := range HTTPMETHOD {
|
||||||
|
p.addToRouter(m, route.pattern, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
route := &ControllerInfo{}
|
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) {
|
||||||
route.pattern = pattern
|
reflectVal := reflect.ValueOf(c)
|
||||||
route.methods = methods
|
t := reflect.Indirect(reflectVal).Type()
|
||||||
route.routerType = routerTypeBeego
|
|
||||||
route.controllerType = t
|
route := p.createBeegoRouter(t, pattern)
|
||||||
route.initialize = func() ControllerInterface {
|
route.initialize = func() ControllerInterface {
|
||||||
vc := reflect.New(route.controllerType)
|
vc := reflect.New(route.controllerType)
|
||||||
execController, ok := vc.Interface().(ControllerInterface)
|
execController, ok := vc.Interface().(ControllerInterface)
|
||||||
@ -229,23 +289,18 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
|
|||||||
|
|
||||||
return execController
|
return execController
|
||||||
}
|
}
|
||||||
|
|
||||||
route.methodParams = methodParams
|
route.methodParams = methodParams
|
||||||
if len(methods) == 0 {
|
for i := range opts {
|
||||||
for m := range HTTPMETHOD {
|
opts[i](route)
|
||||||
p.addToRouter(m, pattern, route)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for k := range methods {
|
|
||||||
if k == "*" {
|
|
||||||
for m := range HTTPMETHOD {
|
|
||||||
p.addToRouter(m, pattern, route)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
p.addToRouter(k, pattern, route)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
globalSessionOn := p.cfg.WebConfig.Session.SessionOn
|
||||||
|
if !globalSessionOn && route.sessionOn {
|
||||||
|
logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern)
|
||||||
|
route.sessionOn = globalSessionOn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.addRouterForMethod(route)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
|
func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
|
||||||
@ -273,7 +328,8 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
|
|||||||
for _, f := range a.Filters {
|
for _, f := range a.Filters {
|
||||||
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
|
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
|
||||||
}
|
}
|
||||||
p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
|
|
||||||
|
p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -294,6 +350,261 @@ func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) {
|
|||||||
p.pool.Put(ctx)
|
p.pool.Put(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RouterGet add get method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterGet("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterGet(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod(http.MethodGet, pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPost add post method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterPost("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterPost(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod(http.MethodPost, pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterHead add head method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterHead("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterHead(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod(http.MethodHead, pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPut add put method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterPut("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterPut(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod(http.MethodPut, pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPatch add patch method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterPatch("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterPatch(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod(http.MethodPatch, pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterDelete add delete method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterDelete("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterDelete(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod(http.MethodDelete, pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterOptions add options method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterOptions("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterOptions(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod(http.MethodOptions, pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterAny add all method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterAny("/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) RouterAny(pattern string, f interface{}) {
|
||||||
|
p.AddRouterMethod("*", pattern, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRouterMethod add http method router
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// AddRouterMethod("get","/api/:id", MyController.Ping)
|
||||||
|
func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) {
|
||||||
|
httpMethod = p.getUpperMethodString(httpMethod)
|
||||||
|
ct, methodName := getReflectTypeAndMethod(f)
|
||||||
|
|
||||||
|
p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addBeegoTypeRouter add beego type router
|
||||||
|
func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) {
|
||||||
|
route := p.createBeegoRouter(ct, pattern)
|
||||||
|
methods := p.getHttpMethodMapMethod(httpMethod, ctMethod)
|
||||||
|
route.methods = methods
|
||||||
|
|
||||||
|
p.addRouterForMethod(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createBeegoRouter create beego router base on reflect type and pattern
|
||||||
|
func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo {
|
||||||
|
route := &ControllerInfo{}
|
||||||
|
route.pattern = pattern
|
||||||
|
route.routerType = routerTypeBeego
|
||||||
|
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
|
||||||
|
route.controllerType = ct
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
|
// createRestfulRouter create restful router with filter function and pattern
|
||||||
|
func (p *ControllerRegister) createRestfulRouter(f FilterFunc, pattern string) *ControllerInfo {
|
||||||
|
route := &ControllerInfo{}
|
||||||
|
route.pattern = pattern
|
||||||
|
route.routerType = routerTypeRESTFul
|
||||||
|
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
|
||||||
|
route.runFunction = f
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
|
// createHandlerRouter create handler router with handler and pattern
|
||||||
|
func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo {
|
||||||
|
route := &ControllerInfo{}
|
||||||
|
route.pattern = pattern
|
||||||
|
route.routerType = routerTypeHandler
|
||||||
|
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
|
||||||
|
route.handler = h
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will
|
||||||
|
// use http method as the controller method
|
||||||
|
func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string {
|
||||||
|
methods := make(map[string]string)
|
||||||
|
// not match-all sign, only add for the http method
|
||||||
|
if httpMethod != "*" {
|
||||||
|
|
||||||
|
if ctMethod == "" {
|
||||||
|
ctMethod = httpMethod
|
||||||
|
}
|
||||||
|
methods[httpMethod] = ctMethod
|
||||||
|
return methods
|
||||||
|
}
|
||||||
|
|
||||||
|
// add all http method
|
||||||
|
for val := range HTTPMETHOD {
|
||||||
|
if ctMethod == "" {
|
||||||
|
methods[val] = val
|
||||||
|
} else {
|
||||||
|
methods[val] = ctMethod
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return methods
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUpperMethodString get upper string of method, and panic if the method
|
||||||
|
// is not valid
|
||||||
|
func (p *ControllerRegister) getUpperMethodString(method string) string {
|
||||||
|
method = strings.ToUpper(method)
|
||||||
|
if method != "*" && !HTTPMETHOD[method] {
|
||||||
|
panic("not support http method: " + method)
|
||||||
|
}
|
||||||
|
return method
|
||||||
|
}
|
||||||
|
|
||||||
|
// get reflect controller type and method by controller method expression
|
||||||
|
func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method string) {
|
||||||
|
// check f is a function
|
||||||
|
funcType := reflect.TypeOf(f)
|
||||||
|
if funcType.Kind() != reflect.Func {
|
||||||
|
panic("not a method")
|
||||||
|
}
|
||||||
|
|
||||||
|
// get function name
|
||||||
|
funcObj := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
|
||||||
|
if funcObj == nil {
|
||||||
|
panic("cannot find the method")
|
||||||
|
}
|
||||||
|
funcNameSli := strings.Split(funcObj.Name(), ".")
|
||||||
|
lFuncSli := len(funcNameSli)
|
||||||
|
if lFuncSli == 0 {
|
||||||
|
panic("invalid method full name: " + funcObj.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
method = funcNameSli[lFuncSli-1]
|
||||||
|
if len(method) == 0 {
|
||||||
|
panic("method name is empty")
|
||||||
|
} else if method[0] > 96 || method[0] < 65 {
|
||||||
|
panic(fmt.Sprintf("%s is not a public method", method))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check only one param which is the method receiver
|
||||||
|
if numIn := funcType.NumIn(); numIn != 1 {
|
||||||
|
panic("invalid number of param in")
|
||||||
|
}
|
||||||
|
|
||||||
|
controllerType = funcType.In(0)
|
||||||
|
|
||||||
|
// check controller has the method
|
||||||
|
_, exists := controllerType.MethodByName(method)
|
||||||
|
if !exists {
|
||||||
|
panic(controllerType.String() + " has no method " + method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check the receiver implement ControllerInterface
|
||||||
|
if controllerType.Kind() == reflect.Ptr {
|
||||||
|
controllerType = controllerType.Elem()
|
||||||
|
}
|
||||||
|
controller := reflect.New(controllerType)
|
||||||
|
_, ok := controller.Interface().(ControllerInterface)
|
||||||
|
if !ok {
|
||||||
|
panic(controllerType.String() + " is not implemented ControllerInterface")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Get add get method
|
// Get add get method
|
||||||
// usage:
|
// usage:
|
||||||
// Get("/", func(ctx *context.Context){
|
// Get("/", func(ctx *context.Context){
|
||||||
@ -372,34 +683,18 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
|
|||||||
// ctx.Output.Body("hello world")
|
// ctx.Output.Body("hello world")
|
||||||
// })
|
// })
|
||||||
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
|
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
|
||||||
method = strings.ToUpper(method)
|
method = p.getUpperMethodString(method)
|
||||||
if method != "*" && !HTTPMETHOD[method] {
|
|
||||||
panic("not support http method: " + method)
|
route := p.createRestfulRouter(f, pattern)
|
||||||
}
|
methods := p.getHttpMethodMapMethod(method, "")
|
||||||
route := &ControllerInfo{}
|
|
||||||
route.pattern = pattern
|
|
||||||
route.routerType = routerTypeRESTFul
|
|
||||||
route.runFunction = f
|
|
||||||
methods := make(map[string]string)
|
|
||||||
if method == "*" {
|
|
||||||
for val := range HTTPMETHOD {
|
|
||||||
methods[val] = val
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
methods[method] = method
|
|
||||||
}
|
|
||||||
route.methods = methods
|
route.methods = methods
|
||||||
for k := range methods {
|
|
||||||
p.addToRouter(k, pattern, route)
|
p.addRouterForMethod(route)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler add user defined Handler
|
// Handler add user defined Handler
|
||||||
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
|
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
|
||||||
route := &ControllerInfo{}
|
route := p.createHandlerRouter(h, pattern)
|
||||||
route.pattern = pattern
|
|
||||||
route.routerType = routerTypeHandler
|
|
||||||
route.handler = h
|
|
||||||
if len(options) > 0 {
|
if len(options) > 0 {
|
||||||
if _, ok := options[0].(bool); ok {
|
if _, ok := options[0].(bool); ok {
|
||||||
pattern = path.Join(pattern, "?:all(.*)")
|
pattern = path.Join(pattern, "?:all(.*)")
|
||||||
@ -431,15 +726,13 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
|
|||||||
controllerName := strings.TrimSuffix(ct.Name(), "Controller")
|
controllerName := strings.TrimSuffix(ct.Name(), "Controller")
|
||||||
for i := 0; i < rt.NumMethod(); i++ {
|
for i := 0; i < rt.NumMethod(); i++ {
|
||||||
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||||
route := &ControllerInfo{}
|
|
||||||
route.routerType = routerTypeBeego
|
|
||||||
route.methods = map[string]string{"*": rt.Method(i).Name}
|
|
||||||
route.controllerType = ct
|
|
||||||
pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
|
pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
|
||||||
patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
|
patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
|
||||||
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
|
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
|
||||||
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
|
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
|
||||||
route.pattern = pattern
|
|
||||||
|
route := p.createBeegoRouter(ct, pattern)
|
||||||
|
route.methods = map[string]string{"*": rt.Method(i).Name}
|
||||||
for m := range HTTPMETHOD {
|
for m := range HTTPMETHOD {
|
||||||
p.addToRouter(m, pattern, route)
|
p.addToRouter(m, pattern, route)
|
||||||
p.addToRouter(m, patternInit, route)
|
p.addToRouter(m, patternInit, route)
|
||||||
@ -472,12 +765,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
|
|||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
|
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
|
||||||
root := p.chainRoot
|
|
||||||
filterFunc := chain(root.filterFunc)
|
|
||||||
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
|
|
||||||
p.chainRoot = newFilterRouter(pattern, filterFunc, opts...)
|
|
||||||
p.chainRoot.next = root
|
|
||||||
|
|
||||||
|
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
|
||||||
|
p.filterChains = append(p.filterChains, filterChainConfig{
|
||||||
|
pattern: pattern,
|
||||||
|
chain: chain,
|
||||||
|
opts: opts,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// add Filter into
|
// add Filter into
|
||||||
@ -542,7 +836,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str
|
|||||||
for _, l := range t.leaves {
|
for _, l := range t.leaves {
|
||||||
if c, ok := l.runObject.(*ControllerInfo); ok {
|
if c, ok := l.runObject.(*ControllerInfo); ok {
|
||||||
if c.routerType == routerTypeBeego &&
|
if c.routerType == routerTypeBeego &&
|
||||||
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
|
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) {
|
||||||
find := false
|
find := false
|
||||||
if HTTPMETHOD[strings.ToUpper(methodName)] {
|
if HTTPMETHOD[strings.ToUpper(methodName)] {
|
||||||
if len(c.methods) == 0 {
|
if len(c.methods) == 0 {
|
||||||
@ -670,6 +964,9 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
|
|||||||
methodParams []*param.MethodParam
|
methodParams []*param.MethodParam
|
||||||
routerInfo *ControllerInfo
|
routerInfo *ControllerInfo
|
||||||
isRunnable bool
|
isRunnable bool
|
||||||
|
currentSessionOn bool
|
||||||
|
originRouterInfo *ControllerInfo
|
||||||
|
originFindRouter bool
|
||||||
)
|
)
|
||||||
|
|
||||||
if p.cfg.RecoverFunc != nil {
|
if p.cfg.RecoverFunc != nil {
|
||||||
@ -735,7 +1032,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// session init
|
// session init
|
||||||
if p.cfg.WebConfig.Session.SessionOn {
|
currentSessionOn = p.cfg.WebConfig.Session.SessionOn
|
||||||
|
originRouterInfo, originFindRouter = p.FindRouter(ctx)
|
||||||
|
if originFindRouter {
|
||||||
|
currentSessionOn = originRouterInfo.sessionOn
|
||||||
|
}
|
||||||
|
if currentSessionOn {
|
||||||
ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
|
ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ package web
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@ -26,6 +27,25 @@ import (
|
|||||||
"github.com/beego/beego/v2/server/web/context"
|
"github.com/beego/beego/v2/server/web/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PrefixTestController struct {
|
||||||
|
Controller
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ptc *PrefixTestController) PrefixList() {
|
||||||
|
ptc.Ctx.Output.Body([]byte("i am list in prefix test"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestControllerWithInterface struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m TestControllerWithInterface) Ping() {
|
||||||
|
fmt.Println("pong")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TestControllerWithInterface) PingPointer() {
|
||||||
|
fmt.Println("pong pointer")
|
||||||
|
}
|
||||||
|
|
||||||
type TestController struct {
|
type TestController struct {
|
||||||
Controller
|
Controller
|
||||||
}
|
}
|
||||||
@ -87,10 +107,24 @@ func (jc *JSONController) Get() {
|
|||||||
jc.Ctx.Output.Body([]byte("ok"))
|
jc.Ctx.Output.Body([]byte("ok"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPrefixUrlFor(t *testing.T) {
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList"))
|
||||||
|
|
||||||
|
if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` {
|
||||||
|
logs.Info(a)
|
||||||
|
t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list")
|
||||||
|
}
|
||||||
|
if a := handler.URLFor(`TestController.PrefixList`); a != `` {
|
||||||
|
logs.Info(a)
|
||||||
|
t.Errorf("TestController.PrefixList must equal to empty string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUrlFor(t *testing.T) {
|
func TestUrlFor(t *testing.T) {
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/api/list", &TestController{}, "*:List")
|
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
|
||||||
handler.Add("/person/:last/:first", &TestController{}, "*:Param")
|
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
|
||||||
if a := handler.URLFor("TestController.List"); a != "/api/list" {
|
if a := handler.URLFor("TestController.List"); a != "/api/list" {
|
||||||
logs.Info(a)
|
logs.Info(a)
|
||||||
t.Errorf("TestController.List must equal to /api/list")
|
t.Errorf("TestController.List must equal to /api/list")
|
||||||
@ -113,9 +147,9 @@ func TestUrlFor3(t *testing.T) {
|
|||||||
|
|
||||||
func TestUrlFor2(t *testing.T) {
|
func TestUrlFor2(t *testing.T) {
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List")
|
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
|
||||||
handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL")
|
handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL"))
|
||||||
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
|
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
|
||||||
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
|
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
|
||||||
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
|
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
|
||||||
logs.Info(handler.URLFor("TestController.GetURL"))
|
logs.Info(handler.URLFor("TestController.GetURL"))
|
||||||
@ -145,7 +179,7 @@ func TestUserFunc(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/api/list", &TestController{}, "*:List")
|
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
if w.Body.String() != "i am list" {
|
if w.Body.String() != "i am list" {
|
||||||
t.Errorf("user define func can't run")
|
t.Errorf("user define func can't run")
|
||||||
@ -235,7 +269,7 @@ func TestRouteOk(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/person/:last/:first", &TestController{}, "get:GetParams")
|
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
if body != "anderson+thomas+kungfu" {
|
if body != "anderson+thomas+kungfu" {
|
||||||
@ -249,7 +283,7 @@ func TestManyRoute(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter")
|
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
@ -266,7 +300,7 @@ func TestEmptyResponse(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody")
|
handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
if body := w.Body.String(); body != "" {
|
if body := w.Body.String(); body != "" {
|
||||||
@ -750,3 +784,321 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) {
|
|||||||
t.Errorf("TestRouterRequestEntityTooLarge can't run")
|
t.Errorf("TestRouterRequestEntityTooLarge can't run")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouterSessionSet(t *testing.T) {
|
||||||
|
oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn
|
||||||
|
defer func() {
|
||||||
|
BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn
|
||||||
|
}()
|
||||||
|
|
||||||
|
// global sessionOn = false, router sessionOn = false
|
||||||
|
r, _ := http.NewRequest("GET", "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
|
WithRouterSessionOn(false))
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Header().Get("Set-Cookie") != "" {
|
||||||
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// global sessionOn = false, router sessionOn = true
|
||||||
|
r, _ = http.NewRequest("GET", "/user", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler = NewControllerRegister()
|
||||||
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
|
WithRouterSessionOn(true))
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Header().Get("Set-Cookie") != "" {
|
||||||
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
BConfig.WebConfig.Session.SessionOn = true
|
||||||
|
if err := registerSession(); err != nil {
|
||||||
|
t.Errorf("register session failed, error: %s", err.Error())
|
||||||
|
}
|
||||||
|
// global sessionOn = true, router sessionOn = false
|
||||||
|
r, _ = http.NewRequest("GET", "/user", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler = NewControllerRegister()
|
||||||
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
|
WithRouterSessionOn(false))
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Header().Get("Set-Cookie") != "" {
|
||||||
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// global sessionOn = true, router sessionOn = true
|
||||||
|
r, _ = http.NewRequest("GET", "/user", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
handler = NewControllerRegister()
|
||||||
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
|
WithRouterSessionOn(true))
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Header().Get("Set-Cookie") == "" {
|
||||||
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterGet(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterGet("/user", ExampleController.Ping)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestRouterRouterGet can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterPost(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterPost("/user", ExampleController.Ping)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestRouterRouterPost can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterHead(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodHead, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterHead("/user", ExampleController.Ping)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestRouterRouterHead can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterPut(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPut, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterPut("/user", ExampleController.Ping)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestRouterRouterPut can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterPatch(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPatch, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterPatch("/user", ExampleController.Ping)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestRouterRouterPatch can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterDelete(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodDelete, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterDelete("/user", ExampleController.Ping)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestRouterRouterDelete can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterAny(t *testing.T) {
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterAny("/user", ExampleController.Ping)
|
||||||
|
|
||||||
|
for method := range HTTPMETHOD {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, _ := http.NewRequest(method, "/user", nil)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestRouterRouterAny can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterGetPointerMethod(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterGet("/user", (*ExampleController).PingPointer)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != examplePointerBody {
|
||||||
|
t.Errorf("TestRouterRouterGetPointerMethod can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterPostPointerMethod(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterPost("/user", (*ExampleController).PingPointer)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != examplePointerBody {
|
||||||
|
t.Errorf("TestRouterRouterPostPointerMethod can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterHeadPointerMethod(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodHead, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterHead("/user", (*ExampleController).PingPointer)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != examplePointerBody {
|
||||||
|
t.Errorf("TestRouterRouterHeadPointerMethod can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterPutPointerMethod(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPut, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterPut("/user", (*ExampleController).PingPointer)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != examplePointerBody {
|
||||||
|
t.Errorf("TestRouterRouterPutPointerMethod can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterPatchPointerMethod(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPatch, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterPatch("/user", (*ExampleController).PingPointer)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != examplePointerBody {
|
||||||
|
t.Errorf("TestRouterRouterPatchPointerMethod can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterDeletePointerMethod(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodDelete, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterDelete("/user", (*ExampleController).PingPointer)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != examplePointerBody {
|
||||||
|
t.Errorf("TestRouterRouterDeletePointerMethod can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterRouterAnyPointerMethod(t *testing.T) {
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.RouterAny("/user", (*ExampleController).PingPointer)
|
||||||
|
|
||||||
|
for method := range HTTPMETHOD {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, _ := http.NewRequest(method, "/user", nil)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != examplePointerBody {
|
||||||
|
t.Errorf("TestRouterRouterAnyPointerMethod can't run, get the response is " + w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterAddRouterMethodPanicInvalidMethod(t *testing.T) {
|
||||||
|
method := "some random method"
|
||||||
|
message := "not support http method: " + strings.ToUpper(method)
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
if err != nil { //产生了panic异常
|
||||||
|
errStr, ok := err.(string)
|
||||||
|
if ok && errStr == message {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicInvalidMethod failed: %v", err))
|
||||||
|
}()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.AddRouterMethod(method, "/user", ExampleController.Ping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterAddRouterMethodPanicNotAMethod(t *testing.T) {
|
||||||
|
method := http.MethodGet
|
||||||
|
message := "not a method"
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
if err != nil { //产生了panic异常
|
||||||
|
errStr, ok := err.(string)
|
||||||
|
if ok && errStr == message {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotAMethod failed: %v", err))
|
||||||
|
}()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.AddRouterMethod(method, "/user", ExampleController{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterAddRouterMethodPanicNotPublicMethod(t *testing.T) {
|
||||||
|
method := http.MethodGet
|
||||||
|
message := "ping is not a public method"
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
if err != nil { //产生了panic异常
|
||||||
|
errStr, ok := err.(string)
|
||||||
|
if ok && errStr == message {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotPublicMethod failed: %v", err))
|
||||||
|
}()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.AddRouterMethod(method, "/user", ExampleController.ping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterAddRouterMethodPanicNotImplementInterface(t *testing.T) {
|
||||||
|
method := http.MethodGet
|
||||||
|
message := "web.TestControllerWithInterface is not implemented ControllerInterface"
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
if err != nil { //产生了panic异常
|
||||||
|
errStr, ok := err.(string)
|
||||||
|
if ok && errStr == message {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotImplementInterface failed: %v", err))
|
||||||
|
}()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.AddRouterMethod(method, "/user", TestControllerWithInterface.Ping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) {
|
||||||
|
method := http.MethodGet
|
||||||
|
message := "web.TestControllerWithInterface is not implemented ControllerInterface"
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
if err != nil { //产生了panic异常
|
||||||
|
errStr, ok := err.(string)
|
||||||
|
if ok && errStr == message {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf(fmt.Sprintf("TestRouterAddRouterPointerMethodPanicNotImplementInterface failed: %v", err))
|
||||||
|
}()
|
||||||
|
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer)
|
||||||
|
}
|
||||||
|
|||||||
@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
|||||||
|
|
||||||
initBeforeHTTPRun()
|
initBeforeHTTPRun()
|
||||||
|
|
||||||
|
// init...
|
||||||
app.initAddr(addr)
|
app.initAddr(addr)
|
||||||
|
app.Handlers.Init()
|
||||||
|
|
||||||
addr = app.Cfg.Listen.HTTPAddr
|
addr = app.Cfg.Listen.HTTPAddr
|
||||||
|
|
||||||
@ -267,7 +269,11 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
|||||||
|
|
||||||
// Router see HttpServer.Router
|
// Router see HttpServer.Router
|
||||||
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
|
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
|
||||||
return BeeApp.Router(rootpath, c, mappingMethods...)
|
return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
|
||||||
|
return BeeApp.RouterWithOpts(rootpath, c, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Router adds a patterned controller handler to BeeApp.
|
// Router adds a patterned controller handler to BeeApp.
|
||||||
@ -287,7 +293,11 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *H
|
|||||||
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
|
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
|
||||||
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
|
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
|
||||||
func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
|
func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
|
||||||
app.Handlers.Add(rootPath, c, mappingMethods...)
|
return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
|
||||||
|
app.Handlers.Add(rootPath, c, opts...)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -453,6 +463,166 @@ func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpSer
|
|||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RouterGet see HttpServer.RouterGet
|
||||||
|
func RouterGet(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterGet(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterGet used to register router for RouterGet method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterGet("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterGet(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterGet(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPost see HttpServer.RouterGet
|
||||||
|
func RouterPost(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterPost(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPost used to register router for RouterPost method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterPost("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterPost(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterPost(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterHead see HttpServer.RouterHead
|
||||||
|
func RouterHead(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterHead(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterHead used to register router for RouterHead method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterHead("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterHead(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterHead(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPut see HttpServer.RouterPut
|
||||||
|
func RouterPut(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterPut(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPut used to register router for RouterPut method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterPut("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterPut(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterPut(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPatch see HttpServer.RouterPatch
|
||||||
|
func RouterPatch(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterPatch(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterPatch used to register router for RouterPatch method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterPatch("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterPatch(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterPatch(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterDelete see HttpServer.RouterDelete
|
||||||
|
func RouterDelete(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterDelete(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterDelete used to register router for RouterDelete method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterDelete("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterDelete(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterDelete(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterOptions see HttpServer.RouterOptions
|
||||||
|
func RouterOptions(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterOptions(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterOptions used to register router for RouterOptions method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterOptions("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterOptions(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterOptions(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterAny see HttpServer.RouterAny
|
||||||
|
func RouterAny(rootpath string, f interface{}) {
|
||||||
|
BeeApp.RouterAny(rootpath, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterAny used to register router for RouterAny method
|
||||||
|
// usage:
|
||||||
|
// type MyController struct {
|
||||||
|
// web.Controller
|
||||||
|
// }
|
||||||
|
// func (m MyController) Ping() {
|
||||||
|
// m.Ctx.Output.Body([]byte("hello world"))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// RouterAny("/api/:id", MyController.Ping)
|
||||||
|
func (app *HttpServer) RouterAny(rootpath string, f interface{}) *HttpServer {
|
||||||
|
app.Handlers.RouterAny(rootpath, f)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
// Get see HttpServer.Get
|
// Get see HttpServer.Get
|
||||||
func Get(rootpath string, f FilterFunc) *HttpServer {
|
func Get(rootpath string, f FilterFunc) *HttpServer {
|
||||||
return BeeApp.Get(rootpath, f)
|
return BeeApp.Get(rootpath, f)
|
||||||
|
|||||||
@ -15,6 +15,8 @@
|
|||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -28,3 +30,82 @@ func TestNewHttpServerWithCfg(t *testing.T) {
|
|||||||
assert.Equal(t, "hello", BConfig.AppName)
|
assert.Equal(t, "hello", BConfig.AppName)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerRouterGet(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodGet, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
RouterGet("/user", ExampleController.Ping)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestServerRouterGet can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouterPost(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPost, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
RouterPost("/user", ExampleController.Ping)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestServerRouterPost can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouterHead(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodHead, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
RouterHead("/user", ExampleController.Ping)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestServerRouterHead can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouterPut(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPut, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
RouterPut("/user", ExampleController.Ping)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestServerRouterPut can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouterPatch(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodPatch, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
RouterPatch("/user", ExampleController.Ping)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestServerRouterPatch can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouterDelete(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest(http.MethodDelete, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
RouterDelete("/user", ExampleController.Ping)
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestServerRouterDelete can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouterAny(t *testing.T) {
|
||||||
|
RouterAny("/user", ExampleController.Ping)
|
||||||
|
|
||||||
|
for method := range HTTPMETHOD {
|
||||||
|
r, _ := http.NewRequest(method, "/user", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != exampleBody {
|
||||||
|
t.Errorf("TestServerRouterAny can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -15,21 +15,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRedis(t *testing.T) {
|
func TestRedis(t *testing.T) {
|
||||||
sessionConfig := &session.ManagerConfig{
|
|
||||||
CookieName: "gosessionid",
|
|
||||||
EnableSetCookie: true,
|
|
||||||
Gclifetime: 3600,
|
|
||||||
Maxlifetime: 3600,
|
|
||||||
Secure: false,
|
|
||||||
CookieLifeTime: 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
redisAddr := os.Getenv("REDIS_ADDR")
|
redisAddr := os.Getenv("REDIS_ADDR")
|
||||||
if redisAddr == "" {
|
if redisAddr == "" {
|
||||||
redisAddr = "127.0.0.1:6379"
|
redisAddr = "127.0.0.1:6379"
|
||||||
}
|
}
|
||||||
|
redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr)
|
||||||
|
|
||||||
|
sessionConfig := session.NewManagerConfig(
|
||||||
|
session.CfgCookieName(`gosessionid`),
|
||||||
|
session.CfgSetCookie(true),
|
||||||
|
session.CfgGcLifeTime(3600),
|
||||||
|
session.CfgMaxLifeTime(3600),
|
||||||
|
session.CfgSecure(false),
|
||||||
|
session.CfgCookieLifeTime(3600),
|
||||||
|
session.CfgProviderConfig(redisConfig),
|
||||||
|
)
|
||||||
|
|
||||||
sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr)
|
|
||||||
globalSession, err := session.NewManager("redis", sessionConfig)
|
globalSession, err := session.NewManager("redis", sessionConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("could not create manager:", err)
|
t.Fatal("could not create manager:", err)
|
||||||
|
|||||||
@ -13,15 +13,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRedisSentinel(t *testing.T) {
|
func TestRedisSentinel(t *testing.T) {
|
||||||
sessionConfig := &session.ManagerConfig{
|
sessionConfig := session.NewManagerConfig(
|
||||||
CookieName: "gosessionid",
|
session.CfgCookieName(`gosessionid`),
|
||||||
EnableSetCookie: true,
|
session.CfgSetCookie(true),
|
||||||
Gclifetime: 3600,
|
session.CfgGcLifeTime(3600),
|
||||||
Maxlifetime: 3600,
|
session.CfgMaxLifeTime(3600),
|
||||||
Secure: false,
|
session.CfgSecure(false),
|
||||||
CookieLifeTime: 3600,
|
session.CfgCookieLifeTime(3600),
|
||||||
ProviderConfig: "127.0.0.1:6379,100,,0,master",
|
session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"),
|
||||||
}
|
)
|
||||||
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
|
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
t.Log(e)
|
t.Log(e)
|
||||||
|
|||||||
@ -91,25 +91,6 @@ func GetProvider(name string) (Provider, error) {
|
|||||||
return provider, nil
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerConfig define the session config
|
|
||||||
type ManagerConfig struct {
|
|
||||||
CookieName string `json:"cookieName"`
|
|
||||||
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
|
||||||
Gclifetime int64 `json:"gclifetime"`
|
|
||||||
Maxlifetime int64 `json:"maxLifetime"`
|
|
||||||
DisableHTTPOnly bool `json:"disableHTTPOnly"`
|
|
||||||
Secure bool `json:"secure"`
|
|
||||||
CookieLifeTime int `json:"cookieLifeTime"`
|
|
||||||
ProviderConfig string `json:"providerConfig"`
|
|
||||||
Domain string `json:"domain"`
|
|
||||||
SessionIDLength int64 `json:"sessionIDLength"`
|
|
||||||
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
|
|
||||||
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
|
|
||||||
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
|
|
||||||
SessionIDPrefix string `json:"sessionIDPrefix"`
|
|
||||||
CookieSameSite http.SameSite `json:"cookieSameSite"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager contains Provider and its configuration.
|
// Manager contains Provider and its configuration.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
provider Provider
|
provider Provider
|
||||||
|
|||||||
143
server/web/session/session_config.go
Normal file
143
server/web/session/session_config.go
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// ManagerConfig define the session config
|
||||||
|
type ManagerConfig struct {
|
||||||
|
CookieName string `json:"cookieName"`
|
||||||
|
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
||||||
|
Gclifetime int64 `json:"gclifetime"`
|
||||||
|
Maxlifetime int64 `json:"maxLifetime"`
|
||||||
|
DisableHTTPOnly bool `json:"disableHTTPOnly"`
|
||||||
|
Secure bool `json:"secure"`
|
||||||
|
CookieLifeTime int `json:"cookieLifeTime"`
|
||||||
|
ProviderConfig string `json:"providerConfig"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
SessionIDLength int64 `json:"sessionIDLength"`
|
||||||
|
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
|
||||||
|
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
|
||||||
|
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
|
||||||
|
SessionIDPrefix string `json:"sessionIDPrefix"`
|
||||||
|
CookieSameSite http.SameSite `json:"cookieSameSite"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) {
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ManagerConfigOpt func(config *ManagerConfig)
|
||||||
|
|
||||||
|
func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig {
|
||||||
|
config := &ManagerConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(config)
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// CfgCookieName set key of session id
|
||||||
|
func CfgCookieName(cookieName string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.CookieName = cookieName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CfgCookieName set len of session id
|
||||||
|
func CfgSessionIdLength(len int64) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.SessionIDLength = len
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CfgSessionIdPrefix set prefix of session id
|
||||||
|
func CfgSessionIdPrefix(prefix string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.SessionIDPrefix = prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSetCookie whether set `Set-Cookie` header in HTTP response
|
||||||
|
func CfgSetCookie(enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.EnableSetCookie = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgGcLifeTime set session gc lift time
|
||||||
|
func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Gclifetime = lifeTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgMaxLifeTime set session lift time
|
||||||
|
func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Maxlifetime = lifeTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgGcLifeTime set session lift time
|
||||||
|
func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.CookieLifeTime = lifeTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgProviderConfig configure session provider
|
||||||
|
func CfgProviderConfig(providerConfig string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.ProviderConfig = providerConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgDomain set cookie domain
|
||||||
|
func CfgDomain(domain string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Domain = domain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSessionIdInHTTPHeader enable session id in http header
|
||||||
|
func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.EnableSidInHTTPHeader = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSetSessionNameInHTTPHeader set key of session id in http header
|
||||||
|
func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.SessionNameInHTTPHeader = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//EnableSidInURLQuery enable session id in query string
|
||||||
|
func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.EnableSidInURLQuery = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//DisableHTTPOnly set HTTPOnly for http.Cookie
|
||||||
|
func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.DisableHTTPOnly = !HTTPOnly
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSecure set Secure for http.Cookie
|
||||||
|
func CfgSecure(Enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Secure = Enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSameSite set http.SameSite
|
||||||
|
func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.CookieSameSite = sameSite
|
||||||
|
}
|
||||||
|
}
|
||||||
222
server/web/session/session_config_test.go
Normal file
222
server/web/session/session_config_test.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCfgCookieLifeTime(t *testing.T) {
|
||||||
|
value := 8754
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgCookieLifeTime(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.CookieLifeTime != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgDomain(t *testing.T) {
|
||||||
|
value := `http://domain.com`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgDomain(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Domain != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSameSite(t *testing.T) {
|
||||||
|
value := http.SameSiteLaxMode
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSameSite(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.CookieSameSite != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSecure(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSecure(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Secure != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSecure1(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSecure(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Secure != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdPrefix(t *testing.T) {
|
||||||
|
value := `sodiausodkljalsd`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdPrefix(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.SessionIDPrefix != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSetSessionNameInHTTPHeader(t *testing.T) {
|
||||||
|
value := `sodiausodkljalsd`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSetSessionNameInHTTPHeader(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.SessionNameInHTTPHeader != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgCookieName(t *testing.T) {
|
||||||
|
value := `sodiausodkljalsd`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgCookieName(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.CookieName != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgEnableSidInURLQuery(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgEnableSidInURLQuery(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSidInURLQuery != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgGcLifeTime(t *testing.T) {
|
||||||
|
value := int64(5454)
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgGcLifeTime(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Gclifetime != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgHTTPOnly(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgHTTPOnly(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.DisableHTTPOnly != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgHTTPOnly2(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgHTTPOnly(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.DisableHTTPOnly != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgMaxLifeTime(t *testing.T) {
|
||||||
|
value := int64(5454)
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgMaxLifeTime(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Maxlifetime != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgProviderConfig(t *testing.T) {
|
||||||
|
value := `asodiuasldkj12i39809as`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgProviderConfig(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.ProviderConfig != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdInHTTPHeader(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdInHTTPHeader(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSidInHTTPHeader != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdInHTTPHeader1(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdInHTTPHeader(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSidInHTTPHeader != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdLength(t *testing.T) {
|
||||||
|
value := int64(100)
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdLength(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.SessionIDLength != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSetCookie(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSetCookie(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSetCookie != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSetCookie1(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSetCookie(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSetCookie != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewManagerConfig(t *testing.T) {
|
||||||
|
c := NewManagerConfig()
|
||||||
|
if c == nil {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerConfig_Opts(t *testing.T) {
|
||||||
|
c := NewManagerConfig()
|
||||||
|
c.Opts(CfgSetCookie(true))
|
||||||
|
|
||||||
|
if c.EnableSetCookie != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
18
server/web/session/session_provider_type.go
Normal file
18
server/web/session/session_provider_type.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
type ProviderType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProviderCookie ProviderType = `cookie`
|
||||||
|
ProviderFile ProviderType = `file`
|
||||||
|
ProviderMemory ProviderType = `memory`
|
||||||
|
ProviderCouchbase ProviderType = `couchbase`
|
||||||
|
ProviderLedis ProviderType = `ledis`
|
||||||
|
ProviderMemcache ProviderType = `memcache`
|
||||||
|
ProviderMysql ProviderType = `mysql`
|
||||||
|
ProviderPostgresql ProviderType = `postgresql`
|
||||||
|
ProviderRedis ProviderType = `redis`
|
||||||
|
ProviderRedisCluster ProviderType = `redis_cluster`
|
||||||
|
ProviderRedisSentinel ProviderType = `redis_sentinel`
|
||||||
|
ProviderSsdb ProviderType = `ssdb`
|
||||||
|
)
|
||||||
@ -210,9 +210,9 @@ func (t *Tree) AddRouter(pattern string, runObject interface{}) {
|
|||||||
func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) {
|
func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) {
|
||||||
if len(segments) == 0 {
|
if len(segments) == 0 {
|
||||||
if reg != "" {
|
if reg != "" {
|
||||||
t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")})
|
t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}}, t.leaves...)
|
||||||
} else {
|
} else {
|
||||||
t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards})
|
t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards}}, t.leaves...)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
seg := segments[0]
|
seg := segments[0]
|
||||||
|
|||||||
@ -90,7 +90,17 @@ func init() {
|
|||||||
routers = append(routers, matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}))
|
routers = append(routers, matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}))
|
||||||
routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}))
|
routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}))
|
||||||
routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}))
|
routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11", map[string]string{":year": "2020", ":month": "11"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year", "/2020", map[string]string{":year": "2020"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year([0-9]+)/?:month([0-9]+)/mid/?:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/mid/10", map[string]string{":year": "2020", ":day": "10"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/11/mid", map[string]string{":year": "2020", ":month": "11"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/mid/10/24", map[string]string{":day": "10", ":hour": "24"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year([0-9]+)/:month([0-9]+)/mid/:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10/24", map[string]string{":month": "11", ":day": "10"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/2020/11/mid/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}))
|
||||||
|
routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10", map[string]string{":month": "11", ":day": "10"}))
|
||||||
// not match example
|
// not match example
|
||||||
|
|
||||||
// https://github.com/beego/beego/v2/issues/3865
|
// https://github.com/beego/beego/v2/issues/3865
|
||||||
|
|||||||
@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
|
|||||||
var method = "GET"
|
var method = "GET"
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
|
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
||||||
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
|
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
||||||
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
|
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test original root
|
// Test original root
|
||||||
testHelperFnContentCheck(t, handler, "Test original root",
|
testHelperFnContentCheck(t, handler, "Test original root",
|
||||||
@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
|
|||||||
|
|
||||||
// Replace the root path TestPreUnregController action with the action from
|
// Replace the root path TestPreUnregController action with the action from
|
||||||
// TestPostUnregController
|
// TestPostUnregController
|
||||||
handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot")
|
handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot"))
|
||||||
|
|
||||||
// Test replacement root (expect change)
|
// Test replacement root (expect change)
|
||||||
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
|
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
|
||||||
@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
|
|||||||
var method = "GET"
|
var method = "GET"
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
|
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
||||||
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
|
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
||||||
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
|
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test original root
|
// Test original root
|
||||||
testHelperFnContentCheck(t, handler,
|
testHelperFnContentCheck(t, handler,
|
||||||
@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
|
|||||||
|
|
||||||
// Replace the "level1" path TestPreUnregController action with the action from
|
// Replace the "level1" path TestPreUnregController action with the action from
|
||||||
// TestPostUnregController
|
// TestPostUnregController
|
||||||
handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1")
|
handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1"))
|
||||||
|
|
||||||
// Test replacement root (expect no change from the original)
|
// Test replacement root (expect no change from the original)
|
||||||
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
||||||
@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
|
|||||||
var method = "GET"
|
var method = "GET"
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
|
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
||||||
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
|
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
||||||
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
|
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test original root
|
// Test original root
|
||||||
testHelperFnContentCheck(t, handler,
|
testHelperFnContentCheck(t, handler,
|
||||||
@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
|
|||||||
|
|
||||||
// Replace the "/level1/level2" path TestPreUnregController action with the action from
|
// Replace the "/level1/level2" path TestPreUnregController action with the action from
|
||||||
// TestPostUnregController
|
// TestPostUnregController
|
||||||
handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2")
|
handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test replacement root (expect no change from the original)
|
// Test replacement root (expect no change from the original)
|
||||||
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
||||||
|
|||||||
7
sonar-project.properties
Normal file
7
sonar-project.properties
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
sonar.organization=beego
|
||||||
|
sonar.projectKey=beego_beego
|
||||||
|
|
||||||
|
# relative paths to source directories. More details and properties are described
|
||||||
|
# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/
|
||||||
|
sonar.sources=.
|
||||||
|
sonar.exclusions=**/*_test.go
|
||||||
@ -55,6 +55,10 @@ func (c *countTask) GetPrev(ctx context.Context) time.Time {
|
|||||||
return time.Now()
|
return time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *countTask) GetTimeout(ctx context.Context) time.Duration {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func TestRunTaskCommand_Execute(t *testing.T) {
|
func TestRunTaskCommand_Execute(t *testing.T) {
|
||||||
task := &countTask{}
|
task := &countTask{}
|
||||||
AddTask("count", task)
|
AddTask("count", task)
|
||||||
|
|||||||
112
task/task.go
112
task/task.go
@ -109,6 +109,7 @@ type Tasker interface {
|
|||||||
GetNext(ctx context.Context) time.Time
|
GetNext(ctx context.Context) time.Time
|
||||||
SetPrev(context.Context, time.Time)
|
SetPrev(context.Context, time.Time)
|
||||||
GetPrev(ctx context.Context) time.Time
|
GetPrev(ctx context.Context) time.Time
|
||||||
|
GetTimeout(ctx context.Context) time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// task error
|
// task error
|
||||||
@ -127,13 +128,14 @@ type Task struct {
|
|||||||
DoFunc TaskFunc
|
DoFunc TaskFunc
|
||||||
Prev time.Time
|
Prev time.Time
|
||||||
Next time.Time
|
Next time.Time
|
||||||
|
Timeout time.Duration // timeout duration
|
||||||
Errlist []*taskerr // like errtime:errinfo
|
Errlist []*taskerr // like errtime:errinfo
|
||||||
ErrLimit int // max length for the errlist, 0 stand for no limit
|
ErrLimit int // max length for the errlist, 0 stand for no limit
|
||||||
errCnt int // records the error count during the execution
|
errCnt int // records the error count during the execution
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTask add new task with name, time and func
|
// NewTask add new task with name, time and func
|
||||||
func NewTask(tname string, spec string, f TaskFunc) *Task {
|
func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task {
|
||||||
|
|
||||||
task := &Task{
|
task := &Task{
|
||||||
Taskname: tname,
|
Taskname: tname,
|
||||||
@ -144,6 +146,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task {
|
|||||||
// we only store the pointer, so it won't use too many space
|
// we only store the pointer, so it won't use too many space
|
||||||
Errlist: make([]*taskerr, 100, 100),
|
Errlist: make([]*taskerr, 100, 100),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt.apply(task)
|
||||||
|
}
|
||||||
|
|
||||||
task.SetCron(spec)
|
task.SetCron(spec)
|
||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
@ -196,6 +203,31 @@ func (t *Task) GetPrev(context.Context) time.Time {
|
|||||||
return t.Prev
|
return t.Prev
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetTimeout get timeout duration of this task
|
||||||
|
func (t *Task) GetTimeout(context.Context) time.Duration {
|
||||||
|
return t.Timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option interface
|
||||||
|
type Option interface {
|
||||||
|
apply(*Task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// optionFunc return a function to set task element
|
||||||
|
type optionFunc func(*Task)
|
||||||
|
|
||||||
|
// apply option to task
|
||||||
|
func (f optionFunc) apply(t *Task) {
|
||||||
|
f(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeoutOption return a option to set timeout duration for task
|
||||||
|
func TimeoutOption(timeout time.Duration) Option {
|
||||||
|
return optionFunc(func(t *Task) {
|
||||||
|
t.Timeout = timeout
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// six columns mean:
|
// six columns mean:
|
||||||
// second:0-59
|
// second:0-59
|
||||||
// minute:0-59
|
// minute:0-59
|
||||||
@ -455,14 +487,12 @@ func (m *taskManager) StartTask() {
|
|||||||
|
|
||||||
func (m *taskManager) run() {
|
func (m *taskManager) run() {
|
||||||
now := time.Now().Local()
|
now := time.Now().Local()
|
||||||
m.taskLock.Lock()
|
// first run the tasks, so set all tasks next run time.
|
||||||
for _, t := range m.adminTaskList {
|
m.setTasksStartTime(now)
|
||||||
t.SetNext(nil, now)
|
|
||||||
}
|
|
||||||
m.taskLock.Unlock()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// we only use RLock here because NewMapSorter copy the reference, do not change any thing
|
// we only use RLock here because NewMapSorter copy the reference, do not change any thing
|
||||||
|
// here, we sort all task and get first task running time (effective).
|
||||||
m.taskLock.RLock()
|
m.taskLock.RLock()
|
||||||
sortList := NewMapSorter(m.adminTaskList)
|
sortList := NewMapSorter(m.adminTaskList)
|
||||||
m.taskLock.RUnlock()
|
m.taskLock.RUnlock()
|
||||||
@ -475,34 +505,72 @@ func (m *taskManager) run() {
|
|||||||
} else {
|
} else {
|
||||||
effective = sortList.Vals[0].GetNext(context.Background())
|
effective = sortList.Vals[0].GetNext(context.Background())
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case now = <-time.After(effective.Sub(now)):
|
case now = <-time.After(effective.Sub(now)): // wait for effective time
|
||||||
// Run every entry whose next time was this effective time.
|
runNextTasks(sortList, effective)
|
||||||
for _, e := range sortList.Vals {
|
|
||||||
if e.GetNext(context.Background()) != effective {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
go e.Run(nil)
|
|
||||||
e.SetPrev(context.Background(), e.GetNext(context.Background()))
|
|
||||||
e.SetNext(nil, effective)
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
case <-m.changed:
|
case <-m.changed: // tasks have been changed, set all tasks run again now
|
||||||
now = time.Now().Local()
|
now = time.Now().Local()
|
||||||
|
m.setTasksStartTime(now)
|
||||||
|
continue
|
||||||
|
case <-m.stop: // manager is stopped, and mark manager is stopped
|
||||||
|
m.markManagerStop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setTasksStartTime is set all tasks next running time
|
||||||
|
func (m *taskManager) setTasksStartTime(now time.Time) {
|
||||||
m.taskLock.Lock()
|
m.taskLock.Lock()
|
||||||
for _, t := range m.adminTaskList {
|
for _, task := range m.adminTaskList {
|
||||||
t.SetNext(nil, now)
|
task.SetNext(context.Background(), now)
|
||||||
}
|
}
|
||||||
m.taskLock.Unlock()
|
m.taskLock.Unlock()
|
||||||
continue
|
}
|
||||||
case <-m.stop:
|
|
||||||
|
// markManagerStop it sets manager to be stopped
|
||||||
|
func (m *taskManager) markManagerStop() {
|
||||||
m.taskLock.Lock()
|
m.taskLock.Lock()
|
||||||
if m.started {
|
if m.started {
|
||||||
m.started = false
|
m.started = false
|
||||||
}
|
}
|
||||||
m.taskLock.Unlock()
|
m.taskLock.Unlock()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runNextTasks it runs next task which next run time is equal to effective
|
||||||
|
func runNextTasks(sortList *MapSorter, effective time.Time) {
|
||||||
|
// Run every entry whose next time was this effective time.
|
||||||
|
var i = 0
|
||||||
|
for _, e := range sortList.Vals {
|
||||||
|
i++
|
||||||
|
if e.GetNext(context.Background()) != effective {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if timeout is on, if yes passing the timeout context
|
||||||
|
ctx := context.Background()
|
||||||
|
if duration := e.GetTimeout(ctx); duration != 0 {
|
||||||
|
go func(e Tasker) {
|
||||||
|
ctx, cancelFunc := context.WithTimeout(ctx, duration)
|
||||||
|
defer cancelFunc()
|
||||||
|
err := e.Run(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("tasker.run err: %s\n", err.Error())
|
||||||
|
}
|
||||||
|
}(e)
|
||||||
|
} else {
|
||||||
|
go func(e Tasker) {
|
||||||
|
err := e.Run(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("tasker.run err: %s\n", err.Error())
|
||||||
|
}
|
||||||
|
}(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.SetPrev(context.Background(), e.GetNext(context.Background()))
|
||||||
|
e.SetNext(context.Background(), effective)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -90,6 +90,57 @@ func TestSpec(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeout(t *testing.T) {
|
||||||
|
m := newTaskManager()
|
||||||
|
defer m.ClearTask()
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
once1, once2 := sync.Once{}, sync.Once{}
|
||||||
|
|
||||||
|
tk1 := NewTask("tk1", "0/10 * * ? * *",
|
||||||
|
func(ctx context.Context) error {
|
||||||
|
time.Sleep(4 * time.Second)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
once1.Do(func() {
|
||||||
|
fmt.Println("tk1 done")
|
||||||
|
wg.Done()
|
||||||
|
})
|
||||||
|
return errors.New("timeout")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}, TimeoutOption(3*time.Second),
|
||||||
|
)
|
||||||
|
|
||||||
|
tk2 := NewTask("tk2", "0/11 * * ? * *",
|
||||||
|
func(ctx context.Context) error {
|
||||||
|
time.Sleep(4 * time.Second)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return errors.New("timeout")
|
||||||
|
default:
|
||||||
|
once2.Do(func() {
|
||||||
|
fmt.Println("tk2 done")
|
||||||
|
wg.Done()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
m.AddTask("tk1", tk1)
|
||||||
|
m.AddTask("tk2", tk2)
|
||||||
|
m.StartTask()
|
||||||
|
defer m.StopTask()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(19 * time.Second):
|
||||||
|
t.Error("TestTimeout failed")
|
||||||
|
case <-wait(wg):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTask_Run(t *testing.T) {
|
func TestTask_Run(t *testing.T) {
|
||||||
cnt := -1
|
cnt := -1
|
||||||
task := func(ctx context.Context) error {
|
task := func(ctx context.Context) error {
|
||||||
@ -109,6 +160,23 @@ func TestTask_Run(t *testing.T) {
|
|||||||
assert.Equal(t, "Hello, world! 101", l[1].errinfo)
|
assert.Equal(t, "Hello, world! 101", l[1].errinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCrudTask(t *testing.T) {
|
||||||
|
m := newTaskManager()
|
||||||
|
m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
m.DeleteTask("my-task1")
|
||||||
|
assert.Equal(t, 1, len(m.adminTaskList))
|
||||||
|
|
||||||
|
m.ClearTask()
|
||||||
|
assert.Equal(t, 0, len(m.adminTaskList))
|
||||||
|
}
|
||||||
|
|
||||||
func wait(wg *sync.WaitGroup) chan bool {
|
func wait(wg *sync.WaitGroup) chan bool {
|
||||||
ch := make(chan bool)
|
ch := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user