Merge branch 'develop' of https://gitclone.com/github.com/beego/beego into frt/delete-txorm

# Conflicts:
#	CHANGELOG.md
This commit is contained in:
Anker Jam 2021-01-28 00:01:55 +08:00
commit e8448a520f
98 changed files with 5278 additions and 647 deletions

View File

@ -8,7 +8,7 @@ on:
pull_request:
types: [opened, synchronize, reopened, labeled, unlabeled]
branches:
- master
- develop
jobs:
changelog:

32
.github/workflows/golangci-lint.yml vendored Normal file
View File

@ -0,0 +1,32 @@
name: golangci-lint
on:
push:
tags:
- v*
branches:
- master
- main
pull_request:
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
# Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
version: v1.29
# Optional: working directory, useful for monorepos
# working-directory: ./
# Optional: golangci-lint command line arguments.
args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true
# Optional: show only new issues if it's a pull request. The default value is `false`.
only-new-issues: true
# Optional: if set to true then the action will use pre-installed Go
# skip-go-installation: true

View File

@ -1,7 +1,19 @@
# developing
- ORM mock. [4407](https://github.com/beego/beego/pull/4407)
- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433)
- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427)
- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398)
- Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391)
- Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385)
- Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385)
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294)
- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404)
- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416)
- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424)
- Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441)
- Error Module brief design & using httplib module to validate this design. [4453](https://github.com/beego/beego/pull/4453)
- Fix 4444: panic when 404 not found. [4446](https://github.com/beego/beego/pull/4446)
- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452)
- Fix 4456: Fix router method expression [4456](https://github.com/beego/beego/pull/4456)
- Fix 4451: support QueryExecutor interface. [4461](https://github.com/beego/beego/pull/4461)

1
ERROR_SPECIFICATION.md Normal file
View File

@ -0,0 +1 @@
# Error Module

View File

@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
(*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...)
(*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...))
}
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller

View File

@ -289,3 +289,7 @@ func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) {
func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time {
return o.delegate.GetPrev()
}
func (o *oldToNewAdapter) GetTimeout(ctx context.Context) time.Duration {
return 0
}

View File

@ -0,0 +1,131 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"github.com/beego/beego/v2/core/berror"
)
var InvalidUrl = berror.DefineCode(4001001, moduleName, "InvalidUrl", `
You pass a invalid url to httplib module. Please check your url, be careful about special character.
`)
var InvalidUrlProtocolVersion = berror.DefineCode(4001002, moduleName, "InvalidUrlProtocolVersion", `
You pass a invalid protocol version. In practice, we use HTTP/1.0, HTTP/1.1, HTTP/1.2
But something like HTTP/3.2 is valid for client, and the major version is 3, minor version is 2.
but you must confirm that server support those abnormal protocol version.
`)
var UnsupportedBodyType = berror.DefineCode(4001003, moduleName, "UnsupportedBodyType", `
You use a invalid data as request body.
For now, we only support type string and byte[].
`)
var InvalidXMLBody = berror.DefineCode(4001004, moduleName, "InvalidXMLBody", `
You pass invalid data which could not be converted to XML documents. In general, if you pass structure, it works well.
Sometimes you got XML document and you want to make it as request body. So you call XMLBody.
If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data.
`)
var InvalidYAMLBody = berror.DefineCode(4001005, moduleName, "InvalidYAMLBody", `
You pass invalid data which could not be converted to YAML documents. In general, if you pass structure, it works well.
Sometimes you got YAML document and you want to make it as request body. So you call YAMLBody.
If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data.
`)
var InvalidJSONBody = berror.DefineCode(4001006, moduleName, "InvalidJSONBody", `
You pass invalid data which could not be converted to JSON documents. In general, if you pass structure, it works well.
Sometimes you got JSON document and you want to make it as request body. So you call JSONBody.
If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data.
`)
// start with 5 --------------------------------------------------------------------------
var CreateFormFileFailed = berror.DefineCode(5001001, moduleName, "CreateFormFileFailed", `
In normal case than handling files with BeegoRequest, you should not see this error code.
Unexpected EOF, invalid characters, bad file descriptor may cause this error.
`)
var ReadFileFailed = berror.DefineCode(5001002, moduleName, "ReadFileFailed", `
There are several cases that cause this error:
1. file not found. Please check the file name;
2. file not found, but file name is correct. If you use relative file path, it's very possible for you to see this code.
make sure that this file is in correct directory which Beego looks for;
3. Beego don't have the privilege to read this file, please change file mode;
`)
var CopyFileFailed = berror.DefineCode(5001003, moduleName, "CopyFileFailed", `
When we try to read file content and then copy it to another writer, and failed.
1. Unexpected EOF;
2. Bad file descriptor;
3. Write conflict;
Please check your file content, and confirm that file is not processed by other process (or by user manually).
`)
var CloseFileFailed = berror.DefineCode(5001004, moduleName, "CloseFileFailed", `
After handling files, Beego try to close file but failed. Usually it was caused by bad file descriptor.
`)
var SendRequestFailed = berror.DefineCode(5001005, moduleName, "SendRequestRetryExhausted", `
Beego send HTTP request, but it failed.
If you config retry times, it means that Beego had retried and failed.
When you got this error, there are vary kind of reason:
1. Network unstable and timeout. In this case, sometimes server has received the request.
2. Server error. Make sure that server works well.
3. The request is invalid, which means that you pass some invalid parameter.
`)
var ReadGzipBodyFailed = berror.DefineCode(5001006, moduleName, "BuildGzipReaderFailed", `
Beego parse gzip-encode body failed. Usually Beego got invalid response.
Please confirm that server returns gzip data.
`)
var CreateFileIfNotExistFailed = berror.DefineCode(5001007, moduleName, "CreateFileIfNotExist", `
Beego want to create file if not exist and failed.
In most cases, it means that Beego doesn't have the privilege to create this file.
Please change file mode to ensure that Beego is able to create files on specific directory.
Or you can run Beego with higher authority.
In some cases, you pass invalid filename. Make sure that the file name is valid on your system.
`)
var UnmarshalJSONResponseToObjectFailed = berror.DefineCode(5001008, moduleName,
"UnmarshalResponseToObjectFailed", `
Beego trying to unmarshal response's body to structure but failed.
Make sure that:
1. You pass valid structure pointer to the function;
2. The body is valid json document
`)
var UnmarshalXMLResponseToObjectFailed = berror.DefineCode(5001009, moduleName,
"UnmarshalResponseToObjectFailed", `
Beego trying to unmarshal response's body to structure but failed.
Make sure that:
1. You pass valid structure pointer to the function;
2. The body is valid XML document
`)
var UnmarshalYAMLResponseToObjectFailed = berror.DefineCode(5001010, moduleName,
"UnmarshalResponseToObjectFailed", `
Beego trying to unmarshal response's body to structure but failed.
Make sure that:
1. You pass valid structure pointer to the function;
2. The body is valid YAML document
`)

View File

@ -18,6 +18,7 @@ import (
"context"
"net/http"
"strconv"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
@ -26,24 +27,30 @@ import (
)
type FilterChainBuilder struct {
summaryVec prometheus.ObserverVec
AppName string
ServerName string
RunMode string
}
var summaryVec prometheus.ObserverVec
var initSummaryVec sync.Once
func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter {
builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "beego",
Subsystem: "remote_http_request",
ConstLabels: map[string]string{
"server": builder.ServerName,
"env": builder.RunMode,
"appname": builder.AppName,
},
Help: "The statics info for remote http requests",
}, []string{"proto", "scheme", "method", "host", "path", "status", "isError"})
initSummaryVec.Do(func() {
summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "beego",
Subsystem: "remote_http_request",
ConstLabels: map[string]string{
"server": builder.ServerName,
"env": builder.RunMode,
"appname": builder.AppName,
},
Help: "The statics info for remote http requests",
}, []string{"proto", "scheme", "method", "host", "path", "status", "isError"})
prometheus.MustRegister(summaryVec)
})
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
startTime := time.Now()
@ -72,6 +79,6 @@ func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time
dur := int(endTime.Sub(startTime) / time.Millisecond)
builder.summaryVec.WithLabelValues(proto, scheme, method, host, path,
summaryVec.WithLabelValues(proto, scheme, method, host, path,
strconv.Itoa(status), strconv.FormatBool(err != nil)).Observe(float64(dur))
}

View File

@ -0,0 +1,36 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewHttpResponseWithJsonBody(t *testing.T) {
// string
resp := NewHttpResponseWithJsonBody("{}")
assert.Equal(t, int64(2), resp.ContentLength)
resp = NewHttpResponseWithJsonBody([]byte("{}"))
assert.Equal(t, int64(2), resp.ContentLength)
resp = NewHttpResponseWithJsonBody(&user{
Name: "Tom",
})
assert.True(t, resp.ContentLength > 0)
}

View File

@ -40,59 +40,36 @@ import (
"encoding/xml"
"io"
"io/ioutil"
"log"
"mime/multipart"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httputil"
"net/url"
"os"
"path"
"strings"
"sync"
"time"
"gopkg.in/yaml.v2"
"github.com/beego/beego/v2/core/berror"
"github.com/beego/beego/v2/core/logs"
)
var defaultSetting = BeegoHTTPSettings{
UserAgent: "beegoServer",
ConnectTimeout: 60 * time.Second,
ReadWriteTimeout: 60 * time.Second,
Gzip: true,
DumpBody: true,
FilterChains: []FilterChain{mockFilter.FilterChain},
}
var defaultCookieJar http.CookieJar
var settingMutex sync.Mutex
// it will be the last filter and execute request.Do
var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return req.doRequest(ctx)
}
// createDefaultCookie creates a global cookiejar to store cookies.
func createDefaultCookie() {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultCookieJar, _ = cookiejar.New(nil)
}
// SetDefaultSetting overwrites default settings
func SetDefaultSetting(setting BeegoHTTPSettings) {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultSetting = setting
}
// NewBeegoRequest returns *BeegoHttpRequest with specific method
// TODO add error as return value
// I think if we don't return error
// users are hard to check whether we create Beego request successfully
func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
var resp http.Response
u, err := url.Parse(rawurl)
if err != nil {
log.Println("Httplib:", err)
logs.Error("%+v", berror.Wrapf(err, InvalidUrl, "invalid raw url: %s", rawurl))
}
req := http.Request{
URL: u,
@ -137,24 +114,6 @@ func Head(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "HEAD")
}
// BeegoHTTPSettings is the http.Client setting
type BeegoHTTPSettings struct {
ShowDebug bool
UserAgent string
ConnectTimeout time.Duration
ReadWriteTimeout time.Duration
TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper
CheckRedirect func(req *http.Request, via []*http.Request) error
EnableCookie bool
Gzip bool
DumpBody bool
Retries int // if set to -1 means will retry forever
RetryDelay time.Duration
FilterChains []FilterChain
}
// BeegoHTTPRequest provides more useful methods than http.Request for requesting a url.
type BeegoHTTPRequest struct {
url string
@ -254,7 +213,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
}
// SetProtocolVersion sets the protocol version for incoming requests.
// Client requests always use HTTP/1.1.
// Client requests always use HTTP/1.1
func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
if len(vers) == 0 {
vers = "HTTP/1.1"
@ -265,8 +224,9 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
b.req.Proto = vers
b.req.ProtoMajor = major
b.req.ProtoMinor = minor
return b
}
logs.Error("%+v", berror.Errorf(InvalidUrlProtocolVersion, "invalid protocol: %s", vers))
return b
}
@ -334,6 +294,7 @@ func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest
// Body adds request raw body.
// Supports string and []byte.
// TODO return error if data is invalid
func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
switch t := data.(type) {
case string:
@ -350,6 +311,8 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
return ioutil.NopCloser(bf), nil
}
b.req.ContentLength = int64(len(t))
default:
logs.Error("%+v", berror.Errorf(UnsupportedBodyType, "unsupported body data type: %s", t))
}
return b
}
@ -359,9 +322,12 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
byts, err := xml.Marshal(obj)
if err != nil {
return b, err
return b, berror.Wrap(err, InvalidXMLBody, "obj could not be converted to XML data")
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.GetBody = func() (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewReader(byts)), nil
}
b.req.ContentLength = int64(len(byts))
b.req.Header.Set("Content-Type", "application/xml")
}
@ -373,7 +339,7 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error)
if b.req.Body == nil && obj != nil {
byts, err := yaml.Marshal(obj)
if err != nil {
return b, err
return b, berror.Wrap(err, InvalidYAMLBody, "obj could not be converted to YAML data")
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
@ -387,7 +353,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error)
if b.req.Body == nil && obj != nil {
byts, err := json.Marshal(obj)
if err != nil {
return b, err
return b, berror.Wrap(err, InvalidJSONBody, "obj could not be converted to JSON body")
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
@ -415,28 +381,15 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
bodyWriter := multipart.NewWriter(pw)
go func() {
for formname, filename := range b.files {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
if err != nil {
log.Println("Httplib:", err)
}
fh, err := os.Open(filename)
if err != nil {
log.Println("Httplib:", err)
}
// iocopy
_, err = io.Copy(fileWriter, fh)
fh.Close()
if err != nil {
log.Println("Httplib:", err)
}
b.handleFileToBody(bodyWriter, formname, filename)
}
for k, v := range b.params {
for _, vv := range v {
bodyWriter.WriteField(k, vv)
_ = bodyWriter.WriteField(k, vv)
}
}
bodyWriter.Close()
pw.Close()
_ = bodyWriter.Close()
_ = pw.Close()
}()
b.Header("Content-Type", bodyWriter.FormDataContentType())
b.req.Body = ioutil.NopCloser(pr)
@ -452,6 +405,29 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
}
}
func (b *BeegoHTTPRequest) handleFileToBody(bodyWriter *multipart.Writer, formname string, filename string) {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
const errFmt = "Httplib: %+v"
if err != nil {
logs.Error(errFmt, berror.Wrapf(err, CreateFormFileFailed,
"could not create form file, formname: %s, filename: %s", formname, filename))
}
fh, err := os.Open(filename)
if err != nil {
logs.Error(errFmt, berror.Wrapf(err, ReadFileFailed, "could not open this file %s", filename))
}
// iocopy
_, err = io.Copy(fileWriter, fh)
if err != nil {
logs.Error(errFmt, berror.Wrapf(err, CopyFileFailed, "could not copy this file %s", filename))
}
err = fh.Close()
if err != nil {
logs.Error(errFmt, berror.Wrapf(err, CloseFileFailed, "could not close this file %s", filename))
}
}
func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
if b.resp.StatusCode != 0 {
return b.resp, nil
@ -480,62 +456,20 @@ func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Res
return root(ctx, b)
}
func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) {
var paramBody string
if len(b.params) > 0 {
var buf bytes.Buffer
for k, v := range b.params {
for _, vv := range v {
buf.WriteString(url.QueryEscape(k))
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(vv))
buf.WriteByte('&')
}
}
paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1]
}
func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (*http.Response, error) {
paramBody := b.buildParamBody()
b.buildURL(paramBody)
urlParsed, err := url.Parse(b.url)
if err != nil {
return nil, err
return nil, berror.Wrapf(err, InvalidUrl, "parse url failed, the url is %s", b.url)
}
b.req.URL = urlParsed
trans := b.setting.Transport
trans := b.buildTrans()
if trans == nil {
// create default transport
trans = &http.Transport{
TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
MaxIdleConnsPerHost: 100,
}
} else {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
}
if t.Dial == nil {
t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
}
}
}
var jar http.CookieJar
if b.setting.EnableCookie {
if defaultCookieJar == nil {
createDefaultCookie()
}
jar = defaultCookieJar
}
jar := b.buildCookieJar()
client := &http.Client{
Transport: trans,
@ -551,12 +485,16 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response,
}
if b.setting.ShowDebug {
dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody)
if err != nil {
log.Println(err.Error())
dump, e := httputil.DumpRequest(b.req, b.setting.DumpBody)
if e != nil {
logs.Error("%+v", e)
}
b.dump = dump
}
return b.sendRequest(client)
}
func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response, err error) {
// retries default value is 0, it will run once.
// retries equal to -1, it will run forever until success
// retries is setted, it will retries fixed times.
@ -564,11 +502,68 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response,
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
resp, err = client.Do(b.req)
if err == nil {
break
return
}
time.Sleep(b.setting.RetryDelay)
}
return resp, err
return nil, berror.Wrap(err, SendRequestFailed, "sending request fail")
}
func (b *BeegoHTTPRequest) buildCookieJar() http.CookieJar {
var jar http.CookieJar
if b.setting.EnableCookie {
if defaultCookieJar == nil {
createDefaultCookie()
}
jar = defaultCookieJar
}
return jar
}
func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper {
trans := b.setting.Transport
if trans == nil {
// create default transport
trans = &http.Transport{
TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy,
DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
MaxIdleConnsPerHost: 100,
}
} else {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
}
if t.DialContext == nil {
t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
}
}
}
return trans
}
func (b *BeegoHTTPRequest) buildParamBody() string {
var paramBody string
if len(b.params) > 0 {
var buf bytes.Buffer
for k, v := range b.params {
for _, vv := range v {
buf.WriteString(url.QueryEscape(k))
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(vv))
buf.WriteByte('&')
}
}
paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1]
}
return paramBody
}
// String returns the body string in response.
@ -599,10 +594,10 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
reader, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, err
return nil, berror.Wrap(err, ReadGzipBodyFailed, "building gzip reader failed")
}
b.body, err = ioutil.ReadAll(reader)
return b.body, err
return b.body, berror.Wrap(err, ReadGzipBodyFailed, "reading gzip data failed")
}
b.body, err = ioutil.ReadAll(resp.Body)
return b.body, err
@ -645,7 +640,7 @@ func pathExistAndMkdir(filename string) (err error) {
return nil
}
}
return err
return berror.Wrapf(err, CreateFileIfNotExistFailed, "try to create(if not exist) failed: %s", filename)
}
// ToJSON returns the map that marshals from the body bytes as json in response.
@ -655,7 +650,8 @@ func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
if err != nil {
return err
}
return json.Unmarshal(data, v)
return berror.Wrap(json.Unmarshal(data, v),
UnmarshalJSONResponseToObjectFailed, "unmarshal json body to object failed.")
}
// ToXML returns the map that marshals from the body bytes as xml in response .
@ -665,7 +661,8 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
if err != nil {
return err
}
return xml.Unmarshal(data, v)
return berror.Wrap(xml.Unmarshal(data, v),
UnmarshalXMLResponseToObjectFailed, "unmarshal xml body to object failed.")
}
// ToYAML returns the map that marshals from the body bytes as yaml in response .
@ -675,7 +672,8 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
if err != nil {
return err
}
return yaml.Unmarshal(data, v)
return berror.Wrap(yaml.Unmarshal(data, v),
UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.")
}
// Response executes request client gets response manually.
@ -684,8 +682,18 @@ func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
}
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
// Deprecated
// we will move this at the end of 2021
// please use TimeoutDialerCtx
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
return func(netw, addr string) (net.Conn, error) {
return TimeoutDialerCtx(cTimeout, rwTimeout)(context.Background(), netw, addr)
}
}
func TimeoutDialerCtx(cTimeout time.Duration,
rwTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) {
return func(ctx context.Context, netw, addr string) (net.Conn, error) {
conn, err := net.DialTimeout(netw, addr, cTimeout)
if err != nil {
return nil, err

View File

@ -300,3 +300,136 @@ func TestAddFilter(t *testing.T) {
r := Get("http://beego.me")
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
}
func TestFilterChainOrder(t *testing.T) {
req := Get("http://beego.me")
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("first"), nil
}
})
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("second"), nil
}
})
resp, err := req.DoRequestWithCtx(context.Background())
assert.Nil(t, err)
data := make([]byte, 5)
_, _ = resp.Body.Read(data)
assert.Equal(t, "first", string(data))
}
func TestHead(t *testing.T) {
req := Head("http://beego.me")
assert.NotNil(t, req)
assert.Equal(t, "HEAD", req.req.Method)
}
func TestDelete(t *testing.T) {
req := Delete("http://beego.me")
assert.NotNil(t, req)
assert.Equal(t, "DELETE", req.req.Method)
}
func TestPost(t *testing.T) {
req := Post("http://beego.me")
assert.NotNil(t, req)
assert.Equal(t, "POST", req.req.Method)
}
func TestNewBeegoRequest(t *testing.T) {
req := NewBeegoRequest("http://beego.me", "GET")
assert.NotNil(t, req)
assert.Equal(t, "GET", req.req.Method)
// invalid case but still go request
req = NewBeegoRequest("httpa\ta://beego.me", "GET")
assert.NotNil(t, req)
}
func TestBeegoHTTPRequest_SetProtocolVersion(t *testing.T) {
req := NewBeegoRequest("http://beego.me", "GET")
req.SetProtocolVersion("HTTP/3.10")
assert.Equal(t, "HTTP/3.10", req.req.Proto)
assert.Equal(t, 3, req.req.ProtoMajor)
assert.Equal(t, 10, req.req.ProtoMinor)
req.SetProtocolVersion("")
assert.Equal(t, "HTTP/1.1", req.req.Proto)
assert.Equal(t, 1, req.req.ProtoMajor)
assert.Equal(t, 1, req.req.ProtoMinor)
// invalid case
req.SetProtocolVersion("HTTP/aaa1.1")
assert.Equal(t, "HTTP/1.1", req.req.Proto)
assert.Equal(t, 1, req.req.ProtoMajor)
assert.Equal(t, 1, req.req.ProtoMinor)
}
func TestPut(t *testing.T) {
req := Put("http://beego.me")
assert.NotNil(t, req)
assert.Equal(t, "PUT", req.req.Method)
}
func TestBeegoHTTPRequest_Header(t *testing.T) {
req := Post("http://beego.me")
key, value := "test-header", "test-header-value"
req.Header(key, value)
assert.Equal(t, value, req.req.Header.Get(key))
}
func TestBeegoHTTPRequest_SetHost(t *testing.T) {
req := Post("http://beego.me")
host := "test-hose"
req.SetHost(host)
assert.Equal(t, host, req.req.Host)
}
func TestBeegoHTTPRequest_Param(t *testing.T) {
req := Post("http://beego.me")
key, value := "test-param", "test-param-value"
req.Param(key, value)
assert.Equal(t, value, req.params[key][0])
value1 := "test-param-value-1"
req.Param(key, value1)
assert.Equal(t, value1, req.params[key][1])
}
func TestBeegoHTTPRequest_Body(t *testing.T) {
req := Post("http://beego.me")
body := `hello, world`
req.Body([]byte(body))
assert.Equal(t, int64(len(body)), req.req.ContentLength)
assert.NotNil(t, req.req.GetBody)
assert.NotNil(t, req.req.Body)
body = "hhhh, i am test"
req.Body(body)
assert.Equal(t, int64(len(body)), req.req.ContentLength)
assert.NotNil(t, req.req.GetBody)
assert.NotNil(t, req.req.Body)
// invalid case
req.Body(13)
}
type user struct {
Name string `xml:"name"`
}
func TestBeegoHTTPRequest_XMLBody(t *testing.T) {
req := Post("http://beego.me")
body := &user{
Name: "Tom",
}
_, err := req.XMLBody(body)
assert.True(t, req.req.ContentLength > 0)
assert.Nil(t, err)
assert.NotNil(t, req.req.GetBody)
}

