fix: refactor ReadBatch method (#5298)

This commit is contained in:
Uzziah 2023-08-23 18:35:17 +08:00 committed by GitHub
parent 46a00d3592
commit e9d3357643
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 400 additions and 33 deletions

View File

@ -16,6 +16,7 @@
- [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274)
- [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295)
- [fix: refactor InsertOrUpdate method](https://github.com/beego/beego/pull/5296)
- [fix: refactor ReadBatch method](https://github.com/beego/beego/pull/5298)
## ORM refactoring
- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238)

View File

@ -1127,11 +1127,6 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
RegisterModel(container)
}
rlimit := qs.limit
offset := qs.offset
Q := d.ins.TableQuote()
var tCols []string
if len(cols) > 0 {
hasRel := len(qs.related) > 0 || qs.relDepth > 0
@ -1163,44 +1158,18 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
tCols = mi.Fields.DBcols
}
colsNum := len(tCols)
sep := fmt.Sprintf("%s, T0.%s", Q, Q)
sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSQL(qs.groups)
orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
colsNum := len(tCols)
for _, tbl := range tables.tables {
if tbl.sel {
colsNum += len(tbl.mi.Fields.DBcols)
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.Fields.DBcols, sep), Q)
}
}
sqlSelect := "SELECT"
if qs.distinct {
sqlSelect += " DISTINCT"
}
if qs.aggregate != "" {
sels = qs.aggregate
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, Q, mi.Table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit)
if qs.forUpdate {
query += " FOR UPDATE"
}
d.ins.ReplaceMarks(&query)
query, args := d.readBatchSQL(tables, tCols, cond, qs, mi, tz)
rs, err := q.QueryContext(ctx, query, args...)
if err != nil {
@ -1322,6 +1291,79 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
return cnt, nil
}
// readBatchSQL assembles the SELECT statement (and its bound arguments) used by
// ReadBatch: the column list (or a caller-supplied aggregate expression), the
// FROM clause with optional index hints, joins, WHERE, GROUP BY, ORDER BY,
// LIMIT/OFFSET and an optional trailing FOR UPDATE.
func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) {
	Q := d.ins.TableQuote()

	// Render every clause up front; args carries the WHERE placeholders' values.
	whereSQL, args := tables.getCondSQL(cond, false, tz)
	groupSQL := tables.getGroupSQL(qs.groups)
	orderSQL := tables.getOrderSQL(qs.orders)
	limitSQL := tables.getLimitSQL(mi, qs.offset, qs.limit)
	joinSQL := tables.getJoinSQL()
	indexSQL := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)

	buf := buffers.Get()
	defer buffers.Put(buf)

	// writeCol emits one qualified, quoted column reference, e.g. T0.`name`.
	writeCol := func(alias, col string) {
		_, _ = buf.WriteString(alias)
		_, _ = buf.WriteString(".")
		_, _ = buf.WriteString(Q)
		_, _ = buf.WriteString(col)
		_, _ = buf.WriteString(Q)
	}

	_, _ = buf.WriteString("SELECT ")
	if qs.distinct {
		_, _ = buf.WriteString("DISTINCT ")
	}

	switch {
	case qs.aggregate != "":
		// An explicit aggregate expression replaces the whole column list.
		_, _ = buf.WriteString(qs.aggregate)
	default:
		// Columns of the root table, aliased T0.
		for i, col := range tCols {
			if i > 0 {
				_, _ = buf.WriteString(", ")
			}
			writeCol("T0", col)
		}
		// Columns of every related table selected via related/relDepth.
		for _, tbl := range tables.tables {
			if !tbl.sel {
				continue
			}
			_, _ = buf.WriteString(", ")
			for i, col := range tbl.mi.Fields.DBcols {
				if i > 0 {
					_, _ = buf.WriteString(", ")
				}
				writeCol(tbl.index, col)
			}
		}
	}

	_, _ = buf.WriteString(" FROM ")
	_, _ = buf.WriteString(Q)
	_, _ = buf.WriteString(mi.Table)
	_, _ = buf.WriteString(Q)
	_, _ = buf.WriteString(" T0 ")
	_, _ = buf.WriteString(indexSQL)
	_, _ = buf.WriteString(joinSQL)
	_, _ = buf.WriteString(whereSQL)
	_, _ = buf.WriteString(groupSQL)
	_, _ = buf.WriteString(orderSQL)
	_, _ = buf.WriteString(limitSQL)
	if qs.forUpdate {
		_, _ = buf.WriteString(" FOR UPDATE")
	}

	query := buf.String()
	// Convert generic `?` markers to the driver's placeholder style (e.g. $1 for Postgres).
	d.ins.ReplaceMarks(&query)
	return query, args
}
// Count executes a count SQL statement and returns the count result as int64.
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)

