From 957c526efb61d2bf760779e67be379463f4b9c38 Mon Sep 17 00:00:00 2001 From: Uzziah <120019273+uzziahlin@users.noreply.github.com> Date: Thu, 6 Jul 2023 19:53:07 +0800 Subject: [PATCH] fix: refactor Delete method (#5271) * fix: refactor Delete method and add test * fix: add modify record into CHANGELOG --- CHANGELOG.md | 1 + client/orm/db.go | 37 ++++++++++++++++++++++------ client/orm/db_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f0f9be0..9eede89d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - [fix: refactor InsertValue method](https://github.com/beego/beego/pull/5267) - [fix: modify InsertOrUpdate method, Remove the isMulti variable and its associated code](https://github.com/beego/beego/pull/5269) - [refactor cache/redis: Use redisConfig to receive incoming JSON (previously using a map)](https://github.com/beego/beego/pull/5268) +- [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271) ## ORM refactoring - [introducing internal/models pkg](https://github.com/beego/beego/pull/5238) diff --git a/client/orm/db.go b/client/orm/db.go index 596d1667..fdecc478 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -713,14 +713,8 @@ func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *models.ModelInfo, args = append(args, pkValue) } - Q := d.ins.TableQuote() + query := d.DeleteSQL(whereCols, mi) - sep := fmt.Sprintf("%s = ? AND %s", Q, Q) - wheres := strings.Join(whereCols, sep) - - query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.Table, Q, Q, wheres, Q) - - d.ins.ReplaceMarks(&query) res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() @@ -738,6 +732,35 @@ func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *models.ModelInfo, return 0, err } +func (d *dbBase) DeleteSQL(whereCols []string, mi *models.ModelInfo) string { + buf := buffers.Get() + defer buffers.Put(buf) + + Q := d.ins.TableQuote() + + _, _ = buf.WriteString("DELETE FROM ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" WHERE ") + + for i, col := range whereCols { + if i > 0 { + _, _ = buf.WriteString(" AND ") + } + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(col) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" = ?") + } + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query +} + // UpdateBatch update table-related record by querySet. // need querySet not struct reflect.Value to update related records. func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { diff --git a/client/orm/db_test.go b/client/orm/db_test.go index a0c9eb9b..60624f05 100644 --- a/client/orm/db_test.go +++ b/client/orm/db_test.go @@ -128,3 +128,60 @@ func TestDbBase_InsertValueSQL(t *testing.T) { }) } } + +func TestDbBase_DeleteSQL(t *testing.T) { + mi := &models.ModelInfo{ + Table: "test_table", + } + + testCases := []struct { + name string + db *dbBase + + whereCols []string + + wantRes string + }{ + { + name: "delete by dbBase with id", + db: &dbBase{ + ins: &dbBase{}, + }, + whereCols: []string{"id"}, + wantRes: "DELETE FROM `test_table` WHERE `id` = ?", + }, + { + name: "delete by dbBase not id", + db: &dbBase{ + ins: &dbBase{}, + }, + whereCols: []string{"name", "age"}, + wantRes: "DELETE FROM `test_table` WHERE `name` = ? AND `age` = ?", + }, + { + name: "delete by dbBasePostgres with id", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + whereCols: []string{"id"}, + wantRes: "DELETE FROM \"test_table\" WHERE \"id\" = $1", + }, + { + name: "delete by dbBasePostgres not id", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + whereCols: []string{"name", "age"}, + wantRes: "DELETE FROM \"test_table\" WHERE \"name\" = $1 AND \"age\" = $2", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res := tc.db.DeleteSQL(tc.whereCols, mi) + + assert.Equal(t, tc.wantRes, res) + }) + } +}