View File

@ -12,17 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
package mock
import (
"context"
"net/http"
"github.com/beego/beego/v2/client/httplib"
"github.com/beego/beego/v2/core/logs"
)
const mockCtxKey = "beego-httplib-mock"
func init() {
InitMockSetting()
}
type Stub interface {
Mock(cond RequestCondition, resp *http.Response, err error)
Clear()
@ -31,6 +36,10 @@ type Stub interface {
var mockFilter = &MockResponseFilter{}
func InitMockSetting() {
httplib.AddDefaultFilter(mockFilter.FilterChain)
}
func StartMock() Stub {
return mockFilter
}

View File

@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
package mock
import (
"context"
"encoding/json"
"net/textproto"
"regexp"
"github.com/beego/beego/v2/client/httplib"
)
type RequestCondition interface {
Match(ctx context.Context, req *BeegoHTTPRequest) bool
Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool
}
// reqCondition create condition
@ -54,7 +56,7 @@ func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondi
return sc
}
func (sc *SimpleCondition) Match(ctx context.Context, req *BeegoHTTPRequest) bool {
func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
res := true
if len(sc.path) > 0 {
res = sc.matchPath(ctx, req)
@ -70,12 +72,12 @@ func (sc *SimpleCondition) Match(ctx context.Context, req *BeegoHTTPRequest) boo
sc.matchBodyFields(ctx, req)
}
func (sc *SimpleCondition) matchPath(ctx context.Context, req *BeegoHTTPRequest) bool {
func (sc *SimpleCondition) matchPath(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
path := req.GetRequest().URL.Path
return path == sc.path
}
func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *BeegoHTTPRequest) bool {
func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
path := req.GetRequest().URL.Path
if b, err := regexp.Match(sc.pathReg, []byte(path)); err == nil {
return b
@ -83,7 +85,7 @@ func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *BeegoHTTPReque
return false
}
func (sc *SimpleCondition) matchQuery(ctx context.Context, req *BeegoHTTPRequest) bool {
func (sc *SimpleCondition) matchQuery(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
qs := req.GetRequest().URL.Query()
for k, v := range sc.query {
if uv, ok := qs[k]; !ok || uv[0] != v {
@ -93,7 +95,7 @@ func (sc *SimpleCondition) matchQuery(ctx context.Context, req *BeegoHTTPRequest
return true
}
func (sc *SimpleCondition) matchHeader(ctx context.Context, req *BeegoHTTPRequest) bool {
func (sc *SimpleCondition) matchHeader(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
headers := req.GetRequest().Header
for k, v := range sc.header {
if uv, ok := headers[k]; !ok || uv[0] != v {
@ -103,7 +105,7 @@ func (sc *SimpleCondition) matchHeader(ctx context.Context, req *BeegoHTTPReques
return true
}
func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *BeegoHTTPRequest) bool {
func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
if len(sc.body) == 0 {
return true
}
@ -135,7 +137,7 @@ func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *BeegoHTTPRe
return true
}
func (sc *SimpleCondition) matchMethod(ctx context.Context, req *BeegoHTTPRequest) bool {
func (sc *SimpleCondition) matchMethod(ctx context.Context, req *httplib.BeegoHTTPRequest) bool {
if len(sc.method) > 0 {
return sc.method == req.GetRequest().Method
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
package mock
import (
"context"
@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/httplib"
)
func init() {
@ -28,37 +29,37 @@ func init() {
func TestSimpleCondition_MatchPath(t *testing.T) {
sc := NewSimpleCondition("/abc/s")
res := sc.Match(context.Background(), Get("http://localhost:8080/abc/s"))
res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s"))
assert.True(t, res)
}
func TestSimpleCondition_MatchQuery(t *testing.T) {
k, v := "my-key", "my-value"
sc := NewSimpleCondition("/abc/s")
res := sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value"))
res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value"))
assert.True(t, res)
sc = NewSimpleCondition("/abc/s", WithQuery(k, v))
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value"))
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value"))
assert.True(t, res)
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-valuesss"))
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-valuesss"))
assert.False(t, res)
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key-a=my-value"))
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key-a=my-value"))
assert.False(t, res)
res = sc.Match(context.Background(), Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello"))
res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello"))
assert.True(t, res)
}
func TestSimpleCondition_MatchHeader(t *testing.T) {
k, v := "my-header", "my-header-value"
sc := NewSimpleCondition("/abc/s")
req := Get("http://localhost:8080/abc/s")
req := httplib.Get("http://localhost:8080/abc/s")
assert.True(t, sc.Match(context.Background(), req))
req = Get("http://localhost:8080/abc/s")
req = httplib.Get("http://localhost:8080/abc/s")
req.Header(k, v)
assert.True(t, sc.Match(context.Background(), req))
@ -73,7 +74,7 @@ func TestSimpleCondition_MatchHeader(t *testing.T) {
func TestSimpleCondition_MatchBodyField(t *testing.T) {
sc := NewSimpleCondition("/abc/s")
req := Post("http://localhost:8080/abc/s")
req := httplib.Post("http://localhost:8080/abc/s")
assert.True(t, sc.Match(context.Background(), req))
@ -102,7 +103,7 @@ func TestSimpleCondition_MatchBodyField(t *testing.T) {
func TestSimpleCondition_Match(t *testing.T) {
sc := NewSimpleCondition("/abc/s")
req := Post("http://localhost:8080/abc/s")
req := httplib.Post("http://localhost:8080/abc/s")
assert.True(t, sc.Match(context.Background(), req))
@ -115,9 +116,9 @@ func TestSimpleCondition_Match(t *testing.T) {
func TestSimpleCondition_MatchPathReg(t *testing.T) {
sc := NewSimpleCondition("", WithPathReg(`\/abc\/.*`))
req := Post("http://localhost:8080/abc/s")
req := httplib.Post("http://localhost:8080/abc/s")
assert.True(t, sc.Match(context.Background(), req))
req = Post("http://localhost:8080/abcd/s")
req = httplib.Post("http://localhost:8080/abcd/s")
assert.False(t, sc.Match(context.Background(), req))
}

View File

@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
package mock
import (
"context"
"fmt"
"net/http"
"github.com/beego/beego/v2/client/httplib"
)
// MockResponse will return mock response if find any suitable mock data
@ -32,13 +33,10 @@ func NewMockResponseFilter() *MockResponseFilter {
}
}
func (m *MockResponseFilter) FilterChain(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
func (m *MockResponseFilter) FilterChain(next httplib.Filter) httplib.Filter {
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
ms := mockFromCtx(ctx)
ms = append(ms, m.ms...)
fmt.Printf("url: %s, mock: %d \n", req.url, len(ms))
for _, mock := range ms {
if mock.cond.Match(ctx, req) {
return mock.resp, mock.err

View File

@ -12,20 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
package mock
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/httplib"
)
func TestMockResponseFilter_FilterChain(t *testing.T) {
req := Get("http://localhost:8080/abc/s")
req := httplib.Get("http://localhost:8080/abc/s")
ft := NewMockResponseFilter()
expectedResp := NewHttpResponseWithJsonBody(`{}`)
expectedResp := httplib.NewHttpResponseWithJsonBody(`{}`)
expectedErr := errors.New("expected error")
ft.Mock(NewSimpleCondition("/abc/s"), expectedResp, expectedErr)
@ -35,16 +37,16 @@ func TestMockResponseFilter_FilterChain(t *testing.T) {
assert.Equal(t, expectedErr, err)
assert.Equal(t, expectedResp, resp)
req = Get("http://localhost:8080/abcd/s")
req = httplib.Get("http://localhost:8080/abcd/s")
req.AddFilters(ft.FilterChain)
resp, err = req.DoRequest()
assert.NotEqual(t, expectedErr, err)
assert.NotEqual(t, expectedResp, resp)
req = Get("http://localhost:8080/abc/s")
req = httplib.Get("http://localhost:8080/abc/s")
req.AddFilters(ft.FilterChain)
expectedResp1 := NewHttpResponseWithJsonBody(map[string]string{})
expectedResp1 := httplib.NewHttpResponseWithJsonBody(map[string]string{})
expectedErr1 := errors.New("expected error")
ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1)
@ -52,7 +54,7 @@ func TestMockResponseFilter_FilterChain(t *testing.T) {
assert.Equal(t, expectedErr, err)
assert.Equal(t, expectedResp, resp)
req = Get("http://localhost:8080/abc/abs/bbc")
req = httplib.Get("http://localhost:8080/abc/abs/bbc")
req.AddFilters(ft.FilterChain)
ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1)
resp, err = req.DoRequest()

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
package mock
import (
"context"
@ -21,16 +21,18 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/httplib"
)
func TestStartMock(t *testing.T) {
defaultSetting.FilterChains = []FilterChain{mockFilter.FilterChain}
// httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain}
stub := StartMock()
// defer stub.Clear()
expectedResp := NewHttpResponseWithJsonBody([]byte(`{}`))
expectedResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`))
expectedErr := errors.New("expected err")
stub.Mock(NewSimpleCondition("/abc"), expectedResp, expectedErr)
@ -45,14 +47,14 @@ func TestStartMock(t *testing.T) {
// TestStartMock_Isolation Test StartMock that
// mock only work for this request
func TestStartMock_Isolation(t *testing.T) {
defaultSetting.FilterChains = []FilterChain{mockFilter.FilterChain}
// httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain}
// setup global stub
stub := StartMock()
globalMockResp := NewHttpResponseWithJsonBody([]byte(`{}`))
globalMockResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`))
globalMockErr := errors.New("expected err")
stub.Mock(NewSimpleCondition("/abc"), globalMockResp, globalMockErr)
expectedResp := NewHttpResponseWithJsonBody(struct {
expectedResp := httplib.NewHttpResponseWithJsonBody(struct {
A string `json:"a"`
}{
A: "aaa",
@ -67,9 +69,9 @@ func TestStartMock_Isolation(t *testing.T) {
}
func OriginnalCodeUsingHttplibPassCtx(ctx context.Context) (*http.Response, error) {
return Get("http://localhost:7777/abc").DoRequestWithCtx(ctx)
return httplib.Get("http://localhost:7777/abc").DoRequestWithCtx(ctx)
}
func OriginalCodeUsingHttplib() (*http.Response, error){
return Get("http://localhost:7777/abc").DoRequest()
return httplib.Get("http://localhost:7777/abc").DoRequest()
}

17
client/httplib/module.go Normal file
View File

@ -0,0 +1,17 @@
// Copyright 2021 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
const moduleName = "httplib"

81
client/httplib/setting.go Normal file
View File

@ -0,0 +1,81 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"crypto/tls"
"net/http"
"net/http/cookiejar"
"net/url"
"sync"
"time"
)
// BeegoHTTPSettings is the http.Client setting
type BeegoHTTPSettings struct {
ShowDebug bool
UserAgent string
ConnectTimeout time.Duration
ReadWriteTimeout time.Duration
TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper
CheckRedirect func(req *http.Request, via []*http.Request) error
EnableCookie bool
Gzip bool
DumpBody bool
Retries int // if set to -1 means will retry forever
RetryDelay time.Duration
FilterChains []FilterChain
}
// createDefaultCookie creates a global cookiejar to store cookies.
func createDefaultCookie() {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultCookieJar, _ = cookiejar.New(nil)
}
// SetDefaultSetting overwrites default settings
// Keep in mind that when you invoke the SetDefaultSetting
// some methods invoked before SetDefaultSetting
func SetDefaultSetting(setting BeegoHTTPSettings) {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultSetting = setting
}
var defaultSetting = BeegoHTTPSettings{
UserAgent: "beegoServer",
ConnectTimeout: 60 * time.Second,
ReadWriteTimeout: 60 * time.Second,
Gzip: true,
DumpBody: true,
FilterChains: make([]FilterChain, 0, 4),
}
var defaultCookieJar http.CookieJar
var settingMutex sync.Mutex
// AddDefaultFilter add a new filter into defaultSetting
// Be careful about using this method if you invoke SetDefaultSetting somewhere
func AddDefaultFilter(fc FilterChain) {
settingMutex.Lock()
defer settingMutex.Unlock()
if defaultSetting.FilterChains == nil {
defaultSetting.FilterChains = make([]FilterChain, 0, 4)
}
defaultSetting.FilterChains = append(defaultSetting.FilterChains, fc)
}

View File

@ -0,0 +1,6 @@
package clauses
const (
ExprSep = "__"
ExprDot = "."
)

View File

@ -0,0 +1,103 @@
package order_clause
import (
"github.com/beego/beego/v2/client/orm/clauses"
"strings"
)
type Sort int8
const (
None Sort = 0
Ascending Sort = 1
Descending Sort = 2
)
type Option func(order *Order)
type Order struct {
column string
sort Sort
isRaw bool
}
func Clause(options ...Option) *Order {
o := &Order{}
for _, option := range options {
option(o)
}
return o
}
func (o *Order) GetColumn() string {
return o.column
}
func (o *Order) GetSort() Sort {
return o.sort
}
func (o *Order) SortString() string {
switch o.GetSort() {
case Ascending:
return "ASC"
case Descending:
return "DESC"
}
return ``
}
func (o *Order) IsRaw() bool {
return o.isRaw
}
func ParseOrder(expressions ...string) []*Order {
var orders []*Order
for _, expression := range expressions {
sort := Ascending
column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot)
if column[0] == '-' {
sort = Descending
column = column[1:]
}
orders = append(orders, &Order{
column: column,
sort: sort,
})
}
return orders
}
func Column(column string) Option {
return func(order *Order) {
order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot)
}
}
func sort(sort Sort) Option {
return func(order *Order) {
order.sort = sort
}
}
func SortAscending() Option {
return sort(Ascending)
}
func SortDescending() Option {
return sort(Descending)
}
func SortNone() Option {
return sort(None)
}
func Raw() Option {
return func(order *Order) {
order.isRaw = true
}
}

View File

@ -0,0 +1,144 @@
package order_clause
import (
"testing"
)
func TestClause(t *testing.T) {
var (
column = `a`
)
o := Clause(
Column(column),
)
if o.GetColumn() != column {
t.Error()
}
}
func TestSortAscending(t *testing.T) {
o := Clause(
SortAscending(),
)
if o.GetSort() != Ascending {
t.Error()
}
}
func TestSortDescending(t *testing.T) {
o := Clause(
SortDescending(),
)
if o.GetSort() != Descending {
t.Error()
}
}
func TestSortNone(t *testing.T) {
o1 := Clause(
SortNone(),
)
if o1.GetSort() != None {
t.Error()
}
o2 := Clause()
if o2.GetSort() != None {
t.Error()
}
}
func TestRaw(t *testing.T) {
o1 := Clause()
if o1.IsRaw() {
t.Error()
}
o2 := Clause(
Raw(),
)
if !o2.IsRaw() {
t.Error()
}
}
func TestColumn(t *testing.T) {
o1 := Clause(
Column(`aaa`),
)
if o1.GetColumn() != `aaa` {
t.Error()
}
}
func TestParseOrder(t *testing.T) {
orders := ParseOrder(
`-user__status`,
`status`,
`user__status`,
)
t.Log(orders)
if orders[0].GetSort() != Descending {
t.Error()
}
if orders[0].GetColumn() != `user.status` {
t.Error()
}
if orders[1].GetColumn() != `status` {
t.Error()
}
if orders[1].GetSort() != Ascending {
t.Error()
}
if orders[2].GetColumn() != `user.status` {
t.Error()
}
}
func TestOrder_GetColumn(t *testing.T) {
o := Clause(
Column(`user__id`),
)
if o.GetColumn() != `user.id` {
t.Error()
}
}
func TestOrder_GetSort(t *testing.T) {
o := Clause(
SortDescending(),
)
if o.GetSort() != Descending {
t.Error()
}
}
func TestOrder_IsRaw(t *testing.T) {
o1 := Clause()
if o1.IsRaw() {
t.Error()
}
o2 := Clause(
Raw(),
)
if !o2.IsRaw() {
t.Error()
}
}

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"flag"
"fmt"
"os"
@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf(" %s\n", err.Error())
}
ctx := context.Background()
for i, mi := range modelCache.allOrdered() {
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error {
}
var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table)
if err != nil {
if d.rtOnError {
return err
@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error {
}
for _, idx := range indexes[mi.table] {
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"errors"
"fmt"
@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
// create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote()
dbcols := make([]string, 0, len(mi.fields.dbcols))
@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
d.ins.HasReturningID(mi, &query)
stmt, err := q.Prepare(query)
stmt, err := q.PrepareContext(ctx, query)
return stmt, query, err
}
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return 0, err
@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
err := row.Scan(&id)
return id, err
}
res, err := stmt.Exec(values...)
res, err := stmt.ExecContext(ctx, values...)
if err == nil {
return res.LastInsertId()
}
@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
}
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
var whereCols []string
var args []interface{}
@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...)
row := q.QueryRowContext(ctx, query, args...)
if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
}
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols))
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil {
return 0, err
}
id, err := d.InsertValue(q, mi, false, names, values)
id, err := d.InsertValue(ctx, q, mi, false, names, values)
if err != nil {
return 0, err
}
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
err = d.ins.setval(ctx, q, mi, autoFields)
}
return id, err
}
// multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
var err error
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
err = d.ins.setval(ctx, q, mi, autoFields)
}
return cnt, err
@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
}
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err := row.Scan(&id)
return id, err
@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
// InsertOrUpdate a row
// If your primary key or unique column conflict will update
// If no will insert
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
args0 := ""
iouStr := ""
argsMap := map[string]string{}
@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err = row.Scan(&id)
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
}
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if !ok {
return 0, ErrMissPK
@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, setValues...)
res, err := q.ExecContext(ctx, query, setValues...)
if err == nil {
return res.RowsAffected()
}
@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
var whereCols []string
var args []interface{}
// if specify cols length > 0, then use it for where condition.
@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...)
res, err := q.ExecContext(ctx, query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
}
}
err := d.deleteRels(q, mi, args, tz)
err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil {
return num, err
}
@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// update table-related record by querySet.
// need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
var err error
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
return res.RowsAffected()
}
@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
// delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
_, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
_, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
}
// delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins)
tables.skipEnd = true
@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
r, err := q.Query(query, args...)
r, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query)
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
res, err := q.ExecContext(ctx, query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
err := d.deleteRels(q, mi, args, tz)
err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil {
return num, err
}
@ -943,14 +933,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
// read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
errTyp := true
unregister := true
one := true
isPtr := true
name := ""
if val.Kind() == reflect.Ptr {
fn := ""
@ -963,19 +954,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
case reflect.Struct:
isPtr = false
fn = getFullName(typ)
name = getTableName(reflect.New(typ))
}
} else {
fn = getFullName(ind.Type())
name = getTableName(ind)
}
errTyp = fn != mi.fullName
unregister = fn != mi.fullName
}
if errTyp {
if one {
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName))
} else {
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName))
}
if unregister {
RegisterModel(container)
}
rlimit := qs.limit
@ -1040,6 +1029,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
if qs.distinct {
sqlSelect += " DISTINCT"
}
if qs.aggregate != "" {
sels = qs.aggregate
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, Q, mi.table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit)
@ -1050,18 +1042,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
var err error
if qs != nil && qs.forContext {
rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil {
return 0, err
}
} else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
rs, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rs.Close()
slice := ind
if unregister {
mi, _ = modelCache.get(name)
tCols = mi.fields.dbcols
colsNum = len(tCols)
}
refs := make([]interface{}, colsNum)
@ -1069,11 +1061,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
var ref interface{}
refs[i] = &ref
}
defer rs.Close()
slice := ind
var cnt int64
for rs.Next() {
if one && cnt == 0 || !one {
@ -1172,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
}
// excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
@ -1194,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query)
var row *sql.Row
if qs != nil && qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
row := q.QueryRowContext(ctx, query, args...)
err = row.Scan(&cnt)
return
}
@ -1649,7 +1631,7 @@ setValue:
}
// query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var (
maps []Params
@ -1732,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
d.ins.ReplaceMarks(&query)
rs, err := q.Query(query, args...)
rs, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
@ -1847,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
}
// sync auto key
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
return nil
}
@ -1892,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
}
// get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
rows, err := db.QueryContext(ctx, query)
if err != nil {
return columns, err
}
@ -1934,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
}
// not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool {
panic(ErrNotImplement)
}

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"reflect"
"strings"
@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
}
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
// If your primary key or unique column conflict will update
// If no will insert
// Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string
argsMap := map[string]string{}
@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err = row.Scan(&id)
return id, err

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"strings"
@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
}
// check index is exist
func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
}
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err := row.Scan(&id)
return id, err

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"strconv"
)
@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
}
// sync auto key
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
}
@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string
mi.table, name,
Q, name, Q,
Q, mi.table, Q)
if _, err := db.Exec(query); err != nil {
if _, err := db.ExecContext(ctx, query); err != nil {
return err
}
}
@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string {
}
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query)
row := db.QueryRowContext(ctx, query)
var cnt int
row.Scan(&cnt)
return cnt > 0

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"fmt"
"reflect"
@ -73,11 +74,11 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
}
return d.dbBase.Read(q, mi, ind, tz, cols, false)
return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
}
// get sqlite operator.
@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string {
}
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
rows, err := db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
}
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query)
rows, err := db.QueryContext(ctx, query)
if err != nil {
panic(err)
}

