diff --git a/CHANGELOG.md b/CHANGELOG.md index 1747197a..964545a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271) - [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274) - [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295) +- [fix: refactor InsertOrUpdate method](https://github.com/beego/beego/pull/5296) ## ORM refactoring - [introducing internal/models pkg](https://github.com/beego/beego/pull/5238) diff --git a/client/orm/db.go b/client/orm/db.go index bded065f..15db2f77 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -528,80 +528,20 @@ func (d *dbBase) InsertValueSQL(names []string, values []interface{}, isMulti bo // If your primary key or unique column conflict will update // If no will insert func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { - args0 := "" - iouStr := "" - argsMap := map[string]string{} - switch a.Driver { - case DRMySQL: - iouStr = "ON DUPLICATE KEY UPDATE" - case DRPostgres: - if len(args) == 0 { - return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) - } - args0 = strings.ToLower(args[0]) - iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) - default: - return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) - } - - // Get on the key-value pairs - for _, v := range args { - kv := strings.Split(v, "=") - if len(kv) == 2 { - argsMap[strings.ToLower(kv[0])] = kv[1] - } - } names := make([]string, 0, len(mi.Fields.DBcols)-1) - Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ) if err != nil { return 0, err } - marks := make([]string, len(names)) - updateValues := make([]interface{}, 0) - updates := make([]string, len(names)) - var conflitValue interface{} - for i, v := range names { - // identifier in database may not be case-sensitive, so quote it - v = fmt.Sprintf("%s%s%s", Q, v, Q) - marks[i] = "?" - valueStr := argsMap[strings.ToLower(v)] - if v == args0 { - conflitValue = values[i] - } - if valueStr != "" { - switch a.Driver { - case DRMySQL: - updates[i] = v + "=" + valueStr - case DRPostgres: - if conflitValue != nil { - // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.Table, args0) - updateValues = append(updateValues, conflitValue) - } else { - return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) - } - } - } else { - updates[i] = v + "=?" - updateValues = append(updateValues, values[i]) - } + query, err := d.InsertOrUpdateSQL(names, &values, mi, a, args...) + + if err != nil { + return 0, err } - values = append(values, updateValues...) - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - qupdates := strings.Join(updates, ", ") - columns := strings.Join(names, sep) - - // conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr) - - d.ins.ReplaceMarks(&query) - if !d.ins.HasReturningID(mi, &query) { res, err := q.ExecContext(ctx, query, values...) if err == nil { @@ -625,6 +565,117 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.Mod return id, err } +func (d *dbBase) InsertOrUpdateSQL(names []string, values *[]interface{}, mi *models.ModelInfo, a *alias, args ...string) (string, error) { + + args0 := "" + + switch a.Driver { + case DRMySQL: + case DRPostgres: + if len(args) == 0 { + return "", fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) + } + args0 = strings.ToLower(args[0]) + default: + return "", fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) + } + + argsMap := map[string]string{} + // Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + quote := d.ins.TableQuote() + + buf := buffers.Get() + defer buffers.Put(buf) + + _, _ = buf.WriteString("INSERT INTO ") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(" (") + + for i, name := range names { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(name) + _, _ = buf.WriteString(quote) + } + + _, _ = buf.WriteString(") VALUES (") + + for i := 0; i < len(names); i++ { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString("?") + } + + _, _ = buf.WriteString(") ") + + switch a.Driver { + case DRMySQL: + _, _ = buf.WriteString("ON DUPLICATE KEY UPDATE ") + case DRPostgres: + _, _ = buf.WriteString("ON CONFLICT (") + _, _ = buf.WriteString(args0) + _, _ = buf.WriteString(") DO UPDATE SET ") + } + + var conflitValue interface{} + for i, v := range names { + if i > 0 { + _, _ = buf.WriteString(", ") + } + // identifier in database may not be case-sensitive, so quote it + v = fmt.Sprintf("%s%s%s", quote, v, quote) + valueStr := argsMap[strings.ToLower(v)] + if v == args0 { + conflitValue = (*values)[i] + } + if valueStr != "" { + switch a.Driver { + case DRMySQL: + _, _ = buf.WriteString(v) + _, _ = buf.WriteString("=") + _, _ = buf.WriteString(valueStr) + case DRPostgres: + if conflitValue != nil { + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + _, _ = buf.WriteString(v) + _, _ = buf.WriteString("=(select ") + _, _ = buf.WriteString(valueStr) + _, _ = buf.WriteString(" from ") + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(" where ") + _, _ = buf.WriteString(args0) + _, _ = buf.WriteString(" = ? )") + *values = append(*values, conflitValue) + } else { + return "", fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) + } + } + } else { + _, _ = buf.WriteString(v) + _, _ = buf.WriteString("=?") + *values = append(*values, (*values)[i]) + } + } + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query, nil +} + // Update execute update sql dbQuerier with given struct reflect.Value. func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) diff --git a/client/orm/db_test.go b/client/orm/db_test.go index 6a61f380..cc79b108 100644 --- a/client/orm/db_test.go +++ b/client/orm/db_test.go @@ -15,6 +15,7 @@ package orm import ( + "errors" "github.com/beego/beego/v2/client/orm/internal/buffers" "testing" @@ -648,3 +649,240 @@ func TestDbBase_UpdateBatchSQL(t *testing.T) { }) } } + +func TestDbBase_InsertOrUpdateSQL(t *testing.T) { + + mi := &models.ModelInfo{ + Table: "test_tab", + } + + testCases := []struct { + name string + db *dbBase + + names []string + values []interface{} + a *alias + args []string + + wantRes string + wantErr error + wantValues []interface{} + }{ + { + name: "test nonsupport driver", + db: &dbBase{ + ins: newdbBaseSqlite(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &alias{ + Driver: DRSqlite, + DriverName: "sqlite3", + }, + args: []string{ + "`age`=20", + "`score`=`score`+1", + }, + + wantErr: errors.New("`sqlite3` nonsupport InsertOrUpdate in beego"), + wantValues: []interface{}{ + "test_name", + 18, + 12, + }, + }, + { + name: "insert or update with MySQL", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &alias{ + Driver: DRMySQL, + DriverName: "mysql", + }, + args: []string{ + "`age`=20", + "`score`=`score`+1", + }, + + wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=20, `score`=`score`+1", + wantValues: []interface{}{ + "test_name", + 18, + 12, + "test_name", + }, + }, + { + name: "insert or update with MySQL with no args", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &alias{ + Driver: DRMySQL, + DriverName: "mysql", + }, + + wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=?, `score`=?", + wantValues: []interface{}{ + "test_name", + 18, + 12, + "test_name", + 18, + 12, + }, + }, + { + name: "insert or update with PostgreSQL normal", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &alias{ + Driver: DRPostgres, + DriverName: "postgres", + }, + args: []string{ + `"name"`, + `"score"="score_1"`, + }, + + wantRes: `INSERT INTO "test_tab" ("name", "age", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "name"=$4, "age"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`, + wantValues: []interface{}{ + "test_name", + 18, + 12, + "test_name", + 18, + "test_name", + }, + }, + { + name: "insert or update with PostgreSQL without conflict column", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &alias{ + Driver: DRPostgres, + DriverName: "postgres", + }, + + wantErr: errors.New("`postgres` use InsertOrUpdate must have a conflict column"), + wantValues: []interface{}{ + "test_name", + 18, + 12, + }, + }, + { + name: "insert or update with PostgreSQL the conflict column is not in front of the specified column", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"score", "name", "age"}, + values: []interface{}{ + 12, + "test_name", + 18, + }, + a: &alias{ + Driver: DRPostgres, + DriverName: "postgres", + }, + args: []string{ + `"name"`, + `"score"="score_1"`, + }, + + wantErr: errors.New("`\"name\"` must be in front of `\"score\"` in your struct"), + wantValues: []interface{}{ + 12, + "test_name", + 18, + }, + }, + { + name: "insert or update with PostgreSQL the conflict column is in front of the specified column", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"age", "name", "score"}, + values: []interface{}{ + 18, + "test_name", + 12, + }, + a: &alias{ + Driver: DRPostgres, + DriverName: "postgres", + }, + args: []string{ + `"name"`, + `"score"="score_1"`, + }, + + wantRes: `INSERT INTO "test_tab" ("age", "name", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "age"=$4, "name"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`, + wantValues: []interface{}{ + 18, + "test_name", + 12, + 18, + "test_name", + "test_name", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res, err := tc.db.InsertOrUpdateSQL(tc.names, &tc.values, mi, tc.a, tc.args...) + + assert.Equal(t, tc.wantValues, tc.values) + + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + + assert.Equal(t, tc.wantRes, res) + }) + } + +}