Merge branch 'develop' of https://github.com/beego/beego into orm-mock

This commit is contained in:
Ming Deng 2021-01-11 23:42:15 +08:00
commit 89eeedbf9d
56 changed files with 1581 additions and 342 deletions

View File

@ -8,7 +8,7 @@ on:
pull_request:
types: [opened, synchronize, reopened, labeled, unlabeled]
branches:
- master
- develop
jobs:
changelog:

32
.github/workflows/golangci-lint.yml vendored Normal file
View File

@ -0,0 +1,32 @@
name: golangci-lint
on:
push:
tags:
- v*
branches:
- master
- main
pull_request:
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
# Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
version: v1.29
# Optional: working directory, useful for monorepos
# working-directory: ./
# Optional: golangci-lint command line arguments.
args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true
# Optional: show only new issues if it's a pull request. The default value is `false`.
only-new-issues: true
# Optional: if set to true then the action will use pre-installed Go
# skip-go-installation: true

View File

@ -1,6 +1,11 @@
# developing
- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433)
- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427)
- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398)
- Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391)
- Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385)
- Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385)
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294)
- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404)
- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424)

View File

@ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare {
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
return (*App)(web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...)))
return (*App)(web.Router(rootpath, c, mappingMethods...))
}
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful

View File

@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
(*web.ControllerRegister)(p).Add(pattern, c, web.SetRouterMethods(c, mappingMethods...))
(*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...))
}
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller

View File

@ -301,6 +301,28 @@ func TestAddFilter(t *testing.T) {
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
}
func TestFilterChainOrder(t *testing.T) {
req := Get("http://beego.me")
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("first"), nil
}
})
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("second"), nil
}
})
resp, err := req.DoRequestWithCtx(context.Background())
assert.Nil(t, err)
data := make([]byte, 5)
_, _ = resp.Body.Read(data)
assert.Equal(t, "first", string(data))
}
func TestHead(t *testing.T) {
req := Head("http://beego.me")
assert.NotNil(t, req)

View File

@ -0,0 +1,6 @@
package clauses
const (
ExprSep = "__"
ExprDot = "."
)

View File

@ -0,0 +1,103 @@
package order_clause
import (
"github.com/beego/beego/v2/client/orm/clauses"
"strings"
)
type Sort int8
const (
None Sort = 0
Ascending Sort = 1
Descending Sort = 2
)
type Option func(order *Order)
type Order struct {
column string
sort Sort
isRaw bool
}
func Clause(options ...Option) *Order {
o := &Order{}
for _, option := range options {
option(o)
}
return o
}
func (o *Order) GetColumn() string {
return o.column
}
func (o *Order) GetSort() Sort {
return o.sort
}
func (o *Order) SortString() string {
switch o.GetSort() {
case Ascending:
return "ASC"
case Descending:
return "DESC"
}
return ``
}
func (o *Order) IsRaw() bool {
return o.isRaw
}
func ParseOrder(expressions ...string) []*Order {
var orders []*Order
for _, expression := range expressions {
sort := Ascending
column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot)
if column[0] == '-' {
sort = Descending
column = column[1:]
}
orders = append(orders, &Order{
column: column,
sort: sort,
})
}
return orders
}
func Column(column string) Option {
return func(order *Order) {
order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot)
}
}
func sort(sort Sort) Option {
return func(order *Order) {
order.sort = sort
}
}
func SortAscending() Option {
return sort(Ascending)
}
func SortDescending() Option {
return sort(Descending)
}
func SortNone() Option {
return sort(None)
}
func Raw() Option {
return func(order *Order) {
order.isRaw = true
}
}

View File

