fix: refactor Count method (#5300)

* fix: refactor Count method and add test

* fix: add the change record into the CHANGELOG.md

* fix: refactor the readSQL method and let countSQL reuse readSQL method

* fix: fix the bug in the construction process of the order by clause

* fix: modify the TestCountOrderBy、add the TestCount and TestOrderBy

* fix: move the change record in CHANGELOG.md to developing

---------

Co-authored-by: Ken <azai8599@163.com>
This commit is contained in:
Uzziah 2023-09-12 00:18:04 +08:00 committed by GitHub
parent 4eea71f1d7
commit b2a37fe60e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 217 additions and 38 deletions

View File

@ -1,5 +1,5 @@
# developing # developing
- [orm: move the modelCache to internal/models package](https://github.com/beego/beego/pull/5306) - [fix: refactor Count method](https://github.com/beego/beego/pull/5300)
# v2.1.1 # v2.1.1
- [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232) - [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232)

View File

@ -1293,7 +1293,17 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) { func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
cols := d.preProcCols(tCols) // pre process columns cols := d.preProcCols(tCols) // pre process columns
return d.readSQL(tables, cols, cond, qs, mi, tz)
buf := buffers.Get()
defer buffers.Put(buf)
args := d.readSQL(buf, tables, cols, cond, qs, mi, tz)
query := buf.String()
d.ins.ReplaceMarks(&query)
return query, args
} }
func (d *dbBase) preProcCols(cols []string) []string { func (d *dbBase) preProcCols(cols []string) []string {
@ -1309,7 +1319,7 @@ func (d *dbBase) preProcCols(cols []string) []string {
// readSQL generate a select sql string and return args // readSQL generate a select sql string and return args
// ReadBatch and ReadValues methods will reuse this method. // ReadBatch and ReadValues methods will reuse this method.
func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) { func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) []interface{} {
quote := d.ins.TableQuote() quote := d.ins.TableQuote()
@ -1320,9 +1330,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
buf := buffers.Get()
defer buffers.Put(buf)
_, _ = buf.WriteString("SELECT ") _, _ = buf.WriteString("SELECT ")
if qs.distinct { if qs.distinct {
@ -1372,6 +1379,37 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
_, _ = buf.WriteString(" FOR UPDATE") _, _ = buf.WriteString(" FOR UPDATE")
} }
return args
}
// Count excute count sql and return count result int64.
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
query, args := d.countSQL(qs, mi, cond, tz)
row := q.QueryRowContext(ctx, query, args...)
err = row.Scan(&cnt)
return
}
func (d *dbBase) countSQL(qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
buf := buffers.Get()
defer buffers.Put(buf)
if len(qs.groups) > 0 {
_, _ = buf.WriteString("SELECT COUNT(*) FROM (")
}
qs.aggregate = "COUNT(*)"
args := d.readSQL(buf, tables, nil, cond, qs, mi, tz)
if len(qs.groups) > 0 {
_, _ = buf.WriteString(") AS T")
}
query := buf.String() query := buf.String()
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -1379,34 +1417,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
return query, args return query, args
} }
// Count excute count sql and return count result int64.
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSQL(qs.groups)
tables.getOrderSQL(qs.orders)
join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
Q := d.ins.TableQuote()
query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s",
Q, mi.Table, Q,
specifyIndexes, join, where, groupBy)
if groupBy != "" {
query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query)
}
d.ins.ReplaceMarks(&query)
row := q.QueryRowContext(ctx, query, args...)
err = row.Scan(&cnt)
return
}
// GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values. // GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
var sql string var sql string
@ -2009,7 +2019,16 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
} }
func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) { func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
return d.readSQL(tables, cols, cond, qs, mi, tz) buf := buffers.Get()
defer buffers.Put(buf)
args := d.readSQL(buf, tables, cols, cond, qs, mi, tz)
query := buf.String()
d.ins.ReplaceMarks(&query)
return query, args
} }
// SupportUpdateJoin flag of update joined record. // SupportUpdateJoin flag of update joined record.

View File

