From 20a50308426fe8e7cbb1c0cfd442366c6778f757 Mon Sep 17 00:00:00 2001 From: Uzziah <120019273+uzziahlin@users.noreply.github.com> Date: Thu, 29 Jun 2023 21:33:46 +0800 Subject: [PATCH] fix: refactor InsertValue method (#5267) * fix: refactor insertValue method and add the test * fix: exec goimports and add Licence file header * fix: modify construct method of dbBase * fix: add modify record into CHANGELOG --- CHANGELOG.md | 1 + client/orm/db.go | 70 +++++++++---- client/orm/db_test.go | 130 +++++++++++++++++++++++++ client/orm/internal/buffers/buffers.go | 1 + 4 files changed, 182 insertions(+), 20 deletions(-) create mode 100644 client/orm/db_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index c358f69d..7323e8b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - [orm: missing handling %COL% placeholder](https://github.com/beego/beego/pull/5257) - [fix: use of ioutil package](https://github.com/beego/beego/pull/5261) - [cache/redis: support skipEmptyPrefix option ](https://github.com/beego/beego/pull/5264) +- [fix: refactor InsertValue method](https://github.com/beego/beego/pull/5267) ## ORM refactoring - [introducing internal/models pkg](https://github.com/beego/beego/pull/5238) diff --git a/client/orm/db.go b/client/orm/db.go index 0b3cfb80..ddb9e589 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -23,6 +23,8 @@ import ( "strings" "time" + "github.com/beego/beego/v2/client/orm/internal/buffers" + "github.com/beego/beego/v2/client/orm/internal/logs" "github.com/beego/beego/v2/client/orm/internal/utils" @@ -450,26 +452,7 @@ func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *models.ModelI // InsertValue execute insert sql with given struct and given values. // insert the given values, not the field values in struct. func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { - Q := d.ins.TableQuote() - - marks := make([]string, len(names)) - for i := range marks { - marks[i] = "?" - } - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti && multi > 1 { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks) - - d.ins.ReplaceMarks(&query) + query := d.InsertValueSQL(names, values, isMulti, mi) if isMulti || !d.ins.HasReturningID(mi, &query) { res, err := q.ExecContext(ctx, query, values...) @@ -494,6 +477,53 @@ func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelI return id, err } +func (d *dbBase) InsertValueSQL(names []string, values []interface{}, isMulti bool, mi *models.ModelInfo) string { + buf := buffers.Get() + defer buffers.Put(buf) + + Q := d.ins.TableQuote() + + _, _ = buf.WriteString("INSERT INTO ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(Q) + + _, _ = buf.WriteString(" (") + for i, name := range names { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(name) + _, _ = buf.WriteString(Q) + } + _, _ = buf.WriteString(") VALUES (") + + marks := make([]string, len(names)) + for i := range marks { + marks[i] = "?" + } + qmarks := strings.Join(marks, ", ") + + _, _ = buf.WriteString(qmarks) + + multi := len(values) / len(names) + + if isMulti && multi > 1 { + for i := 0; i < multi-1; i++ { + _, _ = buf.WriteString("), (") + _, _ = buf.WriteString(qmarks) + } + } + + _ = buf.WriteByte(')') + + query := buf.String() + d.ins.ReplaceMarks(&query) + + return query +} + // InsertOrUpdate a row // If your primary key or unique column conflict will update // If no will insert diff --git a/client/orm/db_test.go b/client/orm/db_test.go new file mode 100644 index 00000000..a0c9eb9b --- /dev/null +++ b/client/orm/db_test.go @@ -0,0 +1,130 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm/internal/models" +) + +func TestDbBase_InsertValueSQL(t *testing.T) { + + mi := &models.ModelInfo{ + Table: "test_table", + } + + testCases := []struct { + name string + db *dbBase + isMulti bool + names []string + values []interface{} + + wantRes string + }{ + { + name: "single insert by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)", + }, + { + name: "single insert by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)", + }, + { + name: "multi insert by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?), (?, ?)", + }, + { + name: "multi insert by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2), ($3, $4)", + }, + { + name: "multi insert by dbBase but values is not enough", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2"}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)", + }, + { + name: "multi insert by dbBasePostgres but values is not enough", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2"}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)", + }, + { + name: "single insert by dbBase but values is double to names", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)", + }, + { + name: "single insert by dbBasePostgres but values is double to names", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res := tc.db.InsertValueSQL(tc.names, tc.values, tc.isMulti, mi) + + assert.Equal(t, tc.wantRes, res) + }) + } +} diff --git a/client/orm/internal/buffers/buffers.go b/client/orm/internal/buffers/buffers.go index 045c00e0..db341ff4 100644 --- a/client/orm/internal/buffers/buffers.go +++ b/client/orm/internal/buffers/buffers.go @@ -22,6 +22,7 @@ type Buffer interface { Write(p []byte) (int, error) WriteString(s string) (int, error) WriteByte(c byte) error + String() string } func Get() Buffer {