@ -0,0 +1,144 @@
package order_clause
import (
"testing"
)
func TestClause(t *testing.T) {
var (
column = `a`
)
o := Clause(
Column(column),
)
if o.GetColumn() != column {
t.Error()
}
}
func TestSortAscending(t *testing.T) {
o := Clause(
SortAscending(),
)
if o.GetSort() != Ascending {
t.Error()
}
}
func TestSortDescending(t *testing.T) {
o := Clause(
SortDescending(),
)
if o.GetSort() != Descending {
t.Error()
}
}
func TestSortNone(t *testing.T) {
o1 := Clause(
SortNone(),
)
if o1.GetSort() != None {
t.Error()
}
o2 := Clause()
if o2.GetSort() != None {
t.Error()
}
}
func TestRaw(t *testing.T) {
o1 := Clause()
if o1.IsRaw() {
t.Error()
}
o2 := Clause(
Raw(),
)
if !o2.IsRaw() {
t.Error()
}
}
func TestColumn(t *testing.T) {
o1 := Clause(
Column(`aaa`),
)
if o1.GetColumn() != `aaa` {
t.Error()
}
}
func TestParseOrder(t *testing.T) {
orders := ParseOrder(
`-user__status`,
`status`,
`user__status`,
)
t.Log(orders)
if orders[0].GetSort() != Descending {
t.Error()
}
if orders[0].GetColumn() != `user.status` {
t.Error()
}
if orders[1].GetColumn() != `status` {
t.Error()
}
if orders[1].GetSort() != Ascending {
t.Error()
}
if orders[2].GetColumn() != `user.status` {
t.Error()
}
}
func TestOrder_GetColumn(t *testing.T) {
o := Clause(
Column(`user__id`),
)
if o.GetColumn() != `user.id` {
t.Error()
}
}
func TestOrder_GetSort(t *testing.T) {
o := Clause(
SortDescending(),
)
if o.GetSort() != Descending {
t.Error()
}
}
func TestOrder_IsRaw(t *testing.T) {
o1 := Clause()
if o1.IsRaw() {
t.Error()
}
o2 := Clause(
Raw(),
)
if !o2.IsRaw() {
t.Error()
}
}

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"flag"
"fmt"
"os"
@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf(" %s\n", err.Error())
}
ctx := context.Background()
for i, mi := range modelCache.allOrdered() {
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error {
}
var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table)
if err != nil {
if d.rtOnError {
return err
@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error {
}
for _, idx := range indexes[mi.table] {
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"errors"
"fmt"
@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
// create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote()
dbcols := make([]string, 0, len(mi.fields.dbcols))
@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
d.ins.HasReturningID(mi, &query)
stmt, err := q.Prepare(query)
stmt, err := q.PrepareContext(ctx, query)
return stmt, query, err
}
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return 0, err
@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
err := row.Scan(&id)
return id, err
}
res, err := stmt.Exec(values...)
res, err := stmt.ExecContext(ctx, values...)
if err == nil {
return res.LastInsertId()
}
@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
}
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
var whereCols []string
var args []interface{}
@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...)
row := q.QueryRowContext(ctx, query, args...)
if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
}
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols))
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil {
return 0, err
}
id, err := d.InsertValue(q, mi, false, names, values)
id, err := d.InsertValue(ctx, q, mi, false, names, values)
if err != nil {
return 0, err
}
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
err = d.ins.setval(ctx, q, mi, autoFields)
}
return id, err
}
// multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
var err error
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
err = d.ins.setval(ctx, q, mi, autoFields)
}
return cnt, err
@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
}
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err := row.Scan(&id)
return id, err
@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
// InsertOrUpdate a row
// If your primary key or unique column conflict will update
// If no will insert
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
args0 := ""
iouStr := ""
argsMap := map[string]string{}
@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err = row.Scan(&id)
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
}
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if !ok {
return 0, ErrMissPK
@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, setValues...)
res, err := q.ExecContext(ctx, query, setValues...)
if err == nil {
return res.RowsAffected()
}
@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
var whereCols []string
var args []interface{}
// if specify cols length > 0, then use it for where condition.
@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...)
res, err := q.ExecContext(ctx, query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
}
}
err := d.deleteRels(q, mi, args, tz)
err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil {
return num, err
}
@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// update table-related record by querySet.
// need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
var err error
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
return res.RowsAffected()
}
@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
// delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
_, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
_, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
}
// delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins)
tables.skipEnd = true
@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
r, err := q.Query(query, args...)
r, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query)
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
res, err := q.ExecContext(ctx, query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
err := d.deleteRels(q, mi, args, tz)
err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil {
return num, err
}
@ -943,14 +933,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
// read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
errTyp := true
unregister := true
one := true
isPtr := true
name := ""
if val.Kind() == reflect.Ptr {
fn := ""
@ -963,19 +954,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
case reflect.Struct:
isPtr = false
fn = getFullName(typ)
name = getTableName(reflect.New(typ))
}
} else {
fn = getFullName(ind.Type())
name = getTableName(ind)
}
errTyp = fn != mi.fullName
unregister = fn != mi.fullName
}
if errTyp {
if one {
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName))
} else {
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName))
}
if unregister {
RegisterModel(container)
}
rlimit := qs.limit
@ -1040,6 +1029,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
if qs.distinct {
sqlSelect += " DISTINCT"
}
if qs.aggregate != "" {
sels = qs.aggregate
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, Q, mi.table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit)
@ -1050,18 +1042,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
var err error
if qs != nil && qs.forContext {
rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil {
return 0, err
}
} else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
rs, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rs.Close()
slice := ind
if unregister {
mi, _ = modelCache.get(name)
tCols = mi.fields.dbcols
colsNum = len(tCols)
}
refs := make([]interface{}, colsNum)
@ -1069,11 +1061,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
var ref interface{}
refs[i] = &ref
}
defer rs.Close()
slice := ind
var cnt int64
for rs.Next() {
if one && cnt == 0 || !one {
@ -1172,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
}
// excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
@ -1194,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query)
var row *sql.Row
if qs != nil && qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
row := q.QueryRowContext(ctx, query, args...)
err = row.Scan(&cnt)
return
}
@ -1649,7 +1631,7 @@ setValue:
}
// query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var (
maps []Params
@ -1732,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
d.ins.ReplaceMarks(&query)
rs, err := q.Query(query, args...)
rs, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
@ -1847,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
}
// sync auto key
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
return nil
}
@ -1892,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
}
// get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
rows, err := db.QueryContext(ctx, query)
if err != nil {
return columns, err
}
@ -1934,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
}
// not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool {
panic(ErrNotImplement)
}

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"reflect"
"strings"
@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
}
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
// If your primary key or unique column conflict will update
// If no will insert
// Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string
argsMap := map[string]string{}
@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err = row.Scan(&id)
return id, err

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"strings"
@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
}
// check index is exist
func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
}
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err := row.Scan(&id)
return id, err

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"strconv"
)
@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
}
// sync auto key
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
}
@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string
mi.table, name,
Q, name, Q,
Q, mi.table, Q)
if _, err := db.Exec(query); err != nil {
if _, err := db.ExecContext(ctx, query); err != nil {
return err
}
}
@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string {
}
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query)
row := db.QueryRowContext(ctx, query)
var cnt int
row.Scan(&cnt)
return cnt > 0

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"fmt"
"reflect"
@ -73,11 +74,11 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
}
return d.dbBase.Read(q, mi, ind, tz, cols, false)
return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
}
// get sqlite operator.
@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string {
}
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
rows, err := db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
}
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query)
rows, err := db.QueryContext(ctx, query)
if err != nil {
panic(err)
}

View File

@ -16,6 +16,8 @@ package orm
import (
"fmt"
"github.com/beego/beego/v2/client/orm/clauses"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"strings"
"time"
)
@ -421,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
}
// generate order sql.
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
if len(orders) == 0 {
return
}
@ -430,19 +432,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
asc := "ASC"
if order[0] == '-' {
asc = "DESC"
order = order[1:]
}
exprs := strings.Split(order, ExprSep)
column := order.GetColumn()
clause := strings.Split(column, clauses.ExprDot)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
if order.IsRaw() {
if len(clause) == 2 {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
} else if len(clause) == 1 {
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
} else {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
}
} else {
index, _, fi, suc := t.parseExprs(t.mi, clause)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString()))
}
}
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
)
@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
}
// execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)

View File

@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
return nil
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
return nil
}
@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return nil
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
return nil
}

View File

@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) {
assert.Nil(t, o.Driver())
assert.Nil(t, o.QueryM2MWithCtx(nil, nil, ""))
assert.Nil(t, o.QueryM2M(nil, ""))
assert.Nil(t, o.ReadWithCtx(nil, nil))
assert.Nil(t, o.Read(nil))
@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, int64(0), i)
assert.Nil(t, o.QueryTableWithCtx(nil, nil))
assert.Nil(t, o.QueryTable(nil))
assert.Nil(t, o.Read(nil))

View File

