Merge branch 'develop' of https://github.com/beego/beego into orm-mock
This commit is contained in:
commit
89eeedbf9d
2
.github/workflows/changelog.yml
vendored
2
.github/workflows/changelog.yml
vendored
@ -8,7 +8,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, labeled, unlabeled]
|
types: [opened, synchronize, reopened, labeled, unlabeled]
|
||||||
branches:
|
branches:
|
||||||
- master
|
- develop
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
changelog:
|
changelog:
|
||||||
|
|||||||
32
.github/workflows/golangci-lint.yml
vendored
Normal file
32
.github/workflows/golangci-lint.yml
vendored
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
name: golangci-lint
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- v*
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
jobs:
|
||||||
|
golangci:
|
||||||
|
name: lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v2
|
||||||
|
with:
|
||||||
|
# Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
|
||||||
|
version: v1.29
|
||||||
|
|
||||||
|
# Optional: working directory, useful for monorepos
|
||||||
|
# working-directory: ./
|
||||||
|
|
||||||
|
# Optional: golangci-lint command line arguments.
|
||||||
|
args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true
|
||||||
|
|
||||||
|
# Optional: show only new issues if it's a pull request. The default value is `false`.
|
||||||
|
only-new-issues: true
|
||||||
|
|
||||||
|
# Optional: if set to true then the action will use pre-installed Go
|
||||||
|
# skip-go-installation: true
|
||||||
@ -1,6 +1,11 @@
|
|||||||
# developing
|
# developing
|
||||||
|
- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433)
|
||||||
|
- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427)
|
||||||
- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398)
|
- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398)
|
||||||
- Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391)
|
- Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391)
|
||||||
- Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385)
|
- Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385)
|
||||||
- Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385)
|
- Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385)
|
||||||
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
|
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
|
||||||
|
- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294)
|
||||||
|
- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404)
|
||||||
|
- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424)
|
||||||
|
|||||||
@ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare {
|
|||||||
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
|
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
|
||||||
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
|
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
|
||||||
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
|
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
|
||||||
return (*App)(web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...)))
|
return (*App)(web.Router(rootpath, c, mappingMethods...))
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful
|
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful
|
||||||
|
|||||||
@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister {
|
|||||||
// Add("/api",&RestController{},"get,post:ApiFunc"
|
// Add("/api",&RestController{},"get,post:ApiFunc"
|
||||||
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
||||||
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
||||||
(*web.ControllerRegister)(p).Add(pattern, c, web.SetRouterMethods(c, mappingMethods...))
|
(*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller
|
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller
|
||||||
|
|||||||
@ -301,6 +301,28 @@ func TestAddFilter(t *testing.T) {
|
|||||||
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
|
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilterChainOrder(t *testing.T) {
|
||||||
|
req := Get("http://beego.me")
|
||||||
|
req.AddFilters(func(next Filter) Filter {
|
||||||
|
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
||||||
|
return NewHttpResponseWithJsonBody("first"), nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
req.AddFilters(func(next Filter) Filter {
|
||||||
|
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
||||||
|
return NewHttpResponseWithJsonBody("second"), nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := req.DoRequestWithCtx(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
data := make([]byte, 5)
|
||||||
|
_, _ = resp.Body.Read(data)
|
||||||
|
assert.Equal(t, "first", string(data))
|
||||||
|
}
|
||||||
|
|
||||||
func TestHead(t *testing.T) {
|
func TestHead(t *testing.T) {
|
||||||
req := Head("http://beego.me")
|
req := Head("http://beego.me")
|
||||||
assert.NotNil(t, req)
|
assert.NotNil(t, req)
|
||||||
|
|||||||
6
client/orm/clauses/const.go
Normal file
6
client/orm/clauses/const.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package clauses
|
||||||
|
|
||||||
|
const (
|
||||||
|
ExprSep = "__"
|
||||||
|
ExprDot = "."
|
||||||
|
)
|
||||||
103
client/orm/clauses/order_clause/order.go
Normal file
103
client/orm/clauses/order_clause/order.go
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package order_clause
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sort int8
|
||||||
|
|
||||||
|
const (
|
||||||
|
None Sort = 0
|
||||||
|
Ascending Sort = 1
|
||||||
|
Descending Sort = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Option func(order *Order)
|
||||||
|
|
||||||
|
type Order struct {
|
||||||
|
column string
|
||||||
|
sort Sort
|
||||||
|
isRaw bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func Clause(options ...Option) *Order {
|
||||||
|
o := &Order{}
|
||||||
|
for _, option := range options {
|
||||||
|
option(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) GetColumn() string {
|
||||||
|
return o.column
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) GetSort() Sort {
|
||||||
|
return o.sort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) SortString() string {
|
||||||
|
switch o.GetSort() {
|
||||||
|
case Ascending:
|
||||||
|
return "ASC"
|
||||||
|
case Descending:
|
||||||
|
return "DESC"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ``
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Order) IsRaw() bool {
|
||||||
|
return o.isRaw
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseOrder(expressions ...string) []*Order {
|
||||||
|
var orders []*Order
|
||||||
|
for _, expression := range expressions {
|
||||||
|
sort := Ascending
|
||||||
|
column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot)
|
||||||
|
if column[0] == '-' {
|
||||||
|
sort = Descending
|
||||||
|
column = column[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
orders = append(orders, &Order{
|
||||||
|
column: column,
|
||||||
|
sort: sort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return orders
|
||||||
|
}
|
||||||
|
|
||||||
|
func Column(column string) Option {
|
||||||
|
return func(order *Order) {
|
||||||
|
order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sort(sort Sort) Option {
|
||||||
|
return func(order *Order) {
|
||||||
|
order.sort = sort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortAscending() Option {
|
||||||
|
return sort(Ascending)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortDescending() Option {
|
||||||
|
return sort(Descending)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortNone() Option {
|
||||||
|
return sort(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Raw() Option {
|
||||||
|
return func(order *Order) {
|
||||||
|
order.isRaw = true
|
||||||
|
}
|
||||||
|
}
|
||||||
144
client/orm/clauses/order_clause/order_test.go
Normal file
144
client/orm/clauses/order_clause/order_test.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package order_clause
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClause(t *testing.T) {
|
||||||
|
var (
|
||||||
|
column = `a`
|
||||||
|
)
|
||||||
|
|
||||||
|
o := Clause(
|
||||||
|
Column(column),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o.GetColumn() != column {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortAscending(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
SortAscending(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o.GetSort() != Ascending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortDescending(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
SortDescending(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o.GetSort() != Descending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortNone(t *testing.T) {
|
||||||
|
o1 := Clause(
|
||||||
|
SortNone(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o1.GetSort() != None {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
o2 := Clause()
|
||||||
|
|
||||||
|
if o2.GetSort() != None {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRaw(t *testing.T) {
|
||||||
|
o1 := Clause()
|
||||||
|
|
||||||
|
if o1.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
o2 := Clause(
|
||||||
|
Raw(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if !o2.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestColumn(t *testing.T) {
|
||||||
|
o1 := Clause(
|
||||||
|
Column(`aaa`),
|
||||||
|
)
|
||||||
|
|
||||||
|
if o1.GetColumn() != `aaa` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOrder(t *testing.T) {
|
||||||
|
orders := ParseOrder(
|
||||||
|
`-user__status`,
|
||||||
|
`status`,
|
||||||
|
`user__status`,
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Log(orders)
|
||||||
|
|
||||||
|
if orders[0].GetSort() != Descending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[0].GetColumn() != `user.status` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[1].GetColumn() != `status` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[1].GetSort() != Ascending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if orders[2].GetColumn() != `user.status` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrder_GetColumn(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
Column(`user__id`),
|
||||||
|
)
|
||||||
|
if o.GetColumn() != `user.id` {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrder_GetSort(t *testing.T) {
|
||||||
|
o := Clause(
|
||||||
|
SortDescending(),
|
||||||
|
)
|
||||||
|
if o.GetSort() != Descending {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrder_IsRaw(t *testing.T) {
|
||||||
|
o1 := Clause()
|
||||||
|
if o1.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
o2 := Clause(
|
||||||
|
Raw(),
|
||||||
|
)
|
||||||
|
if !o2.IsRaw() {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
fmt.Printf(" %s\n", err.Error())
|
fmt.Printf(" %s\n", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
for i, mi := range modelCache.allOrdered() {
|
for i, mi := range modelCache.allOrdered() {
|
||||||
|
|
||||||
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
|
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
|
||||||
@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var fields []*fieldInfo
|
var fields []*fieldInfo
|
||||||
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
|
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if d.rtOnError {
|
if d.rtOnError {
|
||||||
return err
|
return err
|
||||||
@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, idx := range indexes[mi.table] {
|
for _, idx := range indexes[mi.table] {
|
||||||
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
|
if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
|
||||||
if !d.noInfo {
|
if !d.noInfo {
|
||||||
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
||||||
}
|
}
|
||||||
|
|||||||
146
client/orm/db.go
146
client/orm/db.go
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create insert sql preparation statement object.
|
// create insert sql preparation statement object.
|
||||||
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
||||||
@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
|||||||
|
|
||||||
d.ins.HasReturningID(mi, &query)
|
d.ins.HasReturningID(mi, &query)
|
||||||
|
|
||||||
stmt, err := q.Prepare(query)
|
stmt, err := q.PrepareContext(ctx, query)
|
||||||
return stmt, query, err
|
return stmt, query, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert struct with prepared statement and given struct reflect value.
|
// insert struct with prepared statement and given struct reflect value.
|
||||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
|||||||
err := row.Scan(&id)
|
err := row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
res, err := stmt.Exec(values...)
|
res, err := stmt.ExecContext(ctx, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return res.LastInsertId()
|
return res.LastInsertId()
|
||||||
}
|
}
|
||||||
@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// query sql ,read records and persist in dbBaser.
|
// query sql ,read records and persist in dbBaser.
|
||||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
||||||
var whereCols []string
|
var whereCols []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
row := q.QueryRow(query, args...)
|
row := q.QueryRowContext(ctx, query, args...)
|
||||||
if err := row.Scan(refs...); err != nil {
|
if err := row.Scan(refs...); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return ErrNoRows
|
return ErrNoRows
|
||||||
@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute insert sql dbQuerier with given struct reflect.Value.
|
// execute insert sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
names := make([]string, 0, len(mi.fields.dbcols))
|
names := make([]string, 0, len(mi.fields.dbcols))
|
||||||
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := d.InsertValue(q, mi, false, names, values)
|
id, err := d.InsertValue(ctx, q, mi, false, names, values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(autoFields) > 0 {
|
if len(autoFields) > 0 {
|
||||||
err = d.ins.setval(q, mi, autoFields)
|
err = d.ins.setval(ctx, q, mi, autoFields)
|
||||||
}
|
}
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// multi-insert sql with given slice struct reflect.Value.
|
// multi-insert sql with given slice struct reflect.Value.
|
||||||
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
||||||
var (
|
var (
|
||||||
cnt int64
|
cnt int64
|
||||||
nums int
|
nums int
|
||||||
@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
}
|
}
|
||||||
|
|
||||||
if i > 1 && i%bulk == 0 || length == i {
|
if i > 1 && i%bulk == 0 || length == i {
|
||||||
num, err := d.InsertValue(q, mi, true, names, values[:nums])
|
num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cnt, err
|
return cnt, err
|
||||||
}
|
}
|
||||||
@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
if len(autoFields) > 0 {
|
if len(autoFields) > 0 {
|
||||||
err = d.ins.setval(q, mi, autoFields)
|
err = d.ins.setval(ctx, q, mi, autoFields)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cnt, err
|
return cnt, err
|
||||||
@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
|
|||||||
|
|
||||||
// execute insert sql with given struct and given values.
|
// execute insert sql with given struct and given values.
|
||||||
// insert the given values, not the field values in struct.
|
// insert the given values, not the field values in struct.
|
||||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
marks := make([]string, len(names))
|
marks := make([]string, len(names))
|
||||||
@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err := row.Scan(&id)
|
err := row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
|
|||||||
// InsertOrUpdate a row
|
// InsertOrUpdate a row
|
||||||
// If your primary key or unique column conflict will update
|
// If your primary key or unique column conflict will update
|
||||||
// If no will insert
|
// If no will insert
|
||||||
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
args0 := ""
|
args0 := ""
|
||||||
iouStr := ""
|
iouStr := ""
|
||||||
argsMap := map[string]string{}
|
argsMap := map[string]string{}
|
||||||
@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err = row.Scan(&id)
|
err = row.Scan(&id)
|
||||||
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
|
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
|
||||||
@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute update sql dbQuerier with given struct reflect.Value.
|
// execute update sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, ErrMissPK
|
return 0, ErrMissPK
|
||||||
@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
res, err := q.Exec(query, setValues...)
|
res, err := q.ExecContext(ctx, query, setValues...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
}
|
}
|
||||||
@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
|
|
||||||
// execute delete sql dbQuerier with given struct reflect.Value.
|
// execute delete sql dbQuerier with given struct reflect.Value.
|
||||||
// delete index is pk.
|
// delete index is pk.
|
||||||
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
var whereCols []string
|
var whereCols []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
// if specify cols length > 0, then use it for where condition.
|
// if specify cols length > 0, then use it for where condition.
|
||||||
@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
|
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
res, err := q.Exec(query, args...)
|
res, err := q.ExecContext(ctx, query, args...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
num, err := res.RowsAffected()
|
num, err := res.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
|
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := d.deleteRels(q, mi, args, tz)
|
err := d.deleteRels(ctx, q, mi, args, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
|
|
||||||
// update table-related record by querySet.
|
// update table-related record by querySet.
|
||||||
// need querySet not struct reflect.Value to update related records.
|
// need querySet not struct reflect.Value to update related records.
|
||||||
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
||||||
columns := make([]string, 0, len(params))
|
columns := make([]string, 0, len(params))
|
||||||
values := make([]interface{}, 0, len(params))
|
values := make([]interface{}, 0, len(params))
|
||||||
for col, val := range params {
|
for col, val := range params {
|
||||||
@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
var err error
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
var res sql.Result
|
|
||||||
if qs != nil && qs.forContext {
|
|
||||||
res, err = q.ExecContext(qs.ctx, query, values...)
|
|
||||||
} else {
|
|
||||||
res, err = q.Exec(query, values...)
|
|
||||||
}
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
}
|
}
|
||||||
@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
|
|
||||||
// delete related records.
|
// delete related records.
|
||||||
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
|
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
|
||||||
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
||||||
for _, fi := range mi.fields.fieldsReverse {
|
for _, fi := range mi.fields.fieldsReverse {
|
||||||
fi = fi.reverseFieldInfo
|
fi = fi.reverseFieldInfo
|
||||||
switch fi.onDelete {
|
switch fi.onDelete {
|
||||||
case odCascade:
|
case odCascade:
|
||||||
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
|
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
|
||||||
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
|
_, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
|
|||||||
if fi.onDelete == odSetDefault {
|
if fi.onDelete == odSetDefault {
|
||||||
params[fi.column] = fi.initial.String()
|
params[fi.column] = fi.initial.String()
|
||||||
}
|
}
|
||||||
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
|
_, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// delete table-related records.
|
// delete table-related records.
|
||||||
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
tables.skipEnd = true
|
tables.skipEnd = true
|
||||||
|
|
||||||
@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var rs *sql.Rows
|
var rs *sql.Rows
|
||||||
r, err := q.Query(query, args...)
|
r, err := q.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
|
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
var res sql.Result
|
res, err := q.ExecContext(ctx, query, args...)
|
||||||
if qs != nil && qs.forContext {
|
|
||||||
res, err = q.ExecContext(qs.ctx, query, args...)
|
|
||||||
} else {
|
|
||||||
res, err = q.Exec(query, args...)
|
|
||||||
}
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
num, err := res.RowsAffected()
|
num, err := res.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if num > 0 {
|
if num > 0 {
|
||||||
err := d.deleteRels(q, mi, args, tz)
|
err := d.deleteRels(ctx, q, mi, args, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
@ -943,14 +933,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
// read related records.
|
// read related records.
|
||||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||||
|
|
||||||
val := reflect.ValueOf(container)
|
val := reflect.ValueOf(container)
|
||||||
ind := reflect.Indirect(val)
|
ind := reflect.Indirect(val)
|
||||||
|
|
||||||
errTyp := true
|
unregister := true
|
||||||
one := true
|
one := true
|
||||||
isPtr := true
|
isPtr := true
|
||||||
|
name := ""
|
||||||
|
|
||||||
if val.Kind() == reflect.Ptr {
|
if val.Kind() == reflect.Ptr {
|
||||||
fn := ""
|
fn := ""
|
||||||
@ -963,19 +954,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
isPtr = false
|
isPtr = false
|
||||||
fn = getFullName(typ)
|
fn = getFullName(typ)
|
||||||
|
name = getTableName(reflect.New(typ))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fn = getFullName(ind.Type())
|
fn = getFullName(ind.Type())
|
||||||
|
name = getTableName(ind)
|
||||||
}
|
}
|
||||||
errTyp = fn != mi.fullName
|
unregister = fn != mi.fullName
|
||||||
}
|
}
|
||||||
|
|
||||||
if errTyp {
|
if unregister {
|
||||||
if one {
|
RegisterModel(container)
|
||||||
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName))
|
|
||||||
} else {
|
|
||||||
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rlimit := qs.limit
|
rlimit := qs.limit
|
||||||
@ -1040,6 +1029,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
if qs.distinct {
|
if qs.distinct {
|
||||||
sqlSelect += " DISTINCT"
|
sqlSelect += " DISTINCT"
|
||||||
}
|
}
|
||||||
|
if qs.aggregate != "" {
|
||||||
|
sels = qs.aggregate
|
||||||
|
}
|
||||||
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
|
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
|
||||||
sqlSelect, sels, Q, mi.table, Q,
|
sqlSelect, sels, Q, mi.table, Q,
|
||||||
specifyIndexes, join, where, groupBy, orderBy, limit)
|
specifyIndexes, join, where, groupBy, orderBy, limit)
|
||||||
@ -1050,18 +1042,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var rs *sql.Rows
|
rs, err := q.QueryContext(ctx, query, args...)
|
||||||
var err error
|
if err != nil {
|
||||||
if qs != nil && qs.forContext {
|
return 0, err
|
||||||
rs, err = q.QueryContext(qs.ctx, query, args...)
|
}
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
defer rs.Close()
|
||||||
}
|
|
||||||
} else {
|
slice := ind
|
||||||
rs, err = q.Query(query, args...)
|
if unregister {
|
||||||
if err != nil {
|
mi, _ = modelCache.get(name)
|
||||||
return 0, err
|
tCols = mi.fields.dbcols
|
||||||
}
|
colsNum = len(tCols)
|
||||||
}
|
}
|
||||||
|
|
||||||
refs := make([]interface{}, colsNum)
|
refs := make([]interface{}, colsNum)
|
||||||
@ -1069,11 +1061,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
var ref interface{}
|
var ref interface{}
|
||||||
refs[i] = &ref
|
refs[i] = &ref
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rs.Close()
|
|
||||||
|
|
||||||
slice := ind
|
|
||||||
|
|
||||||
var cnt int64
|
var cnt int64
|
||||||
for rs.Next() {
|
for rs.Next() {
|
||||||
if one && cnt == 0 || !one {
|
if one && cnt == 0 || !one {
|
||||||
@ -1172,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// excute count sql and return count result int64.
|
// excute count sql and return count result int64.
|
||||||
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
tables.parseRelated(qs.related, qs.relDepth)
|
tables.parseRelated(qs.related, qs.relDepth)
|
||||||
|
|
||||||
@ -1194,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
var row *sql.Row
|
row := q.QueryRowContext(ctx, query, args...)
|
||||||
if qs != nil && qs.forContext {
|
|
||||||
row = q.QueryRowContext(qs.ctx, query, args...)
|
|
||||||
} else {
|
|
||||||
row = q.QueryRow(query, args...)
|
|
||||||
}
|
|
||||||
err = row.Scan(&cnt)
|
err = row.Scan(&cnt)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1649,7 +1631,7 @@ setValue:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// query sql, read values , save to *[]ParamList.
|
// query sql, read values , save to *[]ParamList.
|
||||||
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
maps []Params
|
maps []Params
|
||||||
@ -1732,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
rs, err := q.Query(query, args...)
|
rs, err := q.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -1847,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sync auto key
|
// sync auto key
|
||||||
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1892,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get all cloumns in table.
|
// get all cloumns in table.
|
||||||
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
columns := make(map[string][3]string)
|
columns := make(map[string][3]string)
|
||||||
query := d.ins.ShowColumnsQuery(table)
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return columns, err
|
return columns, err
|
||||||
}
|
}
|
||||||
@ -1934,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// not implement.
|
// not implement.
|
||||||
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool {
|
||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute sql to check index exist.
|
// execute sql to check index exist.
|
||||||
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
|
||||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||||
var cnt int
|
var cnt int
|
||||||
row.Scan(&cnt)
|
row.Scan(&cnt)
|
||||||
@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
|
|||||||
// If your primary key or unique column conflict will update
|
// If your primary key or unique column conflict will update
|
||||||
// If no will insert
|
// If no will insert
|
||||||
// Add "`" for mysql sql building
|
// Add "`" for mysql sql building
|
||||||
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
var iouStr string
|
var iouStr string
|
||||||
argsMap := map[string]string{}
|
argsMap := map[string]string{}
|
||||||
|
|
||||||
@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err = row.Scan(&id)
|
err = row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check index is exist
|
// check index is exist
|
||||||
func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
|
row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
|
||||||
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
|
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
|
||||||
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
|
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
|
||||||
|
|
||||||
@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde
|
|||||||
|
|
||||||
// execute insert sql with given struct and given values.
|
// execute insert sql with given struct and given values.
|
||||||
// insert the given values, not the field values in struct.
|
// insert the given values, not the field values in struct.
|
||||||
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
marks := make([]string, len(names))
|
marks := make([]string, len(names))
|
||||||
@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
|
|||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.Exec(query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if isMulti {
|
if isMulti {
|
||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
|
|||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRowContext(ctx, query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err := row.Scan(&id)
|
err := row.Scan(&id)
|
||||||
return id, err
|
return id, err
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sync auto key
|
// sync auto key
|
||||||
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||||
if len(autoFields) == 0 {
|
if len(autoFields) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string
|
|||||||
mi.table, name,
|
mi.table, name,
|
||||||
Q, name, Q,
|
Q, name, Q,
|
||||||
Q, mi.table, Q)
|
Q, mi.table, Q)
|
||||||
if _, err := db.Exec(query); err != nil {
|
if _, err := db.ExecContext(ctx, query); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check index exist in postgresql.
|
// check index exist in postgresql.
|
||||||
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||||
row := db.QueryRow(query)
|
row := db.QueryRowContext(ctx, query)
|
||||||
var cnt int
|
var cnt int
|
||||||
row.Scan(&cnt)
|
row.Scan(&cnt)
|
||||||
return cnt > 0
|
return cnt > 0
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -73,11 +74,11 @@ type dbBaseSqlite struct {
|
|||||||
var _ dbBaser = new(dbBaseSqlite)
|
var _ dbBaser = new(dbBaseSqlite)
|
||||||
|
|
||||||
// override base db read for update behavior as SQlite does not support syntax
|
// override base db read for update behavior as SQlite does not support syntax
|
||||||
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
||||||
if isForUpdate {
|
if isForUpdate {
|
||||||
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
|
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
|
||||||
}
|
}
|
||||||
return d.dbBase.Read(q, mi, ind, tz, cols, false)
|
return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// get sqlite operator.
|
// get sqlite operator.
|
||||||
@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get columns in sqlite.
|
// get columns in sqlite.
|
||||||
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
query := d.ins.ShowColumnsQuery(table)
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check index exist in sqlite.
|
// check index exist in sqlite.
|
||||||
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,6 +16,8 @@ package orm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -421,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generate order sql.
|
// generate order sql.
|
||||||
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
|
func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
|
||||||
if len(orders) == 0 {
|
if len(orders) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -430,19 +432,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
|
|||||||
|
|
||||||
orderSqls := make([]string, 0, len(orders))
|
orderSqls := make([]string, 0, len(orders))
|
||||||
for _, order := range orders {
|
for _, order := range orders {
|
||||||
asc := "ASC"
|
column := order.GetColumn()
|
||||||
if order[0] == '-' {
|
clause := strings.Split(column, clauses.ExprDot)
|
||||||
asc = "DESC"
|
|
||||||
order = order[1:]
|
|
||||||
}
|
|
||||||
exprs := strings.Split(order, ExprSep)
|
|
||||||
|
|
||||||
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
if order.IsRaw() {
|
||||||
if !suc {
|
if len(clause) == 2 {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
|
||||||
}
|
} else if len(clause) == 1 {
|
||||||
|
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
|
||||||
|
} else {
|
||||||
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
index, _, fi, suc := t.parseExprs(t.mi, clause)
|
||||||
|
if !suc {
|
||||||
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
|
||||||
|
}
|
||||||
|
|
||||||
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
|
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute sql to check index exist.
|
// execute sql to check index exist.
|
||||||
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
|
||||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
|
||||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||||
var cnt int
|
var cnt int
|
||||||
row.Scan(&cnt)
|
row.Scan(&cnt)
|
||||||
|
|||||||
@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) {
|
|||||||
|
|
||||||
assert.Nil(t, o.Driver())
|
assert.Nil(t, o.Driver())
|
||||||
|
|
||||||
assert.Nil(t, o.QueryM2MWithCtx(nil, nil, ""))
|
|
||||||
assert.Nil(t, o.QueryM2M(nil, ""))
|
assert.Nil(t, o.QueryM2M(nil, ""))
|
||||||
assert.Nil(t, o.ReadWithCtx(nil, nil))
|
assert.Nil(t, o.ReadWithCtx(nil, nil))
|
||||||
assert.Nil(t, o.Read(nil))
|
assert.Nil(t, o.Read(nil))
|
||||||
@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, int64(0), i)
|
assert.Equal(t, int64(0), i)
|
||||||
|
|
||||||
assert.Nil(t, o.QueryTableWithCtx(nil, nil))
|
|
||||||
assert.Nil(t, o.QueryTable(nil))
|
assert.Nil(t, o.QueryTable(nil))
|
||||||
|
|
||||||
assert.Nil(t, o.Read(nil))
|
assert.Nil(t, o.Read(nil))
|
||||||
|
|||||||
@ -27,7 +27,7 @@ import (
|
|||||||
// this Filter's behavior looks a little bit strange
|
// this Filter's behavior looks a little bit strange
|
||||||
// for example:
|
// for example:
|
||||||
// if we want to trace QuerySetter
|
// if we want to trace QuerySetter
|
||||||
// actually we trace invoking "QueryTable" and "QueryTableWithCtx"
|
// actually we trace invoking "QueryTable"
|
||||||
// the method Begin*, Commit and Rollback are ignored.
|
// the method Begin*, Commit and Rollback are ignored.
|
||||||
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
|
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
|
||||||
type FilterChainBuilder struct {
|
type FilterChainBuilder struct {
|
||||||
|
|||||||
@ -31,7 +31,7 @@ import (
|
|||||||
// this Filter's behavior looks a little bit strange
|
// this Filter's behavior looks a little bit strange
|
||||||
// for example:
|
// for example:
|
||||||
// if we want to records the metrics of QuerySetter
|
// if we want to records the metrics of QuerySetter
|
||||||
// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx"
|
// actually we only records metrics of invoking "QueryTable"
|
||||||
type FilterChainBuilder struct {
|
type FilterChainBuilder struct {
|
||||||
summaryVec prometheus.ObserverVec
|
summaryVec prometheus.ObserverVec
|
||||||
AppName string
|
AppName string
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/core/logs"
|
||||||
"github.com/beego/beego/v2/core/utils"
|
"github.com/beego/beego/v2/core/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
|
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||||
return f.QueryM2MWithCtx(context.Background(), md, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
|
||||||
|
|
||||||
mi, _ := modelCache.getByMd(md)
|
mi, _ := modelCache.getByMd(md)
|
||||||
inv := &Invocation{
|
inv := &Invocation{
|
||||||
Method: "QueryM2MWithCtx",
|
Method: "QueryM2M",
|
||||||
Args: []interface{}{md, name},
|
Args: []interface{}{md, name},
|
||||||
Md: md,
|
Md: md,
|
||||||
mi: mi,
|
mi: mi,
|
||||||
InsideTx: f.insideTx,
|
InsideTx: f.insideTx,
|
||||||
TxStartTime: f.txStartTime,
|
TxStartTime: f.txStartTime,
|
||||||
f: func(c context.Context) []interface{} {
|
f: func(c context.Context) []interface{} {
|
||||||
res := f.ormer.QueryM2MWithCtx(c, md, name)
|
res := f.ormer.QueryM2M(md, name)
|
||||||
return []interface{}{res}
|
return []interface{}{res}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
res := f.root(ctx, inv)
|
res := f.root(context.Background(), inv)
|
||||||
if res[0] == nil {
|
if res[0] == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return res[0].(QueryM2Mer)
|
return res[0].(QueryM2Mer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
|
func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
|
||||||
|
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.")
|
||||||
|
return f.QueryM2M(md, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
||||||
var (
|
var (
|
||||||
name string
|
name string
|
||||||
md interface{}
|
md interface{}
|
||||||
@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
|
|||||||
}
|
}
|
||||||
|
|
||||||
inv := &Invocation{
|
inv := &Invocation{
|
||||||
Method: "QueryTableWithCtx",
|
Method: "QueryTable",
|
||||||
Args: []interface{}{ptrStructOrTableName},
|
Args: []interface{}{ptrStructOrTableName},
|
||||||
InsideTx: f.insideTx,
|
InsideTx: f.insideTx,
|
||||||
TxStartTime: f.txStartTime,
|
TxStartTime: f.txStartTime,
|
||||||
Md: md,
|
Md: md,
|
||||||
mi: mi,
|
mi: mi,
|
||||||
f: func(c context.Context) []interface{} {
|
f: func(c context.Context) []interface{} {
|
||||||
res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName)
|
res := f.ormer.QueryTable(ptrStructOrTableName)
|
||||||
return []interface{}{res}
|
return []interface{}{res}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
res := f.root(ctx, inv)
|
res := f.root(context.Background(), inv)
|
||||||
|
|
||||||
if res[0] == nil {
|
if res[0] == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
|
|||||||
return res[0].(QuerySeter)
|
return res[0].(QuerySeter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
||||||
|
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.")
|
||||||
|
return f.QueryTable(ptrStructOrTableName)
|
||||||
|
}
|
||||||
|
|
||||||
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
|
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
|
||||||
inv := &Invocation{
|
inv := &Invocation{
|
||||||
Method: "DBStats",
|
Method: "DBStats",
|
||||||
|
|||||||
@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) {
|
|||||||
o := &filterMockOrm{}
|
o := &filterMockOrm{}
|
||||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||||
assert.Equal(t, "QueryM2MWithCtx", inv.Method)
|
assert.Equal(t, "QueryM2M", inv.Method)
|
||||||
assert.Equal(t, 2, len(inv.Args))
|
assert.Equal(t, 2, len(inv.Args))
|
||||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||||
assert.False(t, inv.InsideTx)
|
assert.False(t, inv.InsideTx)
|
||||||
@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) {
|
|||||||
o := &filterMockOrm{}
|
o := &filterMockOrm{}
|
||||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||||
assert.Equal(t, "QueryTableWithCtx", inv.Method)
|
assert.Equal(t, "QueryTable", inv.Method)
|
||||||
assert.Equal(t, 1, len(inv.Args))
|
assert.Equal(t, 1, len(inv.Args))
|
||||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||||
assert.False(t, inv.InsideTx)
|
assert.False(t, inv.InsideTx)
|
||||||
|
|||||||
@ -332,10 +332,6 @@ end:
|
|||||||
|
|
||||||
// register register models to model cache
|
// register register models to model cache
|
||||||
func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
|
func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
|
||||||
if mc.done {
|
|
||||||
err = fmt.Errorf("register must be run before BootStrap")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
val := reflect.ValueOf(model)
|
val := reflect.ValueOf(model)
|
||||||
@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
|
|||||||
err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
|
err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if val.Elem().Kind() == reflect.Slice {
|
||||||
|
val = reflect.New(val.Elem().Type().Elem())
|
||||||
|
}
|
||||||
table := getTableName(val)
|
table := getTableName(val)
|
||||||
|
|
||||||
if prefixOrSuffixStr != "" {
|
if prefixOrSuffixStr != "" {
|
||||||
@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := mc.get(table); ok {
|
if _, ok := mc.get(table); ok {
|
||||||
err = fmt.Errorf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
|
return nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mi := newModelInfo(val)
|
mi := newModelInfo(val)
|
||||||
@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mi.fields.pk == nil {
|
|
||||||
err = fmt.Errorf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mi.table = table
|
mi.table = table
|
||||||
|
|||||||
@ -255,6 +255,22 @@ func NewTM() *TM {
|
|||||||
return obj
|
return obj
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DeptInfo struct {
|
||||||
|
ID int `orm:"column(id)"`
|
||||||
|
Created time.Time `orm:"auto_now_add"`
|
||||||
|
DeptName string
|
||||||
|
EmployeeName string
|
||||||
|
Salary int
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnregisterModel struct {
|
||||||
|
ID int `orm:"column(id)"`
|
||||||
|
Created time.Time `orm:"auto_now_add"`
|
||||||
|
DeptName string
|
||||||
|
EmployeeName string
|
||||||
|
Salary int
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int `orm:"column(id)"`
|
ID int `orm:"column(id)"`
|
||||||
UserName string `orm:"size(30);unique"`
|
UserName string `orm:"size(30);unique"`
|
||||||
|
|||||||
@ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string {
|
|||||||
|
|
||||||
// get whether the table needs to be created for the database alias
|
// get whether the table needs to be created for the database alias
|
||||||
func isApplicableTableForDB(val reflect.Value, db string) bool {
|
func isApplicableTableForDB(val reflect.Value, db string) bool {
|
||||||
|
if !val.IsValid() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
fun := val.MethodByName("IsApplicableTableForDB")
|
fun := val.MethodByName("IsApplicableTableForDB")
|
||||||
if fun.IsValid() {
|
if fun.IsValid() {
|
||||||
vals := fun.Call([]reflect.Value{reflect.ValueOf(db)})
|
vals := fun.Call([]reflect.Value{reflect.ValueOf(db)})
|
||||||
|
|||||||
@ -58,6 +58,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
@ -135,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
|
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
|
||||||
@ -144,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
|
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to read a row from the database, or insert one if it doesn't exist
|
// Try to read a row from the database, or insert one if it doesn't exist
|
||||||
@ -154,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo
|
|||||||
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||||
cols = append([]string{col1}, cols...)
|
cols = append([]string{col1}, cols...)
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
|
||||||
if err == ErrNoRows {
|
if err == ErrNoRows {
|
||||||
// Create
|
// Create
|
||||||
id, err := o.InsertWithCtx(ctx, md)
|
id, err := o.InsertWithCtx(ctx, md)
|
||||||
@ -179,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
@ -222,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
|
|||||||
for i := 0; i < sind.Len(); i++ {
|
for i := 0; i < sind.Len(); i++ {
|
||||||
ind := reflect.Indirect(sind.Index(i))
|
ind := reflect.Indirect(sind.Index(i))
|
||||||
mi, _ := o.getMiInd(ind.Interface(), false)
|
mi, _ := o.getMiInd(ind.Interface(), false)
|
||||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cnt, err
|
return cnt, err
|
||||||
}
|
}
|
||||||
@ -233,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
||||||
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
|
return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ)
|
||||||
}
|
}
|
||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
@ -244,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (
|
|||||||
}
|
}
|
||||||
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
|
id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
@ -261,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete model in database
|
// delete model in database
|
||||||
@ -271,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) {
|
|||||||
}
|
}
|
||||||
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
|
num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
@ -283,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str
|
|||||||
|
|
||||||
// create a models to models queryer
|
// create a models to models queryer
|
||||||
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
|
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||||
return o.QueryM2MWithCtx(context.Background(), md, name)
|
|
||||||
}
|
|
||||||
func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
fi := o.getFieldInfo(mi, name)
|
fi := o.getFieldInfo(mi, name)
|
||||||
|
|
||||||
@ -299,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri
|
|||||||
return newQueryM2M(md, o, mi, fi, ind)
|
return newQueryM2M(md, o, mi, fi, ind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
|
||||||
|
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.")
|
||||||
|
return o.QueryM2M(md, name)
|
||||||
|
}
|
||||||
|
|
||||||
// load related models to md model.
|
// load related models to md model.
|
||||||
// args are limit, offset int and order string.
|
// args are limit, offset int and order string.
|
||||||
//
|
//
|
||||||
@ -351,7 +355,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s
|
|||||||
qs.relDepth = relDepth
|
qs.relDepth = relDepth
|
||||||
|
|
||||||
if len(order) > 0 {
|
if len(order) > 0 {
|
||||||
qs.orders = []string{order}
|
qs.orders = order_clause.ParseOrder(order)
|
||||||
}
|
}
|
||||||
|
|
||||||
find := ind.FieldByIndex(fi.fieldIndex)
|
find := ind.FieldByIndex(fi.fieldIndex)
|
||||||
@ -451,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
|
|||||||
// table name can be string or struct.
|
// table name can be string or struct.
|
||||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||||
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||||
return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
|
|
||||||
}
|
|
||||||
func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
|
|
||||||
var name string
|
var name string
|
||||||
if table, ok := ptrStructOrTableName.(string); ok {
|
if table, ok := ptrStructOrTableName.(string); ok {
|
||||||
name = nameStrategyMap[defaultNameStrategy](table)
|
name = nameStrategyMap[defaultNameStrategy](table)
|
||||||
@ -469,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in
|
|||||||
if qs == nil {
|
if qs == nil {
|
||||||
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
|
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
|
||||||
}
|
}
|
||||||
return
|
return qs
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||||
|
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.")
|
||||||
|
return o.QueryTable(ptrStructOrTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// return a raw query seter for raw sql string.
|
// return a raw query seter for raw sql string.
|
||||||
@ -595,9 +602,8 @@ func NewOrm() Ormer {
|
|||||||
func NewOrmUsingDB(aliasName string) Ormer {
|
func NewOrmUsingDB(aliasName string) Ormer {
|
||||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||||
return newDBWithAlias(al)
|
return newDBWithAlias(al)
|
||||||
} else {
|
|
||||||
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
|
|
||||||
}
|
}
|
||||||
|
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
|
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
|
||||||
|
|||||||
@ -16,12 +16,13 @@ package orm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExprSep define the expression separation
|
// ExprSep define the expression separation
|
||||||
const (
|
const (
|
||||||
ExprSep = "__"
|
ExprSep = clauses.ExprSep
|
||||||
)
|
)
|
||||||
|
|
||||||
type condValue struct {
|
type condValue struct {
|
||||||
|
|||||||
@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
|
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
|
||||||
|
return d.ExecContext(context.Background(), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||||
a := time.Now()
|
a := time.Now()
|
||||||
res, err := d.stmt.Exec(args...)
|
res, err := d.stmt.ExecContext(ctx, args...)
|
||||||
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
|
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
|
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
|
||||||
|
return d.QueryContext(context.Background(), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||||
a := time.Now()
|
a := time.Now()
|
||||||
res, err := d.stmt.Query(args...)
|
res, err := d.stmt.QueryContext(ctx, args...)
|
||||||
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
|
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
|
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
|
||||||
|
return d.QueryRowContext(context.Background(), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
|
||||||
a := time.Now()
|
a := time.Now()
|
||||||
res := d.stmt.QueryRow(args...)
|
res := d.stmt.QueryRow(args...)
|
||||||
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)
|
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
@ -31,6 +32,10 @@ var _ Inserter = new(insertSet)
|
|||||||
|
|
||||||
// insert model ignore it's registered or not.
|
// insert model ignore it's registered or not.
|
||||||
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||||
|
return o.InsertWithCtx(context.Background(), md)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||||
if o.closed {
|
if o.closed {
|
||||||
return 0, ErrStmtClosed
|
return 0, ErrStmtClosed
|
||||||
}
|
}
|
||||||
@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
|
|||||||
if name != o.mi.fullName {
|
if name != o.mi.fullName {
|
||||||
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
|
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
|
||||||
}
|
}
|
||||||
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
|
id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
@ -70,11 +75,11 @@ func (o *insertSet) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create new insert queryer.
|
// create new insert queryer.
|
||||||
func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) {
|
func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) {
|
||||||
bi := new(insertSet)
|
bi := new(insertSet)
|
||||||
bi.orm = orm
|
bi.orm = orm
|
||||||
bi.mi = mi
|
bi.mi = mi
|
||||||
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
|
st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,7 +14,10 @@
|
|||||||
|
|
||||||
package orm
|
package orm
|
||||||
|
|
||||||
import "reflect"
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
// model to model struct
|
// model to model struct
|
||||||
type queryM2M struct {
|
type queryM2M struct {
|
||||||
@ -33,6 +36,10 @@ type queryM2M struct {
|
|||||||
//
|
//
|
||||||
// make sure the relation is defined in post model struct tag.
|
// make sure the relation is defined in post model struct tag.
|
||||||
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||||
|
return o.AddWithCtx(context.Background(), mds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
mi := fi.relThroughModelInfo
|
mi := fi.relThroughModelInfo
|
||||||
mfi := fi.reverseFieldInfo
|
mfi := fi.reverseFieldInfo
|
||||||
@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
names = append(names, otherNames...)
|
names = append(names, otherNames...)
|
||||||
values = append(values, otherValues...)
|
values = append(values, otherValues...)
|
||||||
return dbase.InsertValue(orm.db, mi, true, names, values)
|
return dbase.InsertValue(ctx, orm.db, mi, true, names, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove models following the origin model relationship
|
// remove models following the origin model relationship
|
||||||
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||||
|
return o.RemoveWithCtx(context.Background(), mds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
||||||
|
|
||||||
@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
|||||||
|
|
||||||
// check model is existed in relationship of origin model
|
// check model is existed in relationship of origin model
|
||||||
func (o *queryM2M) Exist(md interface{}) bool {
|
func (o *queryM2M) Exist(md interface{}) bool {
|
||||||
|
return o.ExistWithCtx(context.Background(), md)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
||||||
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
|
Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean all models in related of origin model
|
// clean all models in related of origin model
|
||||||
func (o *queryM2M) Clear() (int64, error) {
|
func (o *queryM2M) Clear() (int64, error) {
|
||||||
|
return o.ClearWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// count all related models of origin model
|
// count all related models of origin model
|
||||||
func (o *queryM2M) Count() (int64, error) {
|
func (o *queryM2M) Count() (int64, error) {
|
||||||
|
return o.CountWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ QueryM2Mer = new(queryM2M)
|
var _ QueryM2Mer = new(queryM2M)
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"github.com/beego/beego/v2/client/orm/hints"
|
"github.com/beego/beego/v2/client/orm/hints"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,21 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} {
|
|||||||
|
|
||||||
// real query struct
|
// real query struct
|
||||||
type querySet struct {
|
type querySet struct {
|
||||||
mi *modelInfo
|
mi *modelInfo
|
||||||
cond *Condition
|
cond *Condition
|
||||||
related []string
|
related []string
|
||||||
relDepth int
|
relDepth int
|
||||||
limit int64
|
limit int64
|
||||||
offset int64
|
offset int64
|
||||||
groups []string
|
groups []string
|
||||||
orders []string
|
orders []*order_clause.Order
|
||||||
distinct bool
|
distinct bool
|
||||||
forUpdate bool
|
forUpdate bool
|
||||||
useIndex int
|
useIndex int
|
||||||
indexes []string
|
indexes []string
|
||||||
orm *ormBase
|
orm *ormBase
|
||||||
ctx context.Context
|
aggregate string
|
||||||
forContext bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ QuerySeter = new(querySet)
|
var _ QuerySeter = new(querySet)
|
||||||
@ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter {
|
|||||||
|
|
||||||
// add ORDER expression.
|
// add ORDER expression.
|
||||||
// "column" means ASC, "-column" means DESC.
|
// "column" means ASC, "-column" means DESC.
|
||||||
func (o querySet) OrderBy(exprs ...string) QuerySeter {
|
func (o querySet) OrderBy(expressions ...string) QuerySeter {
|
||||||
o.orders = exprs
|
if len(expressions) <= 0 {
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
o.orders = order_clause.ParseOrder(expressions...)
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
|
||||||
|
// add ORDER expression.
|
||||||
|
func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter {
|
||||||
|
if len(orders) <= 0 {
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
o.orders = orders
|
||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition {
|
|||||||
|
|
||||||
// return QuerySeter execution result number
|
// return QuerySeter execution result number
|
||||||
func (o *querySet) Count() (int64, error) {
|
func (o *querySet) Count() (int64, error) {
|
||||||
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.CountWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check result empty or not after QuerySeter executed
|
// check result empty or not after QuerySeter executed
|
||||||
func (o *querySet) Exist() bool {
|
func (o *querySet) Exist() bool {
|
||||||
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.ExistWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ExistWithCtx(ctx context.Context) bool {
|
||||||
|
cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
return cnt > 0
|
return cnt > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute update with parameters
|
// execute update with parameters
|
||||||
func (o *querySet) Update(values Params) (int64, error) {
|
func (o *querySet) Update(values Params) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
return o.UpdateWithCtx(context.Background(), values)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute delete
|
// execute delete
|
||||||
func (o *querySet) Delete() (int64, error) {
|
func (o *querySet) Delete() (int64, error) {
|
||||||
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.DeleteWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// return a insert queryer.
|
// return a insert queryer.
|
||||||
@ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) {
|
|||||||
// i,err := sq.PrepareInsert()
|
// i,err := sq.PrepareInsert()
|
||||||
// i.Add(&user1{},&user2{})
|
// i.Add(&user1{},&user2{})
|
||||||
func (o *querySet) PrepareInsert() (Inserter, error) {
|
func (o *querySet) PrepareInsert() (Inserter, error) {
|
||||||
return newInsertSet(o.orm, o.mi)
|
return o.PrepareInsertWithCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
|
||||||
|
return newInsertSet(ctx, o.orm, o.mi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all data and map to containers.
|
// query all data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
return o.AllWithCtx(context.Background(), container, cols...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query one row data and map to containers.
|
// query one row data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
func (o *querySet) One(container interface{}, cols ...string) error {
|
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||||
|
return o.OneWithCtx(context.Background(), container, cols...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
|
||||||
o.limit = 1
|
o.limit = 1
|
||||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error {
|
|||||||
// expres means condition expression.
|
// expres means condition expression.
|
||||||
// it converts data to []map[column]value.
|
// it converts data to []map[column]value.
|
||||||
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
return o.ValuesWithCtx(context.Background(), results, exprs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all data and map to [][]interface
|
// query all data and map to [][]interface
|
||||||
// it converts data to [][column_index]value
|
// it converts data to [][column_index]value
|
||||||
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
return o.ValuesListWithCtx(context.Background(), results, exprs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all data and map to []interface.
|
// query all data and map to []interface.
|
||||||
// it's designed for one row record set, auto change to []value, not [][column]value.
|
// it's designed for one row record set, auto change to []value, not [][column]value.
|
||||||
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
return o.ValuesFlatWithCtx(context.Background(), result, expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) {
|
||||||
|
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query all rows into map[string]interface with specify key and value column name.
|
// query all rows into map[string]interface with specify key and value column name.
|
||||||
@ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
|
|||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
|
||||||
// set context to QuerySeter.
|
|
||||||
func (o querySet) WithContext(ctx context.Context) QuerySeter {
|
|
||||||
o.ctx = ctx
|
|
||||||
o.forContext = true
|
|
||||||
return &o
|
|
||||||
}
|
|
||||||
|
|
||||||
// create new QuerySeter.
|
// create new QuerySeter.
|
||||||
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
|
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
|
||||||
o := new(querySet)
|
o := new(querySet)
|
||||||
@ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
|
|||||||
o.orm = orm
|
o.orm = orm
|
||||||
return o
|
return o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// aggregate func
|
||||||
|
func (o querySet) Aggregate(s string) QuerySeter {
|
||||||
|
o.aggregate = s
|
||||||
|
return &o
|
||||||
|
}
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
@ -205,6 +206,7 @@ func TestSyncDb(t *testing.T) {
|
|||||||
RegisterModel(new(Index))
|
RegisterModel(new(Index))
|
||||||
RegisterModel(new(StrPk))
|
RegisterModel(new(StrPk))
|
||||||
RegisterModel(new(TM))
|
RegisterModel(new(TM))
|
||||||
|
RegisterModel(new(DeptInfo))
|
||||||
|
|
||||||
err := RunSyncdb("default", true, Debug)
|
err := RunSyncdb("default", true, Debug)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -232,6 +234,7 @@ func TestRegisterModels(t *testing.T) {
|
|||||||
RegisterModel(new(Index))
|
RegisterModel(new(Index))
|
||||||
RegisterModel(new(StrPk))
|
RegisterModel(new(StrPk))
|
||||||
RegisterModel(new(TM))
|
RegisterModel(new(TM))
|
||||||
|
RegisterModel(new(DeptInfo))
|
||||||
|
|
||||||
BootStrap()
|
BootStrap()
|
||||||
|
|
||||||
@ -333,6 +336,73 @@ func TestTM(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC"))
|
throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnregisterModel(t *testing.T) {
|
||||||
|
data := []*DeptInfo{
|
||||||
|
{
|
||||||
|
DeptName: "A",
|
||||||
|
EmployeeName: "A1",
|
||||||
|
Salary: 1000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "A",
|
||||||
|
EmployeeName: "A2",
|
||||||
|
Salary: 2000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "B",
|
||||||
|
EmployeeName: "B1",
|
||||||
|
Salary: 2000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "B",
|
||||||
|
EmployeeName: "B2",
|
||||||
|
Salary: 4000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeptName: "B",
|
||||||
|
EmployeeName: "B3",
|
||||||
|
Salary: 3000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
qs := dORM.QueryTable("dept_info")
|
||||||
|
i, _ := qs.PrepareInsert()
|
||||||
|
for _, d := range data {
|
||||||
|
_, err := i.Insert(d)
|
||||||
|
if err != nil {
|
||||||
|
throwFail(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f := func() {
|
||||||
|
var res []UnregisterModel
|
||||||
|
n, err := dORM.QueryTable("dept_info").All(&res)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(n, 5))
|
||||||
|
throwFail(t, AssertIs(res[0].EmployeeName, "A1"))
|
||||||
|
|
||||||
|
type Sum struct {
|
||||||
|
DeptName string
|
||||||
|
Total int
|
||||||
|
}
|
||||||
|
var sun []Sum
|
||||||
|
qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun)
|
||||||
|
throwFail(t, AssertIs(sun[0].DeptName, "A"))
|
||||||
|
throwFail(t, AssertIs(sun[0].Total, 3000))
|
||||||
|
|
||||||
|
type Max struct {
|
||||||
|
DeptName string
|
||||||
|
Max float64
|
||||||
|
}
|
||||||
|
var max []Max
|
||||||
|
qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max)
|
||||||
|
throwFail(t, AssertIs(max[1].DeptName, "B"))
|
||||||
|
throwFail(t, AssertIs(max[1].Max, 4000))
|
||||||
|
}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNullDataTypes(t *testing.T) {
|
func TestNullDataTypes(t *testing.T) {
|
||||||
d := DataNull{}
|
d := DataNull{}
|
||||||
|
|
||||||
@ -1077,6 +1147,26 @@ func TestOrderBy(t *testing.T) {
|
|||||||
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
|
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column(`profile__age`),
|
||||||
|
order_clause.SortDescending(),
|
||||||
|
),
|
||||||
|
).Filter("user_name", "astaxie").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
if IsMysql {
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column(`rand()`),
|
||||||
|
order_clause.Raw(),
|
||||||
|
),
|
||||||
|
).Filter("user_name", "astaxie").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAll(t *testing.T) {
|
func TestAll(t *testing.T) {
|
||||||
@ -1163,6 +1253,19 @@ func TestValues(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(maps[2]["Profile"], nil))
|
throwFail(t, AssertIs(maps[2]["Profile"], nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column("Id"),
|
||||||
|
order_clause.SortAscending(),
|
||||||
|
),
|
||||||
|
).Values(&maps)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 3))
|
||||||
|
if num == 3 {
|
||||||
|
throwFail(t, AssertIs(maps[0]["UserName"], "slene"))
|
||||||
|
throwFail(t, AssertIs(maps[2]["Profile"], nil))
|
||||||
|
}
|
||||||
|
|
||||||
num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age")
|
num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age")
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 3))
|
throwFail(t, AssertIs(num, 3))
|
||||||
@ -2717,3 +2820,23 @@ func TestCondition(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(!cycleFlag, true))
|
throwFail(t, AssertIs(!cycleFlag, true))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
user := User{UserName: "slene"}
|
||||||
|
|
||||||
|
err := dORM.ReadWithCtx(ctx, &user, "UserName")
|
||||||
|
throwFail(t, err)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
err = dORM.ReadWithCtx(ctx, &user, "UserName")
|
||||||
|
throwFail(t, AssertIs(err, context.Canceled))
|
||||||
|
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
qs := dORM.QueryTable(user)
|
||||||
|
_, err = qs.Filter("UserName", "slene").CountWithCtx(ctx)
|
||||||
|
throwFail(t, AssertIs(err, context.Canceled))
|
||||||
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
|
||||||
"github.com/beego/beego/v2/core/utils"
|
"github.com/beego/beego/v2/core/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,12 +197,16 @@ type DQL interface {
|
|||||||
// post := Post{Id: 4}
|
// post := Post{Id: 4}
|
||||||
// m2m := Ormer.QueryM2M(&post, "Tags")
|
// m2m := Ormer.QueryM2M(&post, "Tags")
|
||||||
QueryM2M(md interface{}, name string) QueryM2Mer
|
QueryM2M(md interface{}, name string) QueryM2Mer
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
|
||||||
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
|
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
|
||||||
|
|
||||||
// return a QuerySeter for table operations.
|
// return a QuerySeter for table operations.
|
||||||
// table name can be string or struct.
|
// table name can be string or struct.
|
||||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||||
QueryTable(ptrStructOrTableName interface{}) QuerySeter
|
QueryTable(ptrStructOrTableName interface{}) QuerySeter
|
||||||
|
// NOTE: this method is deprecated, context parameter will not take effect.
|
||||||
|
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
|
||||||
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
|
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
|
||||||
|
|
||||||
DBStats() *sql.DBStats
|
DBStats() *sql.DBStats
|
||||||
@ -230,6 +235,7 @@ type TxOrmer interface {
|
|||||||
// Inserter insert prepared statement
|
// Inserter insert prepared statement
|
||||||
type Inserter interface {
|
type Inserter interface {
|
||||||
Insert(interface{}) (int64, error)
|
Insert(interface{}) (int64, error)
|
||||||
|
InsertWithCtx(context.Context, interface{}) (int64, error)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,6 +295,28 @@ type QuerySeter interface {
|
|||||||
// for example:
|
// for example:
|
||||||
// qs.OrderBy("-status")
|
// qs.OrderBy("-status")
|
||||||
OrderBy(exprs ...string) QuerySeter
|
OrderBy(exprs ...string) QuerySeter
|
||||||
|
// add ORDER expression by order clauses
|
||||||
|
// for example:
|
||||||
|
// OrderClauses(
|
||||||
|
// order_clause.Clause(
|
||||||
|
// order.Column("Id"),
|
||||||
|
// order.SortAscending(),
|
||||||
|
// ),
|
||||||
|
// order_clause.Clause(
|
||||||
|
// order.Column("status"),
|
||||||
|
// order.SortDescending(),
|
||||||
|
// ),
|
||||||
|
// )
|
||||||
|
// OrderClauses(order_clause.Clause(
|
||||||
|
// order_clause.Column(`user__status`),
|
||||||
|
// order_clause.SortDescending(),//default None
|
||||||
|
// ))
|
||||||
|
// OrderClauses(order_clause.Clause(
|
||||||
|
// order_clause.Column(`random()`),
|
||||||
|
// order_clause.SortNone(),//default None
|
||||||
|
// order_clause.Raw(),//default false.if true, do not check field is valid or not
|
||||||
|
// ))
|
||||||
|
OrderClauses(orders ...*order_clause.Order) QuerySeter
|
||||||
// add FORCE INDEX expression.
|
// add FORCE INDEX expression.
|
||||||
// for example:
|
// for example:
|
||||||
// qs.ForceIndex(`idx_name1`,`idx_name2`)
|
// qs.ForceIndex(`idx_name1`,`idx_name2`)
|
||||||
@ -327,9 +355,11 @@ type QuerySeter interface {
|
|||||||
// for example:
|
// for example:
|
||||||
// num, err = qs.Filter("profile__age__gt", 28).Count()
|
// num, err = qs.Filter("profile__age__gt", 28).Count()
|
||||||
Count() (int64, error)
|
Count() (int64, error)
|
||||||
|
CountWithCtx(context.Context) (int64, error)
|
||||||
// check result empty or not after QuerySeter executed
|
// check result empty or not after QuerySeter executed
|
||||||
// the same as QuerySeter.Count > 0
|
// the same as QuerySeter.Count > 0
|
||||||
Exist() bool
|
Exist() bool
|
||||||
|
ExistWithCtx(context.Context) bool
|
||||||
// execute update with parameters
|
// execute update with parameters
|
||||||
// for example:
|
// for example:
|
||||||
// num, err = qs.Filter("user_name", "slene").Update(Params{
|
// num, err = qs.Filter("user_name", "slene").Update(Params{
|
||||||
@ -339,11 +369,13 @@ type QuerySeter interface {
|
|||||||
// "user_name": "slene2"
|
// "user_name": "slene2"
|
||||||
// }) // user slene's name will change to slene2
|
// }) // user slene's name will change to slene2
|
||||||
Update(values Params) (int64, error)
|
Update(values Params) (int64, error)
|
||||||
|
UpdateWithCtx(ctx context.Context, values Params) (int64, error)
|
||||||
// delete from table
|
// delete from table
|
||||||
// for example:
|
// for example:
|
||||||
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
|
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
|
||||||
// //delete two user who's name is testing1 or testing2
|
// //delete two user who's name is testing1 or testing2
|
||||||
Delete() (int64, error)
|
Delete() (int64, error)
|
||||||
|
DeleteWithCtx(context.Context) (int64, error)
|
||||||
// return a insert queryer.
|
// return a insert queryer.
|
||||||
// it can be used in times.
|
// it can be used in times.
|
||||||
// example:
|
// example:
|
||||||
@ -352,18 +384,21 @@ type QuerySeter interface {
|
|||||||
// num, err = i.Insert(&user2) // user table will add one record user2 at once
|
// num, err = i.Insert(&user2) // user table will add one record user2 at once
|
||||||
// err = i.Close() //don't forget call Close
|
// err = i.Close() //don't forget call Close
|
||||||
PrepareInsert() (Inserter, error)
|
PrepareInsert() (Inserter, error)
|
||||||
|
PrepareInsertWithCtx(context.Context) (Inserter, error)
|
||||||
// query all data and map to containers.
|
// query all data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
// for example:
|
// for example:
|
||||||
// var users []*User
|
// var users []*User
|
||||||
// qs.All(&users) // users[0],users[1],users[2] ...
|
// qs.All(&users) // users[0],users[1],users[2] ...
|
||||||
All(container interface{}, cols ...string) (int64, error)
|
All(container interface{}, cols ...string) (int64, error)
|
||||||
|
AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error)
|
||||||
// query one row data and map to containers.
|
// query one row data and map to containers.
|
||||||
// cols means the columns when querying.
|
// cols means the columns when querying.
|
||||||
// for example:
|
// for example:
|
||||||
// var user User
|
// var user User
|
||||||
// qs.One(&user) //user.UserName == "slene"
|
// qs.One(&user) //user.UserName == "slene"
|
||||||
One(container interface{}, cols ...string) error
|
One(container interface{}, cols ...string) error
|
||||||
|
OneWithCtx(ctx context.Context, container interface{}, cols ...string) error
|
||||||
// query all data and map to []map[string]interface.
|
// query all data and map to []map[string]interface.
|
||||||
// expres means condition expression.
|
// expres means condition expression.
|
||||||
// it converts data to []map[column]value.
|
// it converts data to []map[column]value.
|
||||||
@ -371,18 +406,21 @@ type QuerySeter interface {
|
|||||||
// var maps []Params
|
// var maps []Params
|
||||||
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
|
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
|
||||||
Values(results *[]Params, exprs ...string) (int64, error)
|
Values(results *[]Params, exprs ...string) (int64, error)
|
||||||
|
ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error)
|
||||||
// query all data and map to [][]interface
|
// query all data and map to [][]interface
|
||||||
// it converts data to [][column_index]value
|
// it converts data to [][column_index]value
|
||||||
// for example:
|
// for example:
|
||||||
// var list []ParamsList
|
// var list []ParamsList
|
||||||
// qs.ValuesList(&list) // list[0][1] == "slene"
|
// qs.ValuesList(&list) // list[0][1] == "slene"
|
||||||
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
|
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
|
||||||
|
ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error)
|
||||||
// query all data and map to []interface.
|
// query all data and map to []interface.
|
||||||
// it's designed for one column record set, auto change to []value, not [][column]value.
|
// it's designed for one column record set, auto change to []value, not [][column]value.
|
||||||
// for example:
|
// for example:
|
||||||
// var list ParamsList
|
// var list ParamsList
|
||||||
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
|
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
|
||||||
ValuesFlat(result *ParamsList, expr string) (int64, error)
|
ValuesFlat(result *ParamsList, expr string) (int64, error)
|
||||||
|
ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error)
|
||||||
// query all rows into map[string]interface with specify key and value column name.
|
// query all rows into map[string]interface with specify key and value column name.
|
||||||
// keyCol = "name", valueCol = "value"
|
// keyCol = "name", valueCol = "value"
|
||||||
// table data
|
// table data
|
||||||
@ -405,6 +443,15 @@ type QuerySeter interface {
|
|||||||
// Found int
|
// Found int
|
||||||
// }
|
// }
|
||||||
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
|
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
|
||||||
|
// aggregate func.
|
||||||
|
// for example:
|
||||||
|
// type result struct {
|
||||||
|
// DeptName string
|
||||||
|
// Total int
|
||||||
|
// }
|
||||||
|
// var res []result
|
||||||
|
// o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res)
|
||||||
|
Aggregate(s string) QuerySeter
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryM2Mer model to model query struct
|
// QueryM2Mer model to model query struct
|
||||||
@ -422,18 +469,23 @@ type QueryM2Mer interface {
|
|||||||
// insert one or more rows to m2m table
|
// insert one or more rows to m2m table
|
||||||
// make sure the relation is defined in post model struct tag.
|
// make sure the relation is defined in post model struct tag.
|
||||||
Add(...interface{}) (int64, error)
|
Add(...interface{}) (int64, error)
|
||||||
|
AddWithCtx(context.Context, ...interface{}) (int64, error)
|
||||||
// remove models following the origin model relationship
|
// remove models following the origin model relationship
|
||||||
// only delete rows from m2m table
|
// only delete rows from m2m table
|
||||||
// for example:
|
// for example:
|
||||||
// tag3 := &Tag{Id:5,Name: "TestTag3"}
|
// tag3 := &Tag{Id:5,Name: "TestTag3"}
|
||||||
// num, err = m2m.Remove(tag3)
|
// num, err = m2m.Remove(tag3)
|
||||||
Remove(...interface{}) (int64, error)
|
Remove(...interface{}) (int64, error)
|
||||||
|
RemoveWithCtx(context.Context, ...interface{}) (int64, error)
|
||||||
// check model is existed in relationship of origin model
|
// check model is existed in relationship of origin model
|
||||||
Exist(interface{}) bool
|
Exist(interface{}) bool
|
||||||
|
ExistWithCtx(context.Context, interface{}) bool
|
||||||
// clean all models in related of origin model
|
// clean all models in related of origin model
|
||||||
Clear() (int64, error)
|
Clear() (int64, error)
|
||||||
|
ClearWithCtx(context.Context) (int64, error)
|
||||||
// count all related models of origin model
|
// count all related models of origin model
|
||||||
Count() (int64, error)
|
Count() (int64, error)
|
||||||
|
CountWithCtx(context.Context) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RawPreparer raw query statement
|
// RawPreparer raw query statement
|
||||||
@ -507,11 +559,11 @@ type RawSeter interface {
|
|||||||
type stmtQuerier interface {
|
type stmtQuerier interface {
|
||||||
Close() error
|
Close() error
|
||||||
Exec(args ...interface{}) (sql.Result, error)
|
Exec(args ...interface{}) (sql.Result, error)
|
||||||
// ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
|
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
|
||||||
Query(args ...interface{}) (*sql.Rows, error)
|
Query(args ...interface{}) (*sql.Rows, error)
|
||||||
// QueryContext(args ...interface{}) (*sql.Rows, error)
|
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
|
||||||
QueryRow(args ...interface{}) *sql.Row
|
QueryRow(args ...interface{}) *sql.Row
|
||||||
// QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
|
QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
// db querier
|
// db querier
|
||||||
@ -548,28 +600,28 @@ type txEnder interface {
|
|||||||
|
|
||||||
// base database struct
|
// base database struct
|
||||||
type dbBaser interface {
|
type dbBaser interface {
|
||||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
|
Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
|
||||||
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
|
ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
|
||||||
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||||
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
||||||
|
|
||||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
|
InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
|
||||||
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
||||||
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
||||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
|
|
||||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||||
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
||||||
|
|
||||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||||
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||||
|
|
||||||
SupportUpdateJoin() bool
|
SupportUpdateJoin() bool
|
||||||
OperatorSQL(string) string
|
OperatorSQL(string) string
|
||||||
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
|
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
|
||||||
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
||||||
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
||||||
MaxLimit() uint64
|
MaxLimit() uint64
|
||||||
TableQuote() string
|
TableQuote() string
|
||||||
ReplaceMarks(*string)
|
ReplaceMarks(*string)
|
||||||
@ -578,12 +630,12 @@ type dbBaser interface {
|
|||||||
TimeToDB(*time.Time, *time.Location)
|
TimeToDB(*time.Time, *time.Location)
|
||||||
DbTypes() map[string]string
|
DbTypes() map[string]string
|
||||||
GetTables(dbQuerier) (map[string]bool, error)
|
GetTables(dbQuerier) (map[string]bool, error)
|
||||||
GetColumns(dbQuerier, string) (map[string][3]string, error)
|
GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error)
|
||||||
ShowTablesQuery() string
|
ShowTablesQuery() string
|
||||||
ShowColumnsQuery(string) string
|
ShowColumnsQuery(string) string
|
||||||
IndexExists(dbQuerier, string, string) bool
|
IndexExists(context.Context, dbQuerier, string, string) bool
|
||||||
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
||||||
setval(dbQuerier, *modelInfo, []string) error
|
setval(context.Context, dbQuerier, *modelInfo, []string) error
|
||||||
|
|
||||||
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
|
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
|
||||||
}
|
}
|
||||||
|
|||||||
2
go.mod
2
go.mod
@ -25,7 +25,7 @@ require (
|
|||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible
|
github.com/gomodule/redigo v2.0.0+incompatible
|
||||||
github.com/google/go-cmp v0.5.0 // indirect
|
github.com/google/go-cmp v0.5.0 // indirect
|
||||||
github.com/google/uuid v1.1.1 // indirect
|
github.com/google/uuid v1.1.1
|
||||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
|
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
|
||||||
github.com/hashicorp/golang-lru v0.5.4
|
github.com/hashicorp/golang-lru v0.5.4
|
||||||
github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6
|
github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6
|
||||||
|
|||||||
@ -112,13 +112,13 @@ func registerAdmin() error {
|
|||||||
HttpServer: NewHttpServerWithCfg(BConfig),
|
HttpServer: NewHttpServerWithCfg(BConfig),
|
||||||
}
|
}
|
||||||
// keep in mind that all data should be html escaped to avoid XSS attack
|
// keep in mind that all data should be html escaped to avoid XSS attack
|
||||||
beeAdminApp.Router("/", c, SetRouterMethods(c, "get:AdminIndex"))
|
beeAdminApp.Router("/", c, "get:AdminIndex")
|
||||||
beeAdminApp.Router("/qps", c, SetRouterMethods(c, "get:QpsIndex"))
|
beeAdminApp.Router("/qps", c, "get:QpsIndex")
|
||||||
beeAdminApp.Router("/prof", c, SetRouterMethods(c, "get:ProfIndex"))
|
beeAdminApp.Router("/prof", c, "get:ProfIndex")
|
||||||
beeAdminApp.Router("/healthcheck", c, SetRouterMethods(c, "get:Healthcheck"))
|
beeAdminApp.Router("/healthcheck", c, "get:Healthcheck")
|
||||||
beeAdminApp.Router("/task", c, SetRouterMethods(c, "get:TaskStatus"))
|
beeAdminApp.Router("/task", c, "get:TaskStatus")
|
||||||
beeAdminApp.Router("/listconf", c, SetRouterMethods(c, "get:ListConf"))
|
beeAdminApp.Router("/listconf", c, "get:ListConf")
|
||||||
beeAdminApp.Router("/metrics", c, SetRouterMethods(c, "get:PrometheusMetrics"))
|
beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics")
|
||||||
|
|
||||||
go beeAdminApp.Run()
|
go beeAdminApp.Run()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,6 +29,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -195,6 +196,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Session return session store of this context of request
|
||||||
|
func (ctx *Context) Session() (store session.Store, err error) {
|
||||||
|
if ctx.Input != nil {
|
||||||
|
if ctx.Input.CruSession != nil {
|
||||||
|
store = ctx.Input.CruSession
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
err = errors.New(`no valid session store(please initialize session)`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = errors.New(`no valid input`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Response is a wrapper for the http.ResponseWriter
|
// Response is a wrapper for the http.ResponseWriter
|
||||||
// Started: if true, response was already written to so the other handler will not be executed
|
// Started: if true, response was already written to so the other handler will not be executed
|
||||||
type Response struct {
|
type Response struct {
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package context
|
package context
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -45,3 +46,26 @@ func TestXsrfReset_01(t *testing.T) {
|
|||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContext_Session(t *testing.T) {
|
||||||
|
c := NewContext()
|
||||||
|
if store, err := c.Session(); store != nil || err == nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContext_Session1(t *testing.T) {
|
||||||
|
c := Context{}
|
||||||
|
if store, err := c.Session(); store != nil || err == nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContext_Session2(t *testing.T) {
|
||||||
|
c := NewContext()
|
||||||
|
c.Input.CruSession = &session.MemSessionStore{}
|
||||||
|
|
||||||
|
if store, err := c.Session(); store == nil || err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@ -37,4 +38,5 @@ func TestFilterChain(t *testing.T) {
|
|||||||
ctx.Input.SetData("RouterPattern", "my-route")
|
ctx.Input.SetData("RouterPattern", "my-route")
|
||||||
filter(ctx)
|
filter(ctx)
|
||||||
assert.True(t, ctx.Input.GetData("invocation").(bool))
|
assert.True(t, ctx.Input.GetData("invocation").(bool))
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|||||||
35
server/web/filter/session/filter.go
Normal file
35
server/web/filter/session/filter.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/beego/beego/v2/core/logs"
|
||||||
|
"github.com/beego/beego/v2/server/web"
|
||||||
|
webContext "github.com/beego/beego/v2/server/web/context"
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
//Session maintain session for web service
|
||||||
|
//Session new a session storage and store it into webContext.Context
|
||||||
|
func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain {
|
||||||
|
sessionConfig := session.NewManagerConfig(options...)
|
||||||
|
sessionManager, _ := session.NewManager(string(providerType), sessionConfig)
|
||||||
|
go sessionManager.GC()
|
||||||
|
|
||||||
|
return func(next web.FilterFunc) web.FilterFunc {
|
||||||
|
return func(ctx *webContext.Context) {
|
||||||
|
if ctx.Input.CruSession != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil {
|
||||||
|
logs.Error(`init session error:%s`, err.Error())
|
||||||
|
} else {
|
||||||
|
//release session at the end of request
|
||||||
|
defer sess.SessionRelease(context.Background(), ctx.ResponseWriter)
|
||||||
|
ctx.Input.CruSession = sess
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
86
server/web/filter/session/filter_test.go
Normal file
86
server/web/filter/session/filter_test.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/beego/beego/v2/server/web"
|
||||||
|
webContext "github.com/beego/beego/v2/server/web/context"
|
||||||
|
"github.com/beego/beego/v2/server/web/session"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) {
|
||||||
|
r, _ := http.NewRequest(method, path, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Code != code {
|
||||||
|
t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSession(t *testing.T) {
|
||||||
|
storeKey := uuid.New().String()
|
||||||
|
handler := web.NewControllerRegister()
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
Session(
|
||||||
|
session.ProviderMemory,
|
||||||
|
session.CfgCookieName(`go_session_id`),
|
||||||
|
session.CfgSetCookie(true),
|
||||||
|
session.CfgGcLifeTime(3600),
|
||||||
|
session.CfgMaxLifeTime(3600),
|
||||||
|
session.CfgSecure(false),
|
||||||
|
session.CfgCookieLifeTime(3600),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
func(next web.FilterFunc) web.FilterFunc {
|
||||||
|
return func(ctx *webContext.Context) {
|
||||||
|
if store := ctx.Input.GetData(storeKey); store == nil {
|
||||||
|
t.Error(`store should not be nil`)
|
||||||
|
}
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
handler.Any("*", func(ctx *webContext.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSession1(t *testing.T) {
|
||||||
|
handler := web.NewControllerRegister()
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
Session(
|
||||||
|
session.ProviderMemory,
|
||||||
|
session.CfgCookieName(`go_session_id`),
|
||||||
|
session.CfgSetCookie(true),
|
||||||
|
session.CfgGcLifeTime(3600),
|
||||||
|
session.CfgMaxLifeTime(3600),
|
||||||
|
session.CfgSecure(false),
|
||||||
|
session.CfgCookieLifeTime(3600),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
handler.InsertFilterChain(
|
||||||
|
"*",
|
||||||
|
func(next web.FilterFunc) web.FilterFunc {
|
||||||
|
return func(ctx *webContext.Context) {
|
||||||
|
if store, err := ctx.Session(); store == nil || err != nil {
|
||||||
|
t.Error(`store should not be nil`)
|
||||||
|
}
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
handler.Any("*", func(ctx *webContext.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
|
||||||
|
}
|
||||||
@ -15,9 +15,12 @@
|
|||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) {
|
|||||||
ns := NewNamespace("/chain")
|
ns := NewNamespace("/chain")
|
||||||
|
|
||||||
ns.Get("/*", func(ctx *context.Context) {
|
ns.Get("/*", func(ctx *context.Context) {
|
||||||
ctx.Output.Body([]byte("hello"))
|
_ = ctx.Output.Body([]byte("hello"))
|
||||||
})
|
})
|
||||||
|
|
||||||
r, _ := http.NewRequest("GET", "/chain/user", nil)
|
r, _ := http.NewRequest("GET", "/chain/user", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
BeeApp.Handlers.Init()
|
||||||
BeeApp.Handlers.ServeHTTP(w, r)
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
|
||||||
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
|
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestControllerRegister_InsertFilterChain_Order(t *testing.T) {
|
||||||
|
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ := http.NewRequest("GET", "/abc", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
BeeApp.Handlers.Init()
|
||||||
|
BeeApp.Handlers.ServeHTTP(w, r)
|
||||||
|
first := w.Header().Get("first")
|
||||||
|
second := w.Header().Get("second")
|
||||||
|
|
||||||
|
ft, _ := strconv.ParseInt(first, 10, 64)
|
||||||
|
st, _ := strconv.ParseInt(second, 10, 64)
|
||||||
|
|
||||||
|
assert.True(t, st > ft)
|
||||||
|
}
|
||||||
|
|||||||
@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) {
|
|||||||
|
|
||||||
// setup the handler
|
// setup the handler
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestFlashController{}, SetRouterMethods(&TestFlashController{}, "get:TestWriteFlash"))
|
handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
// get the Set-Cookie value
|
// get the Set-Cookie value
|
||||||
|
|||||||
@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
|
|||||||
// Router same as beego.Rourer
|
// Router same as beego.Rourer
|
||||||
// refer: https://godoc.org/github.com/beego/beego/v2#Router
|
// refer: https://godoc.org/github.com/beego/beego/v2#Router
|
||||||
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
|
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
|
||||||
n.handlers.Add(rootpath, c, SetRouterMethods(c, mappingMethods...))
|
n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...))
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -121,24 +121,30 @@ type ControllerInfo struct {
|
|||||||
sessionOn bool
|
sessionOn bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ControllerOptions func(*ControllerInfo)
|
type ControllerOption func(*ControllerInfo)
|
||||||
|
|
||||||
func (c *ControllerInfo) GetPattern() string {
|
func (c *ControllerInfo) GetPattern() string {
|
||||||
return c.pattern
|
return c.pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions {
|
func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
|
||||||
return func(c *ControllerInfo) {
|
return func(c *ControllerInfo) {
|
||||||
c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
|
c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetRouterSessionOn(sessionOn bool) ControllerOptions {
|
func WithRouterSessionOn(sessionOn bool) ControllerOption {
|
||||||
return func(c *ControllerInfo) {
|
return func(c *ControllerInfo) {
|
||||||
c.sessionOn = sessionOn
|
c.sessionOn = sessionOn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type filterChainConfig struct {
|
||||||
|
pattern string
|
||||||
|
chain FilterChain
|
||||||
|
opts []FilterOpt
|
||||||
|
}
|
||||||
|
|
||||||
// ControllerRegister containers registered router rules, controller handlers and filters.
|
// ControllerRegister containers registered router rules, controller handlers and filters.
|
||||||
type ControllerRegister struct {
|
type ControllerRegister struct {
|
||||||
routers map[string]*Tree
|
routers map[string]*Tree
|
||||||
@ -151,6 +157,9 @@ type ControllerRegister struct {
|
|||||||
// the filter created by FilterChain
|
// the filter created by FilterChain
|
||||||
chainRoot *FilterRouter
|
chainRoot *FilterRouter
|
||||||
|
|
||||||
|
// keep registered chain and build it when serve http
|
||||||
|
filterChains []filterChainConfig
|
||||||
|
|
||||||
cfg *Config
|
cfg *Config
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,11 +180,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
filterChains: make([]filterChainConfig, 0, 4),
|
||||||
}
|
}
|
||||||
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
|
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init will be executed when HttpServer start running
|
||||||
|
func (p *ControllerRegister) Init() {
|
||||||
|
for i := len(p.filterChains) - 1; i >= 0 ; i -- {
|
||||||
|
fc := p.filterChains[i]
|
||||||
|
root := p.chainRoot
|
||||||
|
filterFunc := fc.chain(root.filterFunc)
|
||||||
|
p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...)
|
||||||
|
p.chainRoot.next = root
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add controller handler and pattern rules to ControllerRegister.
|
// Add controller handler and pattern rules to ControllerRegister.
|
||||||
// usage:
|
// usage:
|
||||||
// default methods is the same name as method
|
// default methods is the same name as method
|
||||||
@ -186,7 +207,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
|
|||||||
// Add("/api/delete",&RestController{},"delete:DeleteFood")
|
// Add("/api/delete",&RestController{},"delete:DeleteFood")
|
||||||
// Add("/api",&RestController{},"get,post:ApiFunc"
|
// Add("/api",&RestController{},"get,post:ApiFunc"
|
||||||
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
||||||
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOptions) {
|
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) {
|
||||||
p.addWithMethodParams(pattern, c, nil, opts...)
|
p.addWithMethodParams(pattern, c, nil, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -239,7 +260,7 @@ func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOptions) {
|
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) {
|
||||||
reflectVal := reflect.ValueOf(c)
|
reflectVal := reflect.ValueOf(c)
|
||||||
t := reflect.Indirect(reflectVal).Type()
|
t := reflect.Indirect(reflectVal).Type()
|
||||||
|
|
||||||
@ -311,7 +332,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
|
|||||||
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
|
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
|
||||||
}
|
}
|
||||||
|
|
||||||
p.addWithMethodParams(a.Router, c, a.MethodParams, SetRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method))
|
p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -513,12 +534,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
|
|||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
|
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
|
||||||
root := p.chainRoot
|
|
||||||
filterFunc := chain(root.filterFunc)
|
|
||||||
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
|
|
||||||
p.chainRoot = newFilterRouter(pattern, filterFunc, opts...)
|
|
||||||
p.chainRoot.next = root
|
|
||||||
|
|
||||||
|
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
|
||||||
|
p.filterChains = append(p.filterChains, filterChainConfig{
|
||||||
|
pattern: pattern,
|
||||||
|
chain: chain,
|
||||||
|
opts: opts,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// add Filter into
|
// add Filter into
|
||||||
@ -583,7 +605,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str
|
|||||||
for _, l := range t.leaves {
|
for _, l := range t.leaves {
|
||||||
if c, ok := l.runObject.(*ControllerInfo); ok {
|
if c, ok := l.runObject.(*ControllerInfo); ok {
|
||||||
if c.routerType == routerTypeBeego &&
|
if c.routerType == routerTypeBeego &&
|
||||||
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
|
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) {
|
||||||
find := false
|
find := false
|
||||||
if HTTPMETHOD[strings.ToUpper(methodName)] {
|
if HTTPMETHOD[strings.ToUpper(methodName)] {
|
||||||
if len(c.methods) == 0 {
|
if len(c.methods) == 0 {
|
||||||
|
|||||||
@ -26,6 +26,14 @@ import (
|
|||||||
"github.com/beego/beego/v2/server/web/context"
|
"github.com/beego/beego/v2/server/web/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PrefixTestController struct {
|
||||||
|
Controller
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ptc *PrefixTestController) PrefixList() {
|
||||||
|
ptc.Ctx.Output.Body([]byte("i am list in prefix test"))
|
||||||
|
}
|
||||||
|
|
||||||
type TestController struct {
|
type TestController struct {
|
||||||
Controller
|
Controller
|
||||||
}
|
}
|
||||||
@ -87,10 +95,24 @@ func (jc *JSONController) Get() {
|
|||||||
jc.Ctx.Output.Body([]byte("ok"))
|
jc.Ctx.Output.Body([]byte("ok"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPrefixUrlFor(t *testing.T){
|
||||||
|
handler := NewControllerRegister()
|
||||||
|
handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList"))
|
||||||
|
|
||||||
|
if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` {
|
||||||
|
logs.Info(a)
|
||||||
|
t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list")
|
||||||
|
}
|
||||||
|
if a := handler.URLFor(`TestController.PrefixList`); a != `` {
|
||||||
|
logs.Info(a)
|
||||||
|
t.Errorf("TestController.PrefixList must equal to empty string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUrlFor(t *testing.T) {
|
func TestUrlFor(t *testing.T) {
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List"))
|
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
|
||||||
handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "*:Param"))
|
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
|
||||||
if a := handler.URLFor("TestController.List"); a != "/api/list" {
|
if a := handler.URLFor("TestController.List"); a != "/api/list" {
|
||||||
logs.Info(a)
|
logs.Info(a)
|
||||||
t.Errorf("TestController.List must equal to /api/list")
|
t.Errorf("TestController.List must equal to /api/list")
|
||||||
@ -113,9 +135,9 @@ func TestUrlFor3(t *testing.T) {
|
|||||||
|
|
||||||
func TestUrlFor2(t *testing.T) {
|
func TestUrlFor2(t *testing.T) {
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:List"))
|
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
|
||||||
handler.Add("/v1/:username/edit", &TestController{}, SetRouterMethods(&TestController{}, "get:GetURL"))
|
handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL"))
|
||||||
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:Param"))
|
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
|
||||||
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
|
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
|
||||||
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
|
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
|
||||||
logs.Info(handler.URLFor("TestController.GetURL"))
|
logs.Info(handler.URLFor("TestController.GetURL"))
|
||||||
@ -145,7 +167,7 @@ func TestUserFunc(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List"))
|
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
if w.Body.String() != "i am list" {
|
if w.Body.String() != "i am list" {
|
||||||
t.Errorf("user define func can't run")
|
t.Errorf("user define func can't run")
|
||||||
@ -235,7 +257,7 @@ func TestRouteOk(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "get:GetParams"))
|
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
if body != "anderson+thomas+kungfu" {
|
if body != "anderson+thomas+kungfu" {
|
||||||
@ -249,7 +271,7 @@ func TestManyRoute(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetManyRouter"))
|
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
body := w.Body.String()
|
body := w.Body.String()
|
||||||
@ -266,7 +288,7 @@ func TestEmptyResponse(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/beego-empty.html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetEmptyBody"))
|
handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody"))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
if body := w.Body.String(); body != "" {
|
if body := w.Body.String(); body != "" {
|
||||||
@ -761,8 +783,8 @@ func TestRouterSessionSet(t *testing.T) {
|
|||||||
r, _ := http.NewRequest("GET", "/user", nil)
|
r, _ := http.NewRequest("GET", "/user", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
SetRouterSessionOn(false))
|
WithRouterSessionOn(false))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
if w.Header().Get("Set-Cookie") != "" {
|
if w.Header().Get("Set-Cookie") != "" {
|
||||||
t.Errorf("TestRotuerSessionSet failed")
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
@ -772,8 +794,8 @@ func TestRouterSessionSet(t *testing.T) {
|
|||||||
r, _ = http.NewRequest("GET", "/user", nil)
|
r, _ = http.NewRequest("GET", "/user", nil)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
handler = NewControllerRegister()
|
handler = NewControllerRegister()
|
||||||
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
SetRouterSessionOn(true))
|
WithRouterSessionOn(true))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
if w.Header().Get("Set-Cookie") != "" {
|
if w.Header().Get("Set-Cookie") != "" {
|
||||||
t.Errorf("TestRotuerSessionSet failed")
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
@ -787,8 +809,8 @@ func TestRouterSessionSet(t *testing.T) {
|
|||||||
r, _ = http.NewRequest("GET", "/user", nil)
|
r, _ = http.NewRequest("GET", "/user", nil)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
handler = NewControllerRegister()
|
handler = NewControllerRegister()
|
||||||
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
SetRouterSessionOn(false))
|
WithRouterSessionOn(false))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
if w.Header().Get("Set-Cookie") != "" {
|
if w.Header().Get("Set-Cookie") != "" {
|
||||||
t.Errorf("TestRotuerSessionSet failed")
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
@ -798,8 +820,8 @@ func TestRouterSessionSet(t *testing.T) {
|
|||||||
r, _ = http.NewRequest("GET", "/user", nil)
|
r, _ = http.NewRequest("GET", "/user", nil)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
handler = NewControllerRegister()
|
handler = NewControllerRegister()
|
||||||
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
|
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
|
||||||
SetRouterSessionOn(true))
|
WithRouterSessionOn(true))
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
if w.Header().Get("Set-Cookie") == "" {
|
if w.Header().Get("Set-Cookie") == "" {
|
||||||
t.Errorf("TestRotuerSessionSet failed")
|
t.Errorf("TestRotuerSessionSet failed")
|
||||||
|
|||||||
@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
|||||||
|
|
||||||
initBeforeHTTPRun()
|
initBeforeHTTPRun()
|
||||||
|
|
||||||
|
// init...
|
||||||
app.initAddr(addr)
|
app.initAddr(addr)
|
||||||
|
app.Handlers.Init()
|
||||||
|
|
||||||
addr = app.Cfg.Listen.HTTPAddr
|
addr = app.Cfg.Listen.HTTPAddr
|
||||||
|
|
||||||
@ -266,8 +268,12 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Router see HttpServer.Router
|
// Router see HttpServer.Router
|
||||||
func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer {
|
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
|
||||||
return BeeApp.Router(rootpath, c, opts...)
|
return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
|
||||||
|
return BeeApp.RouterWithOpts(rootpath, c, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Router adds a patterned controller handler to BeeApp.
|
// Router adds a patterned controller handler to BeeApp.
|
||||||
@ -286,7 +292,11 @@ func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *
|
|||||||
// beego.Router("/api/create",&RestController{},"post:CreateFood")
|
// beego.Router("/api/create",&RestController{},"post:CreateFood")
|
||||||
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
|
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
|
||||||
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
|
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
|
||||||
func (app *HttpServer) Router(rootPath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer {
|
func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
|
||||||
|
return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
|
||||||
app.Handlers.Add(rootPath, c, opts...)
|
app.Handlers.Add(rootPath, c, opts...)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,21 +15,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRedis(t *testing.T) {
|
func TestRedis(t *testing.T) {
|
||||||
sessionConfig := &session.ManagerConfig{
|
|
||||||
CookieName: "gosessionid",
|
|
||||||
EnableSetCookie: true,
|
|
||||||
Gclifetime: 3600,
|
|
||||||
Maxlifetime: 3600,
|
|
||||||
Secure: false,
|
|
||||||
CookieLifeTime: 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
redisAddr := os.Getenv("REDIS_ADDR")
|
redisAddr := os.Getenv("REDIS_ADDR")
|
||||||
if redisAddr == "" {
|
if redisAddr == "" {
|
||||||
redisAddr = "127.0.0.1:6379"
|
redisAddr = "127.0.0.1:6379"
|
||||||
}
|
}
|
||||||
|
redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr)
|
||||||
|
|
||||||
|
sessionConfig := session.NewManagerConfig(
|
||||||
|
session.CfgCookieName(`gosessionid`),
|
||||||
|
session.CfgSetCookie(true),
|
||||||
|
session.CfgGcLifeTime(3600),
|
||||||
|
session.CfgMaxLifeTime(3600),
|
||||||
|
session.CfgSecure(false),
|
||||||
|
session.CfgCookieLifeTime(3600),
|
||||||
|
session.CfgProviderConfig(redisConfig),
|
||||||
|
)
|
||||||
|
|
||||||
sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr)
|
|
||||||
globalSession, err := session.NewManager("redis", sessionConfig)
|
globalSession, err := session.NewManager("redis", sessionConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("could not create manager:", err)
|
t.Fatal("could not create manager:", err)
|
||||||
|
|||||||
@ -13,15 +13,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRedisSentinel(t *testing.T) {
|
func TestRedisSentinel(t *testing.T) {
|
||||||
sessionConfig := &session.ManagerConfig{
|
sessionConfig := session.NewManagerConfig(
|
||||||
CookieName: "gosessionid",
|
session.CfgCookieName(`gosessionid`),
|
||||||
EnableSetCookie: true,
|
session.CfgSetCookie(true),
|
||||||
Gclifetime: 3600,
|
session.CfgGcLifeTime(3600),
|
||||||
Maxlifetime: 3600,
|
session.CfgMaxLifeTime(3600),
|
||||||
Secure: false,
|
session.CfgSecure(false),
|
||||||
CookieLifeTime: 3600,
|
session.CfgCookieLifeTime(3600),
|
||||||
ProviderConfig: "127.0.0.1:6379,100,,0,master",
|
session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"),
|
||||||
}
|
)
|
||||||
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
|
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
t.Log(e)
|
t.Log(e)
|
||||||
|
|||||||
@ -91,25 +91,6 @@ func GetProvider(name string) (Provider, error) {
|
|||||||
return provider, nil
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerConfig define the session config
|
|
||||||
type ManagerConfig struct {
|
|
||||||
CookieName string `json:"cookieName"`
|
|
||||||
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
|
||||||
Gclifetime int64 `json:"gclifetime"`
|
|
||||||
Maxlifetime int64 `json:"maxLifetime"`
|
|
||||||
DisableHTTPOnly bool `json:"disableHTTPOnly"`
|
|
||||||
Secure bool `json:"secure"`
|
|
||||||
CookieLifeTime int `json:"cookieLifeTime"`
|
|
||||||
ProviderConfig string `json:"providerConfig"`
|
|
||||||
Domain string `json:"domain"`
|
|
||||||
SessionIDLength int64 `json:"sessionIDLength"`
|
|
||||||
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
|
|
||||||
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
|
|
||||||
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
|
|
||||||
SessionIDPrefix string `json:"sessionIDPrefix"`
|
|
||||||
CookieSameSite http.SameSite `json:"cookieSameSite"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager contains Provider and its configuration.
|
// Manager contains Provider and its configuration.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
provider Provider
|
provider Provider
|
||||||
|
|||||||
143
server/web/session/session_config.go
Normal file
143
server/web/session/session_config.go
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// ManagerConfig define the session config
|
||||||
|
type ManagerConfig struct {
|
||||||
|
CookieName string `json:"cookieName"`
|
||||||
|
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
||||||
|
Gclifetime int64 `json:"gclifetime"`
|
||||||
|
Maxlifetime int64 `json:"maxLifetime"`
|
||||||
|
DisableHTTPOnly bool `json:"disableHTTPOnly"`
|
||||||
|
Secure bool `json:"secure"`
|
||||||
|
CookieLifeTime int `json:"cookieLifeTime"`
|
||||||
|
ProviderConfig string `json:"providerConfig"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
SessionIDLength int64 `json:"sessionIDLength"`
|
||||||
|
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
|
||||||
|
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
|
||||||
|
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
|
||||||
|
SessionIDPrefix string `json:"sessionIDPrefix"`
|
||||||
|
CookieSameSite http.SameSite `json:"cookieSameSite"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) {
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ManagerConfigOpt func(config *ManagerConfig)
|
||||||
|
|
||||||
|
func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig {
|
||||||
|
config := &ManagerConfig{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(config)
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// CfgCookieName set key of session id
|
||||||
|
func CfgCookieName(cookieName string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.CookieName = cookieName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CfgCookieName set len of session id
|
||||||
|
func CfgSessionIdLength(len int64) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.SessionIDLength = len
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CfgSessionIdPrefix set prefix of session id
|
||||||
|
func CfgSessionIdPrefix(prefix string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.SessionIDPrefix = prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSetCookie whether set `Set-Cookie` header in HTTP response
|
||||||
|
func CfgSetCookie(enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.EnableSetCookie = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgGcLifeTime set session gc lift time
|
||||||
|
func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Gclifetime = lifeTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgMaxLifeTime set session lift time
|
||||||
|
func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Maxlifetime = lifeTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgGcLifeTime set session lift time
|
||||||
|
func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.CookieLifeTime = lifeTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgProviderConfig configure session provider
|
||||||
|
func CfgProviderConfig(providerConfig string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.ProviderConfig = providerConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgDomain set cookie domain
|
||||||
|
func CfgDomain(domain string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Domain = domain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSessionIdInHTTPHeader enable session id in http header
|
||||||
|
func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.EnableSidInHTTPHeader = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSetSessionNameInHTTPHeader set key of session id in http header
|
||||||
|
func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.SessionNameInHTTPHeader = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//EnableSidInURLQuery enable session id in query string
|
||||||
|
func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.EnableSidInURLQuery = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//DisableHTTPOnly set HTTPOnly for http.Cookie
|
||||||
|
func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.DisableHTTPOnly = !HTTPOnly
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSecure set Secure for http.Cookie
|
||||||
|
func CfgSecure(Enable bool) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.Secure = Enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//CfgSameSite set http.SameSite
|
||||||
|
func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt {
|
||||||
|
return func(config *ManagerConfig) {
|
||||||
|
config.CookieSameSite = sameSite
|
||||||
|
}
|
||||||
|
}
|
||||||
222
server/web/session/session_config_test.go
Normal file
222
server/web/session/session_config_test.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCfgCookieLifeTime(t *testing.T) {
|
||||||
|
value := 8754
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgCookieLifeTime(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.CookieLifeTime != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgDomain(t *testing.T) {
|
||||||
|
value := `http://domain.com`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgDomain(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Domain != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSameSite(t *testing.T) {
|
||||||
|
value := http.SameSiteLaxMode
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSameSite(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.CookieSameSite != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSecure(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSecure(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Secure != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSecure1(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSecure(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Secure != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdPrefix(t *testing.T) {
|
||||||
|
value := `sodiausodkljalsd`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdPrefix(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.SessionIDPrefix != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSetSessionNameInHTTPHeader(t *testing.T) {
|
||||||
|
value := `sodiausodkljalsd`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSetSessionNameInHTTPHeader(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.SessionNameInHTTPHeader != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgCookieName(t *testing.T) {
|
||||||
|
value := `sodiausodkljalsd`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgCookieName(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.CookieName != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgEnableSidInURLQuery(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgEnableSidInURLQuery(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSidInURLQuery != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgGcLifeTime(t *testing.T) {
|
||||||
|
value := int64(5454)
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgGcLifeTime(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Gclifetime != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgHTTPOnly(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgHTTPOnly(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.DisableHTTPOnly != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgHTTPOnly2(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgHTTPOnly(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.DisableHTTPOnly != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgMaxLifeTime(t *testing.T) {
|
||||||
|
value := int64(5454)
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgMaxLifeTime(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.Maxlifetime != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgProviderConfig(t *testing.T) {
|
||||||
|
value := `asodiuasldkj12i39809as`
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgProviderConfig(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.ProviderConfig != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdInHTTPHeader(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdInHTTPHeader(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSidInHTTPHeader != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdInHTTPHeader1(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdInHTTPHeader(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSidInHTTPHeader != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSessionIdLength(t *testing.T) {
|
||||||
|
value := int64(100)
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSessionIdLength(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.SessionIDLength != value {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSetCookie(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSetCookie(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSetCookie != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCfgSetCookie1(t *testing.T) {
|
||||||
|
c := NewManagerConfig(
|
||||||
|
CfgSetCookie(false),
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.EnableSetCookie != false {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewManagerConfig(t *testing.T) {
|
||||||
|
c := NewManagerConfig()
|
||||||
|
if c == nil {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerConfig_Opts(t *testing.T) {
|
||||||
|
c := NewManagerConfig()
|
||||||
|
c.Opts(CfgSetCookie(true))
|
||||||
|
|
||||||
|
if c.EnableSetCookie != true {
|
||||||
|
t.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
18
server/web/session/session_provider_type.go
Normal file
18
server/web/session/session_provider_type.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
type ProviderType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProviderCookie ProviderType = `cookie`
|
||||||
|
ProviderFile ProviderType = `file`
|
||||||
|
ProviderMemory ProviderType = `memory`
|
||||||
|
ProviderCouchbase ProviderType = `couchbase`
|
||||||
|
ProviderLedis ProviderType = `ledis`
|
||||||
|
ProviderMemcache ProviderType = `memcache`
|
||||||
|
ProviderMysql ProviderType = `mysql`
|
||||||
|
ProviderPostgresql ProviderType = `postgresql`
|
||||||
|
ProviderRedis ProviderType = `redis`
|
||||||
|
ProviderRedisCluster ProviderType = `redis_cluster`
|
||||||
|
ProviderRedisSentinel ProviderType = `redis_sentinel`
|
||||||
|
ProviderSsdb ProviderType = `ssdb`
|
||||||
|
)
|
||||||
@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
|
|||||||
var method = "GET"
|
var method = "GET"
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
||||||
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
||||||
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test original root
|
// Test original root
|
||||||
testHelperFnContentCheck(t, handler, "Test original root",
|
testHelperFnContentCheck(t, handler, "Test original root",
|
||||||
@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
|
|||||||
|
|
||||||
// Replace the root path TestPreUnregController action with the action from
|
// Replace the root path TestPreUnregController action with the action from
|
||||||
// TestPostUnregController
|
// TestPostUnregController
|
||||||
handler.Add("/", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot"))
|
handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot"))
|
||||||
|
|
||||||
// Test replacement root (expect change)
|
// Test replacement root (expect change)
|
||||||
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
|
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
|
||||||
@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
|
|||||||
var method = "GET"
|
var method = "GET"
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
||||||
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
||||||
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test original root
|
// Test original root
|
||||||
testHelperFnContentCheck(t, handler,
|
testHelperFnContentCheck(t, handler,
|
||||||
@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
|
|||||||
|
|
||||||
// Replace the "level1" path TestPreUnregController action with the action from
|
// Replace the "level1" path TestPreUnregController action with the action from
|
||||||
// TestPostUnregController
|
// TestPostUnregController
|
||||||
handler.Add("/level1", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1"))
|
handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1"))
|
||||||
|
|
||||||
// Test replacement root (expect no change from the original)
|
// Test replacement root (expect no change from the original)
|
||||||
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
||||||
@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
|
|||||||
var method = "GET"
|
var method = "GET"
|
||||||
|
|
||||||
handler := NewControllerRegister()
|
handler := NewControllerRegister()
|
||||||
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
|
||||||
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
|
||||||
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test original root
|
// Test original root
|
||||||
testHelperFnContentCheck(t, handler,
|
testHelperFnContentCheck(t, handler,
|
||||||
@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
|
|||||||
|
|
||||||
// Replace the "/level1/level2" path TestPreUnregController action with the action from
|
// Replace the "/level1/level2" path TestPreUnregController action with the action from
|
||||||
// TestPostUnregController
|
// TestPostUnregController
|
||||||
handler.Add("/level1/level2", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2"))
|
handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2"))
|
||||||
|
|
||||||
// Test replacement root (expect no change from the original)
|
// Test replacement root (expect no change from the original)
|
||||||
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
|
||||||
|
|||||||
7
sonar-project.properties
Normal file
7
sonar-project.properties
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
sonar.organization=beego
|
||||||
|
sonar.projectKey=beego_beego
|
||||||
|
|
||||||
|
# relative paths to source directories. More details and properties are described
|
||||||
|
# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/
|
||||||
|
sonar.sources=.
|
||||||
|
sonar.exclusions=**/*_test.go
|
||||||
@ -109,6 +109,23 @@ func TestTask_Run(t *testing.T) {
|
|||||||
assert.Equal(t, "Hello, world! 101", l[1].errinfo)
|
assert.Equal(t, "Hello, world! 101", l[1].errinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCrudTask(t *testing.T) {
|
||||||
|
m := newTaskManager()
|
||||||
|
m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
m.DeleteTask("my-task1")
|
||||||
|
assert.Equal(t, 1, len(m.adminTaskList))
|
||||||
|
|
||||||
|
m.ClearTask()
|
||||||
|
assert.Equal(t, 0, len(m.adminTaskList))
|
||||||
|
}
|
||||||
|
|
||||||
func wait(wg *sync.WaitGroup) chan bool {
|
func wait(wg *sync.WaitGroup) chan bool {
|
||||||
ch := make(chan bool)
|
ch := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user