@ -439,9 +439,9 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
if order.IsRaw() { if order.IsRaw() {
if len(clause) == 2 { if len(clause) == 2 {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) orderSqls = append(orderSqls, fmt.Sprintf("%s.%s %s", clause[0], clause[1], order.SortString()))
} else if len(clause) == 1 { } else if len(clause) == 1 {
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) orderSqls = append(orderSqls, fmt.Sprintf("%s %s", clause[0], order.SortString()))
} else { } else {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
} }

View File

@ -1331,6 +1331,108 @@ func TestDbBase_readValuesSQL(t *testing.T) {
} }
func TestDbBase_countSQL(t *testing.T) {
mc := models.NewModelCacheHandler()
err := mc.Register("", false, new(testTab), new(testTab1), new(testTab2))
assert.Nil(t, err)
mc.Bootstrap()
mi, ok := mc.GetByMd(new(testTab))
assert.True(t, ok)
cond := NewCondition().And("name", "test_name").
OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60))
tz := time.Local
testCases := []struct {
name string
db *dbBase
qs *querySet
wantRes string
wantArgs []interface{}
}{
{
name: "count with MySQL has no group by",
db: &dbBase{
ins: newdbBaseMysql(),
},
qs: &querySet{
mi: mi,
cond: cond,
useIndex: 1,
indexes: []string{"name", "score"},
related: make([]string, 0),
relDepth: 2,
},
wantRes: "SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) ",
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
},
{
name: "count with MySQL has group by",
db: &dbBase{
ins: newdbBaseMysql(),
},
qs: &querySet{
mi: mi,
cond: cond,
useIndex: 1,
indexes: []string{"name", "score"},
related: make([]string, 0),
relDepth: 2,
groups: []string{"name", "age"},
},
wantRes: "SELECT COUNT(*) FROM (SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ) AS T",
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
},
{
name: "count with PostgreSQL has no group by",
db: &dbBase{
ins: newdbBasePostgres(),
},
qs: &querySet{
mi: mi,
cond: cond,
related: make([]string, 0),
relDepth: 2,
},
wantRes: `SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) `,
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
},
{
name: "count with PostgreSQL has group by",
db: &dbBase{
ins: newdbBasePostgres(),
},
qs: &querySet{
mi: mi,
cond: cond,
related: make([]string, 0),
relDepth: 2,
groups: []string{"name", "age"},
},
wantRes: `SELECT COUNT(*) FROM (SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ) AS T`,
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res, args := tc.db.countSQL(tc.qs, mi, cond, tz)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantArgs, args)
})
}
}
type testTab struct { type testTab struct {
ID int64 `orm:"auto;pk;column(id)"` ID int64 `orm:"auto;pk;column(id)"`
Name string `orm:"column(name)"` Name string `orm:"column(name)"`

View File

@ -1140,7 +1140,10 @@ func TestOffset(t *testing.T) {
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
} }
func TestOrderBy(t *testing.T) { func TestCountOrderBy(t *testing.T) {
if IsPostgres {
return
}
qs := dORM.QueryTable("user") qs := dORM.QueryTable("user")
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count()
throwFail(t, err) throwFail(t, err)
@ -1175,6 +1178,61 @@ func TestOrderBy(t *testing.T) {
} }
} }
func TestOrderBy(t *testing.T) {
var users []*User
qs := dORM.QueryTable("user")
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").All(&users)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.OrderBy("status").Filter("user_name", "slene").All(&users)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").All(&users)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column(`profile__age`),
order_clause.SortDescending(),
),
).Filter("user_name", "astaxie").All(&users)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
if IsMysql {
num, err = qs.OrderClauses(
order_clause.Clause(
order_clause.Column(`rand()`),
order_clause.Raw(),
),
).Filter("user_name", "astaxie").All(&users)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
}
func TestCount(t *testing.T) {
qs := dORM.QueryTable("user")
num, err := qs.Filter("user_name", "nobody").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "astaxie").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestAll(t *testing.T) { func TestAll(t *testing.T) {
var users []*User var users []*User
qs := dORM.QueryTable("user") qs := dORM.QueryTable("user")