From 24b41552c53153b47c27d2f92f4b0f52d52a124a Mon Sep 17 00:00:00 2001 From: Uzziah <120019273+uzziahlin@users.noreply.github.com> Date: Mon, 10 Jul 2023 21:50:44 +0800 Subject: [PATCH] fix: refactor update sql (#5274) * fix: refactor UpdateSQL method and add test * fix: add modify record into CHANGELOG * fix: modify url in the CHANGELOG * fix: modify pr url in the CHANGELOG --- CHANGELOG.md | 1 + client/orm/db.go | 43 ++++++++++++++++++++++++++++++++++-------- client/orm/db_test.go | 44 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9eede89d..4c8ebfd2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - [fix: modify InsertOrUpdate method, Remove the isMulti variable and its associated code](https://github.com/beego/beego/pull/5269) - [refactor cache/redis: Use redisConfig to receive incoming JSON (previously using a map)](https://github.com/beego/beego/pull/5268) - [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271) +- [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274) ## ORM refactoring - [introducing internal/models pkg](https://github.com/beego/beego/pull/5238) diff --git a/client/orm/db.go b/client/orm/db.go index fdecc478..8bb4bec4 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -674,14 +674,7 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, setValues = append(setValues, pkValue) - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s = ?, %s", Q, Q) - setColumns := strings.Join(setNames, sep) - - query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.Table, Q, Q, setColumns, Q, Q, pkName, Q) - - d.ins.ReplaceMarks(&query) + query := d.UpdateSQL(setNames, pkName, mi) res, err := q.ExecContext(ctx, query, setValues...) if err == nil { @@ -690,6 +683,40 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, return 0, err } +func (d *dbBase) UpdateSQL(setNames []string, pkName string, mi *models.ModelInfo) string { + buf := buffers.Get() + defer buffers.Put(buf) + + Q := d.ins.TableQuote() + + _, _ = buf.WriteString("UPDATE ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" SET ") + + for i, name := range setNames { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(name) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" = ?") + } + + _, _ = buf.WriteString(" WHERE ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(pkName) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" = ?") + + query := buf.String() + d.ins.ReplaceMarks(&query) + + return query +} + // Delete execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { diff --git a/client/orm/db_test.go b/client/orm/db_test.go index 60624f05..43fa3798 100644 --- a/client/orm/db_test.go +++ b/client/orm/db_test.go @@ -129,6 +129,50 @@ func TestDbBase_InsertValueSQL(t *testing.T) { } } +func TestDbBase_UpdateSQL(t *testing.T) { + mi := &models.ModelInfo{ + Table: "test_table", + } + + testCases := []struct { + name string + db *dbBase + + setNames []string + pkName string + + wantRes string + }{ + { + name: "update by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + setNames: []string{"name", "age", "sender"}, + pkName: "id", + wantRes: "UPDATE `test_table` SET `name` = ?, `age` = ?, `sender` = ? WHERE `id` = ?", + }, + { + name: "update by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + setNames: []string{"name", "age", "sender"}, + pkName: "id", + wantRes: "UPDATE \"test_table\" SET \"name\" = $1, \"age\" = $2, \"sender\" = $3 WHERE \"id\" = $4", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res := tc.db.UpdateSQL(tc.setNames, tc.pkName, mi) + + assert.Equal(t, tc.wantRes, res) + }) + } +} + func TestDbBase_DeleteSQL(t *testing.T) { mi := &models.ModelInfo{ Table: "test_table",