View File

@ -16,6 +16,8 @@ package orm
import (
"fmt"
"github.com/beego/beego/v2/client/orm/clauses"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"strings"
"time"
)
@ -421,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
}
// generate order sql.
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
if len(orders) == 0 {
return
}
@ -430,19 +432,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
asc := "ASC"
if order[0] == '-' {
asc = "DESC"
order = order[1:]
}
exprs := strings.Split(order, ExprSep)
column := order.GetColumn()
clause := strings.Split(column, clauses.ExprDot)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
if order.IsRaw() {
if len(clause) == 2 {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
} else if len(clause) == 1 {
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
} else {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
}
} else {
index, _, fi, suc := t.parseExprs(t.mi, clause)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString()))
}
}
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
)
@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
}
// execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)

View File

@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
return nil
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
return nil
}
@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return nil
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
return nil
}

View File

@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) {
assert.Nil(t, o.Driver())
assert.Nil(t, o.QueryM2MWithCtx(nil, nil, ""))
assert.Nil(t, o.QueryM2M(nil, ""))
assert.Nil(t, o.ReadWithCtx(nil, nil))
assert.Nil(t, o.Read(nil))
@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
assert.Nil(t, o.QueryTableWithCtx(nil, nil))
assert.Nil(t, o.QueryTable(nil))
assert.Nil(t, o.Read(nil))

View File

@ -27,7 +27,7 @@ import (
// this Filter's behavior looks a little bit strange
// for example:
// if we want to trace QuerySetter
// actually we trace invoking "QueryTable" and "QueryTableWithCtx"
// actually we trace invoking "QueryTable"
// the method Begin*, Commit and Rollback are ignored.
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
type FilterChainBuilder struct {

View File

@ -18,6 +18,7 @@ import (
"context"
"strconv"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
@ -31,26 +32,31 @@ import (
// this Filter's behavior looks a little bit strange
// for example:
// if we want to records the metrics of QuerySetter
// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx"
// actually we only records metrics of invoking "QueryTable"
type FilterChainBuilder struct {
summaryVec prometheus.ObserverVec
AppName string
ServerName string
RunMode string
}
var summaryVec prometheus.ObserverVec
var initSummaryVec sync.Once
func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "beego",
Subsystem: "orm_operation",
ConstLabels: map[string]string{
"server": builder.ServerName,
"env": builder.RunMode,
"appname": builder.AppName,
},
Help: "The statics info for orm operation",
}, []string{"method", "name", "insideTx", "txName"})
initSummaryVec.Do(func() {
summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "beego",
Subsystem: "orm_operation",
ConstLabels: map[string]string{
"server": builder.ServerName,
"env": builder.RunMode,
"appname": builder.AppName,
},
Help: "The statics info for orm operation",
}, []string{"method", "name", "insideTx", "txName"})
prometheus.MustRegister(summaryVec)
})
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
startTime := time.Now()
@ -74,12 +80,12 @@ func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocati
builder.reportTxn(ctx, inv)
return
}
builder.summaryVec.WithLabelValues(inv.Method, inv.GetTableName(),
summaryVec.WithLabelValues(inv.Method, inv.GetTableName(),
strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur))
}
func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) {
dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond
builder.summaryVec.WithLabelValues(inv.Method, inv.TxName,
summaryVec.WithLabelValues(inv.Method, inv.TxName,
strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur))
}

View File

@ -32,7 +32,7 @@ func TestFilterChainBuilder_FilterChain1(t *testing.T) {
builder := &FilterChainBuilder{}
filter := builder.FilterChain(next)
assert.NotNil(t, builder.summaryVec)
assert.NotNil(t, summaryVec)
assert.NotNil(t, filter)
inv := &orm.Invocation{}

View File

@ -20,6 +20,7 @@ import (
"reflect"
"time"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/core/utils"
)
@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac
}
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
return f.QueryM2MWithCtx(context.Background(), md, name)
}
func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "QueryM2MWithCtx",
Method: "QueryM2M",
Args: []interface{}{md, name},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func(c context.Context) []interface{} {
res := f.ormer.QueryM2MWithCtx(c, md, name)
res := f.ormer.QueryM2M(md, name)
return []interface{}{res}
},
}
res := f.root(ctx, inv)
res := f.root(context.Background(), inv)
if res[0] == nil {
return nil
}
return res[0].(QueryM2Mer)
}
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
// NOTE: this method is deprecated, context parameter will not take effect.
func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.")
return f.QueryM2M(md, name)
}
func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
var (
name string
md interface{}
@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
}
inv := &Invocation{
Method: "QueryTableWithCtx",
Method: "QueryTable",
Args: []interface{}{ptrStructOrTableName},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
Md: md,
mi: mi,
f: func(c context.Context) []interface{} {
res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName)
res := f.ormer.QueryTable(ptrStructOrTableName)
return []interface{}{res}
},
}
res := f.root(ctx, inv)
res := f.root(context.Background(), inv)
if res[0] == nil {
return nil
@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
return res[0].(QuerySeter)
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter {
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.")
return f.QueryTable(ptrStructOrTableName)
}
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
inv := &Invocation{
Method: "DBStats",

View File

@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) []interface{} {
assert.Equal(t, "QueryM2MWithCtx", inv.Method)
assert.Equal(t, "QueryM2M", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) []interface{} {
assert.Equal(t, "QueryTableWithCtx", inv.Method)
assert.Equal(t, "QueryTable", inv.Method)
assert.Equal(t, 1, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)

View File

@ -0,0 +1,63 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"github.com/beego/beego/v2/client/orm"
)
type Mock struct {
cond Condition
resp []interface{}
cb func(inv *orm.Invocation)
}
func NewMock(cond Condition, resp []interface{}, cb func(inv *orm.Invocation)) *Mock {
return &Mock{
cond: cond,
resp: resp,
cb: cb,
}
}
type Condition interface {
Match(ctx context.Context, inv *orm.Invocation) bool
}
type SimpleCondition struct {
tableName string
method string
}
func NewSimpleCondition(tableName string, methodName string) Condition {
return &SimpleCondition{
tableName: tableName,
method: methodName,
}
}
func (s *SimpleCondition) Match(ctx context.Context, inv *orm.Invocation) bool {
res := true
if len(s.tableName) != 0 {
res = res && (s.tableName == inv.GetTableName())
}
if len(s.method) != 0 {
res = res && (s.method == inv.Method)
}
return res
}

View File

@ -0,0 +1,41 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm"
)
func TestSimpleCondition_Match(t *testing.T) {
cond := NewSimpleCondition("", "")
res := cond.Match(context.Background(), &orm.Invocation{})
assert.True(t, res)
cond = NewSimpleCondition("hello", "")
assert.False(t, cond.Match(context.Background(), &orm.Invocation{}))
cond = NewSimpleCondition("", "A")
assert.False(t, cond.Match(context.Background(), &orm.Invocation{
Method: "B",
}))
assert.True(t, cond.Match(context.Background(), &orm.Invocation{
Method: "A",
}))
}

View File

@ -0,0 +1,40 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"github.com/beego/beego/v2/core/logs"
)
type mockCtxKeyType string
const mockCtxKey = mockCtxKeyType("beego-orm-mock")
func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context {
return context.WithValue(ctx, mockCtxKey, mock)
}
func mockFromCtx(ctx context.Context) []*Mock {
ms := ctx.Value(mockCtxKey)
if ms != nil {
if res, ok := ms.([]*Mock); ok {
return res
}
logs.Error("mockCtxKey found in context, but value is not type []*Mock")
}
return nil
}

View File

@ -0,0 +1,29 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCtx(t *testing.T) {
ms := make([]*Mock, 0, 4)
ctx := CtxWithMock(context.Background(), ms...)
res := mockFromCtx(ctx)
assert.Equal(t, ms, res)
}

72
client/orm/mock/mock.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"github.com/beego/beego/v2/client/orm"
)
var stub = newOrmStub()
func init() {
orm.AddGlobalFilterChain(stub.FilterChain)
}
type Stub interface {
Mock(m *Mock)
Clear()
}
type OrmStub struct {
ms []*Mock
}
func StartMock() Stub {
return stub
}
func newOrmStub() *OrmStub {
return &OrmStub{
ms: make([]*Mock, 0, 4),
}
}
func (o *OrmStub) Mock(m *Mock) {
o.ms = append(o.ms, m)
}
func (o *OrmStub) Clear() {
o.ms = make([]*Mock, 0, 4)
}
func (o *OrmStub) FilterChain(next orm.Filter) orm.Filter {
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
ms := mockFromCtx(ctx)
ms = append(ms, o.ms...)
for _, mock := range ms {
if mock.cond.Match(ctx, inv) {
if mock.cb != nil {
mock.cb(inv)
}
return mock.resp
}
}
return next(ctx, inv)
}
}

162
client/orm/mock/mock_orm.go Normal file
View File

@ -0,0 +1,162 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"database/sql"
"os"
"path/filepath"
_ "github.com/mattn/go-sqlite3"
"github.com/beego/beego/v2/client/orm"
)
func init() {
RegisterMockDB("default")
}
// RegisterMockDB create an "virtual DB" by using sqllite
// you should not
func RegisterMockDB(name string) {
source := filepath.Join(os.TempDir(), name+".db")
_ = orm.RegisterDataBase(name, "sqlite3", source)
}
// MockTable only check table name
func MockTable(tableName string, resp ...interface{}) *Mock {
return NewMock(NewSimpleCondition(tableName, ""), resp, nil)
}
// MockMethod only check method name
func MockMethod(method string, resp ...interface{}) *Mock {
return NewMock(NewSimpleCondition("", method), resp, nil)
}
// MockOrmRead support orm.Read and orm.ReadWithCtx
// cb is used to mock read data from DB
func MockRead(tableName string, cb func(data interface{}), err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "ReadWithCtx"), []interface{}{err}, func(inv *orm.Invocation) {
if cb != nil {
cb(inv.Args[0])
}
})
}
// MockReadForUpdateWithCtx support ReadForUpdate and ReadForUpdateWithCtx
// cb is used to mock read data from DB
func MockReadForUpdateWithCtx(tableName string, cb func(data interface{}), err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "ReadForUpdateWithCtx"),
[]interface{}{err},
func(inv *orm.Invocation) {
cb(inv.Args[0])
})
}
// MockReadOrCreateWithCtx support ReadOrCreate and ReadOrCreateWithCtx
// cb is used to mock read data from DB
func MockReadOrCreateWithCtx(tableName string,
cb func(data interface{}),
insert bool, id int64, err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "ReadOrCreateWithCtx"),
[]interface{}{insert, id, err},
func(inv *orm.Invocation) {
cb(inv.Args[0])
})
}
// MockInsertWithCtx support Insert and InsertWithCtx
func MockInsertWithCtx(tableName string, id int64, err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "InsertWithCtx"), []interface{}{id, err}, nil)
}
// MockInsertMultiWithCtx support InsertMulti and InsertMultiWithCtx
func MockInsertMultiWithCtx(tableName string, cnt int64, err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "InsertMultiWithCtx"), []interface{}{cnt, err}, nil)
}
// MockInsertOrUpdateWithCtx support InsertOrUpdate and InsertOrUpdateWithCtx
func MockInsertOrUpdateWithCtx(tableName string, id int64, err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "InsertOrUpdateWithCtx"), []interface{}{id, err}, nil)
}
// MockUpdateWithCtx support UpdateWithCtx and Update
func MockUpdateWithCtx(tableName string, affectedRow int64, err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "UpdateWithCtx"), []interface{}{affectedRow, err}, nil)
}
// MockDeleteWithCtx support Delete and DeleteWithCtx
func MockDeleteWithCtx(tableName string, affectedRow int64, err error) *Mock {
return NewMock(NewSimpleCondition(tableName, "DeleteWithCtx"), []interface{}{affectedRow, err}, nil)
}
// MockQueryM2MWithCtx support QueryM2MWithCtx and QueryM2M
// Now you may be need to use golang/mock to generate QueryM2M mock instance
// Or use DoNothingQueryM2Mer
// for example:
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
// when you write test code:
// MockQueryM2MWithCtx("post", "Tags", mockM2Mer)
// "post" is the table name of model Post structure
// TODO provide orm.QueryM2Mer
func MockQueryM2MWithCtx(tableName string, name string, res orm.QueryM2Mer) *Mock {
return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{res}, nil)
}
// MockLoadRelatedWithCtx support LoadRelatedWithCtx and LoadRelated
func MockLoadRelatedWithCtx(tableName string, name string, rows int64, err error) *Mock {
return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{rows, err}, nil)
}
// MockQueryTableWithCtx support QueryTableWithCtx and QueryTable
func MockQueryTableWithCtx(tableName string, qs orm.QuerySeter) *Mock {
return NewMock(NewSimpleCondition(tableName, "QueryTable"), []interface{}{qs}, nil)
}
// MockRawWithCtx support RawWithCtx and Raw
func MockRawWithCtx(rs orm.RawSeter) *Mock {
return NewMock(NewSimpleCondition("", "RawWithCtx"), []interface{}{rs}, nil)
}
// MockDriver support Driver
// func MockDriver(driver orm.Driver) *Mock {
// return NewMock(NewSimpleCondition("", "Driver"), []interface{}{driver})
// }
// MockDBStats support DBStats
func MockDBStats(stats *sql.DBStats) *Mock {
return NewMock(NewSimpleCondition("", "DBStats"), []interface{}{stats}, nil)
}
// MockBeginWithCtxAndOpts support Begin, BeginWithCtx, BeginWithOpts, BeginWithCtxAndOpts
// func MockBeginWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock {
// return NewMock(NewSimpleCondition("", "BeginWithCtxAndOpts"), []interface{}{txOrm, err})
// }
// MockDoTxWithCtxAndOpts support DoTx, DoTxWithCtx, DoTxWithOpts, DoTxWithCtxAndOpts
// func MockDoTxWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock {
// return MockBeginWithCtxAndOpts(txOrm, err)
// }
// MockCommit support Commit
func MockCommit(err error) *Mock {
return NewMock(NewSimpleCondition("", "Commit"), []interface{}{err}, nil)
}
// MockRollback support Rollback
func MockRollback(err error) *Mock {
return NewMock(NewSimpleCondition("", "Rollback"), []interface{}{err}, nil)
}