@ -27,7 +27,7 @@ import (
// this Filter's behavior looks a little bit strange
// for example:
// if we want to trace QuerySetter
// actually we trace invoking "QueryTable" and "QueryTableWithCtx"
// actually we trace invoking "QueryTable"
// the method Begin*, Commit and Rollback are ignored.
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
type FilterChainBuilder struct {

View File

@ -31,7 +31,7 @@ import (
// this Filter's behavior looks a little bit strange
// for example:
// if we want to records the metrics of QuerySetter
// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx"
// actually we only records metrics of invoking "QueryTable"
type FilterChainBuilder struct {
summaryVec prometheus.ObserverVec
AppName string

View File

@ -20,6 +20,7 @@ import (
"reflect"
"time"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/core/utils"
)
@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac
}
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
return f.QueryM2MWithCtx(context.Background(), md, name)
}
func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
mi, _ := modelCache.getByMd(md)
inv := &Invocation{
Method: "QueryM2MWithCtx",
Method: "QueryM2M",
Args: []interface{}{md, name},
Md: md,
mi: mi,
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
f: func(c context.Context) []interface{} {
res := f.ormer.QueryM2MWithCtx(c, md, name)
res := f.ormer.QueryM2M(md, name)
return []interface{}{res}
},
}
res := f.root(ctx, inv)
res := f.root(context.Background(), inv)
if res[0] == nil {
return nil
}
return res[0].(QueryM2Mer)
}
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
// NOTE: this method is deprecated, context parameter will not take effect.
func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.")
return f.QueryM2M(md, name)
}
func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
var (
name string
md interface{}
@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
}
inv := &Invocation{
Method: "QueryTableWithCtx",
Method: "QueryTable",
Args: []interface{}{ptrStructOrTableName},
InsideTx: f.insideTx,
TxStartTime: f.txStartTime,
Md: md,
mi: mi,
f: func(c context.Context) []interface{} {
res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName)
res := f.ormer.QueryTable(ptrStructOrTableName)
return []interface{}{res}
},
}
res := f.root(ctx, inv)
res := f.root(context.Background(), inv)
if res[0] == nil {
return nil
@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
return res[0].(QuerySeter)
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter {
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.")
return f.QueryTable(ptrStructOrTableName)
}
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
inv := &Invocation{
Method: "DBStats",

View File

@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) []interface{} {
assert.Equal(t, "QueryM2MWithCtx", inv.Method)
assert.Equal(t, "QueryM2M", inv.Method)
assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)
@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) {
o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) []interface{} {
assert.Equal(t, "QueryTableWithCtx", inv.Method)
assert.Equal(t, "QueryTable", inv.Method)
assert.Equal(t, 1, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx)

View File

@ -332,10 +332,6 @@ end:
// register register models to model cache
func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
if mc.done {
err = fmt.Errorf("register must be run before BootStrap")
return
}
for _, model := range models {
val := reflect.ValueOf(model)
@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
return
}
if val.Elem().Kind() == reflect.Slice {
val = reflect.New(val.Elem().Type().Elem())
}
table := getTableName(val)
if prefixOrSuffixStr != "" {
@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
}
if _, ok := mc.get(table); ok {
err = fmt.Errorf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
return
return nil
}
mi := newModelInfo(val)
@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
}
}
}
if mi.fields.pk == nil {
err = fmt.Errorf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
return
}
}
mi.table = table

View File

@ -255,6 +255,22 @@ func NewTM() *TM {
return obj
}
type DeptInfo struct {
ID int `orm:"column(id)"`
Created time.Time `orm:"auto_now_add"`
DeptName string
EmployeeName string
Salary int
}
type UnregisterModel struct {
ID int `orm:"column(id)"`
Created time.Time `orm:"auto_now_add"`
DeptName string
EmployeeName string
Salary int
}
type User struct {
ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"`
@ -476,45 +492,45 @@ var (
helpinfo = `need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage:
go get -u github.com/beego/beego/v2/client/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
export ORM_DRIVER=mysql
export ORM_SOURCE="root:@/orm_test?charset=utf8"
go test -v github.com/beego/beego/v2/client/orm
#### Sqlite3
export ORM_DRIVER=sqlite3
export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/beego/beego/v2/client/orm
#### PostgreSQL
psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/beego/beego/v2/client/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test'
go test -v github.com/beego/beego/v2/pgk/orm
`
)

View File

@ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string {
// get whether the table needs to be created for the database alias
func isApplicableTableForDB(val reflect.Value, db string) bool {
if !val.IsValid() {
return true
}
fun := val.MethodByName("IsApplicableTableForDB")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{reflect.ValueOf(db)})

View File

@ -58,6 +58,7 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"os"
"reflect"
"time"
@ -135,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error {
}
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
}
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
@ -144,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error {
}
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true)
}
// Try to read a row from the database, or insert one if it doesn't exist
@ -154,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
cols = append([]string{col1}, cols...)
mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
if err == ErrNoRows {
// Create
id, err := o.InsertWithCtx(ctx, md)
@ -179,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) {
}
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
if err != nil {
return id, err
}
@ -222,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
for i := 0; i < sind.Len(); i++ {
ind := reflect.Indirect(sind.Index(i))
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
@ -233,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
}
@ -244,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (
}
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...)
if err != nil {
return id, err
}
@ -261,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
}
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols)
}
// delete model in database
@ -271,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) {
}
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
}
@ -283,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str
// create a models to models queryer
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
return o.QueryM2MWithCtx(context.Background(), md, name)
}
func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
@ -299,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri
return newQueryM2M(md, o, mi, fi, ind)
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.")
return o.QueryM2M(md, name)
}
// load related models to md model.
// args are limit, offset int and order string.
//
@ -351,7 +355,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s
qs.relDepth = relDepth
if len(order) > 0 {
qs.orders = []string{order}
qs.orders = order_clause.ParseOrder(order)
}
find := ind.FieldByIndex(fi.fieldIndex)
@ -451,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
}
func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string
if table, ok := ptrStructOrTableName.(string); ok {
name = nameStrategyMap[defaultNameStrategy](table)
@ -469,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in
if qs == nil {
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
}
return
return qs
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.")
return o.QueryTable(ptrStructOrTableName)
}
// return a raw query seter for raw sql string.
@ -595,9 +602,8 @@ func NewOrm() Ormer {
func NewOrmUsingDB(aliasName string) Ormer {
if al, ok := dataBaseCache.get(aliasName); ok {
return newDBWithAlias(al)
} else {
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
}
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
}
// NewOrmWithDB create a new ormer object with specify *sql.DB for query

View File

@ -16,12 +16,13 @@ package orm
import (
"fmt"
"github.com/beego/beego/v2/client/orm/clauses"
"strings"
)
// ExprSep define the expression separation
const (
ExprSep = "__"
ExprSep = clauses.ExprSep
)
type condValue struct {

View File

@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error {
}
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
return d.ExecContext(context.Background(), args...)
}
func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.stmt.Exec(args...)
res, err := d.stmt.ExecContext(ctx, args...)
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
return d.QueryContext(context.Background(), args...)
}
func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.stmt.Query(args...)
res, err := d.stmt.QueryContext(ctx, args...)
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
return d.QueryRowContext(context.Background(), args...)
}
func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
a := time.Now()
res := d.stmt.QueryRow(args...)
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)

View File

