fix: refactor InsertOrUpdate method in dbBase (#5296)
* fix: refactor InsertOrUpdate method in dbBase and add the test * fix: add the change record to the CHANGELOG.md
This commit is contained in:
parent
a5eda3267a
commit
46a00d3592
@ -15,6 +15,7 @@
|
||||
- [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271)
|
||||
- [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274)
|
||||
- [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295)
|
||||
- [fix: refactor InsertOrUpdate method](https://github.com/beego/beego/pull/5296)
|
||||
|
||||
## ORM refactoring
|
||||
- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238)
|
||||
|
||||
181
client/orm/db.go
181
client/orm/db.go
@ -528,80 +528,20 @@ func (d *dbBase) InsertValueSQL(names []string, values []interface{}, isMulti bo
|
||||
// If your primary key or unique column conflict will update
|
||||
// If no will insert
|
||||
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||
args0 := ""
|
||||
iouStr := ""
|
||||
argsMap := map[string]string{}
|
||||
switch a.Driver {
|
||||
case DRMySQL:
|
||||
iouStr = "ON DUPLICATE KEY UPDATE"
|
||||
case DRPostgres:
|
||||
if len(args) == 0 {
|
||||
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
||||
}
|
||||
args0 = strings.ToLower(args[0])
|
||||
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
|
||||
default:
|
||||
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
||||
}
|
||||
|
||||
// Get on the key-value pairs
|
||||
for _, v := range args {
|
||||
kv := strings.Split(v, "=")
|
||||
if len(kv) == 2 {
|
||||
argsMap[strings.ToLower(kv[0])] = kv[1]
|
||||
}
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(mi.Fields.DBcols)-1)
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
marks := make([]string, len(names))
|
||||
updateValues := make([]interface{}, 0)
|
||||
updates := make([]string, len(names))
|
||||
var conflitValue interface{}
|
||||
for i, v := range names {
|
||||
// identifier in database may not be case-sensitive, so quote it
|
||||
v = fmt.Sprintf("%s%s%s", Q, v, Q)
|
||||
marks[i] = "?"
|
||||
valueStr := argsMap[strings.ToLower(v)]
|
||||
if v == args0 {
|
||||
conflitValue = values[i]
|
||||
}
|
||||
if valueStr != "" {
|
||||
switch a.Driver {
|
||||
case DRMySQL:
|
||||
updates[i] = v + "=" + valueStr
|
||||
case DRPostgres:
|
||||
if conflitValue != nil {
|
||||
// postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
|
||||
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.Table, args0)
|
||||
updateValues = append(updateValues, conflitValue)
|
||||
} else {
|
||||
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
updates[i] = v + "=?"
|
||||
updateValues = append(updateValues, values[i])
|
||||
}
|
||||
query, err := d.InsertOrUpdateSQL(names, &values, mi, a, args...)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
values = append(values, updateValues...)
|
||||
|
||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||
qmarks := strings.Join(marks, ", ")
|
||||
qupdates := strings.Join(updates, ", ")
|
||||
columns := strings.Join(names, sep)
|
||||
|
||||
// conflitValue maybe is a int,can`t use fmt.Sprintf
|
||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
if !d.ins.HasReturningID(mi, &query) {
|
||||
res, err := q.ExecContext(ctx, query, values...)
|
||||
if err == nil {
|
||||
@ -625,6 +565,117 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.Mod
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertOrUpdateSQL(names []string, values *[]interface{}, mi *models.ModelInfo, a *alias, args ...string) (string, error) {
|
||||
|
||||
args0 := ""
|
||||
|
||||
switch a.Driver {
|
||||
case DRMySQL:
|
||||
case DRPostgres:
|
||||
if len(args) == 0 {
|
||||
return "", fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
||||
}
|
||||
args0 = strings.ToLower(args[0])
|
||||
default:
|
||||
return "", fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
||||
}
|
||||
|
||||
argsMap := map[string]string{}
|
||||
// Get on the key-value pairs
|
||||
for _, v := range args {
|
||||
kv := strings.Split(v, "=")
|
||||
if len(kv) == 2 {
|
||||
argsMap[strings.ToLower(kv[0])] = kv[1]
|
||||
}
|
||||
}
|
||||
|
||||
quote := d.ins.TableQuote()
|
||||
|
||||
buf := buffers.Get()
|
||||
defer buffers.Put(buf)
|
||||
|
||||
_, _ = buf.WriteString("INSERT INTO ")
|
||||
_, _ = buf.WriteString(quote)
|
||||
_, _ = buf.WriteString(mi.Table)
|
||||
_, _ = buf.WriteString(quote)
|
||||
_, _ = buf.WriteString(" (")
|
||||
|
||||
for i, name := range names {
|
||||
if i > 0 {
|
||||
_, _ = buf.WriteString(", ")
|
||||
}
|
||||
_, _ = buf.WriteString(quote)
|
||||
_, _ = buf.WriteString(name)
|
||||
_, _ = buf.WriteString(quote)
|
||||
}
|
||||
|
||||
_, _ = buf.WriteString(") VALUES (")
|
||||
|
||||
for i := 0; i < len(names); i++ {
|
||||
if i > 0 {
|
||||
_, _ = buf.WriteString(", ")
|
||||
}
|
||||
_, _ = buf.WriteString("?")
|
||||
}
|
||||
|
||||
_, _ = buf.WriteString(") ")
|
||||
|
||||
switch a.Driver {
|
||||
case DRMySQL:
|
||||
_, _ = buf.WriteString("ON DUPLICATE KEY UPDATE ")
|
||||
case DRPostgres:
|
||||
_, _ = buf.WriteString("ON CONFLICT (")
|
||||
_, _ = buf.WriteString(args0)
|
||||
_, _ = buf.WriteString(") DO UPDATE SET ")
|
||||
}
|
||||
|
||||
var conflitValue interface{}
|
||||
for i, v := range names {
|
||||
if i > 0 {
|
||||
_, _ = buf.WriteString(", ")
|
||||
}
|
||||
// identifier in database may not be case-sensitive, so quote it
|
||||
v = fmt.Sprintf("%s%s%s", quote, v, quote)
|
||||
valueStr := argsMap[strings.ToLower(v)]
|
||||
if v == args0 {
|
||||
conflitValue = (*values)[i]
|
||||
}
|
||||
if valueStr != "" {
|
||||
switch a.Driver {
|
||||
case DRMySQL:
|
||||
_, _ = buf.WriteString(v)
|
||||
_, _ = buf.WriteString("=")
|
||||
_, _ = buf.WriteString(valueStr)
|
||||
case DRPostgres:
|
||||
if conflitValue != nil {
|
||||
// postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
|
||||
_, _ = buf.WriteString(v)
|
||||
_, _ = buf.WriteString("=(select ")
|
||||
_, _ = buf.WriteString(valueStr)
|
||||
_, _ = buf.WriteString(" from ")
|
||||
_, _ = buf.WriteString(mi.Table)
|
||||
_, _ = buf.WriteString(" where ")
|
||||
_, _ = buf.WriteString(args0)
|
||||
_, _ = buf.WriteString(" = ? )")
|
||||
*values = append(*values, conflitValue)
|
||||
} else {
|
||||
return "", fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_, _ = buf.WriteString(v)
|
||||
_, _ = buf.WriteString("=?")
|
||||
*values = append(*values, (*values)[i])
|
||||
}
|
||||
}
|
||||
|
||||
query := buf.String()
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// Update execute update sql dbQuerier with given struct reflect.Value.
|
||||
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/beego/beego/v2/client/orm/internal/buffers"
|
||||
"testing"
|
||||
|
||||
@ -648,3 +649,240 @@ func TestDbBase_UpdateBatchSQL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_InsertOrUpdateSQL(t *testing.T) {
|
||||
|
||||
mi := &models.ModelInfo{
|
||||
Table: "test_tab",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
db *dbBase
|
||||
|
||||
names []string
|
||||
values []interface{}
|
||||
a *alias
|
||||
args []string
|
||||
|
||||
wantRes string
|
||||
wantErr error
|
||||
wantValues []interface{}
|
||||
}{
|
||||
{
|
||||
name: "test nonsupport driver",
|
||||
db: &dbBase{
|
||||
ins: newdbBaseSqlite(),
|
||||
},
|
||||
|
||||
names: []string{"name", "age", "score"},
|
||||
values: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
a: &alias{
|
||||
Driver: DRSqlite,
|
||||
DriverName: "sqlite3",
|
||||
},
|
||||
args: []string{
|
||||
"`age`=20",
|
||||
"`score`=`score`+1",
|
||||
},
|
||||
|
||||
wantErr: errors.New("`sqlite3` nonsupport InsertOrUpdate in beego"),
|
||||
wantValues: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insert or update with MySQL",
|
||||
db: &dbBase{
|
||||
ins: newdbBaseMysql(),
|
||||
},
|
||||
|
||||
names: []string{"name", "age", "score"},
|
||||
values: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
a: &alias{
|
||||
Driver: DRMySQL,
|
||||
DriverName: "mysql",
|
||||
},
|
||||
args: []string{
|
||||
"`age`=20",
|
||||
"`score`=`score`+1",
|
||||
},
|
||||
|
||||
wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=20, `score`=`score`+1",
|
||||
wantValues: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
"test_name",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insert or update with MySQL with no args",
|
||||
db: &dbBase{
|
||||
ins: newdbBaseMysql(),
|
||||
},
|
||||
|
||||
names: []string{"name", "age", "score"},
|
||||
values: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
a: &alias{
|
||||
Driver: DRMySQL,
|
||||
DriverName: "mysql",
|
||||
},
|
||||
|
||||
wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=?, `score`=?",
|
||||
wantValues: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insert or update with PostgreSQL normal",
|
||||
db: &dbBase{
|
||||
ins: newdbBasePostgres(),
|
||||
},
|
||||
|
||||
names: []string{"name", "age", "score"},
|
||||
values: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
a: &alias{
|
||||
Driver: DRPostgres,
|
||||
DriverName: "postgres",
|
||||
},
|
||||
args: []string{
|
||||
`"name"`,
|
||||
`"score"="score_1"`,
|
||||
},
|
||||
|
||||
wantRes: `INSERT INTO "test_tab" ("name", "age", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "name"=$4, "age"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`,
|
||||
wantValues: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
"test_name",
|
||||
18,
|
||||
"test_name",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insert or update with PostgreSQL without conflict column",
|
||||
db: &dbBase{
|
||||
ins: newdbBasePostgres(),
|
||||
},
|
||||
|
||||
names: []string{"name", "age", "score"},
|
||||
values: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
a: &alias{
|
||||
Driver: DRPostgres,
|
||||
DriverName: "postgres",
|
||||
},
|
||||
|
||||
wantErr: errors.New("`postgres` use InsertOrUpdate must have a conflict column"),
|
||||
wantValues: []interface{}{
|
||||
"test_name",
|
||||
18,
|
||||
12,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insert or update with PostgreSQL the conflict column is not in front of the specified column",
|
||||
db: &dbBase{
|
||||
ins: newdbBasePostgres(),
|
||||
},
|
||||
|
||||
names: []string{"score", "name", "age"},
|
||||
values: []interface{}{
|
||||
12,
|
||||
"test_name",
|
||||
18,
|
||||
},
|
||||
a: &alias{
|
||||
Driver: DRPostgres,
|
||||
DriverName: "postgres",
|
||||
},
|
||||
args: []string{
|
||||
`"name"`,
|
||||
`"score"="score_1"`,
|
||||
},
|
||||
|
||||
wantErr: errors.New("`\"name\"` must be in front of `\"score\"` in your struct"),
|
||||
wantValues: []interface{}{
|
||||
12,
|
||||
"test_name",
|
||||
18,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insert or update with PostgreSQL the conflict column is in front of the specified column",
|
||||
db: &dbBase{
|
||||
ins: newdbBasePostgres(),
|
||||
},
|
||||
|
||||
names: []string{"age", "name", "score"},
|
||||
values: []interface{}{
|
||||
18,
|
||||
"test_name",
|
||||
12,
|
||||
},
|
||||
a: &alias{
|
||||
Driver: DRPostgres,
|
||||
DriverName: "postgres",
|
||||
},
|
||||
args: []string{
|
||||
`"name"`,
|
||||
`"score"="score_1"`,
|
||||
},
|
||||
|
||||
wantRes: `INSERT INTO "test_tab" ("age", "name", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "age"=$4, "name"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`,
|
||||
wantValues: []interface{}{
|
||||
18,
|
||||
"test_name",
|
||||
12,
|
||||
18,
|
||||
"test_name",
|
||||
"test_name",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
res, err := tc.db.InsertOrUpdateSQL(tc.names, &tc.values, mi, tc.a, tc.args...)
|
||||
|
||||
assert.Equal(t, tc.wantValues, tc.values)
|
||||
|
||||
assert.Equal(t, tc.wantErr, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user