View File

@ -0,0 +1,297 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"database/sql"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm"
)
const mockErrorMsg = "mock error"
func init() {
orm.RegisterModel(&User{})
}
func TestMockDBStats(t *testing.T) {
s := StartMock()
defer s.Clear()
stats := &sql.DBStats{}
s.Mock(MockDBStats(stats))
o := orm.NewOrm()
res := o.DBStats()
assert.Equal(t, stats, res)
}
func TestMockDeleteWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
s.Mock(MockDeleteWithCtx((&User{}).TableName(), 12, nil))
o := orm.NewOrm()
rows, err := o.Delete(&User{})
assert.Equal(t, int64(12), rows)
assert.Nil(t, err)
}
func TestMockInsertOrUpdateWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
s.Mock(MockInsertOrUpdateWithCtx((&User{}).TableName(), 12, nil))
o := orm.NewOrm()
id, err := o.InsertOrUpdate(&User{})
assert.Equal(t, int64(12), id)
assert.Nil(t, err)
}
func TestMockRead(t *testing.T) {
s := StartMock()
defer s.Clear()
err := errors.New(mockErrorMsg)
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
u := data.(*User)
u.Name = "Tom"
}, err))
o := orm.NewOrm()
u := &User{}
e := o.Read(u)
assert.Equal(t, err, e)
assert.Equal(t, "Tom", u.Name)
}
func TestMockQueryM2MWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := &DoNothingQueryM2Mer{}
s.Mock(MockQueryM2MWithCtx((&User{}).TableName(), "Tags", mock))
o := orm.NewOrm()
res := o.QueryM2M(&User{}, "Tags")
assert.Equal(t, mock, res)
}
func TestMockQueryTableWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := &DoNothingQuerySetter{}
s.Mock(MockQueryTableWithCtx((&User{}).TableName(), mock))
o := orm.NewOrm()
res := o.QueryTable(&User{})
assert.Equal(t, mock, res)
}
func TestMockTable(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockTable((&User{}).TableName(), mock))
o := orm.NewOrm()
res := o.Read(&User{})
assert.Equal(t, mock, res)
}
func TestMockInsertMultiWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockInsertMultiWithCtx((&User{}).TableName(), 12, mock))
o := orm.NewOrm()
res, err := o.InsertMulti(11, []interface{}{&User{}})
assert.Equal(t, int64(12), res)
assert.Equal(t, mock, err)
}
func TestMockInsertWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockInsertWithCtx((&User{}).TableName(), 13, mock))
o := orm.NewOrm()
res, err := o.Insert(&User{})
assert.Equal(t, int64(13), res)
assert.Equal(t, mock, err)
}
func TestMockUpdateWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockUpdateWithCtx((&User{}).TableName(), 12, mock))
o := orm.NewOrm()
res, err := o.Update(&User{})
assert.Equal(t, int64(12), res)
assert.Equal(t, mock, err)
}
func TestMockLoadRelatedWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockLoadRelatedWithCtx((&User{}).TableName(), "T", 12, mock))
o := orm.NewOrm()
res, err := o.LoadRelated(&User{}, "T")
assert.Equal(t, int64(12), res)
assert.Equal(t, mock, err)
}
func TestMockMethod(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockMethod("ReadWithCtx", mock))
o := orm.NewOrm()
err := o.Read(&User{})
assert.Equal(t, mock, err)
}
func TestMockReadForUpdateWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockReadForUpdateWithCtx((&User{}).TableName(), func(data interface{}) {
u := data.(*User)
u.Name = "Tom"
}, mock))
o := orm.NewOrm()
u := &User{}
err := o.ReadForUpdate(u)
assert.Equal(t, mock, err)
assert.Equal(t, "Tom", u.Name)
}
func TestMockRawWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := &DoNothingRawSetter{}
s.Mock(MockRawWithCtx(mock))
o := orm.NewOrm()
res := o.Raw("")
assert.Equal(t, mock, res)
}
func TestMockReadOrCreateWithCtx(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockReadOrCreateWithCtx((&User{}).TableName(), func(data interface{}) {
u := data.(*User)
u.Name = "Tom"
}, false, 12, mock))
o := orm.NewOrm()
u := &User{}
inserted, id, err := o.ReadOrCreate(u, "")
assert.Equal(t, mock, err)
assert.Equal(t, int64(12), id)
assert.False(t, inserted)
assert.Equal(t, "Tom", u.Name)
}
func TestTransactionClosure(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
u := data.(*User)
u.Name = "Tom"
}, mock))
u, err := originalTxUsingClosure()
assert.Equal(t, mock, err)
assert.Equal(t, "Tom", u.Name)
}
func TestTransactionManually(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
u := data.(*User)
u.Name = "Tom"
}, mock))
u, err := originalTxManually()
assert.Equal(t, mock, err)
assert.Equal(t, "Tom", u.Name)
}
func TestTransactionRollback(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockRead((&User{}).TableName(), nil, errors.New("read error")))
s.Mock(MockRollback(mock))
_, err := originalTx()
assert.Equal(t, mock, err)
}
func TestTransactionCommit(t *testing.T) {
s := StartMock()
defer s.Clear()
mock := errors.New(mockErrorMsg)
s.Mock(MockRead((&User{}).TableName(), func(data interface{}) {
u := data.(*User)
u.Name = "Tom"
}, nil))
s.Mock(MockCommit(mock))
u, err := originalTx()
assert.Equal(t, mock, err)
assert.Equal(t, "Tom", u.Name)
}
func originalTx() (*User, error) {
u := &User{}
o := orm.NewOrm()
txOrm, _ := o.Begin()
err := txOrm.Read(u)
if err == nil {
err = txOrm.Commit()
return u, err
} else {
err = txOrm.Rollback()
return nil, err
}
}
func originalTxManually() (*User, error) {
u := &User{}
o := orm.NewOrm()
txOrm, _ := o.Begin()
err := txOrm.Read(u)
_ = txOrm.Commit()
return u, err
}
func originalTxUsingClosure() (*User, error) {
u := &User{}
var err error
o := orm.NewOrm()
_ = o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error {
err = txOrm.Read(u)
return nil
})
return u, err
}
type User struct {
Id int
Name string
}
func (u *User) TableName() string {
return "user"
}

View File

@ -0,0 +1,92 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"github.com/beego/beego/v2/client/orm"
)
// DoNothingQueryM2Mer do nothing
// use it to build mock orm.QueryM2Mer
type DoNothingQueryM2Mer struct {
}
func (d *DoNothingQueryM2Mer) AddWithCtx(ctx context.Context, i ...interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingQueryM2Mer) RemoveWithCtx(ctx context.Context, i ...interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingQueryM2Mer) ExistWithCtx(ctx context.Context, i interface{}) bool {
return true
}
func (d *DoNothingQueryM2Mer) ClearWithCtx(ctx context.Context) (int64, error) {
return 0, nil
}
func (d *DoNothingQueryM2Mer) CountWithCtx(ctx context.Context) (int64, error) {
return 0, nil
}
func (d *DoNothingQueryM2Mer) Add(i ...interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingQueryM2Mer) Remove(i ...interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingQueryM2Mer) Exist(i interface{}) bool {
return true
}
func (d *DoNothingQueryM2Mer) Clear() (int64, error) {
return 0, nil
}
func (d *DoNothingQueryM2Mer) Count() (int64, error) {
return 0, nil
}
type QueryM2MerCondition struct {
tableName string
name string
}
func NewQueryM2MerCondition(tableName string, name string) *QueryM2MerCondition {
return &QueryM2MerCondition{
tableName: tableName,
name: name,
}
}
func (q *QueryM2MerCondition) Match(ctx context.Context, inv *orm.Invocation) bool {
res := true
if len(q.tableName) > 0 {
res = res && (q.tableName == inv.GetTableName())
}
if len(q.name) > 0 {
res = res && (len(inv.Args) > 1) && (q.name == inv.Args[1].(string))
}
return res
}

View File

@ -0,0 +1,63 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm"
)
func TestDoNothingQueryM2Mer(t *testing.T) {
m2m := &DoNothingQueryM2Mer{}
i, err := m2m.Clear()
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = m2m.Count()
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = m2m.Add()
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = m2m.Remove()
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
assert.True(t, m2m.Exist(nil))
}
func TestNewQueryM2MerCondition(t *testing.T) {
cond := NewQueryM2MerCondition("", "")
res := cond.Match(context.Background(), &orm.Invocation{})
assert.True(t, res)
cond = NewQueryM2MerCondition("hello", "")
assert.False(t, cond.Match(context.Background(), &orm.Invocation{}))
cond = NewQueryM2MerCondition("", "A")
assert.False(t, cond.Match(context.Background(), &orm.Invocation{
Args: []interface{}{0, "B"},
}))
assert.True(t, cond.Match(context.Background(), &orm.Invocation{
Args: []interface{}{0, "A"},
}))
}

View File

@ -0,0 +1,183 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
)
// DoNothingQuerySetter do nothing
// usually you use this to build your mock QuerySetter
type DoNothingQuerySetter struct {
}
func (d *DoNothingQuerySetter) OrderClauses(orders ...*order_clause.Order) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) CountWithCtx(ctx context.Context) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) ExistWithCtx(ctx context.Context) bool {
return true
}
func (d *DoNothingQuerySetter) UpdateWithCtx(ctx context.Context, values orm.Params) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) DeleteWithCtx(ctx context.Context) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) PrepareInsertWithCtx(ctx context.Context) (orm.Inserter, error) {
return nil, nil
}
func (d *DoNothingQuerySetter) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
return nil
}
func (d *DoNothingQuerySetter) ValuesWithCtx(ctx context.Context, results *[]orm.Params, exprs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) ValuesListWithCtx(ctx context.Context, results *[]orm.ParamsList, exprs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) ValuesFlatWithCtx(ctx context.Context, result *orm.ParamsList, expr string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) Aggregate(s string) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) Filter(s string, i ...interface{}) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) FilterRaw(s string, s2 string) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) Exclude(s string, i ...interface{}) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) SetCond(condition *orm.Condition) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) GetCond() *orm.Condition {
return orm.NewCondition()
}
func (d *DoNothingQuerySetter) Limit(limit interface{}, args ...interface{}) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) Offset(offset interface{}) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) GroupBy(exprs ...string) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) OrderBy(exprs ...string) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) UseIndex(indexes ...string) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) RelatedSel(params ...interface{}) orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) Distinct() orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) ForUpdate() orm.QuerySeter {
return d
}
func (d *DoNothingQuerySetter) Count() (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) Exist() bool {
return true
}
func (d *DoNothingQuerySetter) Update(values orm.Params) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) Delete() (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) PrepareInsert() (orm.Inserter, error) {
return nil, nil
}
func (d *DoNothingQuerySetter) All(container interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) One(container interface{}, cols ...string) error {
return nil
}
func (d *DoNothingQuerySetter) Values(results *[]orm.Params, exprs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) ValuesList(results *[]orm.ParamsList, exprs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) ValuesFlat(result *orm.ParamsList, expr string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) {
return 0, nil
}
func (d *DoNothingQuerySetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
return 0, nil
}

View File

@ -0,0 +1,74 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDoNothingQuerySetter(t *testing.T) {
setter := &DoNothingQuerySetter{}
setter.GroupBy().Filter("").Limit(10).
Distinct().Exclude("a").FilterRaw("", "").
ForceIndex().ForUpdate().IgnoreIndex().
Offset(11).OrderBy().RelatedSel().SetCond(nil).UseIndex()
assert.True(t, setter.Exist())
err := setter.One(nil)
assert.Nil(t, err)
i, err := setter.Count()
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.Delete()
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.All(nil)
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.Update(nil)
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.RowsToMap(nil, "", "")
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.RowsToStruct(nil, "", "")
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.Values(nil)
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.ValuesFlat(nil, "")
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = setter.ValuesList(nil)
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
ins, err := setter.PrepareInsert()
assert.Nil(t, err)
assert.Nil(t, ins)
assert.NotNil(t, setter.GetCond())
}

View File

@ -0,0 +1,64 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"database/sql"
"github.com/beego/beego/v2/client/orm"
)
type DoNothingRawSetter struct {
}
func (d *DoNothingRawSetter) Exec() (sql.Result, error) {
return nil, nil
}
func (d *DoNothingRawSetter) QueryRow(containers ...interface{}) error {
return nil
}
func (d *DoNothingRawSetter) QueryRows(containers ...interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingRawSetter) SetArgs(i ...interface{}) orm.RawSeter {
return d
}
func (d *DoNothingRawSetter) Values(container *[]orm.Params, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingRawSetter) ValuesList(container *[]orm.ParamsList, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingRawSetter) ValuesFlat(container *orm.ParamsList, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingRawSetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) {
return 0, nil
}
func (d *DoNothingRawSetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
return 0, nil
}
func (d *DoNothingRawSetter) Prepare() (orm.RawPreparer, error) {
return nil, nil
}

View File

@ -0,0 +1,63 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDoNothingRawSetter(t *testing.T) {
rs := &DoNothingRawSetter{}
i, err := rs.ValuesList(nil)
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = rs.Values(nil)
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = rs.ValuesFlat(nil)
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = rs.RowsToStruct(nil, "", "")
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = rs.RowsToMap(nil, "", "")
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
i, err = rs.QueryRows()
assert.Equal(t, int64(0), i)
assert.Nil(t, err)
err = rs.QueryRow()
// assert.Equal(t, int64(0), i)
assert.Nil(t, err)
s, err := rs.Exec()
assert.Nil(t, err)
assert.Nil(t, s)
p, err := rs.Prepare()
assert.Nil(t, err)
assert.Nil(t, p)
rrs := rs.SetArgs()
assert.Equal(t, rrs, rs)
}

View File

@ -0,0 +1,58 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm"
)
func TestOrmStub_FilterChain(t *testing.T) {
os := newOrmStub()
inv := &orm.Invocation{
Args: []interface{}{10},
}
i := 1
os.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} {
i++
return nil
})(context.Background(), inv)
assert.Equal(t, 2, i)
m := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) {
arg := inv.Args[0]
j := arg.(int)
inv.Args[0] = j + 1
})
os.Mock(m)
os.FilterChain(nil)(context.Background(), inv)
assert.Equal(t, 11, inv.Args[0])
inv.Args[0] = 10
ctxMock := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) {
arg := inv.Args[0]
j := arg.(int)
inv.Args[0] = j + 3
})
os.FilterChain(nil)(CtxWithMock(context.Background(), ctxMock), inv)
assert.Equal(t, 13, inv.Args[0])
}

View File

@ -332,10 +332,6 @@ end:
// register register models to model cache
func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
if mc.done {
err = fmt.Errorf("register must be run before BootStrap")
return
}
for _, model := range models {
val := reflect.ValueOf(model)
@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
return
}
if val.Elem().Kind() == reflect.Slice {
val = reflect.New(val.Elem().Type().Elem())
}
table := getTableName(val)
if prefixOrSuffixStr != "" {
@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
}
if _, ok := mc.get(table); ok {
err = fmt.Errorf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
return
return nil
}
mi := newModelInfo(val)
@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
}
}
}
if mi.fields.pk == nil {
err = fmt.Errorf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
return
}
}
mi.table = table

View File

@ -255,6 +255,22 @@ func NewTM() *TM {
return obj
}
type DeptInfo struct {
ID int `orm:"column(id)"`
Created time.Time `orm:"auto_now_add"`
DeptName string
EmployeeName string
Salary int
}
type UnregisterModel struct {
ID int `orm:"column(id)"`
Created time.Time `orm:"auto_now_add"`
DeptName string
EmployeeName string
Salary int
}
type User struct {
ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"`
@ -476,45 +492,45 @@ var (
helpinfo = `need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage:
go get -u github.com/beego/beego/v2/client/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
export ORM_DRIVER=mysql
export ORM_SOURCE="root:@/orm_test?charset=utf8"
go test -v github.com/beego/beego/v2/client/orm
#### Sqlite3
export ORM_DRIVER=sqlite3
export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/beego/beego/v2/client/orm
#### PostgreSQL
psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/beego/beego/v2/client/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test'
go test -v github.com/beego/beego/v2/pgk/orm
`
)

View File

@ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string {
// get whether the table needs to be created for the database alias
func isApplicableTableForDB(val reflect.Value, db string) bool {
if !val.IsValid() {
return true
}
fun := val.MethodByName("IsApplicableTableForDB")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{reflect.ValueOf(db)})

View File

@ -58,6 +58,7 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"os"
"reflect"
"time"
@ -135,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error {
}
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
}
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
@ -144,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error {
}
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true)
}
// Try to read a row from the database, or insert one if it doesn't exist
@ -154,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
cols = append([]string{col1}, cols...)
mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
if err == ErrNoRows {
// Create
id, err := o.InsertWithCtx(ctx, md)
@ -179,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) {
}
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
if err != nil {
return id, err
}
@ -222,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
for i := 0; i < sind.Len(); i++ {
ind := reflect.Indirect(sind.Index(i))
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
@ -233,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
}
@ -244,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (
}
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...)
if err != nil {
return id, err
}
@ -261,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
}
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols)
}
// delete model in database
@ -271,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) {
}
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
}
@ -283,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str
// create a models to models queryer
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
return o.QueryM2MWithCtx(context.Background(), md, name)
}
func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
@ -299,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri
return newQueryM2M(md, o, mi, fi, ind)
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.")
return o.QueryM2M(md, name)
}
// load related models to md model.
// args are limit, offset int and order string.
//
@ -351,7 +355,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s
qs.relDepth = relDepth
if len(order) > 0 {
qs.orders = []string{order}
qs.orders = order_clause.ParseOrder(order)
}
find := ind.FieldByIndex(fi.fieldIndex)
@ -451,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
}
func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string
if table, ok := ptrStructOrTableName.(string); ok {
name = nameStrategyMap[defaultNameStrategy](table)
@ -469,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in
if qs == nil {
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
}
return
return qs
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.")
return o.QueryTable(ptrStructOrTableName)
}
// return a raw query seter for raw sql string.
@ -595,9 +602,8 @@ func NewOrm() Ormer {
func NewOrmUsingDB(aliasName string) Ormer {
if al, ok := dataBaseCache.get(aliasName); ok {
return newDBWithAlias(al)
} else {
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
}
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
}
// NewOrmWithDB create a new ormer object with specify *sql.DB for query

View File

@ -16,12 +16,13 @@ package orm
import (
"fmt"
"github.com/beego/beego/v2/client/orm/clauses"
"strings"
)
// ExprSep define the expression separation
const (
ExprSep = "__"
ExprSep = clauses.ExprSep
)
type condValue struct {

View File

@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error {
}
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
return d.ExecContext(context.Background(), args...)
}
func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.stmt.Exec(args...)
res, err := d.stmt.ExecContext(ctx, args...)
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
return d.QueryContext(context.Background(), args...)
}
func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.stmt.Query(args...)
res, err := d.stmt.QueryContext(ctx, args...)
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
return d.QueryRowContext(context.Background(), args...)
}
func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
a := time.Now()
res := d.stmt.QueryRow(args...)
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"reflect"
)
@ -31,6 +32,10 @@ var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) {
return o.InsertWithCtx(context.Background(), md)
}
func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
}
@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
if name != o.mi.fullName {
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
}
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil {
return id, err
}
@ -70,11 +75,11 @@ func (o *insertSet) Close() error {
}
// create new insert queryer.
func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) {
func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm
bi.mi = mi
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi)
if err != nil {
return nil, err
}

