refactor code

This commit is contained in:
Jason li 2021-01-12 11:34:40 +08:00
commit 8e2cd8f85f
23 changed files with 391 additions and 261 deletions

View File

@ -8,4 +8,5 @@
- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386)
- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) - Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294)
- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) - Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404)
- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) - Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416)
- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424)

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"os" "os"
@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf(" %s\n", err.Error()) fmt.Printf(" %s\n", err.Error())
} }
ctx := context.Background()
for i, mi := range modelCache.allOrdered() { for i, mi := range modelCache.allOrdered() {
if !isApplicableTableForDB(mi.addrField, d.al.Name) { if !isApplicableTableForDB(mi.addrField, d.al.Name) {
@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error {
} }
var fields []*fieldInfo var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(db, mi.table) columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table)
if err != nil { if err != nil {
if d.rtOnError { if d.rtOnError {
return err return err
@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error {
} }
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.table] {
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
if !d.noInfo { if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
} }

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
// create insert sql preparation statement object. // create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
dbcols := make([]string, 0, len(mi.fields.dbcols)) dbcols := make([]string, 0, len(mi.fields.dbcols))
@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
d.ins.HasReturningID(mi, &query) d.ins.HasReturningID(mi, &query)
stmt, err := q.Prepare(query) stmt, err := q.PrepareContext(ctx, query)
return stmt, query, err return stmt, query, err
} }
// insert struct with prepared statement and given struct reflect value. // insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
err := row.Scan(&id) err := row.Scan(&id)
return id, err return id, err
} }
res, err := stmt.Exec(values...) res, err := stmt.ExecContext(ctx, values...)
if err == nil { if err == nil {
return res.LastInsertId() return res.LastInsertId()
} }
@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
} }
// query sql ,read records and persist in dbBaser. // query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
var whereCols []string var whereCols []string
var args []interface{} var args []interface{}
@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...) row := q.QueryRowContext(ctx, query, args...)
if err := row.Scan(refs...); err != nil { if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return ErrNoRows return ErrNoRows
@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
} }
// execute insert sql dbQuerier with given struct reflect.Value. // execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols)) names := make([]string, 0, len(mi.fields.dbcols))
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
id, err := d.InsertValue(q, mi, false, names, values) id, err := d.InsertValue(ctx, q, mi, false, names, values)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if len(autoFields) > 0 { if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields) err = d.ins.setval(ctx, q, mi, autoFields)
} }
return id, err return id, err
} }
// multi-insert sql with given slice struct reflect.Value. // multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var ( var (
cnt int64 cnt int64
nums int nums int
@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
} }
if i > 1 && i%bulk == 0 || length == i { if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums]) num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums])
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
var err error var err error
if len(autoFields) > 0 { if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields) err = d.ins.setval(ctx, q, mi, autoFields)
} }
return cnt, err return cnt, err
@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// execute insert sql with given struct and given values. // execute insert sql with given struct and given values.
// insert the given values, not the field values in struct. // insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...) res, err := q.ExecContext(ctx, query, values...)
if err == nil { if err == nil {
if isMulti { if isMulti {
return res.RowsAffected() return res.RowsAffected()
@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
} }
return 0, err return 0, err
} }
row := q.QueryRow(query, values...) row := q.QueryRowContext(ctx, query, values...)
var id int64 var id int64
err := row.Scan(&id) err := row.Scan(&id)
return id, err return id, err
@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
// InsertOrUpdate a row // InsertOrUpdate a row
// If your primary key or unique column conflict will update // If your primary key or unique column conflict will update
// If no will insert // If no will insert
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
args0 := "" args0 := ""
iouStr := "" iouStr := ""
argsMap := map[string]string{} argsMap := map[string]string{}
@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...) res, err := q.ExecContext(ctx, query, values...)
if err == nil { if err == nil {
if isMulti { if isMulti {
return res.RowsAffected() return res.RowsAffected()
@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
return 0, err return 0, err
} }
row := q.QueryRow(query, values...) row := q.QueryRowContext(ctx, query, values...)
var id int64 var id int64
err = row.Scan(&id) err = row.Scan(&id)
if err != nil && err.Error() == `pq: syntax error at or near "ON"` { if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
} }
// execute update sql dbQuerier with given struct reflect.Value. // execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if !ok { if !ok {
return 0, ErrMissPK return 0, ErrMissPK
@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, setValues...) res, err := q.ExecContext(ctx, query, setValues...)
if err == nil { if err == nil {
return res.RowsAffected() return res.RowsAffected()
} }
@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// execute delete sql dbQuerier with given struct reflect.Value. // execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk. // delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
var whereCols []string var whereCols []string
var args []interface{} var args []interface{}
// if specify cols length > 0, then use it for where condition. // if specify cols length > 0, then use it for where condition.
@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...) res, err := q.ExecContext(ctx, query, args...)
if err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
} }
} }
err := d.deleteRels(q, mi, args, tz) err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// update table-related record by querySet. // update table-related record by querySet.
// need querySet not struct reflect.Value to update related records. // need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
for col, val := range params { for col, val := range params {
@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var err error res, err := q.ExecContext(ctx, query, values...)
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
if err == nil { if err == nil {
return res.RowsAffected() return res.RowsAffected()
} }
@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
// delete related records. // delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship. // do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
switch fi.onDelete { switch fi.onDelete {
case odCascade: case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz)
if err != nil { if err != nil {
return err return err
} }
@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
if fi.onDelete == odSetDefault { if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String() params[fi.column] = fi.initial.String()
} }
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz)
if err != nil { if err != nil {
return err return err
} }
@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
} }
// delete table-related records. // delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.skipEnd = true tables.skipEnd = true
@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
r, err := q.Query(query, args...) r, err := q.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var res sql.Result res, err := q.ExecContext(ctx, query, args...)
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
if err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
return 0, err return 0, err
} }
if num > 0 { if num > 0 {
err := d.deleteRels(q, mi, args, tz) err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
@ -943,7 +933,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
// read related records. // read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -1052,18 +1042,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows rs, err := q.QueryContext(ctx, query, args...)
var err error if err != nil {
if qs != nil && qs.forContext { return 0, err
rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil {
return 0, err
}
} else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
} }
defer rs.Close() defer rs.Close()
@ -1178,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
} }
// excute count sql and return count result int64. // excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
@ -1200,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var row *sql.Row row := q.QueryRowContext(ctx, query, args...)
if qs != nil && qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
err = row.Scan(&cnt) err = row.Scan(&cnt)
return return
} }
@ -1655,7 +1631,7 @@ setValue:
} }
// query sql, read values , save to *[]ParamList. // query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var ( var (
maps []Params maps []Params
@ -1738,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
rs, err := q.Query(query, args...) rs, err := q.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -1853,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
} }
// sync auto key // sync auto key
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
return nil return nil
} }
@ -1898,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
} }
// get all cloumns in table. // get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string) columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query) rows, err := db.QueryContext(ctx, query)
if err != nil { if err != nil {
return columns, err return columns, err
} }
@ -1940,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
} }
// not implement. // not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool { func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool {
panic(ErrNotImplement) panic(ErrNotImplement)
} }

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
} }
// execute sql to check index exist. // execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int var cnt int
row.Scan(&cnt) row.Scan(&cnt)
@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
// If your primary key or unique column conflict will update // If your primary key or unique column conflict will update
// If no will insert // If no will insert
// Add "`" for mysql sql building // Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string var iouStr string
argsMap := map[string]string{} argsMap := map[string]string{}
@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...) res, err := q.ExecContext(ctx, query, values...)
if err == nil { if err == nil {
if isMulti { if isMulti {
return res.RowsAffected() return res.RowsAffected()
@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val
return 0, err return 0, err
} }
row := q.QueryRow(query, values...) row := q.QueryRowContext(ctx, query, values...)
var id int64 var id int64
err = row.Scan(&id) err = row.Scan(&id)
return id, err return id, err

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
} }
// check index is exist // check index is exist
func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde
// execute insert sql with given struct and given values. // execute insert sql with given struct and given values.
// insert the given values, not the field values in struct. // insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...) res, err := q.ExecContext(ctx, query, values...)
if err == nil { if err == nil {
if isMulti { if isMulti {
return res.RowsAffected() return res.RowsAffected()
@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam
} }
return 0, err return 0, err
} }
row := q.QueryRow(query, values...) row := q.QueryRowContext(ctx, query, values...)
var id int64 var id int64
err := row.Scan(&id) err := row.Scan(&id)
return id, err return id, err

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
) )
@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
} }
// sync auto key // sync auto key
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 { if len(autoFields) == 0 {
return nil return nil
} }
@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string
mi.table, name, mi.table, name,
Q, name, Q, Q, name, Q,
Q, mi.table, Q) Q, mi.table, Q)
if _, err := db.Exec(query); err != nil { if _, err := db.ExecContext(ctx, query); err != nil {
return err return err
} }
} }
@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string {
} }
// check index exist in postgresql. // check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query) row := db.QueryRowContext(ctx, query)
var cnt int var cnt int
row.Scan(&cnt) row.Scan(&cnt)
return cnt > 0 return cnt > 0

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
@ -73,11 +74,11 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax // override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate { if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
} }
return d.dbBase.Read(q, mi, ind, tz, cols, false) return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
} }
// get sqlite operator. // get sqlite operator.
@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string {
} }
// get columns in sqlite. // get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query) rows, err := db.QueryContext(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
} }
// check index exist in sqlite. // check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table) query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query) rows, err := db.QueryContext(ctx, query)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
) )
@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
} }
// execute sql to check index exist. // execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int var cnt int
row.Scan(&cnt) row.Scan(&cnt)

