fix: refactor Count method (#5300)
* fix: refactor Count method and add test * fix: add the change record into the CHANGELOG.md * fix: refactor the readSQL method and let countSQL reuse readSQL method * fix: fix the bug in the construction process of the order by clause * fix: modify the TestCountOrderBy、add the TestCount and TestOrderBy * fix: move the change record in CHANGELOG.md to developing --------- Co-authored-by: Ken <azai8599@163.com>
This commit is contained in:
parent
4eea71f1d7
commit
b2a37fe60e
@ -1,5 +1,5 @@
|
||||
# developing
|
||||
- [orm: move the modelCache to internal/models package](https://github.com/beego/beego/pull/5306)
|
||||
- [fix: refactor Count method](https://github.com/beego/beego/pull/5300)
|
||||
|
||||
# v2.1.1
|
||||
- [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232)
|
||||
|
||||
@ -1293,7 +1293,17 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
|
||||
|
||||
func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
|
||||
cols := d.preProcCols(tCols) // pre process columns
|
||||
return d.readSQL(tables, cols, cond, qs, mi, tz)
|
||||
|
||||
buf := buffers.Get()
|
||||
defer buffers.Put(buf)
|
||||
|
||||
args := d.readSQL(buf, tables, cols, cond, qs, mi, tz)
|
||||
|
||||
query := buf.String()
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
return query, args
|
||||
}
|
||||
|
||||
func (d *dbBase) preProcCols(cols []string) []string {
|
||||
@ -1309,7 +1319,7 @@ func (d *dbBase) preProcCols(cols []string) []string {
|
||||
|
||||
// readSQL generate a select sql string and return args
|
||||
// ReadBatch and ReadValues methods will reuse this method.
|
||||
func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
|
||||
func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) []interface{} {
|
||||
|
||||
quote := d.ins.TableQuote()
|
||||
|
||||
@ -1320,9 +1330,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
|
||||
join := tables.getJoinSQL()
|
||||
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
|
||||
|
||||
buf := buffers.Get()
|
||||
defer buffers.Put(buf)
|
||||
|
||||
_, _ = buf.WriteString("SELECT ")
|
||||
|
||||
if qs.distinct {
|
||||
@ -1372,6 +1379,37 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
|
||||
_, _ = buf.WriteString(" FOR UPDATE")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// Count excute count sql and return count result int64.
|
||||
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||
|
||||
query, args := d.countSQL(qs, mi, cond, tz)
|
||||
|
||||
row := q.QueryRowContext(ctx, query, args...)
|
||||
err = row.Scan(&cnt)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *dbBase) countSQL(qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
|
||||
buf := buffers.Get()
|
||||
defer buffers.Put(buf)
|
||||
|
||||
if len(qs.groups) > 0 {
|
||||
_, _ = buf.WriteString("SELECT COUNT(*) FROM (")
|
||||
}
|
||||
|
||||
qs.aggregate = "COUNT(*)"
|
||||
args := d.readSQL(buf, tables, nil, cond, qs, mi, tz)
|
||||
|
||||
if len(qs.groups) > 0 {
|
||||
_, _ = buf.WriteString(") AS T")
|
||||
}
|
||||
|
||||
query := buf.String()
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
@ -1379,34 +1417,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *
|
||||
return query, args
|
||||
}
|
||||
|
||||
// Count excute count sql and return count result int64.
|
||||
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
|
||||
where, args := tables.getCondSQL(cond, false, tz)
|
||||
groupBy := tables.getGroupSQL(qs.groups)
|
||||
tables.getOrderSQL(qs.orders)
|
||||
join := tables.getJoinSQL()
|
||||
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
|
||||
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s",
|
||||
Q, mi.Table, Q,
|
||||
specifyIndexes, join, where, groupBy)
|
||||
|
||||
if groupBy != "" {
|
||||
query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query)
|
||||
}
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
row := q.QueryRowContext(ctx, query, args...)
|
||||
err = row.Scan(&cnt)
|
||||
return
|
||||
}
|
||||
|
||||
// GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values.
|
||||
func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
||||
var sql string
|
||||
@ -2009,7 +2019,16 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
|
||||
}
|
||||
|
||||
func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) {
|
||||
return d.readSQL(tables, cols, cond, qs, mi, tz)
|
||||
buf := buffers.Get()
|
||||
defer buffers.Put(buf)
|
||||
|
||||
args := d.readSQL(buf, tables, cols, cond, qs, mi, tz)
|
||||
|
||||
query := buf.String()
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
return query, args
|
||||
}
|
||||
|
||||
// SupportUpdateJoin flag of update joined record.
|
||||
|
||||
@ -439,9 +439,9 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
|
||||
|
||||
if order.IsRaw() {
|
||||
if len(clause) == 2 {
|
||||
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
|
||||
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s %s", clause[0], clause[1], order.SortString()))
|
||||
} else if len(clause) == 1 {
|
||||
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
|
||||
orderSqls = append(orderSqls, fmt.Sprintf("%s %s", clause[0], order.SortString()))
|
||||
} else {
|
||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
|
||||
}
|
||||
|
||||
@ -1331,6 +1331,108 @@ func TestDbBase_readValuesSQL(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestDbBase_countSQL(t *testing.T) {
|
||||
|
||||
mc := models.NewModelCacheHandler()
|
||||
|
||||
err := mc.Register("", false, new(testTab), new(testTab1), new(testTab2))
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
mc.Bootstrap()
|
||||
|
||||
mi, ok := mc.GetByMd(new(testTab))
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
cond := NewCondition().And("name", "test_name").
|
||||
OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60))
|
||||
|
||||
tz := time.Local
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
db *dbBase
|
||||
|
||||
qs *querySet
|
||||
|
||||
wantRes string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "count with MySQL has no group by",
|
||||
db: &dbBase{
|
||||
ins: newdbBaseMysql(),
|
||||
},
|
||||
qs: &querySet{
|
||||
mi: mi,
|
||||
cond: cond,
|
||||
useIndex: 1,
|
||||
indexes: []string{"name", "score"},
|
||||
related: make([]string, 0),
|
||||
relDepth: 2,
|
||||
},
|
||||
wantRes: "SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) ",
|
||||
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||
},
|
||||
{
|
||||
name: "count with MySQL has group by",
|
||||
db: &dbBase{
|
||||
ins: newdbBaseMysql(),
|
||||
},
|
||||
qs: &querySet{
|
||||
mi: mi,
|
||||
cond: cond,
|
||||
useIndex: 1,
|
||||
indexes: []string{"name", "score"},
|
||||
related: make([]string, 0),
|
||||
relDepth: 2,
|
||||
groups: []string{"name", "age"},
|
||||
},
|
||||
wantRes: "SELECT COUNT(*) FROM (SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ) AS T",
|
||||
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||
},
|
||||
{
|
||||
name: "count with PostgreSQL has no group by",
|
||||
db: &dbBase{
|
||||
ins: newdbBasePostgres(),
|
||||
},
|
||||
qs: &querySet{
|
||||
mi: mi,
|
||||
cond: cond,
|
||||
related: make([]string, 0),
|
||||
relDepth: 2,
|
||||
},
|
||||
wantRes: `SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) `,
|
||||
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||
},
|
||||
{
|
||||
name: "count with PostgreSQL has group by",
|
||||
db: &dbBase{
|
||||
ins: newdbBasePostgres(),
|
||||
},
|
||||
qs: &querySet{
|
||||
mi: mi,
|
||||
cond: cond,
|
||||
related: make([]string, 0),
|
||||
relDepth: 2,
|
||||
groups: []string{"name", "age"},
|
||||
},
|
||||
wantRes: `SELECT COUNT(*) FROM (SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ) AS T`,
|
||||
wantArgs: []interface{}{"test_name", int64(18), int64(60)},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res, args := tc.db.countSQL(tc.qs, mi, cond, tz)
|
||||
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantArgs, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testTab struct {
|
||||
ID int64 `orm:"auto;pk;column(id)"`
|
||||
Name string `orm:"column(name)"`
|
||||
|
||||
@ -1140,7 +1140,10 @@ func TestOffset(t *testing.T) {
|
||||
throwFail(t, AssertIs(num, 2))
|
||||
}
|
||||
|
||||
func TestOrderBy(t *testing.T) {
|
||||
func TestCountOrderBy(t *testing.T) {
|
||||
if IsPostgres {
|
||||
return
|
||||
}
|
||||
qs := dORM.QueryTable("user")
|
||||
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count()
|
||||
throwFail(t, err)
|
||||
@ -1175,6 +1178,61 @@ func TestOrderBy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderBy(t *testing.T) {
|
||||
var users []*User
|
||||
qs := dORM.QueryTable("user")
|
||||
num, err := qs.OrderBy("-status").Filter("user_name", "nobody").All(&users)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.OrderBy("status").Filter("user_name", "slene").All(&users)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").All(&users)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.OrderClauses(
|
||||
order_clause.Clause(
|
||||
order_clause.Column(`profile__age`),
|
||||
order_clause.SortDescending(),
|
||||
),
|
||||
).Filter("user_name", "astaxie").All(&users)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
if IsMysql {
|
||||
num, err = qs.OrderClauses(
|
||||
order_clause.Clause(
|
||||
order_clause.Column(`rand()`),
|
||||
order_clause.Raw(),
|
||||
),
|
||||
).Filter("user_name", "astaxie").All(&users)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
qs := dORM.QueryTable("user")
|
||||
num, err := qs.Filter("user_name", "nobody").Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.Filter("user_name", "slene").Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.Filter("user_name", "astaxie").Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.Filter("user_name", "astaxie").Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
var users []*User
|
||||
qs := dORM.QueryTable("user")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user