fix: refactor InsertOrUpdate method in dbBase (#5296)

* fix: refactor InsertOrUpdate method in dbBase and add the test

* fix: add the change record to the CHANGELOG.md
This commit is contained in:
Uzziah 2023-08-18 20:47:24 +08:00 committed by GitHub
parent a5eda3267a
commit 46a00d3592
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 355 additions and 65 deletions

View File

@ -15,6 +15,7 @@
- [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271)
- [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274)
- [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295)
- [fix: refactor InsertOrUpdate method](https://github.com/beego/beego/pull/5296)
## ORM refactoring
- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238)

View File

@ -528,80 +528,20 @@ func (d *dbBase) InsertValueSQL(names []string, values []interface{}, isMulti bo
// If your primary key or unique column conflict will update
// If no will insert
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
args0 := ""
iouStr := ""
argsMap := map[string]string{}
switch a.Driver {
case DRMySQL:
iouStr = "ON DUPLICATE KEY UPDATE"
case DRPostgres:
if len(args) == 0 {
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
}
args0 = strings.ToLower(args[0])
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
default:
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
}
// Get on the key-value pairs
for _, v := range args {
kv := strings.Split(v, "=")
if len(kv) == 2 {
argsMap[strings.ToLower(kv[0])] = kv[1]
}
}
names := make([]string, 0, len(mi.Fields.DBcols)-1)
Q := d.ins.TableQuote()
values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ)
if err != nil {
return 0, err
}
marks := make([]string, len(names))
updateValues := make([]interface{}, 0)
updates := make([]string, len(names))
var conflitValue interface{}
for i, v := range names {
// identifier in database may not be case-sensitive, so quote it
v = fmt.Sprintf("%s%s%s", Q, v, Q)
marks[i] = "?"
valueStr := argsMap[strings.ToLower(v)]
if v == args0 {
conflitValue = values[i]
}
if valueStr != "" {
switch a.Driver {
case DRMySQL:
updates[i] = v + "=" + valueStr
case DRPostgres:
if conflitValue != nil {
// postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.Table, args0)
updateValues = append(updateValues, conflitValue)
} else {
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
}
}
} else {
updates[i] = v + "=?"
updateValues = append(updateValues, values[i])
}
query, err := d.InsertOrUpdateSQL(names, &values, mi, a, args...)
if err != nil {
return 0, err
}
values = append(values, updateValues...)
sep := fmt.Sprintf("%s, %s", Q, Q)
qmarks := strings.Join(marks, ", ")
qupdates := strings.Join(updates, ", ")
columns := strings.Join(names, sep)
// conflitValue maybe is a int,can`t use fmt.Sprintf
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr)
d.ins.ReplaceMarks(&query)
if !d.ins.HasReturningID(mi, &query) {
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
@ -625,6 +565,117 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.Mod
return id, err
}
func (d *dbBase) InsertOrUpdateSQL(names []string, values *[]interface{}, mi *models.ModelInfo, a *alias, args ...string) (string, error) {
args0 := ""
switch a.Driver {
case DRMySQL:
case DRPostgres:
if len(args) == 0 {
return "", fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
}
args0 = strings.ToLower(args[0])
default:
return "", fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
}
argsMap := map[string]string{}
// Get on the key-value pairs
for _, v := range args {
kv := strings.Split(v, "=")
if len(kv) == 2 {
argsMap[strings.ToLower(kv[0])] = kv[1]
}
}
quote := d.ins.TableQuote()
buf := buffers.Get()
defer buffers.Put(buf)
_, _ = buf.WriteString("INSERT INTO ")
_, _ = buf.WriteString(quote)
_, _ = buf.WriteString(mi.Table)
_, _ = buf.WriteString(quote)
_, _ = buf.WriteString(" (")
for i, name := range names {
if i > 0 {
_, _ = buf.WriteString(", ")
}
_, _ = buf.WriteString(quote)
_, _ = buf.WriteString(name)
_, _ = buf.WriteString(quote)
}
_, _ = buf.WriteString(") VALUES (")
for i := 0; i < len(names); i++ {
if i > 0 {
_, _ = buf.WriteString(", ")
}
_, _ = buf.WriteString("?")
}
_, _ = buf.WriteString(") ")
switch a.Driver {
case DRMySQL:
_, _ = buf.WriteString("ON DUPLICATE KEY UPDATE ")
case DRPostgres:
_, _ = buf.WriteString("ON CONFLICT (")
_, _ = buf.WriteString(args0)
_, _ = buf.WriteString(") DO UPDATE SET ")
}
var conflitValue interface{}
for i, v := range names {
if i > 0 {
_, _ = buf.WriteString(", ")
}
// identifier in database may not be case-sensitive, so quote it
v = fmt.Sprintf("%s%s%s", quote, v, quote)
valueStr := argsMap[strings.ToLower(v)]
if v == args0 {
conflitValue = (*values)[i]
}
if valueStr != "" {
switch a.Driver {
case DRMySQL:
_, _ = buf.WriteString(v)
_, _ = buf.WriteString("=")
_, _ = buf.WriteString(valueStr)
case DRPostgres:
if conflitValue != nil {
// postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
_, _ = buf.WriteString(v)
_, _ = buf.WriteString("=(select ")
_, _ = buf.WriteString(valueStr)
_, _ = buf.WriteString(" from ")
_, _ = buf.WriteString(mi.Table)
_, _ = buf.WriteString(" where ")
_, _ = buf.WriteString(args0)
_, _ = buf.WriteString(" = ? )")
*values = append(*values, conflitValue)
} else {
return "", fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
}
}
} else {
_, _ = buf.WriteString(v)
_, _ = buf.WriteString("=?")
*values = append(*values, (*values)[i])
}
}
query := buf.String()
d.ins.ReplaceMarks(&query)
return query, nil
}
// Update execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)

View File

@ -15,6 +15,7 @@
package orm
import (
"errors"
"github.com/beego/beego/v2/client/orm/internal/buffers"
"testing"
@ -648,3 +649,240 @@ func TestDbBase_UpdateBatchSQL(t *testing.T) {
})
}
}
func TestDbBase_InsertOrUpdateSQL(t *testing.T) {
mi := &models.ModelInfo{
Table: "test_tab",
}
testCases := []struct {
name string
db *dbBase
names []string
values []interface{}
a *alias
args []string
wantRes string
wantErr error
wantValues []interface{}
}{
{
name: "test nonsupport driver",
db: &dbBase{
ins: newdbBaseSqlite(),
},
names: []string{"name", "age", "score"},
values: []interface{}{
"test_name",
18,
12,
},
a: &alias{
Driver: DRSqlite,
DriverName: "sqlite3",
},
args: []string{
"`age`=20",
"`score`=`score`+1",
},
wantErr: errors.New("`sqlite3` nonsupport InsertOrUpdate in beego"),
wantValues: []interface{}{
"test_name",
18,
12,
},
},
{
name: "insert or update with MySQL",
db: &dbBase{
ins: newdbBaseMysql(),
},
names: []string{"name", "age", "score"},
values: []interface{}{
"test_name",
18,
12,
},
a: &alias{
Driver: DRMySQL,
DriverName: "mysql",
},
args: []string{
"`age`=20",
"`score`=`score`+1",
},
wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=20, `score`=`score`+1",
wantValues: []interface{}{
"test_name",
18,
12,
"test_name",
},
},
{
name: "insert or update with MySQL with no args",
db: &dbBase{
ins: newdbBaseMysql(),
},
names: []string{"name", "age", "score"},
values: []interface{}{
"test_name",
18,
12,
},
a: &alias{
Driver: DRMySQL,
DriverName: "mysql",
},
wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=?, `score`=?",
wantValues: []interface{}{
"test_name",
18,
12,
"test_name",
18,
12,
},
},
{
name: "insert or update with PostgreSQL normal",
db: &dbBase{
ins: newdbBasePostgres(),
},
names: []string{"name", "age", "score"},
values: []interface{}{
"test_name",
18,
12,
},
a: &alias{
Driver: DRPostgres,
DriverName: "postgres",
},
args: []string{
`"name"`,
`"score"="score_1"`,
},
wantRes: `INSERT INTO "test_tab" ("name", "age", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "name"=$4, "age"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`,
wantValues: []interface{}{
"test_name",
18,
12,
"test_name",
18,
"test_name",
},
},
{
name: "insert or update with PostgreSQL without conflict column",
db: &dbBase{
ins: newdbBasePostgres(),
},
names: []string{"name", "age", "score"},
values: []interface{}{
"test_name",
18,
12,
},
a: &alias{
Driver: DRPostgres,
DriverName: "postgres",
},
wantErr: errors.New("`postgres` use InsertOrUpdate must have a conflict column"),
wantValues: []interface{}{
"test_name",
18,
12,
},
},
{
name: "insert or update with PostgreSQL the conflict column is not in front of the specified column",
db: &dbBase{
ins: newdbBasePostgres(),
},
names: []string{"score", "name", "age"},
values: []interface{}{
12,
"test_name",
18,
},
a: &alias{
Driver: DRPostgres,
DriverName: "postgres",
},
args: []string{
`"name"`,
`"score"="score_1"`,
},
wantErr: errors.New("`\"name\"` must be in front of `\"score\"` in your struct"),
wantValues: []interface{}{
12,
"test_name",
18,
},
},
{
name: "insert or update with PostgreSQL the conflict column is in front of the specified column",
db: &dbBase{
ins: newdbBasePostgres(),
},
names: []string{"age", "name", "score"},
values: []interface{}{
18,
"test_name",
12,
},
a: &alias{
Driver: DRPostgres,
DriverName: "postgres",
},
args: []string{
`"name"`,
`"score"="score_1"`,
},
wantRes: `INSERT INTO "test_tab" ("age", "name", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "age"=$4, "name"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`,
wantValues: []interface{}{
18,
"test_name",
12,
18,
"test_name",
"test_name",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res, err := tc.db.InsertOrUpdateSQL(tc.names, &tc.values, mi, tc.a, tc.args...)
assert.Equal(t, tc.wantValues, tc.values)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantRes, res)
})
}
}