View File

@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
return nil return nil
} }
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
return nil return nil
} }
@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return nil return nil
} }
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
return nil return nil
} }

View File

@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) {
assert.Nil(t, o.Driver()) assert.Nil(t, o.Driver())
assert.Nil(t, o.QueryM2MWithCtx(nil, nil, ""))
assert.Nil(t, o.QueryM2M(nil, "")) assert.Nil(t, o.QueryM2M(nil, ""))
assert.Nil(t, o.ReadWithCtx(nil, nil)) assert.Nil(t, o.ReadWithCtx(nil, nil))
assert.Nil(t, o.Read(nil)) assert.Nil(t, o.Read(nil))
@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(0), i) assert.Equal(t, int64(0), i)
assert.Nil(t, o.QueryTableWithCtx(nil, nil))
assert.Nil(t, o.QueryTable(nil)) assert.Nil(t, o.QueryTable(nil))
assert.Nil(t, o.Read(nil)) assert.Nil(t, o.Read(nil))

View File

@ -27,7 +27,7 @@ import (
// this Filter's behavior looks a little bit strange // this Filter's behavior looks a little bit strange
// for example: // for example:
// if we want to trace QuerySetter // if we want to trace QuerySetter
// actually we trace invoking "QueryTable" and "QueryTableWithCtx" // actually we trace invoking "QueryTable"
// the method Begin*, Commit and Rollback are ignored. // the method Begin*, Commit and Rollback are ignored.
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. // When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
type FilterChainBuilder struct { type FilterChainBuilder struct {

View File

@ -31,7 +31,7 @@ import (
// this Filter's behavior looks a little bit strange // this Filter's behavior looks a little bit strange
// for example: // for example:
// if we want to records the metrics of QuerySetter // if we want to records the metrics of QuerySetter
// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" // actually we only records metrics of invoking "QueryTable"
type FilterChainBuilder struct { type FilterChainBuilder struct {
summaryVec prometheus.ObserverVec summaryVec prometheus.ObserverVec
AppName string AppName string

View File

@ -20,6 +20,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/core/utils" "github.com/beego/beego/v2/core/utils"
) )
@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac
} }
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
return f.QueryM2MWithCtx(context.Background(), md, name)
}
func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
mi, _ := modelCache.getByMd(md) mi, _ := modelCache.getByMd(md)
inv := &Invocation{ inv := &Invocation{
Method: "QueryM2MWithCtx", Method: "QueryM2M",
Args: []interface{}{md, name}, Args: []interface{}{md, name},
Md: md, Md: md,
mi: mi, mi: mi,
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
f: func(c context.Context) []interface{} { f: func(c context.Context) []interface{} {
res := f.ormer.QueryM2MWithCtx(c, md, name) res := f.ormer.QueryM2M(md, name)
return []interface{}{res} return []interface{}{res}
}, },
} }
res := f.root(ctx, inv) res := f.root(context.Background(), inv)
if res[0] == nil { if res[0] == nil {
return nil return nil
} }
return res[0].(QueryM2Mer) return res[0].(QueryM2Mer)
} }
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { // NOTE: this method is deprecated, context parameter will not take effect.
return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.")
return f.QueryM2M(md, name)
} }
func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
var ( var (
name string name string
md interface{} md interface{}
@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
} }
inv := &Invocation{ inv := &Invocation{
Method: "QueryTableWithCtx", Method: "QueryTable",
Args: []interface{}{ptrStructOrTableName}, Args: []interface{}{ptrStructOrTableName},
InsideTx: f.insideTx, InsideTx: f.insideTx,
TxStartTime: f.txStartTime, TxStartTime: f.txStartTime,
Md: md, Md: md,
mi: mi, mi: mi,
f: func(c context.Context) []interface{} { f: func(c context.Context) []interface{} {
res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) res := f.ormer.QueryTable(ptrStructOrTableName)
return []interface{}{res} return []interface{}{res}
}, },
} }
res := f.root(ctx, inv) res := f.root(context.Background(), inv)
if res[0] == nil { if res[0] == nil {
return nil return nil
@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT
return res[0].(QuerySeter) return res[0].(QuerySeter)
} }
// NOTE: this method is deprecated, context parameter will not take effect.
func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter {
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.")
return f.QueryTable(ptrStructOrTableName)
}
func (f *filterOrmDecorator) DBStats() *sql.DBStats { func (f *filterOrmDecorator) DBStats() *sql.DBStats {
inv := &Invocation{ inv := &Invocation{
Method: "DBStats", Method: "DBStats",

View File

@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) {
o := &filterMockOrm{} o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter { od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) []interface{} { return func(ctx context.Context, inv *Invocation) []interface{} {
assert.Equal(t, "QueryM2MWithCtx", inv.Method) assert.Equal(t, "QueryM2M", inv.Method)
assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, 2, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx) assert.False(t, inv.InsideTx)
@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) {
o := &filterMockOrm{} o := &filterMockOrm{}
od := NewFilterOrmDecorator(o, func(next Filter) Filter { od := NewFilterOrmDecorator(o, func(next Filter) Filter {
return func(ctx context.Context, inv *Invocation) []interface{} { return func(ctx context.Context, inv *Invocation) []interface{} {
assert.Equal(t, "QueryTableWithCtx", inv.Method) assert.Equal(t, "QueryTable", inv.Method)
assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, 1, len(inv.Args))
assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.Equal(t, "FILTER_TEST", inv.GetTableName())
assert.False(t, inv.InsideTx) assert.False(t, inv.InsideTx)

View File

@ -492,45 +492,45 @@ var (
helpinfo = `need driver and source! helpinfo = `need driver and source!
Default DB Drivers. Default DB Drivers.
driver: url driver: url
mysql: https://github.com/go-sql-driver/mysql mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3 sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb tidb: https://github.com/pingcap/tidb
usage: usage:
go get -u github.com/beego/beego/v2/client/orm go get -u github.com/beego/beego/v2/client/orm
go get -u github.com/go-sql-driver/mysql go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3 go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb go get -u github.com/pingcap/tidb
#### MySQL #### MySQL
mysql -u root -e 'create database orm_test;' mysql -u root -e 'create database orm_test;'
export ORM_DRIVER=mysql export ORM_DRIVER=mysql
export ORM_SOURCE="root:@/orm_test?charset=utf8" export ORM_SOURCE="root:@/orm_test?charset=utf8"
go test -v github.com/beego/beego/v2/client/orm go test -v github.com/beego/beego/v2/client/orm
#### Sqlite3 #### Sqlite3
export ORM_DRIVER=sqlite3 export ORM_DRIVER=sqlite3
export ORM_SOURCE='file:memory_test?mode=memory' export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/beego/beego/v2/client/orm go test -v github.com/beego/beego/v2/client/orm
#### PostgreSQL #### PostgreSQL
psql -c 'create database orm_test;' -U postgres psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/beego/beego/v2/client/orm go test -v github.com/beego/beego/v2/client/orm
#### TiDB #### TiDB
export ORM_DRIVER=tidb export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test' export ORM_SOURCE='memory://test/test'
go test -v github.com/beego/beego/v2/pgk/orm go test -v github.com/beego/beego/v2/pgk/orm
` `
) )

View File

@ -136,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error {
} }
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
} }
// read data to model, like Read(), but use "SELECT FOR UPDATE" form // read data to model, like Read(), but use "SELECT FOR UPDATE" form
@ -145,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error {
} }
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true)
} }
// Try to read a row from the database, or insert one if it doesn't exist // Try to read a row from the database, or insert one if it doesn't exist
@ -155,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
cols = append([]string{col1}, cols...) cols = append([]string{col1}, cols...)
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false)
if err == ErrNoRows { if err == ErrNoRows {
// Create // Create
id, err := o.InsertWithCtx(ctx, md) id, err := o.InsertWithCtx(ctx, md)
@ -180,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) {
} }
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
@ -223,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
for i := 0; i < sind.Len(); i++ { for i := 0; i < sind.Len(); i++ {
ind := reflect.Indirect(sind.Index(i)) ind := reflect.Indirect(sind.Index(i))
mi, _ := o.getMiInd(ind.Interface(), false) mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
@ -234,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac
} }
} else { } else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false) mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ)
} }
return cnt, nil return cnt, nil
} }
@ -245,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (
} }
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...)
if err != nil { if err != nil {
return id, err return id, err
} }
@ -262,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
} }
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols)
} }
// delete model in database // delete model in database
@ -272,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) {
} }
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return num, err return num, err
} }
@ -284,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str
// create a models to models queryer // create a models to models queryer
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
return o.QueryM2MWithCtx(context.Background(), md, name)
}
func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
@ -300,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri
return newQueryM2M(md, o, mi, fi, ind) return newQueryM2M(md, o, mi, fi, ind)
} }
// NOTE: this method is deprecated, context parameter will not take effect.
func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer {
logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.")
return o.QueryM2M(md, name)
}
// load related models to md model. // load related models to md model.
// args are limit, offset int and order string. // args are limit, offset int and order string.
// //
@ -452,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
// table name can be string or struct. // table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
}
func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string var name string
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
name = nameStrategyMap[defaultNameStrategy](table) name = nameStrategyMap[defaultNameStrategy](table)
@ -470,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in
if qs == nil { if qs == nil {
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name)) panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
} }
return return qs
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.")
return o.QueryTable(ptrStructOrTableName)
} }
// return a raw query seter for raw sql string. // return a raw query seter for raw sql string.
@ -596,9 +602,8 @@ func NewOrm() Ormer {
func NewOrmUsingDB(aliasName string) Ormer { func NewOrmUsingDB(aliasName string) Ormer {
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
return newDBWithAlias(al) return newDBWithAlias(al)
} else {
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
} }
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
} }
// NewOrmWithDB create a new ormer object with specify *sql.DB for query // NewOrmWithDB create a new ormer object with specify *sql.DB for query

