diff --git a/CHANGELOG.md b/CHANGELOG.md index e5c3433d..9857d579 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ - [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295) - [fix: refactor InsertOrUpdate method](https://github.com/beego/beego/pull/5296) - [fix: refactor ReadBatch method](https://github.com/beego/beego/pull/5298) +- [fix: refactor ReadValues method](https://github.com/beego/beego/pull/5303) ## ORM refactoring - [introducing internal/models pkg](https://github.com/beego/beego/pull/5238) diff --git a/client/orm/db.go b/client/orm/db.go index 8d59fd01..4c31dbb9 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -1292,6 +1292,24 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m } func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) { + cols := d.preProcCols(tCols) // pre process columns + return d.readSQL(tables, cols, cond, qs, mi, tz) +} + +func (d *dbBase) preProcCols(cols []string) []string { + res := make([]string, len(cols)) + + quote := d.ins.TableQuote() + for i, col := range cols { + res[i] = fmt.Sprintf("T0.%s%s%s", quote, col, quote) + } + + return res +} + +// readSQL generate a select sql string and return args +// ReadBatch and ReadValues methods will reuse this method. +func (d *dbBase) readSQL(tables *dbTables, tCols []string, cond *Condition, qs *querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) { quote := d.ins.TableQuote() @@ -1316,10 +1334,7 @@ func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, if i > 0 { _, _ = buf.WriteString(", ") } - _, _ = buf.WriteString("T0.") - _, _ = buf.WriteString(quote) _, _ = buf.WriteString(tCol) - _, _ = buf.WriteString(quote) } for _, tbl := range tables.tables { @@ -1897,25 +1912,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi * } } - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - orderBy := tables.getOrderSQL(qs.orders) - limit := tables.getLimitSQL(mi, qs.offset, qs.limit) - join := tables.getJoinSQL() - specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) - - sels := strings.Join(cols, ", ") - - sqlSelect := "SELECT" - if qs.distinct { - sqlSelect += " DISTINCT" - } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", - sqlSelect, sels, - Q, mi.Table, Q, - specifyIndexes, join, where, groupBy, orderBy, limit) - - d.ins.ReplaceMarks(&query) + query, args := d.readValuesSQL(tables, cols, qs, mi, cond, tz) rs, err := q.QueryContext(ctx, query, args...) if err != nil { @@ -2011,6 +2008,10 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi * return cnt, nil } +func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) { + return d.readSQL(tables, cols, cond, qs, mi, tz) +} + // SupportUpdateJoin flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true diff --git a/client/orm/db_test.go b/client/orm/db_test.go index 32e90d17..2b551608 100644 --- a/client/orm/db_test.go +++ b/client/orm/db_test.go @@ -890,8 +890,6 @@ func TestDbBase_InsertOrUpdateSQL(t *testing.T) { func TestDbBase_readBatchSQL(t *testing.T) { - tCols := []string{"name", "score"} - mc := models.NewModelCacheHandler() err := mc.Register("", false, new(testTab), new(testTab1), new(testTab2)) @@ -913,7 +911,8 @@ func TestDbBase_readBatchSQL(t *testing.T) { name string db *dbBase - qs *querySet + tCols []string + qs *querySet wantRes string wantArgs []interface{} @@ -923,6 +922,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBaseMysql(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -948,6 +948,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBaseMysql(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -974,6 +975,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBaseMysql(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1000,6 +1002,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBaseMysql(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1027,6 +1030,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBaseMysql(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1053,6 +1057,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBasePostgres(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1076,6 +1081,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBasePostgres(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1100,6 +1106,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBasePostgres(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1124,6 +1131,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBasePostgres(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1149,6 +1157,7 @@ func TestDbBase_readBatchSQL(t *testing.T) { db: &dbBase{ ins: newdbBasePostgres(), }, + tCols: []string{"name", "score"}, qs: &querySet{ mi: mi, cond: cond, @@ -1175,7 +1184,145 @@ func TestDbBase_readBatchSQL(t *testing.T) { tables := newDbTables(mi, tc.db.ins) tables.parseRelated(tc.qs.related, tc.qs.relDepth) - res, args := tc.db.readBatchSQL(tables, tCols, cond, tc.qs, mi, tz) + res, args := tc.db.readBatchSQL(tables, tc.tCols, cond, tc.qs, mi, tz) + + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantArgs, args) + }) + } + +} + +func TestDbBase_readValuesSQL(t *testing.T) { + + mc := models.NewModelCacheHandler() + + err := mc.Register("", false, new(testTab), new(testTab1), new(testTab2)) + + assert.Nil(t, err) + + mc.Bootstrap() + + mi, ok := mc.GetByMd(new(testTab)) + + assert.True(t, ok) + + cond := NewCondition().And("name", "test_name"). + OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60)) + + tz := time.Local + + testCases := []struct { + name string + db *dbBase + + cols []string + qs *querySet + + wantRes string + wantArgs []interface{} + }{ + { + name: "read values with MySQL", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + cols: []string{"T0.`name` name", "T0.`age` age", "T0.`score` score"}, + qs: &querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + }, + wantRes: "SELECT T0.`name` name, T0.`age` age, T0.`score` score FROM `test_tab` T0 USE INDEX(`name`,`score`) WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read values with MySQL and distinct", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + cols: []string{"T0.`name` name", "T0.`age` age", "T0.`score` score"}, + qs: &querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + distinct: true, + }, + wantRes: "SELECT DISTINCT T0.`name` name, T0.`age` age, T0.`score` score FROM `test_tab` T0 USE INDEX(`name`,`score`) WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read values with PostgreSQL", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + cols: []string{`T0."name" name`, `T0."age" age`, `T0."score" score`}, + qs: &querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + }, + wantRes: `SELECT T0."name" name, T0."age" age, T0."score" score FROM "test_tab" T0 WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read values with PostgreSQL and distinct", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + cols: []string{`T0."name" name`, `T0."age" age`, `T0."score" score`}, + qs: &querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + distinct: true, + }, + wantRes: `SELECT DISTINCT T0."name" name, T0."age" age, T0."score" score FROM "test_tab" T0 WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tables := newDbTables(mi, tc.db.ins) + + res, args := tc.db.readValuesSQL(tables, tc.cols, tc.qs, mi, cond, tz) assert.Equal(t, tc.wantRes, res) assert.Equal(t, tc.wantArgs, args)