fix 5500: querySet should not be changed when constructing SQL (#5502)

This commit is contained in:
Ming Deng 2023-10-06 19:48:10 +08:00 committed by GitHub
parent 82aa2c28dc
commit e4b67e86ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 63 deletions

View File

@ -1094,7 +1094,7 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
}
// ReadBatch read related records.
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs querySet, mi *models.ModelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -1291,7 +1291,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
return cnt, nil
}
func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
cols := d.preProcCols(tCols) // pre process columns
buf := buffers.Get()
@ -1319,7 +1319,7 @@ func (d *dbBase) preProcCols(cols []string) []string {
// readSQL generate a select sql string and return args
// ReadBatch and ReadValues methods will reuse this method.
func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) []interface{} {
func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, cond *Condition, qs querySet, mi *models.ModelInfo, tz *time.Location) []interface{} {
quote := d.ins.TableQuote()
@ -1383,7 +1383,7 @@ func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, c
}
// Count excute count sql and return count result int64.
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
query, args := d.countSQL(qs, mi, cond, tz)
@ -1392,7 +1392,7 @@ func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *model
return
}
func (d *dbBase) countSQL(qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
func (d *dbBase) countSQL(qs querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
@ -1860,7 +1860,7 @@ setValue:
}
// ReadValues query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs querySet, mi *models.ModelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var (
maps []Params
lists []ParamsList
@ -2018,7 +2018,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
return cnt, nil
}
func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
buf := buffers.Get()
defer buffers.Put(buf)

View File

@ -912,7 +912,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
db *dbBase
tCols []string
qs *querySet
qs querySet
wantRes string
wantArgs []interface{}
@ -923,7 +923,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBaseMysql(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -949,7 +949,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBaseMysql(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -976,7 +976,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBaseMysql(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1003,7 +1003,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBaseMysql(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1031,7 +1031,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBaseMysql(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1058,7 +1058,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBasePostgres(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1082,7 +1082,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBasePostgres(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1107,7 +1107,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBasePostgres(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1132,7 +1132,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBasePostgres(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1158,7 +1158,7 @@ func TestDbBase_readBatchSQL(t *testing.T) {
ins: newdbBasePostgres(),
},
tCols: []string{"name", "score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1217,7 +1217,7 @@ func TestDbBase_readValuesSQL(t *testing.T) {
db *dbBase
cols []string
qs *querySet
qs querySet
wantRes string
wantArgs []interface{}
@ -1228,7 +1228,7 @@ func TestDbBase_readValuesSQL(t *testing.T) {
ins: newdbBaseMysql(),
},
cols: []string{"T0.`name` name", "T0.`age` age", "T0.`score` score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1252,7 +1252,7 @@ func TestDbBase_readValuesSQL(t *testing.T) {
ins: newdbBaseMysql(),
},
cols: []string{"T0.`name` name", "T0.`age` age", "T0.`score` score"},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1277,7 +1277,7 @@ func TestDbBase_readValuesSQL(t *testing.T) {
ins: newdbBasePostgres(),
},
cols: []string{`T0."name" name`, `T0."age" age`, `T0."score" score`},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1299,7 +1299,7 @@ func TestDbBase_readValuesSQL(t *testing.T) {
ins: newdbBasePostgres(),
},
cols: []string{`T0."name" name`, `T0."age" age`, `T0."score" score`},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
limit: 10,
@ -1354,7 +1354,7 @@ func TestDbBase_countSQL(t *testing.T) {
name string
db *dbBase
qs *querySet
qs querySet
wantRes string
wantArgs []interface{}
@ -1364,7 +1364,7 @@ func TestDbBase_countSQL(t *testing.T) {
db: &dbBase{
ins: newdbBaseMysql(),
},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
useIndex: 1,
@ -1380,7 +1380,7 @@ func TestDbBase_countSQL(t *testing.T) {
db: &dbBase{
ins: newdbBaseMysql(),
},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
useIndex: 1,
@ -1397,7 +1397,7 @@ func TestDbBase_countSQL(t *testing.T) {
db: &dbBase{
ins: newdbBasePostgres(),
},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
related: make([]string, 0),
@ -1411,7 +1411,7 @@ func TestDbBase_countSQL(t *testing.T) {
db: &dbBase{
ins: newdbBasePostgres(),
},
qs: &querySet{
qs: querySet{
mi: mi,
cond: cond,
related: make([]string, 0),

View File

@ -226,73 +226,75 @@ func (o querySet) GetCond() *Condition {
}
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) {
func (o querySet) Count() (int64, error) {
return o.CountWithCtx(context.Background())
}
func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) {
func (o querySet) CountWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool {
func (o querySet) Exist() bool {
return o.ExistWithCtx(context.Background())
}
func (o *querySet) ExistWithCtx(ctx context.Context) bool {
func (o querySet) ExistWithCtx(ctx context.Context) bool {
cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0
}
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) {
func (o querySet) Update(values Params) (int64, error) {
return o.UpdateWithCtx(context.Background(), values)
}
func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
func (o querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, &o, o.mi, o.cond, values, o.orm.alias.TZ)
}
// execute delete
func (o *querySet) Delete() (int64, error) {
func (o querySet) Delete() (int64, error) {
return o.DeleteWithCtx(context.Background())
}
func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
func (o querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, &o, o.mi, o.cond, o.orm.alias.TZ)
}
// return an insert queryer.
// PrepareInsert return an insert queryer.
// it can be used in times.
// example:
//
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) {
func (o querySet) PrepareInsert() (Inserter, error) {
return o.PrepareInsertWithCtx(context.Background())
}
func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
func (o querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
return newInsertSet(ctx, o.orm, o.mi)
}
// query All data and map to containers.
// All query all data and map to containers.
// cols means the Columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
func (o querySet) All(container interface{}, cols ...string) (int64, error) {
return o.AllWithCtx(context.Background(), container, cols...)
}
func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
// AllWithCtx see All
func (o querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
// query one row data and map to containers.
// One query one row data and map to containers.
// cols means the Columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error {
func (o querySet) One(container interface{}, cols ...string) error {
return o.OneWithCtx(context.Background(), container, cols...)
}
func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
// OneWithCtx check One
func (o querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error {
o.limit = 1
num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil {
@ -308,38 +310,40 @@ func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols .
return nil
}
// query All data and map to []map[string]interface.
// Values query All data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
func (o querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.ValuesWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) {
// ValuesWithCtx see Values
func (o querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query All data and map to [][]interface
// ValuesList query data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
func (o querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.ValuesListWithCtx(context.Background(), results, exprs...)
}
func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) {
func (o querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query All data and map to []interface.
// ValuesFlat query all data and map to []interface.
// it's designed for one row record Set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
func (o querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.ValuesFlatWithCtx(context.Background(), result, expr)
}
func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) {
// ValuesFlatWithCtx see ValuesFlat
func (o querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
// query All rows into map[string]interface with specify key and value column name.
// RowsToMap query rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
@ -350,11 +354,11 @@ func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, ex
// "total": 100,
// "found": 200,
// }
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
func (o querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}
// query All rows into struct with specify key and value column name.
// RowsToStruct query rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
@ -365,7 +369,7 @@ func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, er
// Total int
// Found int
// }
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
func (o querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}

View File

@ -623,9 +623,9 @@ type dbQuerier interface {
// base database struct
type dbBaser interface {
Read(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error)
ReadValues(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
ReadBatch(context.Context, dbQuerier, querySet, *models.ModelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(context.Context, dbQuerier, querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error)
ReadValues(context.Context, dbQuerier, querySet, *models.ModelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *alias, ...string) (int64, error)