View File

@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error {
} }
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
return d.ExecContext(context.Background(), args...)
}
func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
a := time.Now() a := time.Now()
res, err := d.stmt.Exec(args...) res, err := d.stmt.ExecContext(ctx, args...)
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
return res, err return res, err
} }
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
return d.QueryContext(context.Background(), args...)
}
func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
a := time.Now() a := time.Now()
res, err := d.stmt.Query(args...) res, err := d.stmt.QueryContext(ctx, args...)
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
return res, err return res, err
} }
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
return d.QueryRowContext(context.Background(), args...)
}
func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
a := time.Now() a := time.Now()
res := d.stmt.QueryRow(args...) res := d.stmt.QueryRow(args...)
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
) )
@ -31,6 +32,10 @@ var _ Inserter = new(insertSet)
// insert model ignore it's registered or not. // insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) { func (o *insertSet) Insert(md interface{}) (int64, error) {
return o.InsertWithCtx(context.Background(), md)
}
func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
if o.closed { if o.closed {
return 0, ErrStmtClosed return 0, ErrStmtClosed
} }
@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
if name != o.mi.fullName { if name != o.mi.fullName {
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name)) panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
} }
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
@ -70,11 +75,11 @@ func (o *insertSet) Close() error {
} }
// create new insert queryer. // create new insert queryer.
func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) {
bi := new(insertSet) bi := new(insertSet)
bi.orm = orm bi.orm = orm
bi.mi = mi bi.mi = mi
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -14,7 +14,10 @@
package orm package orm
import "reflect" import (
"context"
"reflect"
)
// model to model struct // model to model struct
type queryM2M struct { type queryM2M struct {
@ -33,6 +36,10 @@ type queryM2M struct {
// //
// make sure the relation is defined in post model struct tag. // make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) { func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
return o.AddWithCtx(context.Background(), mds...)
}
func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
mi := fi.relThroughModelInfo mi := fi.relThroughModelInfo
mfi := fi.reverseFieldInfo mfi := fi.reverseFieldInfo
@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
} }
names = append(names, otherNames...) names = append(names, otherNames...)
values = append(values, otherValues...) values = append(values, otherValues...)
return dbase.InsertValue(orm.db, mi, true, names, values) return dbase.InsertValue(ctx, orm.db, mi, true, names, values)
} }
// remove models following the origin model relationship // remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return o.RemoveWithCtx(context.Background(), mds...)
}
func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
// check model is existed in relationship of origin model // check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool { func (o *queryM2M) Exist(md interface{}) bool {
return o.ExistWithCtx(context.Background(), md)
}
func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md). return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist() Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx)
} }
// clean all models in related of origin model // clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) { func (o *queryM2M) Clear() (int64, error) {
return o.ClearWithCtx(context.Background())
}
func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx)
} }
// count all related models of origin model // count all related models of origin model
func (o *queryM2M) Count() (int64, error) { func (o *queryM2M) Count() (int64, error) {
return o.CountWithCtx(context.Background())
}
func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx)
} }
var _ QueryM2Mer = new(queryM2M) var _ QueryM2Mer = new(queryM2M)

