fix: refactor InsertOrUpdate method in dbBase (#5296)
* fix: refactor InsertOrUpdate method in dbBase and add the test * fix: add the change record to the CHANGELOG.md
This commit is contained in:
parent
a5eda3267a
commit
46a00d3592
@ -15,6 +15,7 @@
|
|||||||
- [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271)
|
- [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271)
|
||||||
- [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274)
|
- [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274)
|
||||||
- [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295)
|
- [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295)
|
||||||
|
- [fix: refactor InsertOrUpdate method](https://github.com/beego/beego/pull/5296)
|
||||||
|
|
||||||
## ORM refactoring
|
## ORM refactoring
|
||||||
- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238)
|
- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238)
|
||||||
|
|||||||
181
client/orm/db.go
181
client/orm/db.go
@ -528,80 +528,20 @@ func (d *dbBase) InsertValueSQL(names []string, values []interface{}, isMulti bo
|
|||||||
// If your primary key or unique column conflict will update
|
// If your primary key or unique column conflict will update
|
||||||
// If no will insert
|
// If no will insert
|
||||||
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
args0 := ""
|
|
||||||
iouStr := ""
|
|
||||||
argsMap := map[string]string{}
|
|
||||||
switch a.Driver {
|
|
||||||
case DRMySQL:
|
|
||||||
iouStr = "ON DUPLICATE KEY UPDATE"
|
|
||||||
case DRPostgres:
|
|
||||||
if len(args) == 0 {
|
|
||||||
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
|
||||||
}
|
|
||||||
args0 = strings.ToLower(args[0])
|
|
||||||
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get on the key-value pairs
|
|
||||||
for _, v := range args {
|
|
||||||
kv := strings.Split(v, "=")
|
|
||||||
if len(kv) == 2 {
|
|
||||||
argsMap[strings.ToLower(kv[0])] = kv[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
names := make([]string, 0, len(mi.Fields.DBcols)-1)
|
names := make([]string, 0, len(mi.Fields.DBcols)-1)
|
||||||
Q := d.ins.TableQuote()
|
|
||||||
values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ)
|
values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
marks := make([]string, len(names))
|
query, err := d.InsertOrUpdateSQL(names, &values, mi, a, args...)
|
||||||
updateValues := make([]interface{}, 0)
|
|
||||||
updates := make([]string, len(names))
|
if err != nil {
|
||||||
var conflitValue interface{}
|
return 0, err
|
||||||
for i, v := range names {
|
|
||||||
// identifier in database may not be case-sensitive, so quote it
|
|
||||||
v = fmt.Sprintf("%s%s%s", Q, v, Q)
|
|
||||||
marks[i] = "?"
|
|
||||||
valueStr := argsMap[strings.ToLower(v)]
|
|
||||||
if v == args0 {
|
|
||||||
conflitValue = values[i]
|
|
||||||
}
|
|
||||||
if valueStr != "" {
|
|
||||||
switch a.Driver {
|
|
||||||
case DRMySQL:
|
|
||||||
updates[i] = v + "=" + valueStr
|
|
||||||
case DRPostgres:
|
|
||||||
if conflitValue != nil {
|
|
||||||
// postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
|
|
||||||
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.Table, args0)
|
|
||||||
updateValues = append(updateValues, conflitValue)
|
|
||||||
} else {
|
|
||||||
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
updates[i] = v + "=?"
|
|
||||||
updateValues = append(updateValues, values[i])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
values = append(values, updateValues...)
|
|
||||||
|
|
||||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
|
||||||
qmarks := strings.Join(marks, ", ")
|
|
||||||
qupdates := strings.Join(updates, ", ")
|
|
||||||
columns := strings.Join(names, sep)
|
|
||||||
|
|
||||||
// conflitValue maybe is a int,can`t use fmt.Sprintf
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr)
|
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
|
||||||
|
|
||||||
if !d.ins.HasReturningID(mi, &query) {
|
if !d.ins.HasReturningID(mi, &query) {
|
||||||
res, err := q.ExecContext(ctx, query, values...)
|
res, err := q.ExecContext(ctx, query, values...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -625,6 +565,117 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.Mod
|
|||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) InsertOrUpdateSQL(names []string, values *[]interface{}, mi *models.ModelInfo, a *alias, args ...string) (string, error) {
|
||||||
|
|
||||||
|
args0 := ""
|
||||||
|
|
||||||
|
switch a.Driver {
|
||||||
|
case DRMySQL:
|
||||||
|
case DRPostgres:
|
||||||
|
if len(args) == 0 {
|
||||||
|
return "", fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
||||||
|
}
|
||||||
|
args0 = strings.ToLower(args[0])
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
argsMap := map[string]string{}
|
||||||
|
// Get on the key-value pairs
|
||||||
|
for _, v := range args {
|
||||||
|
kv := strings.Split(v, "=")
|
||||||
|
if len(kv) == 2 {
|
||||||
|
argsMap[strings.ToLower(kv[0])] = kv[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
quote := d.ins.TableQuote()
|
||||||
|
|
||||||
|
buf := buffers.Get()
|
||||||
|
defer buffers.Put(buf)
|
||||||
|
|
||||||
|
_, _ = buf.WriteString("INSERT INTO ")
|
||||||
|
_, _ = buf.WriteString(quote)
|
||||||
|
_, _ = buf.WriteString(mi.Table)
|
||||||
|
_, _ = buf.WriteString(quote)
|
||||||
|
_, _ = buf.WriteString(" (")
|
||||||
|
|
||||||
|
for i, name := range names {
|
||||||
|
if i > 0 {
|
||||||
|
_, _ = buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
_, _ = buf.WriteString(quote)
|
||||||
|
_, _ = buf.WriteString(name)
|
||||||
|
_, _ = buf.WriteString(quote)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = buf.WriteString(") VALUES (")
|
||||||
|
|
||||||
|
for i := 0; i < len(names); i++ {
|
||||||
|
if i > 0 {
|
||||||
|
_, _ = buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
_, _ = buf.WriteString("?")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = buf.WriteString(") ")
|
||||||
|
|
||||||
|
switch a.Driver {
|
||||||
|
case DRMySQL:
|
||||||
|
_, _ = buf.WriteString("ON DUPLICATE KEY UPDATE ")
|
||||||
|
case DRPostgres:
|
||||||
|
_, _ = buf.WriteString("ON CONFLICT (")
|
||||||
|
_, _ = buf.WriteString(args0)
|
||||||
|
_, _ = buf.WriteString(") DO UPDATE SET ")
|
||||||
|
}
|
||||||
|
|
||||||
|
var conflitValue interface{}
|
||||||
|
for i, v := range names {
|
||||||
|
if i > 0 {
|
||||||
|
_, _ = buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
// identifier in database may not be case-sensitive, so quote it
|
||||||
|
v = fmt.Sprintf("%s%s%s", quote, v, quote)
|
||||||
|
valueStr := argsMap[strings.ToLower(v)]
|
||||||
|
if v == args0 {
|
||||||
|
conflitValue = (*values)[i]
|
||||||
|
}
|
||||||
|
if valueStr != "" {
|
||||||
|
switch a.Driver {
|
||||||
|
case DRMySQL:
|
||||||
|
_, _ = buf.WriteString(v)
|
||||||
|
_, _ = buf.WriteString("=")
|
||||||
|
_, _ = buf.WriteString(valueStr)
|
||||||
|
case DRPostgres:
|
||||||
|
if conflitValue != nil {
|
||||||
|
// postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
|
||||||
|
_, _ = buf.WriteString(v)
|
||||||
|
_, _ = buf.WriteString("=(select ")
|
||||||
|
_, _ = buf.WriteString(valueStr)
|
||||||
|
_, _ = buf.WriteString(" from ")
|
||||||
|
_, _ = buf.WriteString(mi.Table)
|
||||||
|
_, _ = buf.WriteString(" where ")
|
||||||
|
_, _ = buf.WriteString(args0)
|
||||||
|
_, _ = buf.WriteString(" = ? )")
|
||||||
|
*values = append(*values, conflitValue)
|
||||||
|
} else {
|
||||||
|
return "", fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, _ = buf.WriteString(v)
|
||||||
|
_, _ = buf.WriteString("=?")
|
||||||
|
*values = append(*values, (*values)[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
query := buf.String()
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Update execute update sql dbQuerier with given struct reflect.Value.
|
// Update execute update sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"github.com/beego/beego/v2/client/orm/internal/buffers"
|
"github.com/beego/beego/v2/client/orm/internal/buffers"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -648,3 +649,240 @@ func TestDbBase_UpdateBatchSQL(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDbBase_InsertOrUpdateSQL(t *testing.T) {
|
||||||
|
|
||||||
|
mi := &models.ModelInfo{
|
||||||
|
Table: "test_tab",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
db *dbBase
|
||||||
|
|
||||||
|
names []string
|
||||||
|
values []interface{}
|
||||||
|
a *alias
|
||||||
|
args []string
|
||||||
|
|
||||||
|
wantRes string
|
||||||
|
wantErr error
|
||||||
|
wantValues []interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "test nonsupport driver",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBaseSqlite(),
|
||||||
|
},
|
||||||
|
|
||||||
|
names: []string{"name", "age", "score"},
|
||||||
|
values: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
a: &alias{
|
||||||
|
Driver: DRSqlite,
|
||||||
|
DriverName: "sqlite3",
|
||||||
|
},
|
||||||
|
args: []string{
|
||||||
|
"`age`=20",
|
||||||
|
"`score`=`score`+1",
|
||||||
|
},
|
||||||
|
|
||||||
|
wantErr: errors.New("`sqlite3` nonsupport InsertOrUpdate in beego"),
|
||||||
|
wantValues: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert or update with MySQL",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBaseMysql(),
|
||||||
|
},
|
||||||
|
|
||||||
|
names: []string{"name", "age", "score"},
|
||||||
|
values: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
a: &alias{
|
||||||
|
Driver: DRMySQL,
|
||||||
|
DriverName: "mysql",
|
||||||
|
},
|
||||||
|
args: []string{
|
||||||
|
"`age`=20",
|
||||||
|
"`score`=`score`+1",
|
||||||
|
},
|
||||||
|
|
||||||
|
wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=20, `score`=`score`+1",
|
||||||
|
wantValues: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
"test_name",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert or update with MySQL with no args",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBaseMysql(),
|
||||||
|
},
|
||||||
|
|
||||||
|
names: []string{"name", "age", "score"},
|
||||||
|
values: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
a: &alias{
|
||||||
|
Driver: DRMySQL,
|
||||||
|
DriverName: "mysql",
|
||||||
|
},
|
||||||
|
|
||||||
|
wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=?, `score`=?",
|
||||||
|
wantValues: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert or update with PostgreSQL normal",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBasePostgres(),
|
||||||
|
},
|
||||||
|
|
||||||
|
names: []string{"name", "age", "score"},
|
||||||
|
values: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
a: &alias{
|
||||||
|
Driver: DRPostgres,
|
||||||
|
DriverName: "postgres",
|
||||||
|
},
|
||||||
|
args: []string{
|
||||||
|
`"name"`,
|
||||||
|
`"score"="score_1"`,
|
||||||
|
},
|
||||||
|
|
||||||
|
wantRes: `INSERT INTO "test_tab" ("name", "age", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "name"=$4, "age"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`,
|
||||||
|
wantValues: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
"test_name",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert or update with PostgreSQL without conflict column",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBasePostgres(),
|
||||||
|
},
|
||||||
|
|
||||||
|
names: []string{"name", "age", "score"},
|
||||||
|
values: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
a: &alias{
|
||||||
|
Driver: DRPostgres,
|
||||||
|
DriverName: "postgres",
|
||||||
|
},
|
||||||
|
|
||||||
|
wantErr: errors.New("`postgres` use InsertOrUpdate must have a conflict column"),
|
||||||
|
wantValues: []interface{}{
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert or update with PostgreSQL the conflict column is not in front of the specified column",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBasePostgres(),
|
||||||
|
},
|
||||||
|
|
||||||
|
names: []string{"score", "name", "age"},
|
||||||
|
values: []interface{}{
|
||||||
|
12,
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
},
|
||||||
|
a: &alias{
|
||||||
|
Driver: DRPostgres,
|
||||||
|
DriverName: "postgres",
|
||||||
|
},
|
||||||
|
args: []string{
|
||||||
|
`"name"`,
|
||||||
|
`"score"="score_1"`,
|
||||||
|
},
|
||||||
|
|
||||||
|
wantErr: errors.New("`\"name\"` must be in front of `\"score\"` in your struct"),
|
||||||
|
wantValues: []interface{}{
|
||||||
|
12,
|
||||||
|
"test_name",
|
||||||
|
18,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert or update with PostgreSQL the conflict column is in front of the specified column",
|
||||||
|
db: &dbBase{
|
||||||
|
ins: newdbBasePostgres(),
|
||||||
|
},
|
||||||
|
|
||||||
|
names: []string{"age", "name", "score"},
|
||||||
|
values: []interface{}{
|
||||||
|
18,
|
||||||
|
"test_name",
|
||||||
|
12,
|
||||||
|
},
|
||||||
|
a: &alias{
|
||||||
|
Driver: DRPostgres,
|
||||||
|
DriverName: "postgres",
|
||||||
|
},
|
||||||
|
args: []string{
|
||||||
|
`"name"`,
|
||||||
|
`"score"="score_1"`,
|
||||||
|
},
|
||||||
|
|
||||||
|
wantRes: `INSERT INTO "test_tab" ("age", "name", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "age"=$4, "name"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`,
|
||||||
|
wantValues: []interface{}{
|
||||||
|
18,
|
||||||
|
"test_name",
|
||||||
|
12,
|
||||||
|
18,
|
||||||
|
"test_name",
|
||||||
|
"test_name",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
res, err := tc.db.InsertOrUpdateSQL(tc.names, &tc.values, mi, tc.a, tc.args...)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantValues, tc.values)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantErr, err)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantRes, res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user