diff --git a/CHANGELOG.md b/CHANGELOG.md index b177e912..25c91b6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,4 +8,5 @@ - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) - Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) - Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) -- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) \ No newline at end of file +- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) +- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) diff --git a/client/orm/cmd.go b/client/orm/cmd.go index b0661971..b377a5f2 100644 --- a/client/orm/cmd.go +++ b/client/orm/cmd.go @@ -15,6 +15,7 @@ package orm import ( + "context" "flag" "fmt" "os" @@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } + ctx := context.Background() for i, mi := range modelCache.allOrdered() { if !isApplicableTableForDB(mi.addrField, d.al.Name) { @@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error { } var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) + columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) if err != nil { if d.rtOnError { return err @@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } diff --git a/client/orm/db.go b/client/orm/db.go index c994469f..a49d6df7 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "errors" "fmt" @@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } // create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { +func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() dbcols := make([]string, 0, len(mi.fields.dbcols)) @@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, d.ins.HasReturningID(mi, &query) - stmt, err := q.Prepare(query) + stmt, err := q.PrepareContext(ctx, query) return stmt, query, err } // insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err @@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, err := row.Scan(&id) return id, err } - res, err := stmt.Exec(values...) + res, err := stmt.ExecContext(ctx, values...) if err == nil { return res.LastInsertId() } @@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, args...) + row := q.QueryRowContext(ctx, query, args...) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { return ErrNoRows @@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } // execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - id, err := d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(ctx, q, mi, false, names, values) if err != nil { return 0, err } if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return id, err } // multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { +func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 nums int @@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) + num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) if err != nil { return cnt, err } @@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul var err error if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return cnt, err @@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err @@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s // InsertOrUpdate a row // If your primary key or unique column conflict will update // If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { args0 := "" iouStr := "" argsMap := map[string]string{} @@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) if err != nil && err.Error() == `pq: syntax error at or near "ON"` { @@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a } // execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if !ok { return 0, ErrMissPK @@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, setValues...) + res, err := q.ExecContext(ctx, query, setValues...) if err == nil { return res.RowsAffected() } @@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { var whereCols []string var args []interface{} // if specify cols length > 0, then use it for where condition. @@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { @@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) } } - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // update table-related record by querySet. // need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { +func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { @@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } + res, err := q.ExecContext(ctx, query, values...) if err == nil { return res.RowsAffected() } @@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con // delete related records. // do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { +func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { case odCascade: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) if err != nil { return err } @@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * if fi.onDelete == odSetDefault { params[fi.column] = fi.initial.String() } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) if err != nil { return err } @@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * } // delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { +func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con d.ins.ReplaceMarks(&query) var rs *sql.Rows - r, err := q.Query(query, args...) + r, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -943,7 +933,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } // read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -1052,18 +1042,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi d.ins.ReplaceMarks(&query) - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } + rs, err := q.QueryContext(ctx, query, args...) + if err != nil { + return 0, err } defer rs.Close() @@ -1178,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } // excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -1200,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition d.ins.ReplaceMarks(&query) - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) - } else { - row = q.QueryRow(query, args...) - } + row := q.QueryRowContext(ctx, query, args...) err = row.Scan(&cnt) return } @@ -1655,7 +1631,7 @@ setValue: } // query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { +func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params @@ -1738,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond d.ins.ReplaceMarks(&query) - rs, err := q.Query(query, args...) + rs, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -1853,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { } // sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { return nil } @@ -1898,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { } // get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return columns, err } @@ -1940,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string { } // not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { +func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/client/orm/db_mysql.go b/client/orm/db_mysql.go index ee68baf7..c89b1e52 100644 --- a/client/orm/db_mysql.go +++ b/client/orm/db_mysql.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" "strings" @@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) @@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool // If your primary key or unique column conflict will update // If no will insert // Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { var iouStr string argsMap := map[string]string{} @@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) return id, err diff --git a/client/orm/db_oracle.go b/client/orm/db_oracle.go index 1de440b6..a3b93ff3 100644 --- a/client/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strings" @@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string { } // check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ +func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) @@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err diff --git a/client/orm/db_postgres.go b/client/orm/db_postgres.go index 12431d6e..b2f321db 100644 --- a/client/orm/db_postgres.go +++ b/client/orm/db_postgres.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strconv" ) @@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { } // sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { if len(autoFields) == 0 { return nil } @@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string mi.table, name, Q, name, Q, Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { + if _, err := db.ExecContext(ctx, query); err != nil { return err } } @@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string { } // check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) + row := db.QueryRowContext(ctx, query) var cnt int row.Scan(&cnt) return cnt > 0 diff --git a/client/orm/db_sqlite.go b/client/orm/db_sqlite.go index aff713a5..6a4b3131 100644 --- a/client/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -73,11 +74,11 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) // override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { if isForUpdate { DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") } - return d.dbBase.Read(q, mi, ind, tz, cols, false) + return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) } // get sqlite operator. @@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string { } // get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { } // check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { panic(err) } diff --git a/client/orm/db_tidb.go b/client/orm/db_tidb.go index 6020a488..48c5b4e7 100644 --- a/client/orm/db_tidb.go +++ b/client/orm/db_tidb.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" ) @@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) diff --git a/client/orm/do_nothing_orm.go b/client/orm/do_nothing_orm.go index c6da420d..59ffe877 100644 --- a/client/orm/do_nothing_orm.go +++ b/client/orm/do_nothing_orm.go @@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { return nil } @@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { return nil } diff --git a/client/orm/do_nothing_orm_test.go b/client/orm/do_nothing_orm_test.go index 4d477353..e10f70af 100644 --- a/client/orm/do_nothing_orm_test.go +++ b/client/orm/do_nothing_orm_test.go @@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, o.Driver()) - assert.Nil(t, o.QueryM2MWithCtx(nil, nil, "")) assert.Nil(t, o.QueryM2M(nil, "")) assert.Nil(t, o.ReadWithCtx(nil, nil)) assert.Nil(t, o.Read(nil)) @@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(0), i) - assert.Nil(t, o.QueryTableWithCtx(nil, nil)) assert.Nil(t, o.QueryTable(nil)) assert.Nil(t, o.Read(nil)) diff --git a/client/orm/filter/opentracing/filter.go b/client/orm/filter/opentracing/filter.go index 75852c63..7afa07f1 100644 --- a/client/orm/filter/opentracing/filter.go +++ b/client/orm/filter/opentracing/filter.go @@ -27,7 +27,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to trace QuerySetter -// actually we trace invoking "QueryTable" and "QueryTableWithCtx" +// actually we trace invoking "QueryTable" // the method Begin*, Commit and Rollback are ignored. // When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. type FilterChainBuilder struct { diff --git a/client/orm/filter/prometheus/filter.go b/client/orm/filter/prometheus/filter.go index db60876e..e68e9670 100644 --- a/client/orm/filter/prometheus/filter.go +++ b/client/orm/filter/prometheus/filter.go @@ -31,7 +31,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to records the metrics of QuerySetter -// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" +// actually we only records metrics of invoking "QueryTable" type FilterChainBuilder struct { summaryVec prometheus.ObserverVec AppName string diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index a60390a1..caf2b3f9 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/utils" ) @@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac } func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { - return f.QueryM2MWithCtx(context.Background(), md, name) -} - -func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, _ := modelCache.getByMd(md) inv := &Invocation{ - Method: "QueryM2MWithCtx", + Method: "QueryM2M", Args: []interface{}{md, name}, Md: md, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, f: func(c context.Context) []interface{} { - res := f.ormer.QueryM2MWithCtx(c, md, name) + res := f.ormer.QueryM2M(md, name) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil } return res[0].(QueryM2Mer) } -func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { - return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.") + return f.QueryM2M(md, name) } -func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { +func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { var ( name string md interface{} @@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT } inv := &Invocation{ - Method: "QueryTableWithCtx", + Method: "QueryTable", Args: []interface{}{ptrStructOrTableName}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, Md: md, mi: mi, f: func(c context.Context) []interface{} { - res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) + res := f.ormer.QueryTable(ptrStructOrTableName) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil @@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT return res[0].(QuerySeter) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.") + return f.QueryTable(ptrStructOrTableName) +} + func (f *filterOrmDecorator) DBStats() *sql.DBStats { inv := &Invocation{ Method: "DBStats", diff --git a/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go index 9e223358..566499dd 100644 --- a/client/orm/filter_orm_decorator_test.go +++ b/client/orm/filter_orm_decorator_test.go @@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryM2MWithCtx", inv.Method) + assert.Equal(t, "QueryM2M", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) @@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryTableWithCtx", inv.Method) + assert.Equal(t, "QueryTable", inv.Method) assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) diff --git a/client/orm/models_test.go b/client/orm/models_test.go index 5add6e45..3fd35765 100644 --- a/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -492,45 +492,45 @@ var ( helpinfo = `need driver and source! Default DB Drivers. - + driver: url mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 postgres: https://github.com/lib/pq tidb: https://github.com/pingcap/tidb - + usage: - + go get -u github.com/beego/beego/v2/client/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq go get -u github.com/pingcap/tidb - + #### MySQL mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" go test -v github.com/beego/beego/v2/client/orm - - + + #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' go test -v github.com/beego/beego/v2/client/orm - - + + #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" go test -v github.com/beego/beego/v2/client/orm - + #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' go test -v github.com/beego/beego/v2/pgk/orm - + ` ) diff --git a/client/orm/orm.go b/client/orm/orm.go index f33c4b51..660f2939 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -136,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error { } func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form @@ -145,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { } func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist @@ -155,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create id, err := o.InsertWithCtx(ctx, md) @@ -180,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) { } func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } @@ -223,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac for i := 0; i < sind.Len(); i++ { ind := reflect.Indirect(sind.Index(i)) mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err } @@ -234,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac } } else { mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil } @@ -245,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) ( } func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { return id, err } @@ -262,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) + return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) } // delete model in database @@ -272,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) + num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err } @@ -284,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str // create a models to models queryer func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { - return o.QueryM2MWithCtx(context.Background(), md, name) -} -func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -300,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri return newQueryM2M(md, o, mi, fi, ind) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") + return o.QueryM2M(md, name) +} + // load related models to md model. // args are limit, offset int and order string. // @@ -452,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) -} -func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -470,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in if qs == nil { panic(fmt.Errorf(" table name: `%s` not exists", name)) } - return + return qs +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") + return o.QueryTable(ptrStructOrTableName) } // return a raw query seter for raw sql string. @@ -596,9 +602,8 @@ func NewOrm() Ormer { func NewOrmUsingDB(aliasName string) Ormer { if al, ok := dataBaseCache.get(aliasName); ok { return newDBWithAlias(al) - } else { - panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } + panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } // NewOrmWithDB create a new ormer object with specify *sql.DB for query diff --git a/client/orm/orm_log.go b/client/orm/orm_log.go index 61addeb5..8ac373b5 100644 --- a/client/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error { } func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + return d.ExecContext(context.Background(), args...) +} + +func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { a := time.Now() - res, err := d.stmt.Exec(args...) + res, err := d.stmt.ExecContext(ctx, args...) debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) return res, err } - func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + return d.QueryContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { a := time.Now() - res, err := d.stmt.Query(args...) + res, err := d.stmt.QueryContext(ctx, args...) debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + return d.QueryRowContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { a := time.Now() res := d.stmt.QueryRow(args...) debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) diff --git a/client/orm/orm_object.go b/client/orm/orm_object.go index 6f9798d3..50c1ca41 100644 --- a/client/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" ) @@ -31,6 +32,10 @@ var _ Inserter = new(insertSet) // insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} + +func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed } @@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { if name != o.mi.fullName { panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } @@ -70,11 +75,11 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { +func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) if err != nil { return nil, err } diff --git a/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go index 17e1b5d1..9da49bba 100644 --- a/client/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -14,7 +14,10 @@ package orm -import "reflect" +import ( + "context" + "reflect" +) // model to model struct type queryM2M struct { @@ -33,6 +36,10 @@ type queryM2M struct { // // make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + return o.AddWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi mi := fi.relThroughModelInfo mfi := fi.reverseFieldInfo @@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } names = append(names, otherNames...) values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) + return dbase.InsertValue(ctx, orm.db, mi, true, names, values) } // remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + return o.RemoveWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) @@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { // check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { + return o.ExistWithCtx(context.Background(), md) +} + +func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() + Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) } // clean all models in related of origin model func (o *queryM2M) Clear() (int64, error) { + return o.ClearWithCtx(context.Background()) +} + +func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) } // count all related models of origin model func (o *queryM2M) Count() (int64, error) { + return o.CountWithCtx(context.Background()) +} + +func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) } var _ QueryM2Mer = new(queryM2M) diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 692c24cf..9f7b8441 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -17,8 +17,9 @@ package orm import ( "context" "fmt" - "github.com/beego/beego/v2/client/orm/hints" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" ) type colValue struct { @@ -64,22 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []*order_clause.Order - distinct bool - forUpdate bool - useIndex int - indexes []string - orm *ormBase - ctx context.Context - forContext bool - aggregate string + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []*order_clause.Order + distinct bool + forUpdate bool + useIndex int + indexes []string + orm *ormBase + aggregate string } var _ QuerySeter = new(querySet) @@ -223,23 +222,39 @@ func (o querySet) GetCond() *Condition { // return QuerySeter execution result number func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.CountWithCtx(context.Background()) +} + +func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.ExistWithCtx(context.Background()) +} + +func (o *querySet) ExistWithCtx(ctx context.Context) bool { + cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } // execute update with parameters func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) + return o.UpdateWithCtx(context.Background(), values) +} + +func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } // execute delete func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.DeleteWithCtx(context.Background()) +} + +func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // return a insert queryer. @@ -248,20 +263,32 @@ func (o *querySet) Delete() (int64, error) { // i,err := sq.PrepareInsert() // i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) + return o.PrepareInsertWithCtx(context.Background()) +} + +func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { + return newInsertSet(ctx, o.orm, o.mi) } // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + return o.AllWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } // query one row data and map to containers. // cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { + return o.OneWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { return err } @@ -279,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error { // expres means condition expression. // it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to [][]interface // it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesListWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to []interface. // it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) + return o.ValuesFlatWithCtx(context.Background(), result, expr) +} + +func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } // query all rows into map[string]interface with specify key and value column name. @@ -322,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) panic(ErrNotImplement) } -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - o.forContext = true - return &o -} - // create new QuerySeter. func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) @@ -341,4 +373,4 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { func (o querySet) Aggregate(s string) QuerySeter { o.aggregate = s return &o -} \ No newline at end of file +} diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index d13a0b65..e2e25ac4 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -2820,3 +2820,23 @@ func TestCondition(t *testing.T) { throwFail(t, AssertIs(!cycleFlag, true)) return } + +func TestContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + user := User{UserName: "slene"} + + err := dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, err) + + cancel() + err = dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, AssertIs(err, context.Canceled)) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + qs := dORM.QueryTable(user) + _, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) + throwFail(t, AssertIs(err, context.Canceled)) +} diff --git a/client/orm/types.go b/client/orm/types.go index dd6c0b95..59eb9055 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -17,10 +17,10 @@ package orm import ( "context" "database/sql" - "github.com/beego/beego/v2/client/orm/clauses/order_clause" "reflect" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/core/utils" ) @@ -197,12 +197,16 @@ type DQL interface { // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") QueryM2M(md interface{}, name string) QueryM2Mer + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter DBStats() *sql.DBStats @@ -231,6 +235,7 @@ type TxOrmer interface { // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) + InsertWithCtx(context.Context, interface{}) (int64, error) Close() error } @@ -350,9 +355,11 @@ type QuerySeter interface { // for example: // num, err = qs.Filter("profile__age__gt", 28).Count() Count() (int64, error) + CountWithCtx(context.Context) (int64, error) // check result empty or not after QuerySeter executed // the same as QuerySeter.Count > 0 Exist() bool + ExistWithCtx(context.Context) bool // execute update with parameters // for example: // num, err = qs.Filter("user_name", "slene").Update(Params{ @@ -362,11 +369,13 @@ type QuerySeter interface { // "user_name": "slene2" // }) // user slene's name will change to slene2 Update(values Params) (int64, error) + UpdateWithCtx(ctx context.Context, values Params) (int64, error) // delete from table // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) + DeleteWithCtx(context.Context) (int64, error) // return a insert queryer. // it can be used in times. // example: @@ -375,18 +384,21 @@ type QuerySeter interface { // num, err = i.Insert(&user2) // user table will add one record user2 at once // err = i.Close() //don't forget call Close PrepareInsert() (Inserter, error) + PrepareInsertWithCtx(context.Context) (Inserter, error) // query all data and map to containers. // cols means the columns when querying. // for example: // var users []*User // qs.All(&users) // users[0],users[1],users[2] ... All(container interface{}, cols ...string) (int64, error) + AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) // query one row data and map to containers. // cols means the columns when querying. // for example: // var user User // qs.One(&user) //user.UserName == "slene" One(container interface{}, cols ...string) error + OneWithCtx(ctx context.Context, container interface{}, cols ...string) error // query all data and map to []map[string]interface. // expres means condition expression. // it converts data to []map[column]value. @@ -394,18 +406,21 @@ type QuerySeter interface { // var maps []Params // qs.Values(&maps) //maps[0]["UserName"]=="slene" Values(results *[]Params, exprs ...string) (int64, error) + ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) // query all data and map to [][]interface // it converts data to [][column_index]value // for example: // var list []ParamsList // qs.ValuesList(&list) // list[0][1] == "slene" ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) // query all data and map to []interface. // it's designed for one column record set, auto change to []value, not [][column]value. // for example: // var list ParamsList // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" ValuesFlat(result *ParamsList, expr string) (int64, error) + ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) // query all rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data @@ -454,18 +469,23 @@ type QueryM2Mer interface { // insert one or more rows to m2m table // make sure the relation is defined in post model struct tag. Add(...interface{}) (int64, error) + AddWithCtx(context.Context, ...interface{}) (int64, error) // remove models following the origin model relationship // only delete rows from m2m table // for example: // tag3 := &Tag{Id:5,Name: "TestTag3"} // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) + RemoveWithCtx(context.Context, ...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool + ExistWithCtx(context.Context, interface{}) bool // clean all models in related of origin model Clear() (int64, error) + ClearWithCtx(context.Context) (int64, error) // count all related models of origin model Count() (int64, error) + CountWithCtx(context.Context) (int64, error) } // RawPreparer raw query statement @@ -539,11 +559,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - // QueryContext(args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier @@ -580,28 +600,28 @@ type txEnder interface { // base database struct type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) SupportUpdateJoin() bool OperatorSQL(string) string GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) @@ -610,12 +630,12 @@ type dbBaser interface { TimeToDB(*time.Time, *time.Location) DbTypes() map[string]string GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) + GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error) ShowTablesQuery() string ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool + IndexExists(context.Context, dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error + setval(context.Context, dbQuerier, *modelInfo, []string) error GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string } diff --git a/server/web/router.go b/server/web/router.go index f9fd4c87..d7fd3047 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -142,8 +142,8 @@ func WithRouterSessionOn(sessionOn bool) ControllerOption { type filterChainConfig struct { pattern string - chain FilterChain - opts []FilterOpt + chain FilterChain + opts []FilterOpt } // ControllerRegister containers registered router rules, controller handlers and filters. @@ -180,7 +180,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { return beecontext.NewContext() }, }, - cfg: cfg, + cfg: cfg, filterChains: make([]filterChainConfig, 0, 4), } res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) @@ -189,7 +189,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { // Init will be executed when HttpServer start running func (p *ControllerRegister) Init() { - for i := len(p.filterChains) - 1; i >= 0 ; i -- { + for i := len(p.filterChains) - 1; i >= 0; i-- { fc := p.filterChains[i] root := p.chainRoot filterFunc := fc.chain(root.filterFunc) @@ -265,11 +265,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeBeego - route.sessionOn = p.cfg.WebConfig.Session.SessionOn - route.controllerType = t + route := p.createBeegoRouter(t, pattern) route.initialize = func() ControllerInterface { vc := reflect.New(route.controllerType) execController, ok := vc.Interface().(ControllerInterface) @@ -477,34 +473,84 @@ func (p *ControllerRegister) RouterAny(pattern string, f interface{}) { // // AddRouterMethod("get","/api/:id", MyController.Ping) func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) { - httpMethod = strings.ToUpper(httpMethod) - if httpMethod != "*" && !HTTPMETHOD[httpMethod] { - panic("not support http method: " + httpMethod) - } - + httpMethod = p.getUpperMethodString(httpMethod) ct, methodName := getReflectTypeAndMethod(f) + p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern) } // addBeegoTypeRouter add beego type router func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) { + route := p.createBeegoRouter(ct, pattern) + methods := p.getHttpMethodMapMethod(httpMethod, ctMethod) + route.methods = methods + + p.addRouterForMethod(route) +} + +// createBeegoRouter create beego router base on reflect type and pattern +func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo { route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeBeego route.sessionOn = p.cfg.WebConfig.Session.SessionOn route.controllerType = ct + return route +} +// createRestfulRouter create restful router with filter function and pattern +func (p *ControllerRegister) createRestfulRouter(f FilterFunc, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.runFunction = f + return route +} + +// createHandlerRouter create handler router with handler and pattern +func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.handler = h + return route +} + +// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will +// use http method as the controller method +func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string { methods := make(map[string]string) - if httpMethod == "*" { - for val := range HTTPMETHOD { + // not match-all sign, only add for the http method + if httpMethod != "*" { + + if ctMethod == "" { + ctMethod = httpMethod + } + methods[httpMethod] = ctMethod + return methods + } + + // add all http method + for val := range HTTPMETHOD { + if ctMethod == "" { + methods[val] = val + } else { methods[val] = ctMethod } - } else { - methods[httpMethod] = ctMethod } - route.methods = methods + return methods +} - p.addRouterForMethod(route) +// getUpperMethodString get upper string of method, and panic if the method +// is not valid +func (p *ControllerRegister) getUpperMethodString(method string) string { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + return method } // get reflect controller type and method by controller method expression @@ -632,36 +678,18 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) { // ctx.Output.Body("hello world") // }) func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { - method = strings.ToUpper(method) - if method != "*" && !HTTPMETHOD[method] { - panic("not support http method: " + method) - } - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeRESTFul - route.sessionOn = p.cfg.WebConfig.Session.SessionOn - route.runFunction = f - methods := make(map[string]string) - if method == "*" { - for val := range HTTPMETHOD { - methods[val] = val - } - } else { - methods[method] = method - } + method = p.getUpperMethodString(method) + + route := p.createRestfulRouter(f, pattern) + methods := p.getHttpMethodMapMethod(method, "") route.methods = methods - for k := range methods { - p.addToRouter(k, pattern, route) - } + + p.addRouterForMethod(route) } // Handler add user defined Handler func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeHandler - route.sessionOn = p.cfg.WebConfig.Session.SessionOn - route.handler = h + route := p.createHandlerRouter(h, pattern) if len(options) > 0 { if _, ok := options[0].(bool); ok { pattern = path.Join(pattern, "?:all(.*)") @@ -693,16 +721,13 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) controllerName := strings.TrimSuffix(ct.Name(), "Controller") for i := 0; i < rt.NumMethod(); i++ { if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &ControllerInfo{} - route.routerType = routerTypeBeego - route.sessionOn = p.cfg.WebConfig.Session.SessionOn - route.methods = map[string]string{"*": rt.Method(i).Name} - route.controllerType = ct pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) - route.pattern = pattern + + route := p.createBeegoRouter(ct, pattern) + route.methods = map[string]string{"*": rt.Method(i).Name} for m := range HTTPMETHOD { p.addToRouter(m, pattern, route) p.addToRouter(m, patternInit, route) @@ -739,8 +764,8 @@ func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) p.filterChains = append(p.filterChains, filterChainConfig{ pattern: pattern, - chain: chain, - opts: opts, + chain: chain, + opts: opts, }) }