View File

@ -17,8 +17,9 @@ package orm
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/beego/beego/v2/client/orm/hints"
"github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/hints"
) )
type colValue struct { type colValue struct {
@ -64,22 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct // real query struct
type querySet struct { type querySet struct {
mi *modelInfo mi *modelInfo
cond *Condition cond *Condition
related []string related []string
relDepth int relDepth int
limit int64 limit int64
offset int64 offset int64
groups []string groups []string
orders []*order_clause.Order orders []*order_clause.Order
distinct bool distinct bool
forUpdate bool forUpdate bool
useIndex int useIndex int
indexes []string indexes []string
orm *ormBase orm *ormBase
ctx context.Context aggregate string
forContext bool
aggregate string
} }
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
@ -223,23 +222,39 @@ func (o querySet) GetCond() *Condition {
// return QuerySeter execution result number // return QuerySeter execution result number
func (o *querySet) Count() (int64, error) { func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.CountWithCtx(context.Background())
}
func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// check result empty or not after QuerySeter executed // check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool { func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.ExistWithCtx(context.Background())
}
func (o *querySet) ExistWithCtx(ctx context.Context) bool {
cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0 return cnt > 0
} }
// execute update with parameters // execute update with parameters
func (o *querySet) Update(values Params) (int64, error) { func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) return o.UpdateWithCtx(context.Background(), values)
}
func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
} }
// execute delete // execute delete
func (o *querySet) Delete() (int64, error) { func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.DeleteWithCtx(context.Background())
}
func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// return a insert queryer. // return a insert queryer.
@ -248,20 +263,32 @@ func (o *querySet) Delete() (int64, error) {
// i,err := sq.PrepareInsert() // i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{}) // i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) { func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi) return o.PrepareInsertWithCtx(context.Background())
}
func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
return newInsertSet(ctx, o.orm, o.mi)
} }
// query all data and map to containers. // query all data and map to containers.
// cols means the columns when querying. // cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) { func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) return o.AllWithCtx(context.Background(), container, cols...)
}
func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
} }
// query one row data and map to containers. // query one row data and map to containers.
// cols means the columns when querying. // cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error { func (o *querySet) One(container interface{}, cols ...string) error {
return o.OneWithCtx(context.Background(), container, cols...)
}
func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
o.limit = 1 o.limit = 1
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil { if err != nil {
return err return err
} }
@ -279,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error {
// expres means condition expression. // expres means condition expression.
// it converts data to []map[column]value. // it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.ValuesWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to [][]interface // query all data and map to [][]interface
// it converts data to [][column_index]value // it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.ValuesListWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to []interface. // query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value. // it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) return o.ValuesFlatWithCtx(context.Background(), result, expr)
}
func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
} }
// query all rows into map[string]interface with specify key and value column name. // query all rows into map[string]interface with specify key and value column name.
@ -322,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter. // create new QuerySeter.
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
o := new(querySet) o := new(querySet)
@ -341,4 +373,4 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
func (o querySet) Aggregate(s string) QuerySeter { func (o querySet) Aggregate(s string) QuerySeter {
o.aggregate = s o.aggregate = s
return &o return &o
} }