View File

@ -14,7 +14,10 @@
package orm
import "reflect"
import (
"context"
"reflect"
)
// model to model struct
type queryM2M struct {
@ -33,6 +36,10 @@ type queryM2M struct {
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
return o.AddWithCtx(context.Background(), mds...)
}
func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi
mi := fi.relThroughModelInfo
mfi := fi.reverseFieldInfo
@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
}
names = append(names, otherNames...)
values = append(values, otherValues...)
return dbase.InsertValue(orm.db, mi, true, names, values)
return dbase.InsertValue(ctx, orm.db, mi, true, names, values)
}
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return o.RemoveWithCtx(context.Background(), mds...)
}
func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool {
return o.ExistWithCtx(context.Background(), md)
}
func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx)
}
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) {
return o.ClearWithCtx(context.Background())
}
func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx)
}
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) {
return o.CountWithCtx(context.Background())
}
func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx)
}
var _ QueryM2Mer = new(queryM2M)

View File

@ -18,6 +18,7 @@ import (
"context"
"fmt"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/hints"
)
@ -64,21 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct
type querySet struct {
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int64
offset int64
groups []string
orders []string
distinct bool
forUpdate bool
useIndex int
indexes []string
orm *ormBase
ctx context.Context
forContext bool
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int64
offset int64
groups []string
orders []*order_clause.Order
distinct bool
forUpdate bool
useIndex int
indexes []string
orm *ormBase
aggregate string
}
var _ QuerySeter = new(querySet)
@ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter {
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs
func (o querySet) OrderBy(expressions ...string) QuerySeter {
if len(expressions) <= 0 {
return &o
}
o.orders = order_clause.ParseOrder(expressions...)
return &o
}
// add ORDER expression.
func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter {
if len(orders) <= 0 {
return &o
}
o.orders = orders
return &o
}
@ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition {
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return o.CountWithCtx(context.Background())
}
func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return o.ExistWithCtx(context.Background())
}
func (o *querySet) ExistWithCtx(ctx context.Context) bool {
cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0
}
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
return o.UpdateWithCtx(context.Background(), values)
}
func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
}
// execute delete
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return o.DeleteWithCtx(context.Background())
}
func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// return a insert queryer.
@ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) {
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
return o.PrepareInsertWithCtx(context.Background())
}
func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
return newInsertSet(ctx, o.orm, o.mi)
}
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
return o.AllWithCtx(context.Background(), container, cols...)
}
func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error {
return o.OneWithCtx(context.Background(), container, cols...)
}
func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
o.limit = 1
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil {
return err
}
@ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error {
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
return o.ValuesWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
return o.ValuesListWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
return o.ValuesFlatWithCtx(context.Background(), result, expr)
}
func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
// query all rows into map[string]interface with specify key and value column name.
@ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
panic(ErrNotImplement)
}
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter.
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
o := new(querySet)
@ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
o.orm = orm
return o
}
// aggregate func
func (o querySet) Aggregate(s string) QuerySeter {
o.aggregate = s
return &o
}

View File

@ -21,6 +21,7 @@ import (
"context"
"database/sql"
"fmt"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"io/ioutil"
"math"
"os"
@ -205,6 +206,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Index))
RegisterModel(new(StrPk))
RegisterModel(new(TM))
RegisterModel(new(DeptInfo))
err := RunSyncdb("default", true, Debug)
throwFail(t, err)
@ -232,6 +234,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Index))
RegisterModel(new(StrPk))
RegisterModel(new(TM))
RegisterModel(new(DeptInfo))
BootStrap()
@ -333,6 +336,73 @@ func TestTM(t *testing.T) {
throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC"))
}
func TestUnregisterModel(t *testing.T) {
data := []*DeptInfo{
{
DeptName: "A",
EmployeeName: "A1",
Salary: 1000,
},
{
DeptName: "A",
EmployeeName: "A2",
Salary: 2000,
},
{
DeptName: "B",
EmployeeName: "B1",
Salary: 2000,
},
{
DeptName: "B",
EmployeeName: "B2",
Salary: 4000,
},
{
DeptName: "B",
EmployeeName: "B3",
Salary: 3000,
},
}
qs := dORM.QueryTable("dept_info")
i, _ := qs.PrepareInsert()
for _, d := range data {
_, err := i.Insert(d)
if err != nil {
throwFail(t, err)
}
}
f := func() {
var res []UnregisterModel
n, err := dORM.QueryTable("dept_info").All(&res)
throwFail(t, err)
throwFail(t, AssertIs(n, 5))
throwFail(t, AssertIs(res[0].EmployeeName, "A1"))
type Sum struct {
DeptName string
Total int
}
var sun []Sum
qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun)
throwFail(t, AssertIs(sun[0].DeptName, "A"))
throwFail(t, AssertIs(sun[0].Total, 3000))
type Max struct {
DeptName string
Max float64
}
var max []Max
qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max)
throwFail(t, AssertIs(max[1].DeptName, "B"))
throwFail(t, AssertIs(max[1].Max, 4000))
}
for i := 0; i < 5; i++ {
f()
}
}
func TestNullDataTypes(t *testing.T) {
d := DataNull{}
@ -1077,6 +1147,26 @@ func TestOrderBy(t *testing.T) {
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column(`profile__age`),
order_clause.SortDescending(),
),
).Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
if IsMysql {
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column(`rand()`),
order_clause.Raw(),
),
).Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
}
func TestAll(t *testing.T) {
@ -1163,6 +1253,19 @@ func TestValues(t *testing.T) {
throwFail(t, AssertIs(maps[2]["Profile"], nil))
}
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column("Id"),
order_clause.SortAscending(),
),
).Values(&maps)
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
if num == 3 {
throwFail(t, AssertIs(maps[0]["UserName"], "slene"))
throwFail(t, AssertIs(maps[2]["Profile"], nil))
}
num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age")
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
@ -2717,3 +2820,23 @@ func TestCondition(t *testing.T) {
throwFail(t, AssertIs(!cycleFlag, true))
return
}
func TestContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
user := User{UserName: "slene"}
err := dORM.ReadWithCtx(ctx, &user, "UserName")
throwFail(t, err)
cancel()
err = dORM.ReadWithCtx(ctx, &user, "UserName")
throwFail(t, AssertIs(err, context.Canceled))
ctx, cancel = context.WithCancel(context.Background())
cancel()
qs := dORM.QueryTable(user)
_, err = qs.Filter("UserName", "slene").CountWithCtx(ctx)
throwFail(t, AssertIs(err, context.Canceled))
}

View File