@ -15,6 +15,7 @@
package orm
import (
"context"
"fmt"
"reflect"
)
@ -31,6 +32,10 @@ var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) {
return o.InsertWithCtx(context.Background(), md)
}
func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
}
@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
if name != o.mi.fullName {
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
}
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil {
return id, err
}
@ -70,11 +75,11 @@ func (o *insertSet) Close() error {
}
// create new insert queryer.
func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) {
func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm
bi.mi = mi
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi)
if err != nil {
return nil, err
}

View File

@ -14,7 +14,10 @@
package orm
import "reflect"
import (
"context"
"reflect"
)
// model to model struct
type queryM2M struct {
@ -33,6 +36,10 @@ type queryM2M struct {
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
return o.AddWithCtx(context.Background(), mds...)
}
func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi
mi := fi.relThroughModelInfo
mfi := fi.reverseFieldInfo
@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
}
names = append(names, otherNames...)
values = append(values, otherValues...)
return dbase.InsertValue(orm.db, mi, true, names, values)
return dbase.InsertValue(ctx, orm.db, mi, true, names, values)
}
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return o.RemoveWithCtx(context.Background(), mds...)
}
func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool {
return o.ExistWithCtx(context.Background(), md)
}
func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx)
}
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) {
return o.ClearWithCtx(context.Background())
}
func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx)
}
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) {
return o.CountWithCtx(context.Background())
}
func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx)
}
var _ QueryM2Mer = new(queryM2M)

View File

@ -18,6 +18,7 @@ import (
"context"
"fmt"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/hints"
)
@ -64,21 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct
type querySet struct {
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int64
offset int64
groups []string
orders []string
distinct bool
forUpdate bool
useIndex int
indexes []string
orm *ormBase
ctx context.Context
forContext bool
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int64
offset int64
groups []string
orders []*order_clause.Order
distinct bool
forUpdate bool
useIndex int
indexes []string
orm *ormBase
aggregate string
}
var _ QuerySeter = new(querySet)
@ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter {
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs
func (o querySet) OrderBy(expressions ...string) QuerySeter {
if len(expressions) <= 0 {
return &o
}
o.orders = order_clause.ParseOrder(expressions...)
return &o
}
// add ORDER expression.
func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter {
if len(orders) <= 0 {
return &o
}
o.orders = orders
return &o
}
@ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition {
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return o.CountWithCtx(context.Background())
}
func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return o.ExistWithCtx(context.Background())
}
func (o *querySet) ExistWithCtx(ctx context.Context) bool {
cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0
}
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
return o.UpdateWithCtx(context.Background(), values)
}
func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
}
// execute delete
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return o.DeleteWithCtx(context.Background())
}
func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// return a insert queryer.
@ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) {
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
return o.PrepareInsertWithCtx(context.Background())
}
func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
return newInsertSet(ctx, o.orm, o.mi)
}
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
return o.AllWithCtx(context.Background(), container, cols...)
}
func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error {
return o.OneWithCtx(context.Background(), container, cols...)
}
func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
o.limit = 1
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil {
return err
}
@ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error {
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
return o.ValuesWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
return o.ValuesListWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
return o.ValuesFlatWithCtx(context.Background(), result, expr)
}
func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
// query all rows into map[string]interface with specify key and value column name.
@ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
panic(ErrNotImplement)
}
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter.
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
o := new(querySet)
@ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
o.orm = orm
return o
}
// aggregate func
func (o querySet) Aggregate(s string) QuerySeter {
o.aggregate = s
return &o
}

View File

@ -21,6 +21,7 @@ import (
"context"
"database/sql"
"fmt"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"io/ioutil"
"math"
"os"
@ -205,6 +206,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Index))
RegisterModel(new(StrPk))
RegisterModel(new(TM))
RegisterModel(new(DeptInfo))
err := RunSyncdb("default", true, Debug)
throwFail(t, err)
@ -232,6 +234,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Index))
RegisterModel(new(StrPk))
RegisterModel(new(TM))
RegisterModel(new(DeptInfo))
BootStrap()
@ -333,6 +336,73 @@ func TestTM(t *testing.T) {
throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC"))
}
func TestUnregisterModel(t *testing.T) {
data := []*DeptInfo{
{
DeptName: "A",
EmployeeName: "A1",
Salary: 1000,
},
{
DeptName: "A",
EmployeeName: "A2",
Salary: 2000,
},
{
DeptName: "B",
EmployeeName: "B1",
Salary: 2000,
},
{
DeptName: "B",
EmployeeName: "B2",
Salary: 4000,
},
{
DeptName: "B",
EmployeeName: "B3",
Salary: 3000,
},
}
qs := dORM.QueryTable("dept_info")
i, _ := qs.PrepareInsert()
for _, d := range data {
_, err := i.Insert(d)
if err != nil {
throwFail(t, err)
}
}
f := func() {
var res []UnregisterModel
n, err := dORM.QueryTable("dept_info").All(&res)
throwFail(t, err)
throwFail(t, AssertIs(n, 5))
throwFail(t, AssertIs(res[0].EmployeeName, "A1"))
type Sum struct {
DeptName string
Total int
}
var sun []Sum
qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun)
throwFail(t, AssertIs(sun[0].DeptName, "A"))
throwFail(t, AssertIs(sun[0].Total, 3000))
type Max struct {
DeptName string
Max float64
}
var max []Max
qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max)
throwFail(t, AssertIs(max[1].DeptName, "B"))
throwFail(t, AssertIs(max[1].Max, 4000))
}
for i := 0; i < 5; i++ {
f()
}
}
func TestNullDataTypes(t *testing.T) {
d := DataNull{}
@ -1077,6 +1147,26 @@ func TestOrderBy(t *testing.T) {
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column(`profile__age`),
order_clause.SortDescending(),
),
).Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
if IsMysql {
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column(`rand()`),
order_clause.Raw(),
),
).Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
}
func TestAll(t *testing.T) {
@ -1163,6 +1253,19 @@ func TestValues(t *testing.T) {
throwFail(t, AssertIs(maps[2]["Profile"], nil))
}
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column("Id"),
order_clause.SortAscending(),
),
).Values(&maps)
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
if num == 3 {
throwFail(t, AssertIs(maps[0]["UserName"], "slene"))
throwFail(t, AssertIs(maps[2]["Profile"], nil))
}
num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age")
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
@ -2717,3 +2820,23 @@ func TestCondition(t *testing.T) {
throwFail(t, AssertIs(!cycleFlag, true))
return
}
func TestContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
user := User{UserName: "slene"}
err := dORM.ReadWithCtx(ctx, &user, "UserName")
throwFail(t, err)
cancel()
err = dORM.ReadWithCtx(ctx, &user, "UserName")
throwFail(t, AssertIs(err, context.Canceled))
ctx, cancel = context.WithCancel(context.Background())
cancel()
qs := dORM.QueryTable(user)
_, err = qs.Filter("UserName", "slene").CountWithCtx(ctx)
throwFail(t, AssertIs(err, context.Canceled))
}

