fix: refactor Count method (#5300)
* fix: refactor Count method and add test * fix: add the change record into the CHANGELOG.md * fix: refactor the readSQL method and let countSQL reuse readSQL method * fix: fix the bug in the construction process of the order by clause * fix: modify the TestCountOrderBy、add the TestCount and TestOrderBy * fix: move the change record in CHANGELOG.md to developing --------- Co-authored-by: Ken <azai8599@163.com>
This commit is contained in:
parent
4eea71f1d7
commit
b2a37fe60e
@ -1,5 +1,5 @@
|
|||||||
# developing
|
# developing
|
||||||
- [orm: move the modelCache to internal/models package](https://github.com/beego/beego/pull/5306)
|
- [fix: refactor Count method](https://github.com/beego/beego/pull/5300)
|
||||||
|
|
||||||
# v2.1.1
|
# v2.1.1
|
||||||
- [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232)
|
- [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232)
|
||||||
|
|||||||
@ -1293,7 +1293,17 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
|
|||||||
|
|
||||||
func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
|
func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
|
||||||
cols := d.preProcCols(tCols) // pre process columns
|
cols := d.preProcCols(tCols) // pre process columns
|
||||||
return d.readSQL(tables, cols, cond, qs, mi, tz)
|
|
||||||
|
buf := buffers.Get()
|
||||||
|
defer buffers.Put(buf)
|
||||||
|
|
||||||
|
args := d.readSQL(buf, tables, cols, cond, qs, mi, tz)
|
||||||
|
|
||||||
|
query := buf.String()
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
|
return query, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) preProcCols(cols []string) []string {
|
func (d *dbBase) preProcCols(cols []string) []string {
|
||||||
@ -1309,7 +1319,7 @@ func (d *dbBase) preProcCols(cols []string) []string {
|
|||||||
|
|
||||||
// readSQL generate a select sql string and return args
|
// readSQL generate a select sql string and return args
|
||||||
// ReadBatch and ReadValues methods will reuse this method.
|
// ReadBatch and ReadValues methods will reuse this method.
|
||||||
func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
|
func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) []interface{} {
|
||||||
|
|
||||||
quote := d.ins.TableQuote()
|
quote := d.ins.TableQuote()
|
||||||
|
|
||||||
@ -1320,9 +1330,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
|
|||||||
join := tables.getJoinSQL()
|
join := tables.getJoinSQL()
|
||||||
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
|
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
|
||||||
|
|
||||||
buf := buffers.Get()
|
|
||||||
defer buffers.Put(buf)
|
|
||||||
|
|
||||||
_, _ = buf.WriteString("SELECT ")
|
_, _ = buf.WriteString("SELECT ")
|
||||||
|
|
||||||
if qs.distinct {
|
if qs.distinct {
|
||||||
@ -1372,6 +1379,37 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
|
|||||||
_, _ = buf.WriteString(" FOR UPDATE")
|
_, _ = buf.WriteString(" FOR UPDATE")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count excute count sql and return count result int64.
|
||||||
|
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||||
|
|
||||||
|
query, args := d.countSQL(qs, mi, cond, tz)
|
||||||
|
|
||||||
|
row := q.QueryRowContext(ctx, query, args...)
|
||||||
|
err = row.Scan(&cnt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) countSQL(qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
|
||||||
|
tables := newDbTables(mi, d.ins)
|
||||||
|
tables.parseRelated(qs.related, qs.relDepth)
|
||||||
|
|
||||||
|
buf := buffers.Get()
|
||||||
|
defer buffers.Put(buf)
|
||||||
|
|
||||||
|
if len(qs.groups) > 0 {
|
||||||
|
_, _ = buf.WriteString("SELECT COUNT(*) FROM (")
|
||||||
|
}
|
||||||
|
|
||||||
|
qs.aggregate = "COUNT(*)"
|
||||||
|
args := d.readSQL(buf, tables, nil, cond, qs, mi, tz)
|
||||||
|
|
||||||
|
if len(qs.groups) > 0 {
|
||||||
|
_, _ = buf.WriteString(") AS T")
|
||||||
|
}
|
||||||
|
|
||||||
query := buf.String()
|
query := buf.String()
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
@ -1379,34 +1417,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
|
|||||||
return query, args
|
return query, args
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count excute count sql and return count result int64.
|
|
||||||
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
|
||||||
tables := newDbTables(mi, d.ins)
|
|
||||||
tables.parseRelated(qs.related, qs.relDepth)
|
|
||||||
|
|
||||||
where, args := tables.getCondSQL(cond, false, tz)
|
|
||||||
groupBy := tables.getGroupSQL(qs.groups)
|
|
||||||
tables.getOrderSQL(qs.orders)
|
|
||||||
join := tables.getJoinSQL()
|
|
||||||
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
|
|
||||||
|
|
||||||
Q := d.ins.TableQuote()
|
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s",
|
|
||||||
Q, mi.Table, Q,
|
|
||||||
specifyIndexes, join, where, groupBy)
|
|
||||||
|
|
||||||
if groupBy != "" {
|
|
||||||
query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query)
|
|
||||||
}
|
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
|
||||||
|
|
||||||
row := q.QueryRowContext(ctx, query, args...)
|
|
||||||
err = row.Scan(&cnt)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values.
|
// GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values.
|
||||||
func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
||||||
var sql string
|
var sql string
|
||||||
@ -2009,7 +2019,16 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
|
func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
|
||||||
return d.readSQL(tables, cols, cond, qs, mi, tz)
|
buf := buffers.Get()
|
||||||
|
defer buffers.Put(buf)
|
||||||
|
|
||||||
|
args := d.readSQL(buf, tables, cols, cond, qs, mi, tz)
|
||||||
|
|
||||||
|
query := buf.String()
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
|
return query, args
|
||||||
}
|
}
|
||||||
|
|
||||||
// SupportUpdateJoin flag of update joined record.
|
// SupportUpdateJoin flag of update joined record.
|
||||||
|
|||||||
@ -439,9 +439,9 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
|
|||||||
|
|
||||||
if order.IsRaw() {
|
if order.IsRaw() {
|
||||||
if len(clause) == 2 {
|
if len(clause) == 2 {
|
||||||
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
|
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s %s", clause[0], clause[1], order.SortString()))
|
||||||
} else if len(clause) == 1 {
|
} else if len(clause) == 1 {
|
||||||
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
|
orderSqls = append(orderSqls, fmt.Sprintf("%s %s", clause[0], order.SortString()))
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1331,6 +1331,108 @@ func TestDbBase_readValuesSQL(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDbBase_countSQL(t *testing.T) {
|
||||||
|
|
||||||
|
mc := models.NewModelCacheHandler()
|
||||||
|
|
||||||
|
err := mc.Register("", false, new(testTab), new(testTab1), new(testTab2))
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
mc.Bootstrap()
|
||||||
|
|
||||||
|
mi, ok := mc.GetByMd(new(testTab))
|
||||||
|
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
cond := NewCondition().And("name", "test_name").
|
||||||
|
OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60))
|
||||||
|
|
||||||
|
tz := time.Local
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
db *dbBase
|
||||||
|
|
||||||
|
qs *querySet
|
||||||
|
|
||||||
|
wantRes string
|
||||||
|
wantArgs []interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "count with MySQL has no group by",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBaseMysql(),
|
||||||
|
},
|
||||||
|
qs: &querySet{
|
||||||
|
mi: mi,
|
||||||
|
cond: cond,
|
||||||
|
useIndex: 1,
|
||||||
|
indexes: []string{"name", "score"},
|
||||||
|
related: make([]string, 0),
|
||||||
|
relDepth: 2,
|
||||||
|
},
|
||||||
|
wantRes: "SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) ",
|
||||||
|
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "count with MySQL has group by",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBaseMysql(),
|
||||||
|
},
|
||||||
|
qs: &querySet{
|
||||||
|
mi: mi,
|
||||||
|
cond: cond,
|
||||||
|
useIndex: 1,
|
||||||
|
indexes: []string{"name", "score"},
|
||||||
|
related: make([]string, 0),
|
||||||
|
relDepth: 2,
|
||||||
|
groups: []string{"name", "age"},
|
||||||
|
},
|
||||||
|
wantRes: "SELECT COUNT(*) FROM (SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ) AS T",
|
||||||
|
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "count with PostgreSQL has no group by",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBasePostgres(),
|
||||||
|
},
|
||||||
|
qs: &querySet{
|
||||||
|
mi: mi,
|
||||||
|
cond: cond,
|
||||||
|
related: make([]string, 0),
|
||||||
|
relDepth: 2,
|
||||||
|
},
|
||||||
|
wantRes: `SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) `,
|
||||||
|
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "count with PostgreSQL has group by",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBasePostgres(),
|
||||||
|
},
|
||||||
|
qs: &querySet{
|
||||||
|
mi: mi,
|
||||||
|
cond: cond,
|
||||||
|
related: make([]string, 0),
|
||||||
|
relDepth: 2,
|
||||||
|
groups: []string{"name", "age"},
|
||||||
|
},
|
||||||
|
wantRes: `SELECT COUNT(*) FROM (SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ) AS T`,
|
||||||
|
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
res, args := tc.db.countSQL(tc.qs, mi, cond, tz)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantRes, res)
|
||||||
|
assert.Equal(t, tc.wantArgs, args)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type testTab struct {
|
type testTab struct {
|
||||||
ID int64 `orm:"auto;pk;column(id)"`
|
ID int64 `orm:"auto;pk;column(id)"`
|
||||||
Name string `orm:"column(name)"`
|
Name string `orm:"column(name)"`
|
||||||
|
|||||||
@ -1140,7 +1140,10 @@ func TestOffset(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(num, 2))
|
throwFail(t, AssertIs(num, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOrderBy(t *testing.T) {
|
func TestCountOrderBy(t *testing.T) {
|
||||||
|
if IsPostgres {
|
||||||
|
return
|
||||||
|
}
|
||||||
qs := dORM.QueryTable("user")
|
qs := dORM.QueryTable("user")
|
||||||
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count()
|
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -1175,6 +1178,61 @@ func TestOrderBy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOrderBy(t *testing.T) {
|
||||||
|
var users []*User
|
||||||
|
qs := dORM.QueryTable("user")
|
||||||
|
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").All(&users)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.OrderBy("status").Filter("user_name", "slene").All(&users)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").All(&users)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column(`profile__age`),
|
||||||
|
order_clause.SortDescending(),
|
||||||
|
),
|
||||||
|
).Filter("user_name", "astaxie").All(&users)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
if IsMysql {
|
||||||
|
num, err = qs.OrderClauses(
|
||||||
|
order_clause.Clause(
|
||||||
|
order_clause.Column(`rand()`),
|
||||||
|
order_clause.Raw(),
|
||||||
|
),
|
||||||
|
).Filter("user_name", "astaxie").All(&users)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCount(t *testing.T) {
|
||||||
|
qs := dORM.QueryTable("user")
|
||||||
|
num, err := qs.Filter("user_name", "nobody").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.Filter("user_name", "slene").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.Filter("user_name", "astaxie").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.Filter("user_name", "astaxie").Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
}
|
||||||
|
|
||||||
func TestAll(t *testing.T) {
|
func TestAll(t *testing.T) {
|
||||||
var users []*User
|
var users []*User
|
||||||
qs := dORM.QueryTable("user")
|
qs := dORM.QueryTable("user")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user