View File

@ -16,8 +16,10 @@ package orm
import (
"errors"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/internal/buffers"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -886,3 +888,325 @@ func TestDbBase_InsertOrUpdateSQL(t *testing.T) {
}
}
// TestDbBase_readBatchSQL verifies the SELECT statement produced by
// dbBase.readBatchSQL for both MySQL and PostgreSQL quoting/placeholder
// styles, covering the plain, DISTINCT, aggregate, DISTINCT+aggregate and
// FOR UPDATE variants with joins (relDepth=2), GROUP BY, ORDER BY and
// LIMIT/OFFSET.
func TestDbBase_readBatchSQL(t *testing.T) {
	// Columns selected from the root table (alias T0).
	tCols := []string{"name", "score"}

	// Register the three chained fixture models so relDepth-based joins resolve.
	mc := &modelCache{
		cache:           make(map[string]*models.ModelInfo),
		cacheByFullName: make(map[string]*models.ModelInfo),
	}
	err := mc.register("", false, new(testTab), new(testTab1), new(testTab2))
	assert.Nil(t, err)
	mc.bootstrap()

	mi, ok := mc.getByMd(new(testTab))
	assert.True(t, ok)

	// Shared condition: name = ? OR ( age > ? AND score < ? ).
	cond := NewCondition().And("name", "test_name").
		OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60))

	tz := time.Local

	testCases := []struct {
		name     string
		db       *dbBase
		qs       *querySet
		wantRes  string
		wantArgs []interface{}
	}{
		// MySQL: backtick quoting, `?` placeholders, USE INDEX hint supported.
		{
			name: "read batch with MySQL",
			db: &dbBase{
				ins: newdbBaseMysql(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				useIndex: 1,
				indexes:  []string{"name", "score"},
				related:  make([]string, 0),
				relDepth: 2,
			},
			wantRes:  "SELECT T0.`name`, T0.`score`, T1.`id`, T1.`name_1`, T1.`age_1`, T1.`score_1`, T1.`test_tab_2_id`, T2.`id`, T2.`name_2`, T2.`age_2`, T2.`score_2` FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100",
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		{
			name: "read batch with MySQL and distinct",
			db: &dbBase{
				ins: newdbBaseMysql(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				useIndex: 1,
				indexes:  []string{"name", "score"},
				distinct: true,
				related:  make([]string, 0),
				relDepth: 2,
			},
			wantRes:  "SELECT DISTINCT T0.`name`, T0.`score`, T1.`id`, T1.`name_1`, T1.`age_1`, T1.`score_1`, T1.`test_tab_2_id`, T2.`id`, T2.`name_2`, T2.`age_2`, T2.`score_2` FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100",
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		// Aggregate replaces the generated column list entirely.
		{
			name: "read batch with MySQL and aggregate",
			db: &dbBase{
				ins: newdbBaseMysql(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				useIndex:  1,
				indexes:   []string{"name", "score"},
				aggregate: "sum(`T0`.`score`), count(`T1`.`name_1`)",
				related:   make([]string, 0),
				relDepth:  2,
			},
			wantRes:  "SELECT sum(`T0`.`score`), count(`T1`.`name_1`) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100",
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		{
			name: "read batch with MySQL and distinct and aggregate",
			db: &dbBase{
				ins: newdbBaseMysql(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				useIndex:  1,
				indexes:   []string{"name", "score"},
				distinct:  true,
				aggregate: "sum(`T0`.`score`), count(`T1`.`name_1`)",
				related:   make([]string, 0),
				relDepth:  2,
			},
			wantRes:  "SELECT DISTINCT sum(`T0`.`score`), count(`T1`.`name_1`) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100",
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		{
			name: "read batch with MySQL and for update",
			db: &dbBase{
				ins: newdbBaseMysql(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				useIndex:  1,
				indexes:   []string{"name", "score"},
				forUpdate: true,
				related:   make([]string, 0),
				relDepth:  2,
			},
			wantRes:  "SELECT T0.`name`, T0.`score`, T1.`id`, T1.`name_1`, T1.`age_1`, T1.`score_1`, T1.`test_tab_2_id`, T2.`id`, T2.`name_2`, T2.`age_2`, T2.`score_2` FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100 FOR UPDATE",
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		// PostgreSQL: double-quote quoting, $N placeholders, no index hints.
		{
			name: "read batch with PostgreSQL",
			db: &dbBase{
				ins: newdbBasePostgres(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				related:  make([]string, 0),
				relDepth: 2,
			},
			wantRes:  `SELECT T0."name", T0."score", T1."id", T1."name_1", T1."age_1", T1."score_1", T1."test_tab_2_id", T2."id", T2."name_2", T2."age_2", T2."score_2" FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`,
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		{
			name: "read batch with PostgreSQL and distinct",
			db: &dbBase{
				ins: newdbBasePostgres(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				distinct: true,
				related:  make([]string, 0),
				relDepth: 2,
			},
			wantRes:  `SELECT DISTINCT T0."name", T0."score", T1."id", T1."name_1", T1."age_1", T1."score_1", T1."test_tab_2_id", T2."id", T2."name_2", T2."age_2", T2."score_2" FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`,
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		{
			name: "read batch with PostgreSQL and aggregate",
			db: &dbBase{
				ins: newdbBasePostgres(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				aggregate: `sum("T0"."score"), count("T1"."name_1")`,
				related:   make([]string, 0),
				relDepth:  2,
			},
			wantRes:  `SELECT sum("T0"."score"), count("T1"."name_1") FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`,
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		{
			name: "read batch with PostgreSQL and distinct and aggregate",
			db: &dbBase{
				ins: newdbBasePostgres(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				distinct:  true,
				aggregate: `sum("T0"."score"), count("T1"."name_1")`,
				related:   make([]string, 0),
				relDepth:  2,
			},
			wantRes:  `SELECT DISTINCT sum("T0"."score"), count("T1"."name_1") FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`,
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
		{
			name: "read batch with PostgreSQL and for update",
			db: &dbBase{
				ins: newdbBasePostgres(),
			},
			qs: &querySet{
				mi:     mi,
				cond:   cond,
				limit:  10,
				offset: 100,
				groups: []string{"name", "age"},
				orders: []*order_clause.Order{
					order_clause.Clause(order_clause.Column("score"),
						order_clause.SortDescending()),
					order_clause.Clause(order_clause.Column("age"),
						order_clause.SortAscending()),
				},
				forUpdate: true,
				related:   make([]string, 0),
				relDepth:  2,
			},
			wantRes:  `SELECT T0."name", T0."score", T1."id", T1."name_1", T1."age_1", T1."score_1", T1."test_tab_2_id", T2."id", T2."name_2", T2."age_2", T2."score_2" FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100 FOR UPDATE`,
			wantArgs: []interface{}{"test_name", int64(18), int64(60)},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Join tables must be rebuilt per dialect: quoting differs between drivers.
			tables := newDbTables(mi, tc.db.ins)
			tables.parseRelated(tc.qs.related, tc.qs.relDepth)
			res, args := tc.db.readBatchSQL(tables, tCols, cond, tc.qs, mi, tz)
			assert.Equal(t, tc.wantRes, res)
			assert.Equal(t, tc.wantArgs, args)
		})
	}
}
// testTab is the root fixture model for readBatchSQL tests (table alias T0).
// Its foreign key to testTab1 (column test_tab_1_id) gives relDepth=2 a
// two-level join chain: testTab -> testTab1 -> testTab2.
type testTab struct {
	ID       int64     `orm:"auto;pk;column(id)"`
	Name     string    `orm:"column(name)"`
	Age      int64     `orm:"column(age)"`
	Score    int64     `orm:"column(score)"`
	TestTab1 *testTab1 `orm:"rel(fk);column(test_tab_1_id)"`
}
// testTab1 is the first-level related fixture model (joined as alias T1);
// it links onward to testTab2 via the test_tab_2_id foreign-key column.
type testTab1 struct {
	ID       int64     `orm:"auto;pk;column(id)"`
	Name1    string    `orm:"column(name_1)"`
	Age1     int64     `orm:"column(age_1)"`
	Score1   int64     `orm:"column(score_1)"`
	TestTab2 *testTab2 `orm:"rel(fk);column(test_tab_2_id)"`
}
// testTab2 is the second-level related fixture model (joined as alias T2).
type testTab2 struct {
	ID int64 `orm:"auto;pk;column(id)"`
	// NOTE(review): Name2 is int64 while the other fixtures' Name fields are
	// string — possibly intended to be string; the generated SQL does not
	// depend on the Go field type, so tests pass either way. Confirm intent.
	Name2  int64 `orm:"column(name_2)"`
	Age2   int64 `orm:"column(age_2)"`
	Score2 int64 `orm:"column(score_2)"`
}