View File

@ -20,6 +20,7 @@ import (
"reflect"
"time"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/core/utils"
)
@ -196,12 +197,16 @@ type DQL interface {
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer
// NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter
// NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
DBStats() *sql.DBStats
@ -230,6 +235,7 @@ type TxOrmer interface {
// Inserter insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
InsertWithCtx(context.Context, interface{}) (int64, error)
Close() error
}
@ -289,6 +295,28 @@ type QuerySeter interface {
// for example:
// qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter
// add ORDER expression by order clauses
// for example:
// OrderClauses(
// order_clause.Clause(
// order.Column("Id"),
// order.SortAscending(),
// ),
// order_clause.Clause(
// order.Column("status"),
// order.SortDescending(),
// ),
// )
// OrderClauses(order_clause.Clause(
// order_clause.Column(`user__status`),
// order_clause.SortDescending(),//default None
// ))
// OrderClauses(order_clause.Clause(
// order_clause.Column(`random()`),
// order_clause.SortNone(),//default None
// order_clause.Raw(),//default false.if true, do not check field is valid or not
// ))
OrderClauses(orders ...*order_clause.Order) QuerySeter
// add FORCE INDEX expression.
// for example:
// qs.ForceIndex(`idx_name1`,`idx_name2`)
@ -327,9 +355,11 @@ type QuerySeter interface {
// for example:
// num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
// check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0
Exist() bool
ExistWithCtx(context.Context) bool
// execute update with parameters
// for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{
@ -339,11 +369,13 @@ type QuerySeter interface {
// "user_name": "slene2"
// }) // user slene's name will change to slene2
Update(values Params) (int64, error)
UpdateWithCtx(ctx context.Context, values Params) (int64, error)
// delete from table
// for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2
Delete() (int64, error)
DeleteWithCtx(context.Context) (int64, error)
// return a insert queryer.
// it can be used in times.
// example:
@ -352,18 +384,21 @@ type QuerySeter interface {
// num, err = i.Insert(&user2) // user table will add one record user2 at once
// err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error)
PrepareInsertWithCtx(context.Context) (Inserter, error)
// query all data and map to containers.
// cols means the columns when querying.
// for example:
// var users []*User
// qs.All(&users) // users[0],users[1],users[2] ...
All(container interface{}, cols ...string) (int64, error)
AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error)
// query one row data and map to containers.
// cols means the columns when querying.
// for example:
// var user User
// qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error
OneWithCtx(ctx context.Context, container interface{}, cols ...string) error
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
@ -371,18 +406,21 @@ type QuerySeter interface {
// var maps []Params
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error)
ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface
// it converts data to [][column_index]value
// for example:
// var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value.
// for example:
// var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error)
ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
@ -405,6 +443,15 @@ type QuerySeter interface {
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// aggregate func.
// for example:
// type result struct {
// DeptName string
// Total int
// }
// var res []result
// o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res)
Aggregate(s string) QuerySeter
}
// QueryM2Mer model to model query struct
@ -422,18 +469,23 @@ type QueryM2Mer interface {
// insert one or more rows to m2m table
// make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error)
AddWithCtx(context.Context, ...interface{}) (int64, error)
// remove models following the origin model relationship
// only delete rows from m2m table
// for example:
// tag3 := &Tag{Id:5,Name: "TestTag3"}
// num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error)
RemoveWithCtx(context.Context, ...interface{}) (int64, error)
// check model is existed in relationship of origin model
Exist(interface{}) bool
ExistWithCtx(context.Context, interface{}) bool
// clean all models in related of origin model
Clear() (int64, error)
ClearWithCtx(context.Context) (int64, error)
// count all related models of origin model
Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
}
// RawPreparer raw query statement
@ -507,11 +559,11 @@ type RawSeter interface {
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
// ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
// QueryContext(args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row
// QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
}
// db querier
@ -548,28 +600,28 @@ type txEnder interface {
// base database struct
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
SupportUpdateJoin() bool
OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
@ -578,12 +630,12 @@ type dbBaser interface {
TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string
GetTables(dbQuerier) (map[string]bool, error)
GetColumns(dbQuerier, string) (map[string][3]string, error)
GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error)
ShowTablesQuery() string
ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool
IndexExists(context.Context, dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error
setval(context.Context, dbQuerier, *modelInfo, []string) error
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
}

2
go.mod
View File

@ -25,7 +25,7 @@ require (
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/go-cmp v0.5.0 // indirect
github.com/google/uuid v1.1.1 // indirect
github.com/google/uuid v1.1.1
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/hashicorp/golang-lru v0.5.4
github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6

View File

@ -112,13 +112,13 @@ func registerAdmin() error {
HttpServer: NewHttpServerWithCfg(BConfig),
}
// keep in mind that all data should be html escaped to avoid XSS attack
beeAdminApp.Router("/", c, SetRouterMethods(c, "get:AdminIndex"))
beeAdminApp.Router("/qps", c, SetRouterMethods(c, "get:QpsIndex"))
beeAdminApp.Router("/prof", c, SetRouterMethods(c, "get:ProfIndex"))
beeAdminApp.Router("/healthcheck", c, SetRouterMethods(c, "get:Healthcheck"))
beeAdminApp.Router("/task", c, SetRouterMethods(c, "get:TaskStatus"))
beeAdminApp.Router("/listconf", c, SetRouterMethods(c, "get:ListConf"))
beeAdminApp.Router("/metrics", c, SetRouterMethods(c, "get:PrometheusMetrics"))
beeAdminApp.Router("/", c, "get:AdminIndex")
beeAdminApp.Router("/qps", c, "get:QpsIndex")
beeAdminApp.Router("/prof", c, "get:ProfIndex")
beeAdminApp.Router("/healthcheck", c, "get:Healthcheck")
beeAdminApp.Router("/task", c, "get:TaskStatus")
beeAdminApp.Router("/listconf", c, "get:ListConf")
beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics")
go beeAdminApp.Run()
}

View File

@ -29,6 +29,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/beego/beego/v2/server/web/session"
"net"
"net/http"
"strconv"
@ -195,6 +196,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) {
}
}
// Session return session store of this context of request
func (ctx *Context) Session() (store session.Store, err error) {
if ctx.Input != nil {
if ctx.Input.CruSession != nil {
store = ctx.Input.CruSession
return
} else {
err = errors.New(`no valid session store(please initialize session)`)
return
}
} else {
err = errors.New(`no valid input`)
return
}
}
// Response is a wrapper for the http.ResponseWriter
// Started: if true, response was already written to so the other handler will not be executed
type Response struct {

View File

@ -15,6 +15,7 @@
package context
import (
"github.com/beego/beego/v2/server/web/session"
"net/http"
"net/http/httptest"
"testing"
@ -45,3 +46,26 @@ func TestXsrfReset_01(t *testing.T) {
t.FailNow()
}
}
func TestContext_Session(t *testing.T) {
c := NewContext()
if store, err := c.Session(); store != nil || err == nil {
t.FailNow()
}
}
func TestContext_Session1(t *testing.T) {
c := Context{}
if store, err := c.Session(); store != nil || err == nil {
t.FailNow()
}
}
func TestContext_Session2(t *testing.T) {
c := NewContext()
c.Input.CruSession = &session.MemSessionStore{}
if store, err := c.Session(); store == nil || err != nil {
t.FailNow()
}
}

View File

@ -18,6 +18,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -37,4 +38,5 @@ func TestFilterChain(t *testing.T) {
ctx.Input.SetData("RouterPattern", "my-route")
filter(ctx)
assert.True(t, ctx.Input.GetData("invocation").(bool))
time.Sleep(1 * time.Second)
}

View File

@ -0,0 +1,35 @@
package session
import (
"context"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/server/web"
webContext "github.com/beego/beego/v2/server/web/context"
"github.com/beego/beego/v2/server/web/session"
)
//Session maintain session for web service
//Session new a session storage and store it into webContext.Context
func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain {
sessionConfig := session.NewManagerConfig(options...)
sessionManager, _ := session.NewManager(string(providerType), sessionConfig)
go sessionManager.GC()
return func(next web.FilterFunc) web.FilterFunc {
return func(ctx *webContext.Context) {
if ctx.Input.CruSession != nil {
return
}
if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil {
logs.Error(`init session error:%s`, err.Error())
} else {
//release session at the end of request
defer sess.SessionRelease(context.Background(), ctx.ResponseWriter)
ctx.Input.CruSession = sess
}
next(ctx)
}
}
}

View File

@ -0,0 +1,86 @@
package session
import (
"github.com/beego/beego/v2/server/web"
webContext "github.com/beego/beego/v2/server/web/context"
"github.com/beego/beego/v2/server/web/session"
"github.com/google/uuid"
"net/http"
"net/http/httptest"
"testing"
)
func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) {
r, _ := http.NewRequest(method, path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != code {
t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code)
}
}
func TestSession(t *testing.T) {
storeKey := uuid.New().String()
handler := web.NewControllerRegister()
handler.InsertFilterChain(
"*",
Session(
session.ProviderMemory,
session.CfgCookieName(`go_session_id`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
),
)
handler.InsertFilterChain(
"*",
func(next web.FilterFunc) web.FilterFunc {
return func(ctx *webContext.Context) {
if store := ctx.Input.GetData(storeKey); store == nil {
t.Error(`store should not be nil`)
}
next(ctx)
}
},
)
handler.Any("*", func(ctx *webContext.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
}
func TestSession1(t *testing.T) {
handler := web.NewControllerRegister()
handler.InsertFilterChain(
"*",
Session(
session.ProviderMemory,
session.CfgCookieName(`go_session_id`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
),
)
handler.InsertFilterChain(
"*",
func(next web.FilterFunc) web.FilterFunc {
return func(ctx *webContext.Context) {
if store, err := ctx.Session(); store == nil || err != nil {
t.Error(`store should not be nil`)
}
next(ctx)
}
},
)
handler.Any("*", func(ctx *webContext.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "/dataset1/resource1", "GET", 200)
}

View File

@ -15,9 +15,12 @@
package web
import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) {
ns := NewNamespace("/chain")
ns.Get("/*", func(ctx *context.Context) {
ctx.Output.Body([]byte("hello"))
_ = ctx.Output.Body([]byte("hello"))
})
r, _ := http.NewRequest("GET", "/chain/user", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r)
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
}
func TestControllerRegister_InsertFilterChain_Order(t *testing.T) {
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
r, _ := http.NewRequest("GET", "/abc", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r)
first := w.Header().Get("first")
second := w.Header().Get("second")
ft, _ := strconv.ParseInt(first, 10, 64)
st, _ := strconv.ParseInt(second, 10, 64)
assert.True(t, st > ft)
}

View File

@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) {
// setup the handler
handler := NewControllerRegister()
handler.Add("/", &TestFlashController{}, SetRouterMethods(&TestFlashController{}, "get:TestWriteFlash"))
handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash"))
handler.ServeHTTP(w, r)
// get the Set-Cookie value

View File

@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
// Router same as beego.Rourer
// refer: https://godoc.org/github.com/beego/beego/v2#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, SetRouterMethods(c, mappingMethods...))
n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...))
return n
}

View File

@ -121,24 +121,30 @@ type ControllerInfo struct {
sessionOn bool
}
type ControllerOptions func(*ControllerInfo)
type ControllerOption func(*ControllerInfo)
func (c *ControllerInfo) GetPattern() string {
return c.pattern
}
func SetRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions {
func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
return func(c *ControllerInfo) {
c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
}
}
func SetRouterSessionOn(sessionOn bool) ControllerOptions {
func WithRouterSessionOn(sessionOn bool) ControllerOption {
return func(c *ControllerInfo) {
c.sessionOn = sessionOn
}
}
type filterChainConfig struct {
pattern string
chain FilterChain
opts []FilterOpt
}
// ControllerRegister containers registered router rules, controller handlers and filters.
type ControllerRegister struct {
routers map[string]*Tree
@ -151,6 +157,9 @@ type ControllerRegister struct {
// the filter created by FilterChain
chainRoot *FilterRouter
// keep registered chain and build it when serve http
filterChains []filterChainConfig
cfg *Config
}
@ -171,11 +180,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
},
},
cfg: cfg,
filterChains: make([]filterChainConfig, 0, 4),
}
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
return res
}
// Init will be executed when HttpServer start running
func (p *ControllerRegister) Init() {
for i := len(p.filterChains) - 1; i >= 0 ; i -- {
fc := p.filterChains[i]
root := p.chainRoot
filterFunc := fc.chain(root.filterFunc)
p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...)
p.chainRoot.next = root
}
}
// Add controller handler and pattern rules to ControllerRegister.
// usage:
// default methods is the same name as method
@ -186,7 +207,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
// Add("/api/delete",&RestController{},"delete:DeleteFood")
// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOptions) {
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) {
p.addWithMethodParams(pattern, c, nil, opts...)
}
@ -239,7 +260,7 @@ func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) {
}
}
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOptions) {
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
@ -311,7 +332,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
}
p.addWithMethodParams(a.Router, c, a.MethodParams, SetRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method))
p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method))
}
}
}
@ -513,12 +534,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// }
// }
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
root := p.chainRoot
filterFunc := chain(root.filterFunc)
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.chainRoot = newFilterRouter(pattern, filterFunc, opts...)
p.chainRoot.next = root
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.filterChains = append(p.filterChains, filterChainConfig{
pattern: pattern,
chain: chain,
opts: opts,
})
}
// add Filter into
@ -583,7 +605,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str
for _, l := range t.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) {
find := false
if HTTPMETHOD[strings.ToUpper(methodName)] {
if len(c.methods) == 0 {

View File

@ -26,6 +26,14 @@ import (
"github.com/beego/beego/v2/server/web/context"
)
type PrefixTestController struct {
Controller
}
func (ptc *PrefixTestController) PrefixList() {
ptc.Ctx.Output.Body([]byte("i am list in prefix test"))
}
type TestController struct {
Controller
}
@ -87,10 +95,24 @@ func (jc *JSONController) Get() {
jc.Ctx.Output.Body([]byte("ok"))
}
func TestPrefixUrlFor(t *testing.T){
handler := NewControllerRegister()
handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList"))
if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` {
logs.Info(a)
t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list")
}
if a := handler.URLFor(`TestController.PrefixList`); a != `` {
logs.Info(a)
t.Errorf("TestController.PrefixList must equal to empty string")
}
}
func TestUrlFor(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List"))
handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "*:Param"))
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
if a := handler.URLFor("TestController.List"); a != "/api/list" {
logs.Info(a)
t.Errorf("TestController.List must equal to /api/list")
@ -113,9 +135,9 @@ func TestUrlFor3(t *testing.T) {
func TestUrlFor2(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:List"))
handler.Add("/v1/:username/edit", &TestController{}, SetRouterMethods(&TestController{}, "get:GetURL"))
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:Param"))
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL"))
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
logs.Info(handler.URLFor("TestController.GetURL"))
@ -145,7 +167,7 @@ func TestUserFunc(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List"))
handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("user define func can't run")
@ -235,7 +257,7 @@ func TestRouteOk(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "get:GetParams"))
handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams"))
handler.ServeHTTP(w, r)
body := w.Body.String()
if body != "anderson+thomas+kungfu" {
@ -249,7 +271,7 @@ func TestManyRoute(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetManyRouter"))
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter"))
handler.ServeHTTP(w, r)
body := w.Body.String()
@ -266,7 +288,7 @@ func TestEmptyResponse(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/beego-empty.html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetEmptyBody"))
handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody"))
handler.ServeHTTP(w, r)
if body := w.Body.String(); body != "" {
@ -761,8 +783,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ := http.NewRequest("GET", "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(false))
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(false))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed")
@ -772,8 +794,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder()
handler = NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(true))
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(true))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed")
@ -787,8 +809,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder()
handler = NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(false))
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(false))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed")
@ -798,8 +820,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder()
handler = NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(true))
handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
WithRouterSessionOn(true))
handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") == "" {
t.Errorf("TestRotuerSessionSet failed")

View File

@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
// init...
app.initAddr(addr)
app.Handlers.Init()
addr = app.Cfg.Listen.HTTPAddr
@ -266,8 +268,12 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
}
// Router see HttpServer.Router
func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer {
return BeeApp.Router(rootpath, c, opts...)
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...))
}
func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
return BeeApp.RouterWithOpts(rootpath, c, opts...)
}
// Router adds a patterned controller handler to BeeApp.
@ -286,7 +292,11 @@ func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func (app *HttpServer) Router(rootPath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer {
func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...))
}
func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
app.Handlers.Add(rootPath, c, opts...)
return app
}

View File

@ -15,21 +15,22 @@ import (
)
func TestRedis(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
}
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "127.0.0.1:6379"
}
redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr)
sessionConfig := session.NewManagerConfig(
session.CfgCookieName(`gosessionid`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig(redisConfig),
)
sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr)
globalSession, err := session.NewManager("redis", sessionConfig)
if err != nil {
t.Fatal("could not create manager:", err)

View File

@ -13,15 +13,15 @@ import (
)
func TestRedisSentinel(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
ProviderConfig: "127.0.0.1:6379,100,,0,master",
}
sessionConfig := session.NewManagerConfig(
session.CfgCookieName(`gosessionid`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"),
)
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)

View File

@ -91,25 +91,6 @@ func GetProvider(name string) (Provider, error) {
return provider, nil
}
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
DisableHTTPOnly bool `json:"disableHTTPOnly"`
Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
CookieSameSite http.SameSite `json:"cookieSameSite"`
}
// Manager contains Provider and its configuration.
type Manager struct {
provider Provider

View File

@ -0,0 +1,143 @@
package session
import "net/http"
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
DisableHTTPOnly bool `json:"disableHTTPOnly"`
Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
CookieSameSite http.SameSite `json:"cookieSameSite"`
}
func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) {
for _, opt := range opts {
opt(c)
}
}
type ManagerConfigOpt func(config *ManagerConfig)
func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig {
config := &ManagerConfig{}
for _, opt := range opts {
opt(config)
}
return config
}
// CfgCookieName set key of session id
func CfgCookieName(cookieName string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.CookieName = cookieName
}
}
// CfgCookieName set len of session id
func CfgSessionIdLength(len int64) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.SessionIDLength = len
}
}
// CfgSessionIdPrefix set prefix of session id
func CfgSessionIdPrefix(prefix string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.SessionIDPrefix = prefix
}
}
//CfgSetCookie whether set `Set-Cookie` header in HTTP response
func CfgSetCookie(enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.EnableSetCookie = enable
}
}
//CfgGcLifeTime set session gc lift time
func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Gclifetime = lifeTime
}
}
//CfgMaxLifeTime set session lift time
func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Maxlifetime = lifeTime
}
}
//CfgGcLifeTime set session lift time
func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.CookieLifeTime = lifeTime
}
}
//CfgProviderConfig configure session provider
func CfgProviderConfig(providerConfig string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.ProviderConfig = providerConfig
}
}
//CfgDomain set cookie domain
func CfgDomain(domain string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Domain = domain
}
}
//CfgSessionIdInHTTPHeader enable session id in http header
func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.EnableSidInHTTPHeader = enable
}
}
//CfgSetSessionNameInHTTPHeader set key of session id in http header
func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.SessionNameInHTTPHeader = name
}
}
//EnableSidInURLQuery enable session id in query string
func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.EnableSidInURLQuery = enable
}
}
//DisableHTTPOnly set HTTPOnly for http.Cookie
func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.DisableHTTPOnly = !HTTPOnly
}
}
//CfgSecure set Secure for http.Cookie
func CfgSecure(Enable bool) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.Secure = Enable
}
}
//CfgSameSite set http.SameSite
func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt {
return func(config *ManagerConfig) {
config.CookieSameSite = sameSite
}
}

View File

@ -0,0 +1,222 @@
package session
import (
"net/http"
"testing"
)
func TestCfgCookieLifeTime(t *testing.T) {
value := 8754
c := NewManagerConfig(
CfgCookieLifeTime(value),
)
if c.CookieLifeTime != value {
t.Error()
}
}
func TestCfgDomain(t *testing.T) {
value := `http://domain.com`
c := NewManagerConfig(
CfgDomain(value),
)
if c.Domain != value {
t.Error()
}
}
func TestCfgSameSite(t *testing.T) {
value := http.SameSiteLaxMode
c := NewManagerConfig(
CfgSameSite(value),
)
if c.CookieSameSite != value {
t.Error()
}
}
func TestCfgSecure(t *testing.T) {
c := NewManagerConfig(
CfgSecure(true),
)
if c.Secure != true {
t.Error()
}
}
func TestCfgSecure1(t *testing.T) {
c := NewManagerConfig(
CfgSecure(false),
)
if c.Secure != false {
t.Error()
}
}
func TestCfgSessionIdPrefix(t *testing.T) {
value := `sodiausodkljalsd`
c := NewManagerConfig(
CfgSessionIdPrefix(value),
)
if c.SessionIDPrefix != value {
t.Error()
}
}
func TestCfgSetSessionNameInHTTPHeader(t *testing.T) {
value := `sodiausodkljalsd`
c := NewManagerConfig(
CfgSetSessionNameInHTTPHeader(value),
)
if c.SessionNameInHTTPHeader != value {
t.Error()
}
}
func TestCfgCookieName(t *testing.T) {
value := `sodiausodkljalsd`
c := NewManagerConfig(
CfgCookieName(value),
)
if c.CookieName != value {
t.Error()
}
}
func TestCfgEnableSidInURLQuery(t *testing.T) {
c := NewManagerConfig(
CfgEnableSidInURLQuery(true),
)
if c.EnableSidInURLQuery != true {
t.Error()
}
}
func TestCfgGcLifeTime(t *testing.T) {
value := int64(5454)
c := NewManagerConfig(
CfgGcLifeTime(value),
)
if c.Gclifetime != value {
t.Error()
}
}
func TestCfgHTTPOnly(t *testing.T) {
c := NewManagerConfig(
CfgHTTPOnly(true),
)
if c.DisableHTTPOnly != false {
t.Error()
}
}
func TestCfgHTTPOnly2(t *testing.T) {
c := NewManagerConfig(
CfgHTTPOnly(false),
)
if c.DisableHTTPOnly != true {
t.Error()
}
}
func TestCfgMaxLifeTime(t *testing.T) {
value := int64(5454)
c := NewManagerConfig(
CfgMaxLifeTime(value),
)
if c.Maxlifetime != value {
t.Error()
}
}
func TestCfgProviderConfig(t *testing.T) {
value := `asodiuasldkj12i39809as`
c := NewManagerConfig(
CfgProviderConfig(value),
)
if c.ProviderConfig != value {
t.Error()
}
}
func TestCfgSessionIdInHTTPHeader(t *testing.T) {
c := NewManagerConfig(
CfgSessionIdInHTTPHeader(true),
)
if c.EnableSidInHTTPHeader != true {
t.Error()
}
}
func TestCfgSessionIdInHTTPHeader1(t *testing.T) {
c := NewManagerConfig(
CfgSessionIdInHTTPHeader(false),
)
if c.EnableSidInHTTPHeader != false {
t.Error()
}
}
func TestCfgSessionIdLength(t *testing.T) {
value := int64(100)
c := NewManagerConfig(
CfgSessionIdLength(value),
)
if c.SessionIDLength != value {
t.Error()
}
}
func TestCfgSetCookie(t *testing.T) {
c := NewManagerConfig(
CfgSetCookie(true),
)
if c.EnableSetCookie != true {
t.Error()
}
}
func TestCfgSetCookie1(t *testing.T) {
c := NewManagerConfig(
CfgSetCookie(false),
)
if c.EnableSetCookie != false {
t.Error()
}
}
func TestNewManagerConfig(t *testing.T) {
c := NewManagerConfig()
if c == nil {
t.Error()
}
}
func TestManagerConfig_Opts(t *testing.T) {
c := NewManagerConfig()
c.Opts(CfgSetCookie(true))
if c.EnableSetCookie != true {
t.Error()
}
}

View File

@ -0,0 +1,18 @@
package session
type ProviderType string
const (
ProviderCookie ProviderType = `cookie`
ProviderFile ProviderType = `file`
ProviderMemory ProviderType = `memory`
ProviderCouchbase ProviderType = `couchbase`
ProviderLedis ProviderType = `ledis`
ProviderMemcache ProviderType = `memcache`
ProviderMysql ProviderType = `mysql`
ProviderPostgresql ProviderType = `postgresql`
ProviderRedis ProviderType = `redis`
ProviderRedisCluster ProviderType = `redis_cluster`
ProviderRedisSentinel ProviderType = `redis_sentinel`
ProviderSsdb ProviderType = `ssdb`
)

View File

@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root
testHelperFnContentCheck(t, handler, "Test original root",
@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
// Replace the root path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot"))
handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot"))
// Test replacement root (expect change)
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root
testHelperFnContentCheck(t, handler,
@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
// Replace the "level1" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1"))
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root
testHelperFnContentCheck(t, handler,
@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
// Replace the "/level1/level2" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1/level2", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2"))
handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2"))
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)

7
sonar-project.properties Normal file
View File

@ -0,0 +1,7 @@
sonar.organization=beego
sonar.projectKey=beego_beego
# relative paths to source directories. More details and properties are described
# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/
sonar.sources=.
sonar.exclusions=**/*_test.go

View File

@ -109,6 +109,23 @@ func TestTask_Run(t *testing.T) {
assert.Equal(t, "Hello, world! 101", l[1].errinfo)
}
func TestCrudTask(t *testing.T) {
m := newTaskManager()
m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error {
return nil
}))
m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error {
return nil
}))
m.DeleteTask("my-task1")
assert.Equal(t, 1, len(m.adminTaskList))
m.ClearTask()
assert.Equal(t, 0, len(m.adminTaskList))
}
func wait(wg *sync.WaitGroup) chan bool {
ch := make(chan bool)
go func() {