View File

@ -2820,3 +2820,23 @@ func TestCondition(t *testing.T) {
throwFail(t, AssertIs(!cycleFlag, true)) throwFail(t, AssertIs(!cycleFlag, true))
return return
} }
func TestContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
user := User{UserName: "slene"}
err := dORM.ReadWithCtx(ctx, &user, "UserName")
throwFail(t, err)
cancel()
err = dORM.ReadWithCtx(ctx, &user, "UserName")
throwFail(t, AssertIs(err, context.Canceled))
ctx, cancel = context.WithCancel(context.Background())
cancel()
qs := dORM.QueryTable(user)
_, err = qs.Filter("UserName", "slene").CountWithCtx(ctx)
throwFail(t, AssertIs(err, context.Canceled))
}

View File

@ -17,10 +17,10 @@ package orm
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"reflect" "reflect"
"time" "time"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/core/utils" "github.com/beego/beego/v2/core/utils"
) )
@ -197,12 +197,16 @@ type DQL interface {
// post := Post{Id: 4} // post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags") // m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer QueryM2M(md interface{}, name string) QueryM2Mer
// NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations. // return a QuerySeter for table operations.
// table name can be string or struct. // table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter QueryTable(ptrStructOrTableName interface{}) QuerySeter
// NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
DBStats() *sql.DBStats DBStats() *sql.DBStats
@ -231,6 +235,7 @@ type TxOrmer interface {
// Inserter insert prepared statement // Inserter insert prepared statement
type Inserter interface { type Inserter interface {
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
InsertWithCtx(context.Context, interface{}) (int64, error)
Close() error Close() error
} }
@ -350,9 +355,11 @@ type QuerySeter interface {
// for example: // for example:
// num, err = qs.Filter("profile__age__gt", 28).Count() // num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error) Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
// check result empty or not after QuerySeter executed // check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0 // the same as QuerySeter.Count > 0
Exist() bool Exist() bool
ExistWithCtx(context.Context) bool
// execute update with parameters // execute update with parameters
// for example: // for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{ // num, err = qs.Filter("user_name", "slene").Update(Params{
@ -362,11 +369,13 @@ type QuerySeter interface {
// "user_name": "slene2" // "user_name": "slene2"
// }) // user slene's name will change to slene2 // }) // user slene's name will change to slene2
Update(values Params) (int64, error) Update(values Params) (int64, error)
UpdateWithCtx(ctx context.Context, values Params) (int64, error)
// delete from table // delete from table
// for example: // for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2 // //delete two user who's name is testing1 or testing2
Delete() (int64, error) Delete() (int64, error)
DeleteWithCtx(context.Context) (int64, error)
// return a insert queryer. // return a insert queryer.
// it can be used in times. // it can be used in times.
// example: // example:
@ -375,18 +384,21 @@ type QuerySeter interface {
// num, err = i.Insert(&user2) // user table will add one record user2 at once // num, err = i.Insert(&user2) // user table will add one record user2 at once
// err = i.Close() //don't forget call Close // err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error) PrepareInsert() (Inserter, error)
PrepareInsertWithCtx(context.Context) (Inserter, error)
// query all data and map to containers. // query all data and map to containers.
// cols means the columns when querying. // cols means the columns when querying.
// for example: // for example:
// var users []*User // var users []*User
// qs.All(&users) // users[0],users[1],users[2] ... // qs.All(&users) // users[0],users[1],users[2] ...
All(container interface{}, cols ...string) (int64, error) All(container interface{}, cols ...string) (int64, error)
AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error)
// query one row data and map to containers. // query one row data and map to containers.
// cols means the columns when querying. // cols means the columns when querying.
// for example: // for example:
// var user User // var user User
// qs.One(&user) //user.UserName == "slene" // qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error One(container interface{}, cols ...string) error
OneWithCtx(ctx context.Context, container interface{}, cols ...string) error
// query all data and map to []map[string]interface. // query all data and map to []map[string]interface.
// expres means condition expression. // expres means condition expression.
// it converts data to []map[column]value. // it converts data to []map[column]value.
@ -394,18 +406,21 @@ type QuerySeter interface {
// var maps []Params // var maps []Params
// qs.Values(&maps) //maps[0]["UserName"]=="slene" // qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error) Values(results *[]Params, exprs ...string) (int64, error)
ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface // query all data and map to [][]interface
// it converts data to [][column_index]value // it converts data to [][column_index]value
// for example: // for example:
// var list []ParamsList // var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene" // qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error) ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface. // query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value. // it's designed for one column record set, auto change to []value, not [][column]value.
// for example: // for example:
// var list ParamsList // var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene" // qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error) ValuesFlat(result *ParamsList, expr string) (int64, error)
ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name. // query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value" // keyCol = "name", valueCol = "value"
// table data // table data
@ -454,18 +469,23 @@ type QueryM2Mer interface {
// insert one or more rows to m2m table // insert one or more rows to m2m table
// make sure the relation is defined in post model struct tag. // make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error) Add(...interface{}) (int64, error)
AddWithCtx(context.Context, ...interface{}) (int64, error)
// remove models following the origin model relationship // remove models following the origin model relationship
// only delete rows from m2m table // only delete rows from m2m table
// for example: // for example:
// tag3 := &Tag{Id:5,Name: "TestTag3"} // tag3 := &Tag{Id:5,Name: "TestTag3"}
// num, err = m2m.Remove(tag3) // num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error) Remove(...interface{}) (int64, error)
RemoveWithCtx(context.Context, ...interface{}) (int64, error)
// check model is existed in relationship of origin model // check model is existed in relationship of origin model
Exist(interface{}) bool Exist(interface{}) bool
ExistWithCtx(context.Context, interface{}) bool
// clean all models in related of origin model // clean all models in related of origin model
Clear() (int64, error) Clear() (int64, error)
ClearWithCtx(context.Context) (int64, error)
// count all related models of origin model // count all related models of origin model
Count() (int64, error) Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
} }
// RawPreparer raw query statement // RawPreparer raw query statement
@ -539,11 +559,11 @@ type RawSeter interface {
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
// ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error) Query(args ...interface{}) (*sql.Rows, error)
// QueryContext(args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row QueryRow(args ...interface{}) *sql.Row
// QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
} }
// db querier // db querier
@ -580,28 +600,28 @@ type txEnder interface {
// base database struct // base database struct
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
SupportUpdateJoin() bool SupportUpdateJoin() bool
OperatorSQL(string) string OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error)
MaxLimit() uint64 MaxLimit() uint64
TableQuote() string TableQuote() string
ReplaceMarks(*string) ReplaceMarks(*string)
@ -610,12 +630,12 @@ type dbBaser interface {
TimeToDB(*time.Time, *time.Location) TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string DbTypes() map[string]string
GetTables(dbQuerier) (map[string]bool, error) GetTables(dbQuerier) (map[string]bool, error)
GetColumns(dbQuerier, string) (map[string][3]string, error) GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error)
ShowTablesQuery() string ShowTablesQuery() string
ShowColumnsQuery(string) string ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool IndexExists(context.Context, dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error setval(context.Context, dbQuerier, *modelInfo, []string) error
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
} }

View File

@ -142,8 +142,8 @@ func WithRouterSessionOn(sessionOn bool) ControllerOption {
type filterChainConfig struct { type filterChainConfig struct {
pattern string pattern string
chain FilterChain chain FilterChain
opts []FilterOpt opts []FilterOpt
} }
// ControllerRegister containers registered router rules, controller handlers and filters. // ControllerRegister containers registered router rules, controller handlers and filters.
@ -180,7 +180,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
return beecontext.NewContext() return beecontext.NewContext()
}, },
}, },
cfg: cfg, cfg: cfg,
filterChains: make([]filterChainConfig, 0, 4), filterChains: make([]filterChainConfig, 0, 4),
} }
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
@ -189,7 +189,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
// Init will be executed when HttpServer start running // Init will be executed when HttpServer start running
func (p *ControllerRegister) Init() { func (p *ControllerRegister) Init() {
for i := len(p.filterChains) - 1; i >= 0 ; i -- { for i := len(p.filterChains) - 1; i >= 0; i-- {
fc := p.filterChains[i] fc := p.filterChains[i]
root := p.chainRoot root := p.chainRoot
filterFunc := fc.chain(root.filterFunc) filterFunc := fc.chain(root.filterFunc)
@ -265,11 +265,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type() t := reflect.Indirect(reflectVal).Type()
route := &ControllerInfo{} route := p.createBeegoRouter(t, pattern)
route.pattern = pattern
route.routerType = routerTypeBeego
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.controllerType = t
route.initialize = func() ControllerInterface { route.initialize = func() ControllerInterface {
vc := reflect.New(route.controllerType) vc := reflect.New(route.controllerType)
execController, ok := vc.Interface().(ControllerInterface) execController, ok := vc.Interface().(ControllerInterface)
@ -477,34 +473,84 @@ func (p *ControllerRegister) RouterAny(pattern string, f interface{}) {
// //
// AddRouterMethod("get","/api/:id", MyController.Ping) // AddRouterMethod("get","/api/:id", MyController.Ping)
func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) { func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) {
httpMethod = strings.ToUpper(httpMethod) httpMethod = p.getUpperMethodString(httpMethod)
if httpMethod != "*" && !HTTPMETHOD[httpMethod] {
panic("not support http method: " + httpMethod)
}
ct, methodName := getReflectTypeAndMethod(f) ct, methodName := getReflectTypeAndMethod(f)
p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern) p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern)
} }
// addBeegoTypeRouter add beego type router // addBeegoTypeRouter add beego type router
func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) { func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) {
route := p.createBeegoRouter(ct, pattern)
methods := p.getHttpMethodMapMethod(httpMethod, ctMethod)
route.methods = methods
p.addRouterForMethod(route)
}
// createBeegoRouter create beego router base on reflect type and pattern
func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo {
route := &ControllerInfo{} route := &ControllerInfo{}
route.pattern = pattern route.pattern = pattern
route.routerType = routerTypeBeego route.routerType = routerTypeBeego
route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.controllerType = ct route.controllerType = ct
return route
}
// createRestfulRouter create restful router with filter function and pattern
func (p *ControllerRegister) createRestfulRouter(f FilterFunc, pattern string) *ControllerInfo {
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.runFunction = f
return route
}
// createHandlerRouter create handler router with handler and pattern
func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo {
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeHandler
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.handler = h
return route
}
// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will
// use http method as the controller method
func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string {
methods := make(map[string]string) methods := make(map[string]string)
if httpMethod == "*" { // not match-all sign, only add for the http method
for val := range HTTPMETHOD { if httpMethod != "*" {
if ctMethod == "" {
ctMethod = httpMethod
}
methods[httpMethod] = ctMethod
return methods
}
// add all http method
for val := range HTTPMETHOD {
if ctMethod == "" {
methods[val] = val
} else {
methods[val] = ctMethod methods[val] = ctMethod
} }
} else {
methods[httpMethod] = ctMethod
} }
route.methods = methods return methods
}
p.addRouterForMethod(route) // getUpperMethodString get upper string of method, and panic if the method
// is not valid
func (p *ControllerRegister) getUpperMethodString(method string) string {
method = strings.ToUpper(method)
if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method)
}
return method
} }
// get reflect controller type and method by controller method expression // get reflect controller type and method by controller method expression
@ -632,36 +678,18 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
// ctx.Output.Body("hello world") // ctx.Output.Body("hello world")
// }) // })
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
method = strings.ToUpper(method) method = p.getUpperMethodString(method)
if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method) route := p.createRestfulRouter(f, pattern)
} methods := p.getHttpMethodMapMethod(method, "")
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.runFunction = f
methods := make(map[string]string)
if method == "*" {
for val := range HTTPMETHOD {
methods[val] = val
}
} else {
methods[method] = method
}
route.methods = methods route.methods = methods
for k := range methods {
p.addToRouter(k, pattern, route) p.addRouterForMethod(route)
}
} }
// Handler add user defined Handler // Handler add user defined Handler
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
route := &ControllerInfo{} route := p.createHandlerRouter(h, pattern)
route.pattern = pattern
route.routerType = routerTypeHandler
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.handler = h
if len(options) > 0 { if len(options) > 0 {
if _, ok := options[0].(bool); ok { if _, ok := options[0].(bool); ok {
pattern = path.Join(pattern, "?:all(.*)") pattern = path.Join(pattern, "?:all(.*)")
@ -693,16 +721,13 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
controllerName := strings.TrimSuffix(ct.Name(), "Controller") controllerName := strings.TrimSuffix(ct.Name(), "Controller")
for i := 0; i < rt.NumMethod(); i++ { for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) { if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
route := &ControllerInfo{}
route.routerType = routerTypeBeego
route.sessionOn = p.cfg.WebConfig.Session.SessionOn
route.methods = map[string]string{"*": rt.Method(i).Name}
route.controllerType = ct
pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern
route := p.createBeegoRouter(ct, pattern)
route.methods = map[string]string{"*": rt.Method(i).Name}
for m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route) p.addToRouter(m, patternInit, route)
@ -739,8 +764,8 @@ func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.filterChains = append(p.filterChains, filterChainConfig{ p.filterChains = append(p.filterChains, filterChainConfig{
pattern: pattern, pattern: pattern,
chain: chain, chain: chain,
opts: opts, opts: opts,
}) })
} }