Add context support for orm

Signed-off-by: Penghui Liao <liaoishere@gmail.com>
This commit is contained in:
Penghui Liao
2021-01-08 19:04:00 +08:00
parent c603131436
commit 21777d3143
15 changed files with 215 additions and 175 deletions

View File

@@ -15,6 +15,7 @@
package orm
import (
"context"
"database/sql"
"errors"
"fmt"
@@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
// create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote()
dbcols := make([]string, 0, len(mi.fields.dbcols))
@@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
d.ins.HasReturningID(mi, &query)
stmt, err := q.Prepare(query)
stmt, err := q.PrepareContext(ctx, query)
return stmt, query, err
}
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return 0, err
@@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
err := row.Scan(&id)
return id, err
}
res, err := stmt.Exec(values...)
res, err := stmt.ExecContext(ctx, values...)
if err == nil {
return res.LastInsertId()
}
@@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
}
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
var whereCols []string
var args []interface{}
@@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...)
row := q.QueryRowContext(ctx, query, args...)
if err := row.Scan(refs...); err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
@@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
}
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols))
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil {
return 0, err
}
id, err := d.InsertValue(q, mi, false, names, values)
id, err := d.InsertValue(ctx, q, mi, false, names, values)
if err != nil {
return 0, err
}
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
err = d.ins.setval(ctx, q, mi, autoFields)
}
return id, err
}
// multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
@@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
@@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
var err error
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
err = d.ins.setval(ctx, q, mi, autoFields)
}
return cnt, err
@@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
@@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
}
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err := row.Scan(&id)
return id, err
@@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
// InsertOrUpdate a row
// If your primary key or unique column conflict will update
// If no will insert
func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
args0 := ""
iouStr := ""
argsMap := map[string]string{}
@@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
@@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
return 0, err
}
row := q.QueryRow(query, values...)
row := q.QueryRowContext(ctx, query, values...)
var id int64
err = row.Scan(&id)
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
@@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
}
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if !ok {
return 0, ErrMissPK
@@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, setValues...)
res, err := q.ExecContext(ctx, query, setValues...)
if err == nil {
return res.RowsAffected()
}
@@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
var whereCols []string
var args []interface{}
// if specify cols length > 0, then use it for where condition.
@@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...)
res, err := q.ExecContext(ctx, query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
@@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
}
}
err := d.deleteRels(q, mi, args, tz)
err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil {
return num, err
}
@@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// update table-related record by querySet.
// need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
@@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
var err error
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
return res.RowsAffected()
}
@@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
// delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
_, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
@@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
_, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
@@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
}
// delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins)
tables.skipEnd = true
@@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
r, err := q.Query(query, args...)
r, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
@@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query)
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
res, err := q.ExecContext(ctx, query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
err := d.deleteRels(q, mi, args, tz)
err := d.deleteRels(ctx, q, mi, args, tz)
if err != nil {
return num, err
}
@@ -943,7 +933,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
// read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@@ -1052,18 +1042,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
var err error
if qs != nil && qs.forContext {
rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil {
return 0, err
}
} else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
rs, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rs.Close()
@@ -1178,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
}
// excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
@@ -1200,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query)
var row *sql.Row
if qs != nil && qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
row := q.QueryRowContext(ctx, query, args...)
err = row.Scan(&cnt)
return
}
@@ -1655,7 +1631,7 @@ setValue:
}
// query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var (
maps []Params
@@ -1738,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
d.ins.ReplaceMarks(&query)
rs, err := q.Query(query, args...)
rs, err := q.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
@@ -1853,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
}
// sync auto key
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
return nil
}
@@ -1898,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
}
// get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
rows, err := db.QueryContext(ctx, query)
if err != nil {
return columns, err
}
@@ -1940,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string {
}
// not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool {
panic(ErrNotImplement)
}