From b2a37fe60e1146146afb9c70dc3362d0b4033eed Mon Sep 17 00:00:00 2001 From: Uzziah <120019273+uzziahlin@users.noreply.github.com> Date: Tue, 12 Sep 2023 00:18:04 +0800 Subject: [PATCH] fix: refactor Count method (#5300) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: refactor Count method and add test * fix: add the change record into the CHANGELOG.md * fix: refactor the readSQL method and let countSQL reuse readSQL method * fix: fix the bug in the construction process of the order by clause * fix: modify the TestCountOrderBy态add the TestCount and TestOrderBy * fix: move the change record in CHANGELOG.md to developing --------- Co-authored-by: Ken --- CHANGELOG.md | 2 +- client/orm/db.go | 87 ++++++++++++++++++++-------------- client/orm/db_tables.go | 4 +- client/orm/db_test.go | 102 ++++++++++++++++++++++++++++++++++++++++ client/orm/orm_test.go | 60 ++++++++++++++++++++++- 5 files changed, 217 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9857d579..6f0333c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # developing -- [orm: move the modelCache to internal/models package](https://github.com/beego/beego/pull/5306) +- [fix: refactor Count method](https://github.com/beego/beego/pull/5300) # v2.1.1 - [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232) diff --git a/client/orm/db.go b/client/orm/db.go index 4c31dbb9..b06016e5 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -1293,7 +1293,17 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) { cols := d.preProcCols(tCols) // pre process columns - return d.readSQL(tables, cols, cond, qs, mi, tz) + + buf := buffers.Get() + defer buffers.Put(buf) + + args := d.readSQL(buf, tables, cols, cond, qs, mi, tz) + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query, args } func (d *dbBase) preProcCols(cols []string) []string { @@ -1309,7 +1319,7 @@ func (d *dbBase) preProcCols(cols []string) []string { // readSQL generate a select sql string and return args // ReadBatch and ReadValues methods will reuse this method. -func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) { +func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) []interface{} { quote := d.ins.TableQuote() @@ -1320,9 +1330,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs * join := tables.getJoinSQL() specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) - buf := buffers.Get() - defer buffers.Put(buf) - _, _ = buf.WriteString("SELECT ") if qs.distinct { @@ -1372,6 +1379,37 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs * _, _ = buf.WriteString(" FOR UPDATE") } + return args +} + +// Count excute count sql and return count result int64. +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { + + query, args := d.countSQL(qs, mi, cond, tz) + + row := q.QueryRowContext(ctx, query, args...) + err = row.Scan(&cnt) + return +} + +func (d *dbBase) countSQL(qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) { + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + buf := buffers.Get() + defer buffers.Put(buf) + + if len(qs.groups) > 0 { + _, _ = buf.WriteString("SELECT COUNT(*) FROM (") + } + + qs.aggregate = "COUNT(*)" + args := d.readSQL(buf, tables, nil, cond, qs, mi, tz) + + if len(qs.groups) > 0 { + _, _ = buf.WriteString(") AS T") + } + query := buf.String() d.ins.ReplaceMarks(&query) @@ -1379,34 +1417,6 @@ func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs * return query, args } -// Count excute count sql and return count result int64. -func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { - tables := newDbTables(mi, d.ins) - tables.parseRelated(qs.related, qs.relDepth) - - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - tables.getOrderSQL(qs.orders) - join := tables.getJoinSQL() - specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) - - Q := d.ins.TableQuote() - - query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s", - Q, mi.Table, Q, - specifyIndexes, join, where, groupBy) - - if groupBy != "" { - query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) - } - - d.ins.ReplaceMarks(&query) - - row := q.QueryRowContext(ctx, query, args...) - err = row.Scan(&cnt) - return -} - // GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values. func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { var sql string @@ -2009,7 +2019,16 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi * } func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) { - return d.readSQL(tables, cols, cond, qs, mi, tz) + buf := buffers.Get() + defer buffers.Put(buf) + + args := d.readSQL(buf, tables, cols, cond, qs, mi, tz) + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query, args } // SupportUpdateJoin flag of update joined record. diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 9d30afb3..c22f0a11 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -439,9 +439,9 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { if order.IsRaw() { if len(clause) == 2 { - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s %s", clause[0], clause[1], order.SortString())) } else if len(clause) == 1 { - orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) + orderSqls = append(orderSqls, fmt.Sprintf("%s %s", clause[0], order.SortString())) } else { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) } diff --git a/client/orm/db_test.go b/client/orm/db_test.go index 2b551608..2bb32090 100644 --- a/client/orm/db_test.go +++ b/client/orm/db_test.go @@ -1331,6 +1331,108 @@ func TestDbBase_readValuesSQL(t *testing.T) { } +func TestDbBase_countSQL(t *testing.T) { + + mc := models.NewModelCacheHandler() + + err := mc.Register("", false, new(testTab), new(testTab1), new(testTab2)) + + assert.Nil(t, err) + + mc.Bootstrap() + + mi, ok := mc.GetByMd(new(testTab)) + + assert.True(t, ok) + + cond := NewCondition().And("name", "test_name"). + OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60)) + + tz := time.Local + + testCases := []struct { + name string + db *dbBase + + qs *querySet + + wantRes string + wantArgs []interface{} + }{ + { + name: "count with MySQL has no group by", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + qs: &querySet{ + mi: mi, + cond: cond, + useIndex: 1, + indexes: []string{"name", "score"}, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: "SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) ", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "count with MySQL has group by", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + qs: &querySet{ + mi: mi, + cond: cond, + useIndex: 1, + indexes: []string{"name", "score"}, + related: make([]string, 0), + relDepth: 2, + groups: []string{"name", "age"}, + }, + wantRes: "SELECT COUNT(*) FROM (SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ) AS T", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "count with PostgreSQL has no group by", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + qs: &querySet{ + mi: mi, + cond: cond, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: `SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) `, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "count with PostgreSQL has group by", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + qs: &querySet{ + mi: mi, + cond: cond, + related: make([]string, 0), + relDepth: 2, + groups: []string{"name", "age"}, + }, + wantRes: `SELECT COUNT(*) FROM (SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ) AS T`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res, args := tc.db.countSQL(tc.qs, mi, cond, tz) + + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantArgs, args) + }) + } +} + type testTab struct { ID int64 `orm:"auto;pk;column(id)"` Name string `orm:"column(name)"` diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 3dced8ff..a371f193 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -1140,7 +1140,10 @@ func TestOffset(t *testing.T) { throwFail(t, AssertIs(num, 2)) } -func TestOrderBy(t *testing.T) { +func TestCountOrderBy(t *testing.T) { + if IsPostgres { + return + } qs := dORM.QueryTable("user") num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() throwFail(t, err) @@ -1175,6 +1178,61 @@ func TestOrderBy(t *testing.T) { } } +func TestOrderBy(t *testing.T) { + var users []*User + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("-status").Filter("user_name", "nobody").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("status").Filter("user_name", "slene").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`profile__age`), + order_clause.SortDescending(), + ), + ).Filter("user_name", "astaxie").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsMysql { + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`rand()`), + order_clause.Raw(), + ), + ).Filter("user_name", "astaxie").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } +} + +func TestCount(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + func TestAll(t *testing.T) { var users []*User qs := dORM.QueryTable("user")