@ -20,6 +20,7 @@ import (
"reflect"
"time"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/core/utils"
)
@ -196,12 +197,16 @@ type DQL interface {
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer
// NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter
// NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
DBStats() *sql.DBStats
@ -236,6 +241,7 @@ type TxOrmer interface {
// Inserter insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
InsertWithCtx(context.Context, interface{}) (int64, error)
Close() error
}
@ -295,6 +301,28 @@ type QuerySeter interface {
// for example:
// qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter
// add ORDER expression by order clauses
// for example:
// OrderClauses(
// order_clause.Clause(
// order.Column("Id"),
// order.SortAscending(),
// ),
// order_clause.Clause(
// order.Column("status"),
// order.SortDescending(),
// ),
// )
// OrderClauses(order_clause.Clause(
// order_clause.Column(`user__status`),
// order_clause.SortDescending(),//default None
// ))
// OrderClauses(order_clause.Clause(
// order_clause.Column(`random()`),
// order_clause.SortNone(),//default None
// order_clause.Raw(),//default false.if true, do not check field is valid or not
// ))
OrderClauses(orders ...*order_clause.Order) QuerySeter
// add FORCE INDEX expression.
// for example:
// qs.ForceIndex(`idx_name1`,`idx_name2`)
@ -333,9 +361,11 @@ type QuerySeter interface {
// for example:
// num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
// check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0
Exist() bool
ExistWithCtx(context.Context) bool
// execute update with parameters
// for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{
@ -345,11 +375,13 @@ type QuerySeter interface {
// "user_name": "slene2"
// }) // user slene's name will change to slene2
Update(values Params) (int64, error)
UpdateWithCtx(ctx context.Context, values Params) (int64, error)
// delete from table
// for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2
Delete() (int64, error)
DeleteWithCtx(context.Context) (int64, error)
// return a insert queryer.
// it can be used in times.
// example:
@ -358,18 +390,21 @@ type QuerySeter interface {
// num, err = i.Insert(&user2) // user table will add one record user2 at once
// err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error)
PrepareInsertWithCtx(context.Context) (Inserter, error)
// query all data and map to containers.
// cols means the columns when querying.
// for example:
// var users []*User
// qs.All(&users) // users[0],users[1],users[2] ...
All(container interface{}, cols ...string) (int64, error)
AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error)
// query one row data and map to containers.
// cols means the columns when querying.
// for example:
// var user User
// qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error
OneWithCtx(ctx context.Context, container interface{}, cols ...string) error
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
@ -377,18 +412,21 @@ type QuerySeter interface {
// var maps []Params
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error)
ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface
// it converts data to [][column_index]value
// for example:
// var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value.
// for example:
// var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error)
ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
@ -411,6 +449,15 @@ type QuerySeter interface {
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// aggregate func.
// for example:
// type result struct {
// DeptName string
// Total int
// }
// var res []result
// o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res)
Aggregate(s string) QuerySeter
}
// QueryM2Mer model to model query struct
@ -428,18 +475,23 @@ type QueryM2Mer interface {
// insert one or more rows to m2m table
// make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error)
AddWithCtx(context.Context, ...interface{}) (int64, error)
// remove models following the origin model relationship
// only delete rows from m2m table
// for example:
// tag3 := &Tag{Id:5,Name: "TestTag3"}
// num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error)
RemoveWithCtx(context.Context, ...interface{}) (int64, error)
// check model is existed in relationship of origin model
Exist(interface{}) bool
ExistWithCtx(context.Context, interface{}) bool
// clean all models in related of origin model
Clear() (int64, error)
ClearWithCtx(context.Context) (int64, error)
// count all related models of origin model
Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
}
// RawPreparer raw query statement
@ -513,11 +565,11 @@ type RawSeter interface {
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
// ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
// QueryContext(args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row
// QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
}
// db querier
@ -554,28 +606,28 @@ type txEnder interface {
// base database struct
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
SupportUpdateJoin() bool
OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
@ -584,12 +636,12 @@ type dbBaser interface {
TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string
GetTables(dbQuerier) (map[string]bool, error)
GetColumns(dbQuerier, string) (map[string][3]string, error)
GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error)
ShowTablesQuery() string
ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool
IndexExists(context.Context, dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error
setval(context.Context, dbQuerier, *modelInfo, []string) error
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
}

91
core/berror/codes.go Normal file
View File

@ -0,0 +1,91 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package berror
import (
"fmt"
"sync"
)
// A Code is an unsigned 32-bit error code as defined in the beego spec.
type Code interface {
Code() uint32
Module() string
Desc() string
Name() string
}
var defaultCodeRegistry = &codeRegistry{
codes: make(map[uint32]*codeDefinition, 127),
}
// DefineCode defining a new Code
// Before defining a new code, please read Beego specification.
// desc could be markdown doc
func DefineCode(code uint32, module string, name string, desc string) Code {
res := &codeDefinition{
code: code,
module: module,
desc: desc,
}
defaultCodeRegistry.lock.Lock()
defer defaultCodeRegistry.lock.Unlock()
if _, ok := defaultCodeRegistry.codes[code]; ok {
panic(fmt.Sprintf("duplicate code, code %d has been registered", code))
}
defaultCodeRegistry.codes[code] = res
return res
}
type codeRegistry struct {
lock sync.RWMutex
codes map[uint32]*codeDefinition
}
func (cr *codeRegistry) Get(code uint32) (Code, bool) {
cr.lock.RLock()
defer cr.lock.RUnlock()
c, ok := cr.codes[code]
return c, ok
}
type codeDefinition struct {
code uint32
module string
desc string
name string
}
func (c *codeDefinition) Name() string {
return c.name
}
func (c *codeDefinition) Code() uint32 {
return c.code
}
func (c *codeDefinition) Module() string {
return c.module
}
func (c *codeDefinition) Desc() string {
return c.desc
}

69
core/berror/error.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package berror
import (
"fmt"
"strconv"
"strings"
"github.com/pkg/errors"
)
// code, msg
const errFmt = "ERROR-%d, %s"
// Err returns an error representing c and msg. If c is OK, returns nil.
func Error(c Code, msg string) error {
return fmt.Errorf(errFmt, c.Code(), msg)
}
// Errorf returns error
func Errorf(c Code, format string, a ...interface{}) error {
return Error(c, fmt.Sprintf(format, a...))
}
func Wrap(err error, c Code, msg string) error {
if err == nil {
return nil
}
return errors.Wrap(err, fmt.Sprintf(errFmt, c.Code(), msg))
}
func Wrapf(err error, c Code, format string, a ...interface{}) error {
return Wrap(err, c, fmt.Sprintf(format, a...))
}
// FromError is very simple. It just parse error msg and check whether code has been register
// if code not being register, return unknown
// if err.Error() is not valid beego error code, return unknown
func FromError(err error) (Code, bool) {
msg := err.Error()
codeSeg := strings.SplitN(msg, ",", 2)
if strings.HasPrefix(codeSeg[0], "ERROR-") {
codeStr := strings.SplitN(codeSeg[0], "-", 2)
if len(codeStr) < 2 {
return Unknown, false
}
codeInt, e := strconv.ParseUint(codeStr[1], 10, 32)
if e != nil {
return Unknown, false
}
if code, ok := defaultCodeRegistry.Get(uint32(codeInt)); ok {
return code, true
}
}
return Unknown, false
}

77
core/berror/error_test.go Normal file
View File

@ -0,0 +1,77 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package berror
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
var testCode1 = DefineCode(1, "unit_test", "TestError", "Hello, test code1")
var testErr = errors.New("hello, this is error")
func TestErrorf(t *testing.T) {
msg := Errorf(testCode1, "errorf %s", "aaaa")
assert.NotNil(t, msg)
assert.Equal(t, "ERROR-1, errorf aaaa", msg.Error())
}
func TestWrapf(t *testing.T) {
err := Wrapf(testErr, testCode1, "Wrapf %s", "aaaa")
assert.NotNil(t, err)
assert.True(t, errors.Is(err, testErr))
}
func TestFromError(t *testing.T) {
err := errors.New("ERROR-1, errorf aaaa")
code, ok := FromError(err)
assert.True(t, ok)
assert.Equal(t, testCode1, code)
assert.Equal(t, "unit_test", code.Module())
assert.Equal(t, "Hello, test code1", code.Desc())
err = errors.New("not beego error")
code, ok = FromError(err)
assert.False(t, ok)
assert.Equal(t, Unknown, code)
err = errors.New("ERROR-2, not register")
code, ok = FromError(err)
assert.False(t, ok)
assert.Equal(t, Unknown, code)
err = errors.New("ERROR-aaa, invalid code")
code, ok = FromError(err)
assert.False(t, ok)
assert.Equal(t, Unknown, code)
err = errors.New("aaaaaaaaaaaaaa")
code, ok = FromError(err)
assert.False(t, ok)
assert.Equal(t, Unknown, code)
err = errors.New("ERROR-2-3, invalid error")
code, ok = FromError(err)
assert.False(t, ok)
assert.Equal(t, Unknown, code)
err = errors.New("ERROR, invalid error")
code, ok = FromError(err)
assert.False(t, ok)
assert.Equal(t, Unknown, code)
}

View File

@ -0,0 +1,52 @@
// Copyright 2021 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package berror
import (
"fmt"
)
// pre define code
// Unknown indicates got some error which is not defined
var Unknown = DefineCode(5000001, "error", "Unknown",fmt.Sprintf(`
Unknown error code. Usually you will see this code in three cases:
1. You forget to define Code or function DefineCode not being executed;
2. This is not Beego's error but you call FromError();
3. Beego got unexpected error and don't know how to handle it, and then return Unknown error
A common practice to DefineCode looks like:
%s
In this way, you may forget to import this package, and got Unknown error.
Sometimes, you believe you got Beego error, but actually you don't, and then you call FromError(err)
`, goCodeBlock(`
import your_package
func init() {
DefineCode(5100100, "your_module", "detail")
// ...
}
`)))
func goCodeBlock(code string) string {
return codeBlock("go", code)
}
func codeBlock(lan string, code string) string {
return fmt.Sprintf("```%s\n%s\n```", lan, code)
}

View File

@ -69,8 +69,8 @@ func (p *PatternLogFormatter) ToString(lm *LogMsg) string {
'm': lm.Msg,
'n': strconv.Itoa(lm.LineNumber),
'l': strconv.Itoa(lm.Level),
't': levelPrefix[lm.Level-1],
'T': levelNames[lm.Level-1],
't': levelPrefix[lm.Level],
'T': levelNames[lm.Level],
'F': lm.FilePath,
}
_, m['f'] = path.Split(lm.FilePath)

View File

@ -88,7 +88,7 @@ func TestPatternLogFormatter(t *testing.T) {
}
got := tes.ToString(lm)
want := lm.FilePath + ":" + strconv.Itoa(lm.LineNumber) + "|" +
when.Format(tes.WhenFormat) + levelPrefix[lm.Level-1] + ">> " + lm.Msg
when.Format(tes.WhenFormat) + levelPrefix[lm.Level] + ">> " + lm.Msg
if got != want {
t.Errorf("want %s, got %s", want, got)
}

2
go.mod
View File

@ -25,7 +25,7 @@ require (
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/go-cmp v0.5.0 // indirect
github.com/google/uuid v1.1.1 // indirect
github.com/google/uuid v1.1.1
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/hashicorp/golang-lru v0.5.4
github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6

View File

@ -108,8 +108,11 @@ func registerAdmin() error {
c := &adminController{
servers: make([]*HttpServer, 0, 2),
}
// copy config to avoid conflict
adminCfg := *BConfig
beeAdminApp = &adminApp{
HttpServer: NewHttpServerWithCfg(BConfig),
HttpServer: NewHttpServerWithCfg(&adminCfg),
}
// keep in mind that all data should be html escaped to avoid XSS attack
beeAdminApp.Router("/", c, "get:AdminIndex")

View File

@ -29,6 +29,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/beego/beego/v2/server/web/session"
"net"
"net/http"
"strconv"
@ -195,6 +196,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) {
}
}
// Session return session store of this context of request
func (ctx *Context) Session() (store session.Store, err error) {
if ctx.Input != nil {
if ctx.Input.CruSession != nil {
store = ctx.Input.CruSession
return
} else {
err = errors.New(`no valid session store(please initialize session)`)
return
}
} else {
err = errors.New(`no valid input`)
return
}
}
// Response is a wrapper for the http.ResponseWriter
// Started: if true, response was already written to so the other handler will not be executed
type Response struct {

View File

@ -15,6 +15,7 @@
package context
import (
"github.com/beego/beego/v2/server/web/session"
"net/http"
"net/http/httptest"
"testing"
@ -45,3 +46,26 @@ func TestXsrfReset_01(t *testing.T) {
t.FailNow()
}
}
func TestContext_Session(t *testing.T) {
c := NewContext()
if store, err := c.Session(); store != nil || err == nil {
t.FailNow()
}
}
func TestContext_Session1(t *testing.T) {
c := Context{}
if store, err := c.Session(); store != nil || err == nil {
t.FailNow()
}
}
func TestContext_Session2(t *testing.T) {
c := NewContext()
c.Input.CruSession = &session.MemSessionStore{}
if store, err := c.Session(); store == nil || err != nil {
t.FailNow()
}
}

View File

@ -17,23 +17,49 @@ package prometheus
import (
"strconv"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/beego/beego/v2"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/server/web"
"github.com/beego/beego/v2/server/web/context"
)
const unknownRouterPattern = "UnknownRouterPattern"
// FilterChainBuilder is an extension point,
// when we want to support some configuration,
// please use this structure
type FilterChainBuilder struct {
}
var summaryVec prometheus.ObserverVec
var initSummaryVec sync.Once
// FilterChain returns a FilterFunc. The filter will records some metrics
func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc {
initSummaryVec.Do(func() {
summaryVec = builder.buildVec()
err := prometheus.Register(summaryVec)
if _, ok := err.(*prometheus.AlreadyRegisteredError); err != nil && !ok {
logs.Error("web module register prometheus vector failed, %+v", err)
}
registerBuildInfo()
})
return func(ctx *context.Context) {
startTime := time.Now()
next(ctx)
endTime := time.Now()
go report(endTime.Sub(startTime), ctx, summaryVec)
}
}
func (builder *FilterChainBuilder) buildVec() *prometheus.SummaryVec {
summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "beego",
Subsystem: "http_request",
@ -44,17 +70,7 @@ func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFu
},
Help: "The statics info for http request",
}, []string{"pattern", "method", "status"})
prometheus.MustRegister(summaryVec)
registerBuildInfo()
return func(ctx *context.Context) {
startTime := time.Now()
next(ctx)
endTime := time.Now()
go report(endTime.Sub(startTime), ctx, summaryVec)
}
return summaryVec
}
func registerBuildInfo() {
@ -75,13 +91,17 @@ func registerBuildInfo() {
},
}, []string{})
prometheus.MustRegister(buildInfo)
_ = prometheus.Register(buildInfo)
buildInfo.WithLabelValues().Set(1)
}
func report(dur time.Duration, ctx *context.Context, vec *prometheus.SummaryVec) {
func report(dur time.Duration, ctx *context.Context, vec prometheus.ObserverVec) {
status := ctx.Output.Status
ptn := ctx.Input.GetData("RouterPattern").(string)
ptnItf := ctx.Input.GetData("RouterPattern")
ptn := unknownRouterPattern
if ptnItf != nil {
ptn = ptnItf.(string)
}
ms := dur / time.Millisecond
vec.WithLabelValues(ptn, ctx.Input.Method(), strconv.Itoa(status)).Observe(float64(ms))
}

View File

@ -18,6 +18,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -37,4 +38,19 @@ func TestFilterChain(t *testing.T) {
ctx.Input.SetData("RouterPattern", "my-route")
filter(ctx)
assert.True(t, ctx.Input.GetData("invocation").(bool))
time.Sleep(1 * time.Second)
}
func TestFilterChainBuilder_report(t *testing.T) {
ctx := context.NewContext()
r, _ := http.NewRequest("GET", "/prometheus/user", nil)
w := httptest.NewRecorder()
ctx.Reset(w, r)
fb := &FilterChainBuilder{}
// without router info
report(time.Second, ctx, fb.buildVec())
ctx.Input.SetData("RouterPattern", "my-route")
report(time.Second, ctx, fb.buildVec())
}

View File

@ -0,0 +1,35 @@
package session
import (
"context"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/server/web"
webContext "github.com/beego/beego/v2/server/web/context"
"github.com/beego/beego/v2/server/web/session"
)
//Session maintain session for web service
//Session new a session storage and store it into webContext.Context
func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain {
sessionConfig := session.NewManagerConfig(options...)
sessionManager, _ := session.NewManager(string(providerType), sessionConfig)
go sessionManager.GC()
return func(next web.FilterFunc) web.FilterFunc {
return func(ctx *webContext.Context) {
if ctx.Input.CruSession != nil {
return
}
if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil {
logs.Error(`init session error:%s`, err.Error())
} else {
//release session at the end of request
defer sess.SessionRelease(context.Background(), ctx.ResponseWriter)
ctx.Input.CruSession = sess
}
next(ctx)
}
}
}

View File

@ -0,0 +1,86 @@
package session
import (
"github.com/beego/beego/v2/server/web"
webContext "github.com/beego/beego/v2/server/web/context"
"github.com/beego/beego/v2/server/web/session"
"github.com/google/uuid"
"net/http"
"net/http/httptest"
"testing"
)
func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) {
r, _ := http.NewRequest(method, path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != code {
t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code)
}
}
func TestSession(t *testing.T) {
storeKey := uuid.New().String()
handler := web.NewControllerRegister()
handler.InsertFilterChain(
"*",
Session(
session.ProviderMemory,
session.CfgCookieName(`go_session_id`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
),
)
handler.InsertFilterChain(
"*",
func(next web.FilterFunc) web.FilterFunc {
return func(ctx *webContext.Context) {
if store := ctx.Input.GetData(storeKey); store == nil {
t.Error(`store should not be nil`)
}
next(ctx)
}
},
)
handler.Any("*", func(ctx *webContext.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
}
func TestSession1(t *testing.T) {
handler := web.NewControllerRegister()
handler.InsertFilterChain(
"*",
Session(
session.ProviderMemory,
session.CfgCookieName(`go_session_id`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
),
)
handler.InsertFilterChain(
"*",
func(next web.FilterFunc) web.FilterFunc {
return func(ctx *webContext.Context) {
if store, err := ctx.Session(); store == nil || err != nil {
t.Error(`store should not be nil`)
}
next(ctx)
}
},
)
handler.Any("*", func(ctx *webContext.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
}

View File

@ -15,9 +15,12 @@
package web
import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) {
ns := NewNamespace("/chain")
ns.Get("/*", func(ctx *context.Context) {
ctx.Output.Body([]byte("hello"))
_ = ctx.Output.Body([]byte("hello"))
})
r, _ := http.NewRequest("GET", "/chain/user", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r)
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
}
func TestControllerRegister_InsertFilterChain_Order(t *testing.T) {
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
r, _ := http.NewRequest("GET", "/abc", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r)
first := w.Header().Get("first")
second := w.Header().Get("second")
ft, _ := strconv.ParseInt(first, 10, 64)
st, _ := strconv.ParseInt(second, 10, 64)
assert.True(t, st > ft)
}

View File

@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) {
// setup the handler
handler := NewControllerRegister()
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash"))
handler.ServeHTTP(w, r)
// get the Set-Cookie value

View File

@ -6,6 +6,8 @@ import (
"net/http"
"path/filepath"
"github.com/coreos/etcd/pkg/fileutil"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/server/web/context"
"github.com/beego/beego/v2/server/web/session"
@ -99,7 +101,12 @@ func registerGzip() error {
func registerCommentRouter() error {
if BConfig.RunMode == DEV {
if err := parserPkg(filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath)); err != nil {
ctrlDir := filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath)
if !fileutil.Exist(ctrlDir) {
logs.Warn("controller package not found, won't generate router: ", ctrlDir)
return nil
}
if err := parserPkg(ctrlDir); err != nil {
return err
}
}

View File

@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
// Router same as beego.Rourer
// refer: https://godoc.org/github.com/beego/beego/v2#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, mappingMethods...)
n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...))
return n
}
@ -187,6 +187,54 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
return n
}
// RouterGet same as beego.RouterGet
func (n *Namespace) RouterGet(rootpath string, f interface{}) *Namespace {
n.handlers.RouterGet(rootpath, f)
return n
}
// RouterPost same as beego.RouterPost
func (n *Namespace) RouterPost(rootpath string, f interface{}) *Namespace {
n.handlers.RouterPost(rootpath, f)
return n
}
// RouterDelete same as beego.RouterDelete
func (n *Namespace) RouterDelete(rootpath string, f interface{}) *Namespace {
n.handlers.RouterDelete(rootpath, f)
return n
}
// RouterPut same as beego.RouterPut
func (n *Namespace) RouterPut(rootpath string, f interface{}) *Namespace {
n.handlers.RouterPut(rootpath, f)
return n
}
// RouterHead same as beego.RouterHead
func (n *Namespace) RouterHead(rootpath string, f interface{}) *Namespace {
n.handlers.RouterHead(rootpath, f)
return n
}
// RouterOptions same as beego.RouterOptions
func (n *Namespace) RouterOptions(rootpath string, f interface{}) *Namespace {
n.handlers.RouterOptions(rootpath, f)
return n
}
// RouterPatch same as beego.RouterPatch
func (n *Namespace) RouterPatch(rootpath string, f interface{}) *Namespace {
n.handlers.RouterPatch(rootpath, f)
return n
}
// Any same as beego.RouterAny
func (n *Namespace) RouterAny(rootpath string, f interface{}) *Namespace {
n.handlers.RouterAny(rootpath, f)
return n
}
// Namespace add nest Namespace
// usage:
// ns := beego.NewNamespace(“/v1”).
@ -366,6 +414,62 @@ func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
}
}
// NSRouterGet call Namespace RouterGet
func NSRouterGet(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterGet(rootpath, f)
}
}
// NSRouterPost call Namespace RouterPost
func NSRouterPost(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterPost(rootpath, f)
}
}
// NSRouterHead call Namespace RouterHead
func NSRouterHead(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterHead(rootpath, f)
}
}
// NSRouterPut call Namespace RouterPut
func NSRouterPut(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterPut(rootpath, f)
}
}
// NSRouterDelete call Namespace RouterDelete
func NSRouterDelete(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterDelete(rootpath, f)
}
}
// NSRouterAny call Namespace RouterAny
func NSRouterAny(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterAny(rootpath, f)
}
}
// NSRouterOptions call Namespace RouterOptions
func NSRouterOptions(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterOptions(rootpath, f)
}
}
// NSRouterPatch call Namespace RouterPatch
func NSRouterPatch(rootpath string, f interface{}) LinkNamespace {
return func(ns *Namespace) {
ns.RouterPatch(rootpath, f)
}
}
// NSAutoRouter call Namespace AutoRouter
func NSAutoRouter(c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {

View File

@ -15,6 +15,7 @@
package web
import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
@ -23,6 +24,40 @@ import (
"github.com/beego/beego/v2/server/web/context"
)
const (
exampleBody = "hello world"
examplePointerBody = "hello world pointer"
nsNamespace = "/router"
nsPath = "/user"
nsNamespacePath = "/router/user"
)
type ExampleController struct {
Controller
}
func (m ExampleController) Ping() {
err := m.Ctx.Output.Body([]byte(exampleBody))
if err != nil {
fmt.Println(err)
}
}
func (m *ExampleController) PingPointer() {
err := m.Ctx.Output.Body([]byte(examplePointerBody))
if err != nil {
fmt.Println(err)
}
}
func (m ExampleController) ping() {
err := m.Ctx.Output.Body([]byte("ping method"))
if err != nil {
fmt.Println(err)
}
}
func TestNamespaceGet(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/user", nil)
w := httptest.NewRecorder()
@ -166,3 +201,215 @@ func TestNamespaceInside(t *testing.T) {
t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterGet(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
ns.RouterGet(nsPath, ExampleController.Ping)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterGet can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterPost(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
ns.RouterPost(nsPath, ExampleController.Ping)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterPost can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterDelete(t *testing.T) {
r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
ns.RouterDelete(nsPath, ExampleController.Ping)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterDelete can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterPut(t *testing.T) {
r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
ns.RouterPut(nsPath, ExampleController.Ping)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterPut can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterHead(t *testing.T) {
r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
ns.RouterHead(nsPath, ExampleController.Ping)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterHead can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterOptions(t *testing.T) {
r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
ns.RouterOptions(nsPath, ExampleController.Ping)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterOptions can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterPatch(t *testing.T) {
r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
ns.RouterPatch(nsPath, ExampleController.Ping)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterPatch can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouterAny(t *testing.T) {
ns := NewNamespace(nsNamespace)
ns.RouterAny(nsPath, ExampleController.Ping)
AddNamespace(ns)
for method := range HTTPMETHOD {
w := httptest.NewRecorder()
r, _ := http.NewRequest(method, nsNamespacePath, nil)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceRouterAny can't run, get the response is " + w.Body.String())
}
}
}
func TestNamespaceNSRouterGet(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
NSRouterGet(nsPath, ExampleController.Ping)(ns)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterGet can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNSRouterPost(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace("/router")
NSRouterPost(nsPath, ExampleController.Ping)(ns)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterPost can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNSRouterDelete(t *testing.T) {
r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
NSRouterDelete(nsPath, ExampleController.Ping)(ns)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterDelete can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNSRouterPut(t *testing.T) {
r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
NSRouterPut(nsPath, ExampleController.Ping)(ns)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterPut can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNSRouterHead(t *testing.T) {
r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
NSRouterHead(nsPath, ExampleController.Ping)(ns)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterHead can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNSRouterOptions(t *testing.T) {
r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
NSRouterOptions(nsPath, ExampleController.Ping)(ns)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterOptions can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNSRouterPatch(t *testing.T) {
r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil)
w := httptest.NewRecorder()
ns := NewNamespace(nsNamespace)
NSRouterPatch("/user", ExampleController.Ping)(ns)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterPatch can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNSRouterAny(t *testing.T) {
ns := NewNamespace(nsNamespace)
NSRouterAny(nsPath, ExampleController.Ping)(ns)
AddNamespace(ns)
for method := range HTTPMETHOD {
w := httptest.NewRecorder()
r, _ := http.NewRequest(method, nsNamespacePath, nil)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestNamespaceNSRouterAny can't run, get the response is " + w.Body.String())
}
}
}

View File

@ -20,6 +20,7 @@ import (
"net/http"
"path"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
@ -118,12 +119,33 @@ type ControllerInfo struct {
routerType int
initialize func() ControllerInterface
methodParams []*param.MethodParam
sessionOn bool
}
type ControllerOption func(*ControllerInfo)
func (c *ControllerInfo) GetPattern() string {
return c.pattern
}
func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
return func(c *ControllerInfo) {
c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
}
}
func WithRouterSessionOn(sessionOn bool) ControllerOption {
return func(c *ControllerInfo) {
c.sessionOn = sessionOn
}
}
type filterChainConfig struct {
pattern string
chain FilterChain
opts []FilterOpt
}
// ControllerRegister containers registered router rules, controller handlers and filters.
type ControllerRegister struct {
routers map[string]*Tree
@ -136,6 +158,9 @@ type ControllerRegister struct {
// the filter created by FilterChain
chainRoot *FilterRouter
// keep registered chain and build it when serve http
filterChains []filterChainConfig
cfg *Config
}
@ -155,12 +180,24 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
return beecontext.NewContext()
},
},
cfg: cfg,
cfg: cfg,
filterChains: make([]filterChainConfig, 0, 4),
}
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
return res
}
// Init will be executed when HttpServer start running
func (p *ControllerRegister) Init() {
for i := len(p.filterChains) - 1; i >= 0; i-- {
fc := p.filterChains[i]
root := p.chainRoot
filterFunc := fc.chain(root.filterFunc)
p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...)
p.chainRoot.next = root
}
}
// Add controller handler and pattern rules to ControllerRegister.
// usage:
// default methods is the same name as method
@ -171,41 +208,64 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
// Add("/api/delete",&RestController{},"delete:DeleteFood")
// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
p.addWithMethodParams(pattern, c, nil, mappingMethods...)
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) {
p.addWithMethodParams(pattern, c, nil, opts...)
}
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string)
if len(mappingMethods) > 0 {
semi := strings.Split(mappingMethods[0], ";")
for _, v := range semi {
colon := strings.Split(v, ":")
if len(colon) != 2 {
panic("method mapping format is invalid")
if len(mappingMethods) == 0 {
return methods
}
semi := strings.Split(mappingMethods[0], ";")
for _, v := range semi {
colon := strings.Split(v, ":")
if len(colon) != 2 {
panic("method mapping format is invalid")
}
comma := strings.Split(colon[0], ",")
for _, m := range comma {
if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] {
panic(v + " is an invalid method mapping. Method doesn't exist " + m)
}
comma := strings.Split(colon[0], ",")
for _, m := range comma {
if m == "*" || HTTPMETHOD[strings.ToUpper(m)] {
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
methods[strings.ToUpper(m)] = colon[1]
} else {
panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name())
}
} else {
panic(v + " is an invalid method mapping. Method doesn't exist " + m)
}
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
methods[strings.ToUpper(m)] = colon[1]
continue
}
panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name())
}
}
route := &ControllerInfo{}
route.pattern = pattern
route.methods = methods
route.routerType = routerTypeBeego
route.controllerType = t
return methods
}
func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) {
if len(route.methods) == 0 {
for m := range HTTPMETHOD {
p.addToRouter(m, route.pattern, route)
}
return
}
for k := range route.methods {
if k != "*" {
p.addToRouter(k, route.pattern, route)
continue
}
for m := range HTTPMETHOD {
p.addToRouter(m, route.pattern, route)
}
}
}
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
route := p.createBeegoRouter(t, pattern)
route.initialize = func() ControllerInterface {
vc := reflect.New(route.controllerType)
execController, ok := vc.Interface().(ControllerInterface)
@ -229,23 +289,18 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
return execController
}
route.methodParams = methodParams
if len(methods) == 0 {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
for k := range methods {
if k == "*" {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
p.addToRouter(k, pattern, route)
}
}
for i := range opts {
opts[i](route)
}
globalSessionOn := p.cfg.WebConfig.Session.SessionOn
if !globalSessionOn && route.sessionOn {
logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern)
route.sessionOn = globalSessionOn
}
p.addRouterForMethod(route)
}
func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
@ -273,7 +328,8 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
for _, f := range a.Filters {
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
}
p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method))
}
}
}
@ -294,6 +350,261 @@ func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) {
p.pool.Put(ctx)
}
// RouterGet add get method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterGet("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterGet(pattern string, f interface{}) {
p.AddRouterMethod(http.MethodGet, pattern, f)
}
// RouterPost add post method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterPost("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterPost(pattern string, f interface{}) {
p.AddRouterMethod(http.MethodPost, pattern, f)
}
// RouterHead add head method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterHead("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterHead(pattern string, f interface{}) {
p.AddRouterMethod(http.MethodHead, pattern, f)
}
// RouterPut add put method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterPut("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterPut(pattern string, f interface{}) {
p.AddRouterMethod(http.MethodPut, pattern, f)
}
// RouterPatch add patch method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterPatch("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterPatch(pattern string, f interface{}) {
p.AddRouterMethod(http.MethodPatch, pattern, f)
}
// RouterDelete add delete method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterDelete("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterDelete(pattern string, f interface{}) {
p.AddRouterMethod(http.MethodDelete, pattern, f)
}
// RouterOptions add options method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterOptions("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterOptions(pattern string, f interface{}) {
p.AddRouterMethod(http.MethodOptions, pattern, f)
}
// RouterAny add all method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterAny("/api/:id", MyController.Ping)
func (p *ControllerRegister) RouterAny(pattern string, f interface{}) {
p.AddRouterMethod("*", pattern, f)
}
// AddRouterMethod add http method router
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// AddRouterMethod("get","/api/:id", MyController.Ping)
func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) {
httpMethod = p.getUpperMethodString(httpMethod)
ct, methodName := getReflectTypeAndMethod(f)
p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern)
}
// addBeegoTypeRouter add beego type router
func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) {
route := p.createBeegoRouter(ct, pattern)
methods := p.getHttpMethodMapMethod(httpMethod, ctMethod)
route.methods = methods
p.addRouterForMethod(route)
}
// createBeegoRouter create beego router base on reflect type and pattern
func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo {
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeBeego
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.controllerType = ct
return route
}
// createRestfulRouter create restful router with filter function and pattern
func (p *ControllerRegister) createRestfulRouter(f FilterFunc, pattern string) *ControllerInfo {
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.runFunction = f
return route
}
// createHandlerRouter create handler router with handler and pattern
func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo {
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeHandler
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.handler = h
return route
}
// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will
// use http method as the controller method
func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string {
methods := make(map[string]string)
// not match-all sign, only add for the http method
if httpMethod != "*" {
if ctMethod == "" {
ctMethod = httpMethod
}
methods[httpMethod] = ctMethod
return methods
}
// add all http method
for val := range HTTPMETHOD {
if ctMethod == "" {
methods[val] = val
} else {
methods[val] = ctMethod
}
}
return methods
}
// getUpperMethodString get upper string of method, and panic if the method
// is not valid
func (p *ControllerRegister) getUpperMethodString(method string) string {
method = strings.ToUpper(method)
if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method)
}
return method
}
// get reflect controller type and method by controller method expression
func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method string) {
// check f is a function
funcType := reflect.TypeOf(f)
if funcType.Kind() != reflect.Func {
panic("not a method")
}
// get function name
funcObj := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
if funcObj == nil {
panic("cannot find the method")
}
funcNameSli := strings.Split(funcObj.Name(), ".")
lFuncSli := len(funcNameSli)
if lFuncSli == 0 {
panic("invalid method full name: " + funcObj.Name())
}
method = funcNameSli[lFuncSli-1]
if len(method) == 0 {
panic("method name is empty")
} else if method[0] > 96 || method[0] < 65 {
panic(fmt.Sprintf("%s is not a public method", method))
}
// check only one param which is the method receiver
if numIn := funcType.NumIn(); numIn != 1 {
panic("invalid number of param in")
}
controllerType = funcType.In(0)
// check controller has the method
_, exists := controllerType.MethodByName(method)
if !exists {
panic(controllerType.String() + " has no method " + method)
}
// check the receiver implement ControllerInterface
if controllerType.Kind() == reflect.Ptr {
controllerType = controllerType.Elem()
}
controller := reflect.New(controllerType)
_, ok := controller.Interface().(ControllerInterface)
if !ok {
panic(controllerType.String() + " is not implemented ControllerInterface")
}
return
}
// Get add get method
// usage:
// Get("/", func(ctx *context.Context){
@ -372,34 +683,18 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
method = strings.ToUpper(method)
if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method)
}
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
route.runFunction = f
methods := make(map[string]string)
if method == "*" {
for val := range HTTPMETHOD {
methods[val] = val
}
} else {
methods[method] = method
}
method = p.getUpperMethodString(method)
route := p.createRestfulRouter(f, pattern)
methods := p.getHttpMethodMapMethod(method, "")
route.methods = methods
for k := range methods {
p.addToRouter(k, pattern, route)
}
p.addRouterForMethod(route)
}
// Handler add user defined Handler
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeHandler
route.handler = h
route := p.createHandlerRouter(h, pattern)
if len(options) > 0 {
if _, ok := options[0].(bool); ok {
pattern = path.Join(pattern, "?:all(.*)")
@ -431,15 +726,13 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
controllerName := strings.TrimSuffix(ct.Name(), "Controller")
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
route := &ControllerInfo{}
route.routerType = routerTypeBeego
route.methods = map[string]string{"*": rt.Method(i).Name}
route.controllerType = ct
pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern
route := p.createBeegoRouter(ct, pattern)
route.methods = map[string]string{"*": rt.Method(i).Name}
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route)
@ -472,12 +765,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// }
// }
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
root := p.chainRoot
filterFunc := chain(root.filterFunc)
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.chainRoot = newFilterRouter(pattern, filterFunc, opts...)
p.chainRoot.next = root
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.filterChains = append(p.filterChains, filterChainConfig{
pattern: pattern,
chain: chain,
opts: opts,
})
}
// add Filter into
@ -542,7 +836,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str
for _, l := range t.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) {
find := false
if HTTPMETHOD[strings.ToUpper(methodName)] {
if len(c.methods) == 0 {
@ -664,12 +958,15 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
r := ctx.Request
rw := ctx.ResponseWriter.ResponseWriter
var (
runRouter reflect.Type
findRouter bool
runMethod string
methodParams []*param.MethodParam
routerInfo *ControllerInfo
isRunnable bool
runRouter reflect.Type
findRouter bool
runMethod string
methodParams []*param.MethodParam
routerInfo *ControllerInfo
isRunnable bool
currentSessionOn bool
originRouterInfo *ControllerInfo
originFindRouter bool
)
if p.cfg.RecoverFunc != nil {
@ -735,7 +1032,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
}
// session init
if p.cfg.WebConfig.Session.SessionOn {
currentSessionOn = p.cfg.WebConfig.Session.SessionOn
originRouterInfo, originFindRouter = p.FindRouter(ctx)
if originFindRouter {
currentSessionOn = originRouterInfo.sessionOn
}
if currentSessionOn {
ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil {
logs.Error(err)

View File

@ -16,6 +16,7 @@ package web
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"strings"
@ -26,6 +27,25 @@ import (
"github.com/beego/beego/v2/server/web/context"
)
type PrefixTestController struct {
Controller
}
func (ptc *PrefixTestController) PrefixList() {
ptc.Ctx.Output.Body([]byte("i am list in prefix test"))
}
type TestControllerWithInterface struct {
}
func (m TestControllerWithInterface) Ping() {
fmt.Println("pong")
}
func (m *TestControllerWithInterface) PingPointer() {
fmt.Println("pong pointer")
}
type TestController struct {
Controller
}
@ -87,10 +107,24 @@ func (jc *JSONController) Get() {
jc.Ctx.Output.Body([]byte("ok"))
}
func TestPrefixUrlFor(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList"))
if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` {
logs.Info(a)
t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list")
}
if a := handler.URLFor(`TestController.PrefixList`); a != `` {
logs.Info(a)
t.Errorf("TestController.PrefixList must equal to empty string")
}
}
func TestUrlFor(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}, "*:Param")
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
if a := handler.URLFor("TestController.List"); a != "/api/list" {
logs.Info(a)
t.Errorf("TestController.List must equal to /api/list")
@ -113,9 +147,9 @@ func TestUrlFor3(t *testing.T) {
func TestUrlFor2(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List")
handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL")
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL"))
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
logs.Info(handler.URLFor("TestController.GetURL"))
@ -145,7 +179,7 @@ func TestUserFunc(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("user define func can't run")
@ -235,7 +269,7 @@ func TestRouteOk(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/person/:last/:first", &TestController{}, "get:GetParams")
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams"))
handler.ServeHTTP(w, r)
body := w.Body.String()
if body != "anderson+thomas+kungfu" {
@ -249,7 +283,7 @@ func TestManyRoute(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter")
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter"))
handler.ServeHTTP(w, r)
body := w.Body.String()
@ -266,7 +300,7 @@ func TestEmptyResponse(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody")
handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody"))
handler.ServeHTTP(w, r)
if body := w.Body.String(); body != "" {
@ -750,3 +784,321 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) {
t.Errorf("TestRouterRequestEntityTooLarge can't run")
}
}
func TestRouterSessionSet(t *testing.T) {
oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn
defer func() {
BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn
}()
// global sessionOn = false, router sessionOn = false
r, _ := http.NewRequest("GET", "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(false))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed")
}
// global sessionOn = false, router sessionOn = true
r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder()
handler = NewControllerRegister()
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(true))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed")
}
BConfig.WebConfig.Session.SessionOn = true
if err := registerSession(); err != nil {
t.Errorf("register session failed, error: %s", err.Error())
}
// global sessionOn = true, router sessionOn = false
r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder()
handler = NewControllerRegister()
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(false))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed")
}
// global sessionOn = true, router sessionOn = true
r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder()
handler = NewControllerRegister()
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(true))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") == "" {
t.Errorf("TestRotuerSessionSet failed")
}
}
func TestRouterRouterGet(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterGet("/user", ExampleController.Ping)
handler.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestRouterRouterGet can't run")
}
}
func TestRouterRouterPost(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterPost("/user", ExampleController.Ping)
handler.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestRouterRouterPost can't run")
}
}
func TestRouterRouterHead(t *testing.T) {
r, _ := http.NewRequest(http.MethodHead, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterHead("/user", ExampleController.Ping)
handler.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestRouterRouterHead can't run")
}
}
func TestRouterRouterPut(t *testing.T) {
r, _ := http.NewRequest(http.MethodPut, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterPut("/user", ExampleController.Ping)
handler.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestRouterRouterPut can't run")
}
}
func TestRouterRouterPatch(t *testing.T) {
r, _ := http.NewRequest(http.MethodPatch, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterPatch("/user", ExampleController.Ping)
handler.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestRouterRouterPatch can't run")
}
}
func TestRouterRouterDelete(t *testing.T) {
r, _ := http.NewRequest(http.MethodDelete, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterDelete("/user", ExampleController.Ping)
handler.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestRouterRouterDelete can't run")
}
}
func TestRouterRouterAny(t *testing.T) {
handler := NewControllerRegister()
handler.RouterAny("/user", ExampleController.Ping)
for method := range HTTPMETHOD {
w := httptest.NewRecorder()
r, _ := http.NewRequest(method, "/user", nil)
handler.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestRouterRouterAny can't run, get the response is " + w.Body.String())
}
}
}
func TestRouterRouterGetPointerMethod(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterGet("/user", (*ExampleController).PingPointer)
handler.ServeHTTP(w, r)
if w.Body.String() != examplePointerBody {
t.Errorf("TestRouterRouterGetPointerMethod can't run")
}
}
func TestRouterRouterPostPointerMethod(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterPost("/user", (*ExampleController).PingPointer)
handler.ServeHTTP(w, r)
if w.Body.String() != examplePointerBody {
t.Errorf("TestRouterRouterPostPointerMethod can't run")
}
}
func TestRouterRouterHeadPointerMethod(t *testing.T) {
r, _ := http.NewRequest(http.MethodHead, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterHead("/user", (*ExampleController).PingPointer)
handler.ServeHTTP(w, r)
if w.Body.String() != examplePointerBody {
t.Errorf("TestRouterRouterHeadPointerMethod can't run")
}
}
func TestRouterRouterPutPointerMethod(t *testing.T) {
r, _ := http.NewRequest(http.MethodPut, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterPut("/user", (*ExampleController).PingPointer)
handler.ServeHTTP(w, r)
if w.Body.String() != examplePointerBody {
t.Errorf("TestRouterRouterPutPointerMethod can't run")
}
}
func TestRouterRouterPatchPointerMethod(t *testing.T) {
r, _ := http.NewRequest(http.MethodPatch, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterPatch("/user", (*ExampleController).PingPointer)
handler.ServeHTTP(w, r)
if w.Body.String() != examplePointerBody {
t.Errorf("TestRouterRouterPatchPointerMethod can't run")
}
}
func TestRouterRouterDeletePointerMethod(t *testing.T) {
r, _ := http.NewRequest(http.MethodDelete, "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.RouterDelete("/user", (*ExampleController).PingPointer)
handler.ServeHTTP(w, r)
if w.Body.String() != examplePointerBody {
t.Errorf("TestRouterRouterDeletePointerMethod can't run")
}
}
func TestRouterRouterAnyPointerMethod(t *testing.T) {
handler := NewControllerRegister()
handler.RouterAny("/user", (*ExampleController).PingPointer)
for method := range HTTPMETHOD {
w := httptest.NewRecorder()
r, _ := http.NewRequest(method, "/user", nil)
handler.ServeHTTP(w, r)
if w.Body.String() != examplePointerBody {
t.Errorf("TestRouterRouterAnyPointerMethod can't run, get the response is " + w.Body.String())
}
}
}
func TestRouterAddRouterMethodPanicInvalidMethod(t *testing.T) {
method := "some random method"
message := "not support http method: " + strings.ToUpper(method)
defer func() {
err := recover()
if err != nil { //产生了panic异常
errStr, ok := err.(string)
if ok && errStr == message {
return
}
}
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicInvalidMethod failed: %v", err))
}()
handler := NewControllerRegister()
handler.AddRouterMethod(method, "/user", ExampleController.Ping)
}
func TestRouterAddRouterMethodPanicNotAMethod(t *testing.T) {
method := http.MethodGet
message := "not a method"
defer func() {
err := recover()
if err != nil { //产生了panic异常
errStr, ok := err.(string)
if ok && errStr == message {
return
}
}
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotAMethod failed: %v", err))
}()
handler := NewControllerRegister()
handler.AddRouterMethod(method, "/user", ExampleController{})
}
func TestRouterAddRouterMethodPanicNotPublicMethod(t *testing.T) {
method := http.MethodGet
message := "ping is not a public method"
defer func() {
err := recover()
if err != nil { //产生了panic异常
errStr, ok := err.(string)
if ok && errStr == message {
return
}
}
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotPublicMethod failed: %v", err))
}()
handler := NewControllerRegister()
handler.AddRouterMethod(method, "/user", ExampleController.ping)
}
func TestRouterAddRouterMethodPanicNotImplementInterface(t *testing.T) {
method := http.MethodGet
message := "web.TestControllerWithInterface is not implemented ControllerInterface"
defer func() {
err := recover()
if err != nil { //产生了panic异常
errStr, ok := err.(string)
if ok && errStr == message {
return
}
}
t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotImplementInterface failed: %v", err))
}()
handler := NewControllerRegister()
handler.AddRouterMethod(method, "/user", TestControllerWithInterface.Ping)
}
func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) {
method := http.MethodGet
message := "web.TestControllerWithInterface is not implemented ControllerInterface"
defer func() {
err := recover()
if err != nil { //产生了panic异常
errStr, ok := err.(string)
if ok && errStr == message {
return
}
}
t.Errorf(fmt.Sprintf("TestRouterAddRouterPointerMethodPanicNotImplementInterface failed: %v", err))
}()
handler := NewControllerRegister()
handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer)
}

View File

@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
// init...
app.initAddr(addr)
app.Handlers.Init()
addr = app.Cfg.Listen.HTTPAddr
@ -267,7 +269,11 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
// Router see HttpServer.Router
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
return BeeApp.Router(rootpath, c, mappingMethods...)
return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...))
}
func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
return BeeApp.RouterWithOpts(rootpath, c, opts...)
}
// Router adds a patterned controller handler to BeeApp.
@ -287,7 +293,11 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *H
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
app.Handlers.Add(rootPath, c, mappingMethods...)
return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...))
}
func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
app.Handlers.Add(rootPath, c, opts...)
return app
}
@ -453,6 +463,166 @@ func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpSer
return app
}
// RouterGet see HttpServer.RouterGet
func RouterGet(rootpath string, f interface{}) {
BeeApp.RouterGet(rootpath, f)
}
// RouterGet used to register router for RouterGet method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterGet("/api/:id", MyController.Ping)
func (app *HttpServer) RouterGet(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterGet(rootpath, f)
return app
}
// RouterPost see HttpServer.RouterGet
func RouterPost(rootpath string, f interface{}) {
BeeApp.RouterPost(rootpath, f)
}
// RouterPost used to register router for RouterPost method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterPost("/api/:id", MyController.Ping)
func (app *HttpServer) RouterPost(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterPost(rootpath, f)
return app
}
// RouterHead see HttpServer.RouterHead
func RouterHead(rootpath string, f interface{}) {
BeeApp.RouterHead(rootpath, f)
}
// RouterHead used to register router for RouterHead method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterHead("/api/:id", MyController.Ping)
func (app *HttpServer) RouterHead(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterHead(rootpath, f)
return app
}
// RouterPut see HttpServer.RouterPut
func RouterPut(rootpath string, f interface{}) {
BeeApp.RouterPut(rootpath, f)
}
// RouterPut used to register router for RouterPut method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterPut("/api/:id", MyController.Ping)
func (app *HttpServer) RouterPut(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterPut(rootpath, f)
return app
}
// RouterPatch see HttpServer.RouterPatch
func RouterPatch(rootpath string, f interface{}) {
BeeApp.RouterPatch(rootpath, f)
}
// RouterPatch used to register router for RouterPatch method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterPatch("/api/:id", MyController.Ping)
func (app *HttpServer) RouterPatch(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterPatch(rootpath, f)
return app
}
// RouterDelete see HttpServer.RouterDelete
func RouterDelete(rootpath string, f interface{}) {
BeeApp.RouterDelete(rootpath, f)
}
// RouterDelete used to register router for RouterDelete method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterDelete("/api/:id", MyController.Ping)
func (app *HttpServer) RouterDelete(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterDelete(rootpath, f)
return app
}
// RouterOptions see HttpServer.RouterOptions
func RouterOptions(rootpath string, f interface{}) {
BeeApp.RouterOptions(rootpath, f)
}
// RouterOptions used to register router for RouterOptions method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterOptions("/api/:id", MyController.Ping)
func (app *HttpServer) RouterOptions(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterOptions(rootpath, f)
return app
}
// RouterAny see HttpServer.RouterAny
func RouterAny(rootpath string, f interface{}) {
BeeApp.RouterAny(rootpath, f)
}
// RouterAny used to register router for RouterAny method
// usage:
// type MyController struct {
// web.Controller
// }
// func (m MyController) Ping() {
// m.Ctx.Output.Body([]byte("hello world"))
// }
//
// RouterAny("/api/:id", MyController.Ping)
func (app *HttpServer) RouterAny(rootpath string, f interface{}) *HttpServer {
app.Handlers.RouterAny(rootpath, f)
return app
}
// Get see HttpServer.Get
func Get(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Get(rootpath, f)

View File

@ -15,6 +15,8 @@
package web
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
@ -28,3 +30,82 @@ func TestNewHttpServerWithCfg(t *testing.T) {
assert.Equal(t, "hello", BConfig.AppName)
}
func TestServerRouterGet(t *testing.T) {
r, _ := http.NewRequest(http.MethodGet, "/user", nil)
w := httptest.NewRecorder()
RouterGet("/user", ExampleController.Ping)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestServerRouterGet can't run")
}
}
func TestServerRouterPost(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/user", nil)
w := httptest.NewRecorder()
RouterPost("/user", ExampleController.Ping)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestServerRouterPost can't run")
}
}
func TestServerRouterHead(t *testing.T) {
r, _ := http.NewRequest(http.MethodHead, "/user", nil)
w := httptest.NewRecorder()
RouterHead("/user", ExampleController.Ping)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestServerRouterHead can't run")
}
}
func TestServerRouterPut(t *testing.T) {
r, _ := http.NewRequest(http.MethodPut, "/user", nil)
w := httptest.NewRecorder()
RouterPut("/user", ExampleController.Ping)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestServerRouterPut can't run")
}
}
func TestServerRouterPatch(t *testing.T) {
r, _ := http.NewRequest(http.MethodPatch, "/user", nil)
w := httptest.NewRecorder()
RouterPatch("/user", ExampleController.Ping)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestServerRouterPatch can't run")
}
}
func TestServerRouterDelete(t *testing.T) {
r, _ := http.NewRequest(http.MethodDelete, "/user", nil)
w := httptest.NewRecorder()
RouterDelete("/user", ExampleController.Ping)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestServerRouterDelete can't run")
}
}
func TestServerRouterAny(t *testing.T) {
RouterAny("/user", ExampleController.Ping)
for method := range HTTPMETHOD {
r, _ := http.NewRequest(method, "/user", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != exampleBody {
t.Errorf("TestServerRouterAny can't run")
}
}
}

View File

@ -15,21 +15,22 @@ import (
)
func TestRedis(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
}
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "127.0.0.1:6379"
}
redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr)
sessionConfig := session.NewManagerConfig(
session.CfgCookieName(`gosessionid`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig(redisConfig),
)
sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr)
globalSession, err := session.NewManager("redis", sessionConfig)
if err != nil {
t.Fatal("could not create manager:", err)

View File

@ -13,15 +13,15 @@ import (
)
func TestRedisSentinel(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
ProviderConfig: "127.0.0.1:6379,100,,0,master",
}
sessionConfig := session.NewManagerConfig(
session.CfgCookieName(`gosessionid`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"),
)
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)

View File

@ -91,25 +91,6 @@ func GetProvider(name string) (Provider, error) {
return provider, nil
}
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
DisableHTTPOnly bool `json:"disableHTTPOnly"`
Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
CookieSameSite http.SameSite `json:"cookieSameSite"`
}
// Manager contains Provider and its configuration.
type Manager struct {
provider Provider

View File

@ -0,0 +1,143 @@
package session
import "net/http"
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
DisableHTTPOnly bool `json:"disableHTTPOnly"`
Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
CookieSameSite http.SameSite `json:"cookieSameSite"`
}
func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) {
for _, opt := range opts {
opt(c)
}
}
type ManagerConfigOpt func(config *ManagerConfig)
func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig {
config := &ManagerConfig{}
for _, opt := range opts {
opt(config)
}
return config
}
// CfgCookieName set key of session id
func CfgCookieName(cookieName string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.CookieName = cookieName
}
}
// CfgCookieName set len of session id
func CfgSessionIdLength(len int64) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.SessionIDLength = len
}
}
// CfgSessionIdPrefix set prefix of session id
func CfgSessionIdPrefix(prefix string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.SessionIDPrefix = prefix
}
}
//CfgSetCookie whether set `Set-Cookie` header in HTTP response
func CfgSetCookie(enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.EnableSetCookie = enable
}
}
//CfgGcLifeTime set session gc lift time
func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Gclifetime = lifeTime
}
}
//CfgMaxLifeTime set session lift time
func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Maxlifetime = lifeTime
}
}
//CfgGcLifeTime set session lift time
func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.CookieLifeTime = lifeTime
}
}
//CfgProviderConfig configure session provider
func CfgProviderConfig(providerConfig string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.ProviderConfig = providerConfig
}
}
//CfgDomain set cookie domain
func CfgDomain(domain string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Domain = domain
}
}
//CfgSessionIdInHTTPHeader enable session id in http header
func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.EnableSidInHTTPHeader = enable
}
}
//CfgSetSessionNameInHTTPHeader set key of session id in http header
func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.SessionNameInHTTPHeader = name
}
}
//EnableSidInURLQuery enable session id in query string
func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.EnableSidInURLQuery = enable
}
}
//DisableHTTPOnly set HTTPOnly for http.Cookie
func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.DisableHTTPOnly = !HTTPOnly
}
}
//CfgSecure set Secure for http.Cookie
func CfgSecure(Enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Secure = Enable
}
}
//CfgSameSite set http.SameSite
func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.CookieSameSite = sameSite
}
}

View File

@ -0,0 +1,222 @@
package session
import (
"net/http"
"testing"
)
func TestCfgCookieLifeTime(t *testing.T) {
value := 8754
c := NewManagerConfig(
CfgCookieLifeTime(value),
)
if c.CookieLifeTime != value {
t.Error()
}
}
func TestCfgDomain(t *testing.T) {
value := `http://domain.com`
c := NewManagerConfig(
CfgDomain(value),
)
if c.Domain != value {
t.Error()
}
}
func TestCfgSameSite(t *testing.T) {
value := http.SameSiteLaxMode
c := NewManagerConfig(
CfgSameSite(value),
)
if c.CookieSameSite != value {
t.Error()
}
}
func TestCfgSecure(t *testing.T) {
c := NewManagerConfig(
CfgSecure(true),
)
if c.Secure != true {
t.Error()
}
}
func TestCfgSecure1(t *testing.T) {
c := NewManagerConfig(
CfgSecure(false),
)
if c.Secure != false {
t.Error()
}
}
func TestCfgSessionIdPrefix(t *testing.T) {
value := `sodiausodkljalsd`
c := NewManagerConfig(
CfgSessionIdPrefix(value),
)
if c.SessionIDPrefix != value {
t.Error()
}
}
func TestCfgSetSessionNameInHTTPHeader(t *testing.T) {
value := `sodiausodkljalsd`
c := NewManagerConfig(
CfgSetSessionNameInHTTPHeader(value),
)
if c.SessionNameInHTTPHeader != value {
t.Error()
}
}
func TestCfgCookieName(t *testing.T) {
value := `sodiausodkljalsd`
c := NewManagerConfig(
CfgCookieName(value),
)
if c.CookieName != value {
t.Error()
}
}
func TestCfgEnableSidInURLQuery(t *testing.T) {
c := NewManagerConfig(
CfgEnableSidInURLQuery(true),
)
if c.EnableSidInURLQuery != true {
t.Error()
}
}
func TestCfgGcLifeTime(t *testing.T) {
value := int64(5454)
c := NewManagerConfig(
CfgGcLifeTime(value),
)
if c.Gclifetime != value {
t.Error()
}
}
func TestCfgHTTPOnly(t *testing.T) {
c := NewManagerConfig(
CfgHTTPOnly(true),
)
if c.DisableHTTPOnly != false {
t.Error()
}
}
func TestCfgHTTPOnly2(t *testing.T) {
c := NewManagerConfig(
CfgHTTPOnly(false),
)
if c.DisableHTTPOnly != true {
t.Error()
}
}
func TestCfgMaxLifeTime(t *testing.T) {
value := int64(5454)
c := NewManagerConfig(
CfgMaxLifeTime(value),
)
if c.Maxlifetime != value {
t.Error()
}
}
func TestCfgProviderConfig(t *testing.T) {
value := `asodiuasldkj12i39809as`
c := NewManagerConfig(
CfgProviderConfig(value),
)
if c.ProviderConfig != value {
t.Error()
}
}
func TestCfgSessionIdInHTTPHeader(t *testing.T) {
c := NewManagerConfig(
CfgSessionIdInHTTPHeader(true),
)
if c.EnableSidInHTTPHeader != true {
t.Error()
}
}
func TestCfgSessionIdInHTTPHeader1(t *testing.T) {
c := NewManagerConfig(
CfgSessionIdInHTTPHeader(false),
)
if c.EnableSidInHTTPHeader != false {
t.Error()
}
}
func TestCfgSessionIdLength(t *testing.T) {
value := int64(100)
c := NewManagerConfig(
CfgSessionIdLength(value),
)
if c.SessionIDLength != value {
t.Error()
}
}
func TestCfgSetCookie(t *testing.T) {
c := NewManagerConfig(
CfgSetCookie(true),
)
if c.EnableSetCookie != true {
t.Error()
}
}
func TestCfgSetCookie1(t *testing.T) {
c := NewManagerConfig(
CfgSetCookie(false),
)
if c.EnableSetCookie != false {
t.Error()
}
}
func TestNewManagerConfig(t *testing.T) {
c := NewManagerConfig()
if c == nil {
t.Error()
}
}
func TestManagerConfig_Opts(t *testing.T) {
c := NewManagerConfig()
c.Opts(CfgSetCookie(true))
if c.EnableSetCookie != true {
t.Error()
}
}

View File

@ -0,0 +1,18 @@
package session
type ProviderType string
const (
ProviderCookie ProviderType = `cookie`
ProviderFile ProviderType = `file`
ProviderMemory ProviderType = `memory`
ProviderCouchbase ProviderType = `couchbase`
ProviderLedis ProviderType = `ledis`
ProviderMemcache ProviderType = `memcache`
ProviderMysql ProviderType = `mysql`
ProviderPostgresql ProviderType = `postgresql`
ProviderRedis ProviderType = `redis`
ProviderRedisCluster ProviderType = `redis_cluster`
ProviderRedisSentinel ProviderType = `redis_sentinel`
ProviderSsdb ProviderType = `ssdb`
)

View File

@ -210,9 +210,9 @@ func (t *Tree) AddRouter(pattern string, runObject interface{}) {
func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) {
if len(segments) == 0 {
if reg != "" {
t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")})
t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}}, t.leaves...)
} else {
t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards})
t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards}}, t.leaves...)
}
} else {
seg := segments[0]

View File

@ -90,7 +90,17 @@ func init() {
routers = append(routers, matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}))
routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}))
routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}))
routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}))
routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11", map[string]string{":year": "2020", ":month": "11"}))
routers = append(routers, matchTestInfo("/?:year", "/2020", map[string]string{":year": "2020"}))
routers = append(routers, matchTestInfo("/?:year([0-9]+)/?:month([0-9]+)/mid/?:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}))
routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/mid/10", map[string]string{":year": "2020", ":day": "10"}))
routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/11/mid", map[string]string{":year": "2020", ":month": "11"}))
routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/mid/10/24", map[string]string{":day": "10", ":hour": "24"}))
routers = append(routers, matchTestInfo("/?:year([0-9]+)/:month([0-9]+)/mid/:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}))
routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10/24", map[string]string{":month": "11", ":day": "10"}))
routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/2020/11/mid/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}))
routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10", map[string]string{":month": "11", ":day": "10"}))
// not match example
// https://github.com/beego/beego/v2/issues/3865

View File

@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root
testHelperFnContentCheck(t, handler, "Test original root",
@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
// Replace the root path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot")
handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot"))
// Test replacement root (expect change)
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root
testHelperFnContentCheck(t, handler,
@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
// Replace the "level1" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1"))
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root
testHelperFnContentCheck(t, handler,
@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
// Replace the "/level1/level2" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2")
handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2"))
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)

7
sonar-project.properties Normal file
View File

@ -0,0 +1,7 @@
sonar.organization=beego
sonar.projectKey=beego_beego
# relative paths to source directories. More details and properties are described
# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/
sonar.sources=.
sonar.exclusions=**/*_test.go

View File

@ -55,6 +55,10 @@ func (c *countTask) GetPrev(ctx context.Context) time.Time {
return time.Now()
}
func (c *countTask) GetTimeout(ctx context.Context) time.Duration {
return 0
}
func TestRunTaskCommand_Execute(t *testing.T) {
task := &countTask{}
AddTask("count", task)

View File

@ -109,6 +109,7 @@ type Tasker interface {
GetNext(ctx context.Context) time.Time
SetPrev(context.Context, time.Time)
GetPrev(ctx context.Context) time.Time
GetTimeout(ctx context.Context) time.Duration
}
// task error
@ -127,13 +128,14 @@ type Task struct {
DoFunc TaskFunc
Prev time.Time
Next time.Time
Errlist []*taskerr // like errtime:errinfo
ErrLimit int // max length for the errlist, 0 stand for no limit
errCnt int // records the error count during the execution
Timeout time.Duration // timeout duration
Errlist []*taskerr // like errtime:errinfo
ErrLimit int // max length for the errlist, 0 stand for no limit
errCnt int // records the error count during the execution
}
// NewTask add new task with name, time and func
func NewTask(tname string, spec string, f TaskFunc) *Task {
func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task {
task := &Task{
Taskname: tname,
@ -144,6 +146,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task {
// we only store the pointer, so it won't use too many space
Errlist: make([]*taskerr, 100, 100),
}
for _, opt := range opts {
opt.apply(task)
}
task.SetCron(spec)
return task
}
@ -196,6 +203,31 @@ func (t *Task) GetPrev(context.Context) time.Time {
return t.Prev
}
// GetTimeout get timeout duration of this task
func (t *Task) GetTimeout(context.Context) time.Duration {
return t.Timeout
}
// Option interface
type Option interface {
apply(*Task)
}
// optionFunc return a function to set task element
type optionFunc func(*Task)
// apply option to task
func (f optionFunc) apply(t *Task) {
f(t)
}
// TimeoutOption return a option to set timeout duration for task
func TimeoutOption(timeout time.Duration) Option {
return optionFunc(func(t *Task) {
t.Timeout = timeout
})
}
// six columns mean
// second0-59
// minute0-59
@ -455,14 +487,12 @@ func (m *taskManager) StartTask() {
func (m *taskManager) run() {
now := time.Now().Local()
m.taskLock.Lock()
for _, t := range m.adminTaskList {
t.SetNext(nil, now)
}
m.taskLock.Unlock()
// first run the tasks, so set all tasks next run time.
m.setTasksStartTime(now)
for {
// we only use RLock here because NewMapSorter copy the reference, do not change any thing
// here, we sort all task and get first task running time (effective).
m.taskLock.RLock()
sortList := NewMapSorter(m.adminTaskList)
m.taskLock.RUnlock()
@ -475,37 +505,75 @@ func (m *taskManager) run() {
} else {
effective = sortList.Vals[0].GetNext(context.Background())
}
select {
case now = <-time.After(effective.Sub(now)):
// Run every entry whose next time was this effective time.
for _, e := range sortList.Vals {
if e.GetNext(context.Background()) != effective {
break
}
go e.Run(nil)
e.SetPrev(context.Background(), e.GetNext(context.Background()))
e.SetNext(nil, effective)
}
case now = <-time.After(effective.Sub(now)): // wait for effective time
runNextTasks(sortList, effective)
continue
case <-m.changed:
case <-m.changed: // tasks have been changed, set all tasks run again now
now = time.Now().Local()
m.taskLock.Lock()
for _, t := range m.adminTaskList {
t.SetNext(nil, now)
}
m.taskLock.Unlock()
m.setTasksStartTime(now)
continue
case <-m.stop:
m.taskLock.Lock()
if m.started {
m.started = false
}
m.taskLock.Unlock()
case <-m.stop: // manager is stopped, and mark manager is stopped
m.markManagerStop()
return
}
}
}
// setTasksStartTime is set all tasks next running time
func (m *taskManager) setTasksStartTime(now time.Time) {
m.taskLock.Lock()
for _, task := range m.adminTaskList {
task.SetNext(context.Background(), now)
}
m.taskLock.Unlock()
}
// markManagerStop it sets manager to be stopped
func (m *taskManager) markManagerStop() {
m.taskLock.Lock()
if m.started {
m.started = false
}
m.taskLock.Unlock()
}
// runNextTasks it runs next task which next run time is equal to effective
func runNextTasks(sortList *MapSorter, effective time.Time) {
// Run every entry whose next time was this effective time.
var i = 0
for _, e := range sortList.Vals {
i++
if e.GetNext(context.Background()) != effective {
break
}
// check if timeout is on, if yes passing the timeout context
ctx := context.Background()
if duration := e.GetTimeout(ctx); duration != 0 {
go func(e Tasker) {
ctx, cancelFunc := context.WithTimeout(ctx, duration)
defer cancelFunc()
err := e.Run(ctx)
if err != nil {
log.Printf("tasker.run err: %s\n", err.Error())
}
}(e)
} else {
go func(e Tasker) {
err := e.Run(ctx)
if err != nil {
log.Printf("tasker.run err: %s\n", err.Error())
}
}(e)
}
e.SetPrev(context.Background(), e.GetNext(context.Background()))
e.SetNext(context.Background(), effective)
}
}
// StopTask stop all tasks
func (m *taskManager) StopTask() {
go func() {

View File

@ -90,6 +90,57 @@ func TestSpec(t *testing.T) {
}
}
func TestTimeout(t *testing.T) {
m := newTaskManager()
defer m.ClearTask()
wg := &sync.WaitGroup{}
wg.Add(2)
once1, once2 := sync.Once{}, sync.Once{}
tk1 := NewTask("tk1", "0/10 * * ? * *",
func(ctx context.Context) error {
time.Sleep(4 * time.Second)
select {
case <-ctx.Done():
once1.Do(func() {
fmt.Println("tk1 done")
wg.Done()
})
return errors.New("timeout")
default:
}
return nil
}, TimeoutOption(3*time.Second),
)
tk2 := NewTask("tk2", "0/11 * * ? * *",
func(ctx context.Context) error {
time.Sleep(4 * time.Second)
select {
case <-ctx.Done():
return errors.New("timeout")
default:
once2.Do(func() {
fmt.Println("tk2 done")
wg.Done()
})
}
return nil
},
)
m.AddTask("tk1", tk1)
m.AddTask("tk2", tk2)
m.StartTask()
defer m.StopTask()
select {
case <-time.After(19 * time.Second):
t.Error("TestTimeout failed")
case <-wait(wg):
}
}
func TestTask_Run(t *testing.T) {
cnt := -1
task := func(ctx context.Context) error {
@ -109,6 +160,23 @@ func TestTask_Run(t *testing.T) {
assert.Equal(t, "Hello, world! 101", l[1].errinfo)
}
func TestCrudTask(t *testing.T) {
m := newTaskManager()
m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error {
return nil
}))
m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error {
return nil
}))
m.DeleteTask("my-task1")
assert.Equal(t, 1, len(m.adminTaskList))
m.ClearTask()
assert.Equal(t, 0, len(m.adminTaskList))
}
func wait(wg *sync.WaitGroup) chan bool {
ch := make(chan bool)
go func() {