Resolve conflicts among master branch and develop branch (#5286)

* feature extend readthrough for cache module (#5116)

* feature 增加readthrough

* feature: add write though for cache mode (#5117)

* feature: add writethough for cache mode

* feature add singleflight cache (#5119)

* build(deps): bump go.opentelemetry.io/otel/trace from 1.8.0 to 1.11.2

Bumps [go.opentelemetry.io/otel/trace](https://github.com/open-telemetry/opentelemetry-go) from 1.8.0 to 1.11.2.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.8.0...v1.11.2)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/trace
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix 5129: must set formatter after init the logger

* remove beego.vip

* build(deps): bump actions/stale from 5 to 7

Bumps [actions/stale](https://github.com/actions/stale) from 5 to 7.
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v7)

---
updated-dependencies:
- dependency-name: actions/stale
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix 5079: only log msg when the channel is not closed (#5132)

* optimize test

* upgrade otel dependencies to v1.11.2

* format code

* Bloom filter cache (#5126)

* feature: add bloom filter cache

* feature upload remove all temp file

* bugfix Controller SaveToFile remove all temp file

* rft: motify BeeLogger signalChan (#5139)

* add non-block write log in asynchronous mode (#5150)

* add non-block write log in asynchronous mode

---------

Co-authored-by: chenhaokun <chenhaokun@itiger.com>

* fix the docsite URL (#5173)

* Unified gopkg.in/yaml version to v2 (#5169)

* Unified gopkg.in/yaml version to v2 and go mod tidy

* update CHANGELOG

* bugfix: protect field access with lock to avoid possible data race (#5211)

* fix some comments (#5194)

Signed-off-by: cui fliter <imcusg@gmail.com>

* build(deps): bump github.com/prometheus/client_golang (#5213)

Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.14.0 to 1.15.1.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.14.0...v1.15.1)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* build(deps): bump go.etcd.io/etcd/client/v3 from 3.5.4 to 3.5.9 (#5209)

Bumps [go.etcd.io/etcd/client/v3](https://github.com/etcd-io/etcd) from 3.5.4 to 3.5.9.
- [Release notes](https://github.com/etcd-io/etcd/releases)
- [Commits](https://github.com/etcd-io/etcd/compare/v3.5.4...v3.5.9)

---
updated-dependencies:
- dependency-name: go.etcd.io/etcd/client/v3
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* cache: fix typo and optimize the naming

* Release 2.1.0 change log

* bugfix: beegoAppConfig String and Strings function has bug

* httplib: fix unstable test, do not use httplib.org

* chore: pkg imported more than once

* chore: fmt modify

* chore: Use github.com/go-kit/log

* chore: unnecessary use of fmt.Sprintf

* fix: golangci-lint error

* orm: refactor ORM introducing internal/models pkg

* remove adapter package

* build(deps): bump github.com/bits-and-blooms/bloom/v3

Bumps [github.com/bits-and-blooms/bloom/v3](https://github.com/bits-and-blooms/bloom) from 3.3.1 to 3.5.0.
- [Release notes](https://github.com/bits-and-blooms/bloom/releases)
- [Commits](https://github.com/bits-and-blooms/bloom/compare/v3.3.1...v3.5.0)

---
updated-dependencies:
- dependency-name: github.com/bits-and-blooms/bloom/v3
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* feat: add write-delete cache mode

* fix: unnecessary assignment to the blank identifier

* fix: add change into .CHANGELOG file

* build(deps): bump golang.org/x/sync from 0.1.0 to 0.3.0

Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.1.0 to 0.3.0.
- [Commits](https://github.com/golang/sync/compare/v0.1.0...v0.3.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* build(deps): bump golang.org/x/crypto

Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.0.0-20220315160706-3147a52a75dd to 0.10.0.
- [Commits](https://github.com/golang/crypto/commits/v0.10.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* remove golang--lint-ci

* Beego web.Run() runs the server twice

* fix 5255: Check the rows.Err() if rows.Next() is false

* closes 5254: %COL% should be a common placeholder

* build(deps): bump github.com/prometheus/client_golang

Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.15.1 to 1.16.0.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.15.1...v1.16.0)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix: use of ioutil package (#5261)

* fix ioutil.NopCloser

* fix ioutil.ReadAll

* fix ioutil.ReadFile

* fix ioutil.WriteFile

* run goimports -w -format-only ./

* update CHANGELOG.md

* feature: add write-double-delete cache mode (#5263)

* cache/redis: support skipEmptyPrefix option (#5264)

* fix: refactor InsertValue method (#5267)

* fix: refactor insertValue method and add the test

* fix: exec goimports and add Licence file header

* fix: modify construct method of dbBase

* fix: add modify record into CHANGELOG

* fix: modify InsertOrUpdate method (#5269)

* fix: modify InsertOrUpdate method, Remove the isMulti variable and its associated code

* fix: Delete unnecessary judgment branches

* fix: add modify record into CHANGELOG

* cache/redis: use redisConfig to receive incoming JSON (previously using a map) (#5268)

* refactor cache/redis: Use redisConfig to receive incoming JSON (previously using a map).

* refactor cache/redis: Use the string type to receive JSON parameters.

---------

Co-authored-by: Tan <tanqianheng@gmail.com>

* fix: refactor Delete method (#5271)

* fix: refactor Delete method and add test

* fix: add modify record into CHANGELOG

* fix: refactor update sql (#5274)

* fix: refactor UpdateSQL method and add test

* fix: add modify record into CHANGELOG

* fix: modify url in the CHANGELOG

* fix: modify pr url in the CHANGELOG

* Fix setPK function for table without primary key (#5276)

---------

Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: cui fliter <imcusg@gmail.com>
Co-authored-by: Stone-afk <73482944+Stone-afk@users.noreply.github.com>
Co-authored-by: hookokoko <hooko@tju.edu.cn>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: hookokoko <648646891@qq.com>
Co-authored-by: Stone-afk <1711865140@qq.com>
Co-authored-by: chenhaokun <chenhaokun@itiger.com>
Co-authored-by: Xuing <admin@xuing.cn>
Co-authored-by: cui fliter <imcusg@gmail.com>
Co-authored-by: guoguangwu <guoguangwu@magic-shield.com>
Co-authored-by: uzziah <uzziahlin@gmail.com>
Co-authored-by: Hanjiang Yu <delacroix.yu@gmail.com>
Co-authored-by: Kota <mdryzk64smsh@gmail.com>
Co-authored-by: Uzziah <120019273+uzziahlin@users.noreply.github.com>
Co-authored-by: Handkerchiefs-t <59816423+Handkerchiefs-t@users.noreply.github.com>
Co-authored-by: Tan <tanqianheng@gmail.com>
Co-authored-by: mlgd <mlgd17@gmail.com>
This commit is contained in:
Ming Deng
2023-07-31 23:00:02 +08:00
committed by GitHub
parent 420e11ee63
commit 0bd2df91a1
269 changed files with 3904 additions and 18313 deletions

View File

@@ -20,6 +20,10 @@ import (
"fmt"
"os"
"strings"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
)
type commander interface {
@@ -53,7 +57,7 @@ func RunCommand() {
BootStrap()
args := argString(os.Args[2:])
args := utils.ArgString(os.Args[2:])
name := args.Get(0)
if name == "help" {
@@ -112,7 +116,7 @@ func (d *commandSyncDb) Run() error {
for i, mi := range defaultModelCache.allOrdered() {
query := drops[i]
if !d.noInfo {
fmt.Printf("drop table `%s`\n", mi.table)
fmt.Printf("drop table `%s`\n", mi.Table)
}
_, err := db.Exec(query)
if d.verbose {
@@ -143,18 +147,18 @@ func (d *commandSyncDb) Run() error {
ctx := context.Background()
for i, mi := range defaultModelCache.allOrdered() {
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name)
if !models.IsApplicableTableForDB(mi.AddrField, d.al.Name) {
fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.Table, d.al.Name)
continue
}
if tables[mi.table] {
if tables[mi.Table] {
if !d.noInfo {
fmt.Printf("table `%s` already exists, skip\n", mi.table)
fmt.Printf("table `%s` already exists, skip\n", mi.Table)
}
var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table)
var fields []*models.FieldInfo
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.Table)
if err != nil {
if d.rtOnError {
return err
@@ -162,8 +166,8 @@ func (d *commandSyncDb) Run() error {
fmt.Printf(" %s\n", err.Error())
}
for _, fi := range mi.fields.fieldsDB {
if _, ok := columns[fi.column]; !ok {
for _, fi := range mi.Fields.FieldsDB {
if _, ok := columns[fi.Column]; !ok {
fields = append(fields, fi)
}
}
@@ -172,7 +176,7 @@ func (d *commandSyncDb) Run() error {
query := getColumnAddQuery(d.al, fi)
if !d.noInfo {
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
fmt.Printf("add column `%s` for table `%s`\n", fi.FullName, mi.Table)
}
_, err := db.Exec(query)
@@ -187,7 +191,7 @@ func (d *commandSyncDb) Run() error {
}
}
for _, idx := range indexes[mi.table] {
for _, idx := range indexes[mi.Table] {
if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
@@ -211,11 +215,11 @@ func (d *commandSyncDb) Run() error {
}
if !d.noInfo {
fmt.Printf("create table `%s` \n", mi.table)
fmt.Printf("create table `%s` \n", mi.Table)
}
queries := []string{createQueries[i]}
for _, idx := range indexes[mi.table] {
for _, idx := range indexes[mi.Table] {
queries = append(queries, idx.SQL)
}
@@ -265,7 +269,7 @@ func (d *commandSQLAll) Run() error {
var all []string
for i, mi := range defaultModelCache.allOrdered() {
queries := []string{createQueries[i]}
for _, idx := range indexes[mi.table] {
for _, idx := range indexes[mi.Table] {
queries = append(queries, idx.SQL)
}
sql := strings.Join(queries, "\n")

View File

@@ -17,6 +17,8 @@ package orm
import (
"fmt"
"strings"
"github.com/beego/beego/v2/client/orm/internal/models"
)
type dbIndex struct {
@@ -26,17 +28,22 @@ type dbIndex struct {
}
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
func getColumnTyp(al *alias, fi *models.FieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
fieldSize := fi.size
fieldType := fi.FieldType
fieldSize := fi.Size
defer func() {
// handling the placeholder, including %COL%
col = strings.ReplaceAll(col, "%COL%", fi.Column)
}()
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeVarCharField:
if al.Driver == DRPostgres && fi.toText {
if al.Driver == DRPostgres && fi.ToText {
col = T["string-text"]
} else {
col = fmt.Sprintf(T["string"], fieldSize)
@@ -51,11 +58,11 @@ checkColumn:
col = T["time.Time-date"]
case TypeDateTimeField:
// the precision of sqlite is not implemented
if al.Driver == 2 || fi.timePrecision == nil {
if al.Driver == 2 || fi.TimePrecision == nil {
col = T["time.Time"]
} else {
s := T["time.Time-precision"]
col = fmt.Sprintf(s, *fi.timePrecision)
col = fmt.Sprintf(s, *fi.TimePrecision)
}
case TypeBitField:
@@ -85,7 +92,7 @@ checkColumn:
if !strings.Contains(s, "%d") {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
col = fmt.Sprintf(s, fi.Digits, fi.Decimals)
}
case TypeJSONField:
if al.Driver != DRPostgres {
@@ -100,8 +107,8 @@ checkColumn:
}
col = T["jsonb"]
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size
fieldType = fi.RelModelInfo.Fields.Pk.FieldType
fieldSize = fi.RelModelInfo.Fields.Pk.Size
goto checkColumn
}
@@ -109,34 +116,34 @@ checkColumn:
}
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
func getColumnAddQuery(al *alias, fi *models.FieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
if !fi.null {
if !fi.Null {
typ += " " + "NOT NULL"
}
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
Q, fi.mi.table, Q,
Q, fi.column, Q,
Q, fi.Mi.Table, Q,
Q, fi.Column, Q,
typ, getColumnDefault(fi),
)
}
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string {
func getColumnDefault(fi *models.FieldInfo) string {
var v, t, d string
// Skip default attribute if field is in relations
if fi.rel || fi.reverse {
if fi.Rel || fi.Reverse {
return v
}
t = " DEFAULT '%s' "
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
switch fi.FieldType {
case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
return v
@@ -153,14 +160,14 @@ func getColumnDefault(fi *fieldInfo) string {
d = "{}"
}
if fi.colDefault {
if !fi.initial.Exist() {
if fi.ColDefault {
if !fi.Initial.Exist() {
v = fmt.Sprintf(t, "")
} else {
v = fmt.Sprintf(t, fi.initial.String())
v = fmt.Sprintf(t, fi.Initial.String())
}
} else {
if !fi.null {
if !fi.Null {
v = fmt.Sprintf(t, d)
}
}

View File

@@ -0,0 +1,39 @@
package orm
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm/internal/models"
)
func Test_getColumnTyp(t *testing.T) {
testCases := []struct {
name string
fi *models.FieldInfo
al *alias
wantCol string
}{
{
// https://github.com/beego/beego/issues/5254
name: "issue 5254",
fi: &models.FieldInfo{
FieldType: TypePositiveIntegerField,
Column: "my_col",
},
al: &alias{
DbBaser: newdbBasePostgres(),
},
wantCol: `bigint CHECK("my_col" >= 0)`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
col := getColumnTyp(tc.al, tc.fi)
assert.Equal(t, tc.wantCol, col)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -21,6 +21,8 @@ import (
"sync"
"time"
"github.com/beego/beego/v2/client/orm/internal/logs"
lru "github.com/hashicorp/golang-lru"
)
@@ -320,7 +322,7 @@ func detectTZ(al *alias) {
al.TZ = t.Location()
}
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
logs.DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
@@ -347,7 +349,7 @@ func detectTZ(al *alias) {
if err == nil {
al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
logs.DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
}
@@ -479,7 +481,7 @@ end:
if db != nil {
db.Close()
}
DebugLog.Println(err.Error())
logs.DebugLog.Println(err.Error())
}
return err

View File

@@ -19,6 +19,10 @@ import (
"fmt"
"reflect"
"strings"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
)
// mysql operators.
@@ -72,28 +76,28 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
// OperatorSQL get mysql operator.
func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
// DbTypes get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
// ShowTablesQuery show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
// ShowColumnsQuery show Columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
// IndexExists execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
@@ -106,7 +110,7 @@ func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table strin
// If your primary key or unique column conflict will update
// If no will insert
// Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string
argsMap := map[string]string{}
@@ -120,10 +124,9 @@ func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *model
}
}
isMulti := false
names := make([]string, 0, len(mi.fields.dbcols)-1)
names := make([]string, 0, len(mi.Fields.DBcols)-1)
Q := d.ins.TableQuote()
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ)
if err != nil {
return 0, err
}
@@ -150,26 +153,17 @@ func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *model
qupdates := strings.Join(updates, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
// conflitValue maybe is an int,can`t use fmt.Sprintf
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr)
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
if !d.ins.HasReturningID(mi, &query) {
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
lastInsertId, err := res.LastInsertId()
if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable
} else {
return lastInsertId, nil

View File

@@ -19,6 +19,10 @@ import (
"fmt"
"strings"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/hints"
)
@@ -116,16 +120,16 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde
case hints.KeyIgnoreIndex:
hint = `NO_INDEX`
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`))
}
// execute insert sql with given struct and given values.
// InsertValue execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
@@ -143,7 +147,7 @@ func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelIn
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query)
@@ -156,7 +160,7 @@ func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelIn
lastInsertId, err := res.LastInsertId()
if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable
} else {
return lastInsertId, nil

View File

@@ -18,6 +18,10 @@ import (
"context"
"fmt"
"strconv"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
)
// postgresql operators.
@@ -76,7 +80,7 @@ func (d *dbBasePostgres) OperatorSQL(operator string) string {
}
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol)
@@ -128,20 +132,20 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
}
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
fi := mi.fields.pk
if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
func (d *dbBasePostgres) HasReturningID(mi *models.ModelInfo, query *string) bool {
fi := mi.Fields.Pk
if fi.FieldType&IsPositiveIntegerField == 0 && fi.FieldType&IsIntegerField == 0 {
return false
}
if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.Column)
}
return true
}
// sync auto key
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *models.ModelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
}
@@ -149,9 +153,9 @@ func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo
Q := d.ins.TableQuote()
for _, name := range autoFields {
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
mi.table, name,
mi.Table, name,
Q, name, Q,
Q, mi.table, Q)
Q, mi.Table, Q)
if _, err := db.ExecContext(ctx, query); err != nil {
return err
}
@@ -164,9 +168,9 @@ func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
}
// show table columns sql for postgresql.
// show table Columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.Columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
}
// get column types of postgresql.
@@ -185,7 +189,7 @@ func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table st
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
logs.DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
return ``
}

View File

@@ -22,6 +22,10 @@ import (
"strings"
"time"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/hints"
)
@@ -74,9 +78,9 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
logs.DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
}
return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
}
@@ -88,8 +92,8 @@ func (d *dbBaseSqlite) OperatorSQL(operator string) string {
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) {
if fi.FieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
@@ -114,7 +118,7 @@ func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'"
}
// get columns in sqlite.
// get Columns in sqlite.
func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.QueryContext(ctx, query)
@@ -132,10 +136,10 @@ func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table strin
columns[name.String] = [3]string{name.String, typ.String, null.String}
}
return columns, nil
return columns, rows.Err()
}
// get show columns sql in sqlite.
// get show Columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table)
}
@@ -171,7 +175,7 @@ func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, inde
case hints.KeyUseIndex, hints.KeyForceIndex:
return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`))
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
}

View File

@@ -19,6 +19,8 @@ import (
"strings"
"time"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
)
@@ -31,8 +33,8 @@ type dbTable struct {
names []string
sel bool
inner bool
mi *modelInfo
fi *fieldInfo
mi *models.ModelInfo
fi *models.FieldInfo
jtl *dbTable
}
@@ -40,14 +42,14 @@ type dbTable struct {
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
mi *models.ModelInfo
base dbBaser
skipEnd bool
}
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
func (t *dbTables) set(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
j.name = name
@@ -64,7 +66,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
}
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
func (t *dbTables) add(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; !ok {
i := len(t.tables) + 1
@@ -82,29 +84,29 @@ func (t *dbTables) get(name string) (*dbTable, bool) {
return j, ok
}
// get related fields info in recursive depth loop.
// get related Fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
func (t *dbTables) loopDepth(depth int, prefix string, fi *models.FieldInfo, related []string) []string {
if depth < 0 || fi.FieldType == RelManyToMany {
return related
}
if prefix == "" {
prefix = fi.name
prefix = fi.Name
} else {
prefix = prefix + ExprSep + fi.name
prefix = prefix + ExprSep + fi.Name
}
related = append(related, prefix)
depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel {
for _, fi := range fi.RelModelInfo.Fields.FieldsRel {
related = t.loopDepth(depth, prefix, fi, related)
}
return related
}
// parse related fields.
// parse related Fields.
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
related := make([]string, relsNum)
@@ -117,7 +119,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
}
relDepth--
for _, fi := range t.mi.fields.fieldsRel {
for _, fi := range t.mi.Fields.FieldsRel {
related = t.loopDepth(relDepth, "", fi, related)
}
@@ -133,18 +135,18 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
inner := true
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
if fi, ok := mmi.Fields.GetByAny(ex); ok && fi.Rel && fi.FieldType != RelManyToMany {
names = append(names, fi.Name)
mmi = fi.RelModelInfo
if fi.null || t.skipEnd {
if fi.Null || t.skipEnd {
inner = false
}
jt := t.set(names, mmi, fi, inner)
jt.jtl = jtl
if fi.reverse {
if fi.Reverse {
cancel = false
}
@@ -185,24 +187,24 @@ func (t *dbTables) getJoinSQL() (join string) {
t1 = jt.jtl.index
}
t2 = jt.index
table = jt.mi.table
table = jt.mi.Table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
case jt.fi.FieldType == RelManyToMany || jt.fi.FieldType == RelReverseMany || jt.fi.Reverse && jt.fi.ReverseFieldInfo.FieldType == RelManyToMany:
c1 = jt.fi.Mi.Fields.Pk.Column
for _, ffi := range jt.mi.Fields.FieldsRel {
if jt.fi.Mi == ffi.RelModelInfo {
c2 = ffi.Column
break
}
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk.column
c1 = jt.fi.Column
c2 = jt.fi.RelModelInfo.Fields.Pk.Column
if jt.fi.reverse {
c1 = jt.mi.fields.pk.column
c2 = jt.fi.reverseFieldInfo.column
if jt.fi.Reverse {
c1 = jt.mi.Fields.Pk.Column
c2 = jt.fi.ReverseFieldInfo.Column
}
}
@@ -213,11 +215,11 @@ func (t *dbTables) getJoinSQL() (join string) {
}
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
func (t *dbTables) parseExprs(mi *models.ModelInfo, exprs []string) (index, name string, info *models.FieldInfo, success bool) {
var (
jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
fi *models.FieldInfo
fiN *models.FieldInfo
mmi = mi
)
@@ -238,38 +240,38 @@ loopFor:
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
fi, ok = mmi.Fields.GetByAny(ex)
}
_ = okN
if ok {
isRel := fi.rel || fi.reverse
isRel := fi.Rel || fi.Reverse
names = append(names, fi.name)
names = append(names, fi.Name)
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
case fi.Rel:
mmi = fi.RelModelInfo
if fi.FieldType == RelManyToMany {
mmi = fi.RelThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
case fi.Reverse:
mmi = fi.ReverseFieldInfo.Mi
}
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
fiN, okN = mmi.Fields.GetByAny(exprs[i+1])
}
if isRel && (!fi.mi.isThrough || num != i) {
if fi.null || t.skipEnd {
if isRel && (!fi.Mi.IsThrough || num != i) {
if fi.Null || t.skipEnd {
inner = false
}
if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk {
if t.skipEnd && okN && fiN.Pk {
goto loopEnd
}
@@ -295,20 +297,20 @@ loopFor:
info = fi
if jtl == nil {
name = fi.name
name = fi.Name
} else {
name = jtl.name + ExprSep + fi.name
name = jtl.name + ExprSep + fi.Name
}
switch {
case fi.rel:
case fi.Rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case fi.Reverse:
switch fi.ReverseFieldInfo.FieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
info = fi.ReverseFieldInfo.Mi.Fields.Pk
name = info.Name
}
}
@@ -382,7 +384,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
}
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSQL)
@@ -415,7 +417,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q))
}
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
@@ -449,7 +451,7 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString()))
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.Column, Q, order.SortString()))
}
}
@@ -458,7 +460,7 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
}
// generate limit sql.
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
func (t *dbTables) getLimitSQL(mi *models.ModelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}
@@ -490,7 +492,7 @@ func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string)
}
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
func newDbTables(mi *models.ModelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
tables.mi = mi

231
client/orm/db_test.go Normal file
View File

@@ -0,0 +1,231 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm/internal/models"
)
func TestDbBase_InsertValueSQL(t *testing.T) {
mi := &models.ModelInfo{
Table: "test_table",
}
testCases := []struct {
name string
db *dbBase
isMulti bool
names []string
values []interface{}
wantRes string
}{
{
name: "single insert by dbBase",
db: &dbBase{
ins: &dbBase{},
},
isMulti: false,
names: []string{"name", "age"},
values: []interface{}{"test", 18},
wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)",
},
{
name: "single insert by dbBasePostgres",
db: &dbBase{
ins: newdbBasePostgres(),
},
isMulti: false,
names: []string{"name", "age"},
values: []interface{}{"test", 18},
wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)",
},
{
name: "multi insert by dbBase",
db: &dbBase{
ins: &dbBase{},
},
isMulti: true,
names: []string{"name", "age"},
values: []interface{}{"test", 18, "test2", 19},
wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?), (?, ?)",
},
{
name: "multi insert by dbBasePostgres",
db: &dbBase{
ins: newdbBasePostgres(),
},
isMulti: true,
names: []string{"name", "age"},
values: []interface{}{"test", 18, "test2", 19},
wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2), ($3, $4)",
},
{
name: "multi insert by dbBase but values is not enough",
db: &dbBase{
ins: &dbBase{},
},
isMulti: true,
names: []string{"name", "age"},
values: []interface{}{"test", 18, "test2"},
wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)",
},
{
name: "multi insert by dbBasePostgres but values is not enough",
db: &dbBase{
ins: newdbBasePostgres(),
},
isMulti: true,
names: []string{"name", "age"},
values: []interface{}{"test", 18, "test2"},
wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)",
},
{
name: "single insert by dbBase but values is double to names",
db: &dbBase{
ins: &dbBase{},
},
isMulti: false,
names: []string{"name", "age"},
values: []interface{}{"test", 18, "test2", 19},
wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)",
},
{
name: "single insert by dbBasePostgres but values is double to names",
db: &dbBase{
ins: newdbBasePostgres(),
},
isMulti: false,
names: []string{"name", "age"},
values: []interface{}{"test", 18, "test2", 19},
wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := tc.db.InsertValueSQL(tc.names, tc.values, tc.isMulti, mi)
assert.Equal(t, tc.wantRes, res)
})
}
}
func TestDbBase_UpdateSQL(t *testing.T) {
mi := &models.ModelInfo{
Table: "test_table",
}
testCases := []struct {
name string
db *dbBase
setNames []string
pkName string
wantRes string
}{
{
name: "update by dbBase",
db: &dbBase{
ins: &dbBase{},
},
setNames: []string{"name", "age", "sender"},
pkName: "id",
wantRes: "UPDATE `test_table` SET `name` = ?, `age` = ?, `sender` = ? WHERE `id` = ?",
},
{
name: "update by dbBasePostgres",
db: &dbBase{
ins: newdbBasePostgres(),
},
setNames: []string{"name", "age", "sender"},
pkName: "id",
wantRes: "UPDATE \"test_table\" SET \"name\" = $1, \"age\" = $2, \"sender\" = $3 WHERE \"id\" = $4",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := tc.db.UpdateSQL(tc.setNames, tc.pkName, mi)
assert.Equal(t, tc.wantRes, res)
})
}
}
func TestDbBase_DeleteSQL(t *testing.T) {
mi := &models.ModelInfo{
Table: "test_table",
}
testCases := []struct {
name string
db *dbBase
whereCols []string
wantRes string
}{
{
name: "delete by dbBase with id",
db: &dbBase{
ins: &dbBase{},
},
whereCols: []string{"id"},
wantRes: "DELETE FROM `test_table` WHERE `id` = ?",
},
{
name: "delete by dbBase not id",
db: &dbBase{
ins: &dbBase{},
},
whereCols: []string{"name", "age"},
wantRes: "DELETE FROM `test_table` WHERE `name` = ? AND `age` = ?",
},
{
name: "delete by dbBasePostgres with id",
db: &dbBase{
ins: newdbBasePostgres(),
},
whereCols: []string{"id"},
wantRes: "DELETE FROM \"test_table\" WHERE \"id\" = $1",
},
{
name: "delete by dbBasePostgres not id",
db: &dbBase{
ins: newdbBasePostgres(),
},
whereCols: []string{"name", "age"},
wantRes: "DELETE FROM \"test_table\" WHERE \"name\" = $1 AND \"age\" = $2",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := tc.db.DeleteSQL(tc.whereCols, mi)
assert.Equal(t, tc.wantRes, res)
})
}
}

View File

@@ -41,9 +41,9 @@ func (d *dbBaseTidb) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
// show Columns sql of table for mysql.
func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}

View File

@@ -18,6 +18,10 @@ import (
"fmt"
"reflect"
"time"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
)
// get table alias.
@@ -29,32 +33,32 @@ func getDbAlias(name string) *alias {
}
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
func getExistPk(mi *models.ModelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.Fields.Pk
v := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsPositiveIntegerField > 0 {
v := ind.FieldByIndex(fi.FieldIndex)
if fi.FieldType&IsPositiveIntegerField > 0 {
vu := v.Uint()
exist = vu > 0
value = vu
} else if fi.fieldType&IsIntegerField > 0 {
} else if fi.FieldType&IsIntegerField > 0 {
vu := v.Int()
exist = true
value = vu
} else if fi.fieldType&IsRelField > 0 {
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
} else if fi.FieldType&IsRelField > 0 {
_, value, exist = getExistPk(fi.RelModelInfo, reflect.Indirect(v))
} else {
vu := v.String()
exist = vu != ""
value = vu
}
column = fi.column
column = fi.Column
return
}
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
// get Fields description as flatted string.
func getFlatParams(fi *models.FieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor:
for _, arg := range args {
if arg == nil {
@@ -74,32 +78,32 @@ outFor:
case reflect.String:
v := val.String()
if fi != nil {
if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
if fi.FieldType == TypeTimeField || fi.FieldType == TypeDateField || fi.FieldType == TypeDateTimeField {
var t time.Time
var err error
if len(v) >= 19 {
s := v[:19]
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
t, err = time.ParseInLocation(utils.FormatDateTime, s, DefaultTimeLoc)
} else if len(v) >= 10 {
s := v
if len(v) > 10 {
s = v[:10]
}
t, err = time.ParseInLocation(formatDate, s, tz)
t, err = time.ParseInLocation(utils.FormatDate, s, tz)
} else {
s := v
if len(s) > 8 {
s = v[:8]
}
t, err = time.ParseInLocation(formatTime, s, tz)
t, err = time.ParseInLocation(utils.FormatTime, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
v = t.In(tz).Format(formatDate)
} else if fi.fieldType == TypeDateTimeField {
v = t.In(tz).Format(formatDateTime)
if fi.FieldType == TypeDateField {
v = t.In(tz).Format(utils.FormatDate)
} else if fi.FieldType == TypeDateTimeField {
v = t.In(tz).Format(utils.FormatDateTime)
} else {
v = t.In(tz).Format(formatTime)
v = t.In(tz).Format(utils.FormatTime)
}
}
}
@@ -110,7 +114,7 @@ outFor:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
arg = val.Uint()
case reflect.Float32:
arg, _ = StrTo(ToStr(arg)).Float64()
arg, _ = utils.StrTo(utils.ToStr(arg)).Float64()
case reflect.Float64:
arg = val.Float()
case reflect.Bool:
@@ -143,18 +147,18 @@ outFor:
continue outFor
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(formatDate)
} else if fi != nil && fi.fieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime)
} else if fi != nil && fi.fieldType == TypeTimeField {
arg = v.In(tz).Format(formatTime)
if fi != nil && fi.FieldType == TypeDateField {
arg = v.In(tz).Format(utils.FormatDate)
} else if fi != nil && fi.FieldType == TypeDateTimeField {
arg = v.In(tz).Format(utils.FormatDateTime)
} else if fi != nil && fi.FieldType == TypeTimeField {
arg = v.In(tz).Format(utils.FormatTime)
} else {
arg = v.In(tz).Format(formatDateTime)
arg = v.In(tz).Format(utils.FormatDateTime)
}
} else {
typ := val.Type()
name := getFullName(typ)
name := models.GetFullName(typ)
var value interface{}
if mmi, ok := defaultModelCache.getByFullName(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {

View File

@@ -22,12 +22,12 @@ import (
// don't forget to call next(...) inside your Filter
type FilterChain func(next Filter) Filter
// Filter's behavior is a little big strange.
// Filter behavior is a little big strange.
// it's only be called when users call methods of Ormer
// return value is an array. it's a little bit hard to understand,
// for example, the Ormer's Read method only return error
// so the filter processing this method should return an array whose first element is error
// and, Ormer's ReadOrCreateWithCtx return three values, so the Filter's result should contains three values
// and, Ormer's ReadOrCreateWithCtx return three values, so the Filter's result should contain three values
type Filter func(ctx context.Context, inv *Invocation) []interface{}
var globalFilterChains = make([]FilterChain, 0, 4)

View File

@@ -20,6 +20,10 @@ import (
"reflect"
"time"
utils2 "github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/core/utils"
)
@@ -192,13 +196,13 @@ func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QueryS
var (
name string
md interface{}
mi *modelInfo
mi *models.ModelInfo
)
if table, ok := ptrStructOrTableName.(string); ok {
name = table
} else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
name = models.GetFullName(utils2.IndirectType(reflect.TypeOf(ptrStructOrTableName)))
md = ptrStructOrTableName
}
@@ -303,7 +307,7 @@ func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, erro
func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
var (
md interface{}
mi *modelInfo
mi *models.ModelInfo
)
sind := reflect.Indirect(reflect.ValueOf(mds))

View File

@@ -1,10 +1,10 @@
// Copyright 2020
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,24 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
package buffers
import (
"reflect"
"testing"
import "github.com/valyala/bytebufferpool"
"github.com/stretchr/testify/assert"
)
var _ Buffer = &bytebufferpool.ByteBuffer{}
type NotApplicableModel struct {
Id int
type Buffer interface {
Write(p []byte) (int, error)
WriteString(s string) (int, error)
WriteByte(c byte) error
String() string
}
func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool {
return db == "default"
func Get() Buffer {
return bytebufferpool.Get()
}
func TestIsApplicableTableForDB(t *testing.T) {
assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa"))
assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default"))
func Put(bf Buffer) {
bytebufferpool.Put(bf.(*bytebufferpool.ByteBuffer))
}

View File

@@ -0,0 +1,20 @@
package logs
import (
"io"
"log"
"os"
)
var DebugLog = NewLog(os.Stdout)
// Log implement the log.Logger
type Log struct {
*log.Logger
}
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d
}

View File

@@ -0,0 +1,785 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package models
import (
"fmt"
"strconv"
"time"
"github.com/beego/beego/v2/client/orm/internal/utils"
)
// Define the Type enum
const (
TypeBooleanField = 1 << iota
TypeVarCharField
TypeCharField
TypeTextField
TypeTimeField
TypeDateField
TypeDateTimeField
TypeBitField
TypeSmallIntegerField
TypeIntegerField
TypeBigIntegerField
TypePositiveBitField
TypePositiveSmallIntegerField
TypePositiveIntegerField
TypePositiveBigIntegerField
TypeFloatField
TypeDecimalField
TypeJSONField
TypeJsonbField
RelForeignKey
RelOneToOne
RelManyToMany
RelReverseOne
RelReverseMany
)
// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
IsRelField = ^-RelReverseMany >> 18 << 19
IsFieldType = ^-RelReverseMany<<1 + 1
)
// BooleanField A true/false field.
type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
e.Set(d)
case string:
v, err := utils.StrTo(d).Bool()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
// CharField A string field
// required values tag: size
// The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"`
type CharField string
// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeVarCharField
}
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
// verify CharField implement Fielder
var _ Fielder = new(CharField)
// TimeField A time, represented in go by a time.Time instance.
// only time values like 10:00:00
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type TimeField time.Time
// Value return the time.Time
func (e TimeField) Value() time.Time {
return time.Time(e)
}
// Set set the TimeField's value
func (e *TimeField) Set(d time.Time) {
*e = TimeField(d)
}
// String convert time to string
func (e *TimeField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *TimeField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *TimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := utils.TimeParse(d, utils.FormatTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return time value
func (e *TimeField) RawValue() interface{} {
return e.Value()
}
var _ Fielder = new(TimeField)
// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datetime to string
func (e *DateField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := utils.TimeParse(d, utils.FormatDate)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
// verify DateField implement fielder interface
var _ Fielder = new(DateField)
// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
// Value return the datetime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := utils.TimeParse(d, utils.FormatDateTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// verify datetime implement fielder
var _ Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
// String return the string
func (e *FloatField) String() string {
return utils.ToStr(e.Value(), -1, 32)
}
// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
e.Set(float64(d))
case float64:
e.Set(d)
case string:
v, err := utils.StrTo(d).Float64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
// String convert smallint to string
func (e *SmallIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
e.Set(d)
case string:
v, err := utils.StrTo(d).Int16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
// IntegerField -2147483648 to 2147483647
type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
// String convert Int32 to string
func (e *IntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
e.Set(d)
case string:
v, err := utils.StrTo(d).Int32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
e.Set(d)
case string:
v, err := utils.StrTo(d).Int64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
e.Set(d)
case string:
v, err := utils.StrTo(d).Uint16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
e.Set(d)
case string:
v, err := utils.StrTo(d).Uint32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
e.Set(d)
case string:
v, err := utils.StrTo(d).Uint64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
// TextField A large text field.
type TextField string
// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
// verify TextField implement Fielder
var _ Fielder = new(TextField)
// JSONField postgres json field.
type JSONField string
// Value return JSONField value
func (j JSONField) Value() string {
return string(j)
}
// Set the JSONField value
func (j *JSONField) Set(d string) {
*j = JSONField(d)
}
// String convert JSONField to string
func (j *JSONField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JSONField) FieldType() int {
return TypeJSONField
}
// SetRaw convert interface string to string
func (j *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JSONField value
func (j *JSONField) RawValue() interface{} {
return j.Value()
}
// verify JSONField implement Fielder
var _ Fielder = new(JSONField)
// JsonbField postgres json field.
type JsonbField string
// Value return JsonbField value
func (j JsonbField) Value() string {
return string(j)
}
// Set the JsonbField value
func (j *JsonbField) Set(d string) {
*j = JsonbField(d)
}
// String convert JsonbField to string
func (j *JsonbField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JsonbField) FieldType() int {
return TypeJsonbField
}
// SetRaw convert interface string to string
func (j *JsonbField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JsonbField value
func (j *JsonbField) RawValue() interface{} {
return j.Value()
}
// verify JsonbField implement Fielder
var _ Fielder = new(JsonbField)

View File

@@ -12,147 +12,149 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
package models
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/beego/beego/v2/client/orm/internal/utils"
)
var errSkipField = errors.New("skip field")
// field info collection
type fields struct {
pk *fieldInfo
columns map[string]*fieldInfo
fields map[string]*fieldInfo
fieldsLow map[string]*fieldInfo
fieldsByType map[int][]*fieldInfo
fieldsRel []*fieldInfo
fieldsReverse []*fieldInfo
fieldsDB []*fieldInfo
rels []*fieldInfo
orders []string
dbcols []string
// Fields field info collection
type Fields struct {
Pk *FieldInfo
Columns map[string]*FieldInfo
Fields map[string]*FieldInfo
FieldsLow map[string]*FieldInfo
FieldsByType map[int][]*FieldInfo
FieldsRel []*FieldInfo
FieldsReverse []*FieldInfo
FieldsDB []*FieldInfo
Rels []*FieldInfo
Orders []string
DBcols []string
}
// add field info
func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi
f.fields[fi.name] = fi
f.fieldsLow[strings.ToLower(fi.name)] = fi
// Add adds field info
func (f *Fields) Add(fi *FieldInfo) (added bool) {
if f.Fields[fi.Name] == nil && f.Columns[fi.Column] == nil {
f.Columns[fi.Column] = fi
f.Fields[fi.Name] = fi
f.FieldsLow[strings.ToLower(fi.Name)] = fi
} else {
return
}
if _, ok := f.fieldsByType[fi.fieldType]; !ok {
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
if _, ok := f.FieldsByType[fi.FieldType]; !ok {
f.FieldsByType[fi.FieldType] = make([]*FieldInfo, 0)
}
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
f.orders = append(f.orders, fi.column)
if fi.dbcol {
f.dbcols = append(f.dbcols, fi.column)
f.fieldsDB = append(f.fieldsDB, fi)
f.FieldsByType[fi.FieldType] = append(f.FieldsByType[fi.FieldType], fi)
f.Orders = append(f.Orders, fi.Column)
if fi.DBcol {
f.DBcols = append(f.DBcols, fi.Column)
f.FieldsDB = append(f.FieldsDB, fi)
}
if fi.rel {
f.fieldsRel = append(f.fieldsRel, fi)
if fi.Rel {
f.FieldsRel = append(f.FieldsRel, fi)
}
if fi.reverse {
f.fieldsReverse = append(f.fieldsReverse, fi)
if fi.Reverse {
f.FieldsReverse = append(f.FieldsReverse, fi)
}
return true
}
// get field info by name
func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name]
// GetByName get field info by name
func (f *Fields) GetByName(name string) *FieldInfo {
return f.Fields[name]
}
// get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column]
// GetByColumn get field info by column name
func (f *Fields) GetByColumn(column string) *FieldInfo {
return f.Columns[column]
}
// get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok {
// GetByAny get field info by string, name is prior
func (f *Fields) GetByAny(name string) (*FieldInfo, bool) {
if fi, ok := f.Fields[name]; ok {
return fi, ok
}
if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok {
if fi, ok := f.FieldsLow[strings.ToLower(name)]; ok {
return fi, ok
}
if fi, ok := f.columns[name]; ok {
if fi, ok := f.Columns[name]; ok {
return fi, ok
}
return nil, false
}
// create new field info collection
func newFields() *fields {
f := new(fields)
f.fields = make(map[string]*fieldInfo)
f.fieldsLow = make(map[string]*fieldInfo)
f.columns = make(map[string]*fieldInfo)
f.fieldsByType = make(map[int][]*fieldInfo)
// NewFields create new field info collection
func NewFields() *Fields {
f := new(Fields)
f.Fields = make(map[string]*FieldInfo)
f.FieldsLow = make(map[string]*FieldInfo)
f.Columns = make(map[string]*FieldInfo)
f.FieldsByType = make(map[int][]*FieldInfo)
return f
}
// single field info
type fieldInfo struct {
dbcol bool // table column fk and onetoone
inModel bool
auto bool
pk bool
null bool
index bool
unique bool
colDefault bool // whether has default tag
toText bool
autoNow bool
autoNowAdd bool
rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true
reverse bool
isFielder bool // implement Fielder interface
mi *modelInfo
fieldIndex []int
fieldType int
name string
fullName string
column string
addrValue reflect.Value
sf reflect.StructField
initial StrTo // store the default value
size int
reverseField string
reverseFieldInfo *fieldInfo
reverseFieldInfoTwo *fieldInfo
reverseFieldInfoM2M *fieldInfo
relTable string
relThrough string
relThroughModelInfo *modelInfo
relModelInfo *modelInfo
digits int
decimals int
onDelete string
description string
timePrecision *int
// FieldInfo single field info
type FieldInfo struct {
DBcol bool // table column fk and onetoone
InModel bool
Auto bool
Pk bool
Null bool
Index bool
Unique bool
ColDefault bool // whether has default tag
ToText bool
AutoNow bool
AutoNowAdd bool
Rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true
Reverse bool
IsFielder bool // implement Fielder interface
Mi *ModelInfo
FieldIndex []int
FieldType int
Name string
FullName string
Column string
AddrValue reflect.Value
Sf reflect.StructField
Initial utils.StrTo // store the default value
Size int
ReverseField string
ReverseFieldInfo *FieldInfo
ReverseFieldInfoTwo *FieldInfo
ReverseFieldInfoM2M *FieldInfo
RelTable string
RelThrough string
RelThroughModelInfo *ModelInfo
RelModelInfo *ModelInfo
Digits int
Decimals int
OnDelete string
Description string
TimePrecision *int
}
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) {
// NewFieldInfo new field info
func NewFieldInfo(mi *ModelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *FieldInfo, err error) {
var (
tag string
tagValue string
initial StrTo // store the default value
initial utils.StrTo // store the default value
fieldType int
attrs map[string]bool
tags map[string]string
addrField reflect.Value
)
fi = new(fieldInfo)
fi = new(FieldInfo)
// if field which CanAddr is the follow type
// A value is addressable if it is an element of a slice,
@@ -168,7 +170,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
}
}
attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName))
attrs, tags = ParseStructTag(sf.Tag.Get(DefaultStructTagName))
if _, ok := attrs["-"]; ok {
return nil, errSkipField
@@ -187,7 +189,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
checkType:
switch f := addrField.Interface().(type) {
case Fielder:
fi.isFielder = true
fi.IsFielder = true
if field.Kind() == reflect.Ptr {
err = fmt.Errorf("the model Fielder can not be use ptr")
goto end
@@ -211,9 +213,9 @@ checkType:
case "m2m":
fieldType = RelManyToMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
fi.RelTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
fi.RelThrough = tv
}
break checkType
default:
@@ -231,9 +233,9 @@ checkType:
case "many":
fieldType = RelReverseMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
fi.RelTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
fi.RelThrough = tv
}
break checkType
default:
@@ -295,117 +297,117 @@ checkType:
goto end
}
fi.fieldType = fieldType
fi.name = sf.Name
fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.addrValue = addrField
fi.sf = sf
fi.fullName = mi.fullName + mName + "." + sf.Name
fi.FieldType = fieldType
fi.Name = sf.Name
fi.Column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.AddrValue = addrField
fi.Sf = sf
fi.FullName = mi.FullName + mName + "." + sf.Name
fi.description = tags["description"]
fi.null = attrs["null"]
fi.index = attrs["index"]
fi.auto = attrs["auto"]
fi.pk = attrs["pk"]
fi.unique = attrs["unique"]
fi.Description = tags["description"]
fi.Null = attrs["null"]
fi.Index = attrs["index"]
fi.Auto = attrs["auto"]
fi.Pk = attrs["pk"]
fi.Unique = attrs["unique"]
// Mark object property if there is attribute "default" in the orm configuration
if _, ok := tags["default"]; ok {
fi.colDefault = true
fi.ColDefault = true
}
switch fieldType {
case RelManyToMany, RelReverseMany, RelReverseOne:
fi.null = false
fi.index = false
fi.auto = false
fi.pk = false
fi.unique = false
fi.Null = false
fi.Index = false
fi.Auto = false
fi.Pk = false
fi.Unique = false
default:
fi.dbcol = true
fi.DBcol = true
}
switch fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
fi.rel = true
fi.Rel = true
if fieldType == RelOneToOne {
fi.unique = true
fi.Unique = true
}
case RelReverseMany, RelReverseOne:
fi.reverse = true
fi.Reverse = true
}
if fi.rel && fi.dbcol {
if fi.Rel && fi.DBcol {
switch onDelete {
case odCascade, odDoNothing:
case odSetDefault:
case OdCascade, OdDoNothing:
case OdSetDefault:
if !initial.Exist() {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
case odSetNULL:
if !fi.null {
case OdSetNULL:
if !fi.Null {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
onDelete = odCascade
onDelete = OdCascade
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
}
}
fi.onDelete = onDelete
fi.OnDelete = onDelete
}
switch fieldType {
case TypeBooleanField:
case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" {
v, e := StrTo(size).Int32()
v, e := utils.StrTo(size).Int32()
if e != nil {
err = fmt.Errorf("wrong size value `%s`", size)
} else {
fi.size = int(v)
fi.Size = int(v)
}
} else {
fi.size = 255
fi.toText = true
fi.Size = 255
fi.ToText = true
}
case TypeTextField:
fi.index = false
fi.unique = false
fi.Index = false
fi.Unique = false
case TypeTimeField, TypeDateField, TypeDateTimeField:
if fieldType == TypeDateTimeField {
if precision != "" {
v, e := StrTo(precision).Int()
v, e := utils.StrTo(precision).Int()
if e != nil {
err = fmt.Errorf("convert %s to int error:%v", precision, e)
} else {
fi.timePrecision = &v
fi.TimePrecision = &v
}
}
}
if attrs["auto_now"] {
fi.autoNow = true
fi.AutoNow = true
} else if attrs["auto_now_add"] {
fi.autoNowAdd = true
fi.AutoNowAdd = true
}
case TypeFloatField:
case TypeDecimalField:
d1 := digits
d2 := decimals
v1, er1 := StrTo(d1).Int8()
v2, er2 := StrTo(d2).Int8()
v1, er1 := utils.StrTo(d1).Int8()
v2, er2 := utils.StrTo(d2).Int8()
if er1 != nil || er2 != nil {
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
goto end
}
fi.digits = int(v1)
fi.decimals = int(v2)
fi.Digits = int(v1)
fi.Decimals = int(v2)
default:
switch {
case fieldType&IsIntegerField > 0:
@@ -414,33 +416,33 @@ checkType:
}
if fieldType&IsIntegerField == 0 {
if fi.auto {
if fi.Auto {
err = fmt.Errorf("non-integer type cannot set auto")
goto end
}
}
if fi.auto || fi.pk {
if fi.auto {
if fi.Auto || fi.Pk {
if fi.Auto {
switch addrField.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
default:
err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind())
goto end
}
fi.pk = true
fi.Pk = true
}
fi.null = false
fi.index = false
fi.unique = false
fi.Null = false
fi.Index = false
fi.Unique = false
}
if fi.unique {
fi.index = false
if fi.Unique {
fi.Index = false
}
// can not set default for these type
if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
if fi.Auto || fi.Pk || fi.Unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
initial.Clear()
}
@@ -474,7 +476,7 @@ checkType:
}
}
fi.initial = initial
fi.Initial = initial
end:
if err != nil {
return nil, err

View File

@@ -0,0 +1,148 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package models
import (
"fmt"
"os"
"reflect"
)
// ModelInfo single model info
type ModelInfo struct {
Manual bool
IsThrough bool
Pkg string
Name string
FullName string
Table string
Model interface{}
Fields *Fields
AddrField reflect.Value // store the original struct value
Uniques []string
}
// NewModelInfo new model info
func NewModelInfo(val reflect.Value) (mi *ModelInfo) {
mi = &ModelInfo{}
mi.Fields = NewFields()
ind := reflect.Indirect(val)
mi.AddrField = val
mi.Name = ind.Type().Name()
mi.FullName = GetFullName(ind.Type())
AddModelFields(mi, ind, "", []int{})
return
}
// AddModelFields index: FieldByIndex returns the nested field corresponding to index
func AddModelFields(mi *ModelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *FieldInfo
sf reflect.StructField
)
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
// if the field is unexported skip
if sf.PkgPath != "" {
continue
}
// add anonymous struct Fields
if sf.Anonymous {
AddModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue
}
fi, err = NewFieldInfo(mi, field, sf, mName)
if err == errSkipField {
err = nil
continue
} else if err != nil {
break
}
// record current field index
fi.FieldIndex = append(fi.FieldIndex, index...)
fi.FieldIndex = append(fi.FieldIndex, i)
fi.Mi = mi
fi.InModel = true
if !mi.Fields.Add(fi) {
err = fmt.Errorf("duplicate column name: %s", fi.Column)
break
}
if fi.Pk {
if mi.Fields.Pk != nil {
err = fmt.Errorf("one model must have one pk field only")
break
} else {
mi.Fields.Pk = fi
}
}
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
}
// NewM2MModelInfo combine related model info to new model info.
// prepare for relation models query.
func NewM2MModelInfo(m1, m2 *ModelInfo) (mi *ModelInfo) {
mi = new(ModelInfo)
mi.Fields = NewFields()
mi.Table = m1.Table + "_" + m2.Table + "s"
mi.Name = CamelString(mi.Table)
mi.FullName = m1.Pkg + "." + mi.Name
fa := new(FieldInfo) // pk
f1 := new(FieldInfo) // m1 table RelForeignKey
f2 := new(FieldInfo) // m2 table RelForeignKey
fa.FieldType = TypeBigIntegerField
fa.Auto = true
fa.Pk = true
fa.DBcol = true
fa.Name = "Id"
fa.Column = "id"
fa.FullName = mi.FullName + "." + fa.Name
f1.DBcol = true
f2.DBcol = true
f1.FieldType = RelForeignKey
f2.FieldType = RelForeignKey
f1.Name = CamelString(m1.Table)
f2.Name = CamelString(m2.Table)
f1.FullName = mi.FullName + "." + f1.Name
f2.FullName = mi.FullName + "." + f2.Name
f1.Column = m1.Table + "_id"
f2.Column = m2.Table + "_id"
f1.Rel = true
f2.Rel = true
f1.RelTable = m1.Table
f2.RelTable = m2.Table
f1.RelModelInfo = m1
f2.RelModelInfo = m2
f1.Mi = mi
f2.Mi = mi
mi.Fields.Add(fa)
mi.Fields.Add(f1)
mi.Fields.Add(f2)
mi.Fields.Pk = fa
mi.Uniques = []string{f1.Column, f2.Column}
return
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
package models
import (
"database/sql"
@@ -20,6 +20,8 @@ import (
"reflect"
"strings"
"time"
"github.com/beego/beego/v2/client/orm/internal/logs"
)
// 1 is attr
@@ -48,15 +50,29 @@ var supportTag = map[string]int{
"precision": 2,
}
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string {
type fn func(string) string
var (
NameStrategyMap = map[string]fn{
DefaultNameStrategy: SnakeString,
SnakeAcronymNameStrategy: SnakeStringWithAcronym,
}
DefaultNameStrategy = "snakeString"
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
NameStrategy = DefaultNameStrategy
defaultStructTagDelim = ";"
DefaultStructTagName = "orm"
)
// GetFullName get reflect.Type name with package path.
func GetFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name()
}
// getTableName get struct table name.
// GetTableName get struct table name.
// If the struct implement the TableName, then get the result as tablename
// else use the struct name which will apply snakeString.
func getTableName(val reflect.Value) string {
func GetTableName(val reflect.Value) string {
if fun := val.MethodByName("TableName"); fun.IsValid() {
vals := fun.Call([]reflect.Value{})
// has return and the first val is string
@@ -64,11 +80,11 @@ func getTableName(val reflect.Value) string {
return vals[0].String()
}
}
return snakeString(reflect.Indirect(val).Type().Name())
return SnakeString(reflect.Indirect(val).Type().Name())
}
// get table engine, myisam or innodb.
func getTableEngine(val reflect.Value) string {
// GetTableEngine get table engine, myisam or innodb.
func GetTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
@@ -79,8 +95,8 @@ func getTableEngine(val reflect.Value) string {
return ""
}
// get table index from method.
func getTableIndex(val reflect.Value) [][]string {
// GetTableIndex get table index from method.
func GetTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
@@ -93,8 +109,8 @@ func getTableIndex(val reflect.Value) [][]string {
return nil
}
// get table unique from method
func getTableUnique(val reflect.Value) [][]string {
// GetTableUnique get table unique from method
func GetTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
@@ -107,8 +123,8 @@ func getTableUnique(val reflect.Value) [][]string {
return nil
}
// get whether the table needs to be created for the database alias
func isApplicableTableForDB(val reflect.Value, db string) bool {
// IsApplicableTableForDB get whether the table needs to be created for the database alias
func IsApplicableTableForDB(val reflect.Value, db string) bool {
if !val.IsValid() {
return true
}
@@ -126,7 +142,7 @@ func isApplicableTableForDB(val reflect.Value, db string) bool {
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := col
if col == "" {
column = nameStrategyMap[nameStrategy](sf.Name)
column = NameStrategyMap[NameStrategy](sf.Name)
}
switch ft {
case RelForeignKey, RelOneToOne:
@@ -218,8 +234,8 @@ func getFieldType(val reflect.Value) (ft int, err error) {
return
}
// parse struct tag string
func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
// ParseStructTag parse struct tag string
func ParseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
attrs = make(map[string]bool)
tags = make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) {
@@ -236,8 +252,74 @@ func parseStructTag(data string) (attrs map[string]bool, tags map[string]string)
tags[name] = v
}
} else {
DebugLog.Println("unsupport orm tag", v)
logs.DebugLog.Println("unsupport orm tag", v)
}
}
return
}
func SnakeStringWithAcronym(s string) string {
data := make([]byte, 0, len(s)*2)
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
before := false
after := false
if i > 0 {
before = s[i-1] >= 'a' && s[i-1] <= 'z'
}
if i+1 < num {
after = s[i+1] >= 'a' && s[i+1] <= 'z'
}
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
data = append(data, '_')
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// SnakeString snake string, XxYy to xx_yy , XxYY to xx_y_y
func SnakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j {
data = append(data, '_')
}
if d != '_' {
j = true
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// CamelString camel string, xx_yy to XxYy
func CamelString(s string) string {
data := make([]byte, 0, len(s))
flag, num := true, len(s)-1
for i := 0; i <= num; i++ {
d := s[i]
if d == '_' {
flag = true
continue
} else if flag {
if d >= 'a' && d <= 'z' {
d = d - 32
}
flag = false
}
data = append(data, d)
}
return string(data)
}
const (
OdCascade = "cascade"
OdSetNULL = "set_null"
OdSetDefault = "set_default"
OdDoNothing = "do_nothing"
)

View File

@@ -1,10 +1,10 @@
// Copyright 2014 beego Author. All Rights Reserved.
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,27 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
package models
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCamelString(t *testing.T) {
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
type NotApplicableModel struct {
Id int
}
answer := make(map[string]string)
for i, v := range snake {
answer[v] = camel[i]
}
func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool {
return db == "default"
}
for _, v := range snake {
res := camelString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
func TestIsApplicableTableForDB(t *testing.T) {
assert.False(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa"))
assert.True(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default"))
}
func TestSnakeString(t *testing.T) {
@@ -45,7 +44,7 @@ func TestSnakeString(t *testing.T) {
}
for _, v := range camel {
res := snakeString(v)
res := SnakeString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
@@ -62,7 +61,24 @@ func TestSnakeStringWithAcronym(t *testing.T) {
}
for _, v := range camel {
res := snakeStringWithAcronym(v)
res := SnakeStringWithAcronym(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}
func TestCamelString(t *testing.T) {
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
answer := make(map[string]string)
for i, v := range snake {
answer[v] = camel[i]
}
for _, v := range snake {
res := CamelString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}

View File

@@ -0,0 +1,23 @@
// Copyright 2023 beego-dev. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package models
// Fielder define field info
type Fielder interface {
String() string
FieldType() int
SetRaw(interface{}) error
RawValue() interface{}
}

View File

@@ -0,0 +1,249 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"fmt"
"math/big"
"reflect"
"strconv"
"time"
)
// StrTo is the target string
type StrTo string
// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
} else {
f.Clear()
}
}
// Clear string
func (f *StrTo) Clear() {
*f = StrTo(rune(0x1E))
}
// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(rune(0x1E))
}
// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) // octal
if !ok {
return v, err
}
return ni.Int64(), nil
}
return v, err
}
// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// Uint32 string to uint32
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10)
if !ok {
return v, err
}
return ni.Uint64(), nil
}
return v, err
}
// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
}
return ""
}
// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
s = strconv.FormatBool(v)
case float32:
s = strconv.FormatFloat(float64(v), 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 32))
case float64:
s = strconv.FormatFloat(v, 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int32:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int64:
s = strconv.FormatInt(v, ArgInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint32:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint64:
s = strconv.FormatUint(v, ArgInt(args).Get(0, 10))
case string:
s = v
case []byte:
s = string(v)
default:
s = fmt.Sprintf("%v", v)
}
return s
}
// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
d = val.Int()
case uint, uint8, uint16, uint32, uint64:
d = int64(val.Uint())
default:
panic(fmt.Errorf("ToInt64 need numeric not `%T`", value))
}
return
}
type ArgString []string
// Get get string by index from string slice
func (a ArgString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
} else if len(args) > 0 {
r = args[0]
}
return
}
type ArgInt []int
// Get get int by index from int slice
func (a ArgInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
// TimeParse parse time to string with location
func TimeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// IndirectType get pointer indirect type
func IndirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
case reflect.Ptr:
return IndirectType(v.Elem())
default:
return v
}
}
const (
FormatTime = "15:04:05"
FormatDate = "2006-01-02"
FormatDateTime = "2006-01-02 15:04:05"
)
var (
DefaultTimeLoc = time.Local
)

View File

@@ -17,6 +17,8 @@ package orm
import (
"context"
"time"
"github.com/beego/beego/v2/client/orm/internal/models"
)
// Invocation represents an "Orm" invocation
@@ -27,7 +29,7 @@ type Invocation struct {
// the args are all arguments except context.Context
Args []interface{}
mi *modelInfo
mi *models.ModelInfo
// f is the Orm operation
f func(ctx context.Context) []interface{}
@@ -39,7 +41,7 @@ type Invocation struct {
func (inv *Invocation) GetTableName() string {
if inv.mi != nil {
return inv.mi.table
return inv.mi.Table
}
return ""
}
@@ -51,8 +53,8 @@ func (inv *Invocation) execute(ctx context.Context) []interface{} {
// GetPkFieldName return the primary key of this table
// if not found, "" is returned
func (inv *Invocation) GetPkFieldName() string {
if inv.mi.fields.pk != nil {
return inv.mi.fields.pk.name
if inv.mi.Fields.Pk != nil {
return inv.mi.Fields.Pk.Name
}
return ""
}

View File

@@ -274,7 +274,7 @@ func (m *Migration) GetSQL() (sql string) {
}
if len(m.Primary) > 0 {
sql += fmt.Sprintf(",\n PRIMARY KEY( ")
sql += ",\n PRIMARY KEY( "
}
for index, column := range m.Primary {
sql += fmt.Sprintf(" `%s`", column.Name)
@@ -284,7 +284,7 @@ func (m *Migration) GetSQL() (sql string) {
}
if len(m.Primary) > 0 {
sql += fmt.Sprintf(")")
sql += ")"
}
for _, unique := range m.Uniques {
@@ -295,7 +295,7 @@ func (m *Migration) GetSQL() (sql string) {
sql += ","
}
}
sql += fmt.Sprintf(")")
sql += ")"
}
for _, foreign := range m.Foreigns {
sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
@@ -356,7 +356,7 @@ func (m *Migration) GetSQL() (sql string) {
}
if len(m.Primary) > 0 {
sql += fmt.Sprintf("\n DROP PRIMARY KEY,")
sql += "\n DROP PRIMARY KEY,"
}
for index, unique := range m.Uniques {

View File

@@ -106,8 +106,10 @@ func MockDeleteWithCtx(tableName string, affectedRow int64, err error) *Mock {
// Now you may be need to use golang/mock to generate QueryM2M mock instance
// Or use DoNothingQueryM2Mer
// for example:
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
//
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
//
// when you write test code:
// MockQueryM2MWithCtx("post", "Tags", mockM2Mer)
// "post" is the table name of model Post structure

View File

@@ -17,6 +17,8 @@ package orm
import (
"testing"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/stretchr/testify/assert"
)
@@ -53,10 +55,10 @@ func TestDbBase_GetTables(t *testing.T) {
assert.True(t, ok)
assert.NotNil(t, mi)
engine := getTableEngine(mi.addrField)
engine := models.GetTableEngine(mi.AddrField)
assert.Equal(t, "innodb", engine)
uniques := getTableUnique(mi.addrField)
uniques := models.GetTableUnique(mi.AddrField)
assert.Equal(t, [][]string{{"unique1"}, {"unique2"}}, uniques)
indexes := getTableIndex(mi.addrField)
indexes := models.GetTableIndex(mi.AddrField)
assert.Equal(t, [][]string{{"index1"}, {"index2"}}, indexes)
}

View File

@@ -21,15 +21,8 @@ import (
"runtime/debug"
"strings"
"sync"
)
const (
odCascade = "cascade"
odSetNULL = "set_null"
odSetDefault = "set_default"
odDoNothing = "do_nothing"
defaultStructTagName = "orm"
defaultStructTagDelim = ";"
imodels "github.com/beego/beego/v2/client/orm/internal/models"
)
var defaultModelCache = NewModelCacheHandler()
@@ -38,22 +31,22 @@ var defaultModelCache = NewModelCacheHandler()
type modelCache struct {
sync.RWMutex // only used outsite for bootStrap
orders []string
cache map[string]*modelInfo
cacheByFullName map[string]*modelInfo
cache map[string]*imodels.ModelInfo
cacheByFullName map[string]*imodels.ModelInfo
done bool
}
// NewModelCacheHandler generator of modelCache
func NewModelCacheHandler() *modelCache {
return &modelCache{
cache: make(map[string]*modelInfo),
cacheByFullName: make(map[string]*modelInfo),
cache: make(map[string]*imodels.ModelInfo),
cacheByFullName: make(map[string]*imodels.ModelInfo),
}
}
// get all model info
func (mc *modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache))
func (mc *modelCache) all() map[string]*imodels.ModelInfo {
m := make(map[string]*imodels.ModelInfo, len(mc.cache))
for k, v := range mc.cache {
m[k] = v
}
@@ -61,8 +54,8 @@ func (mc *modelCache) all() map[string]*modelInfo {
}
// get ordered model info
func (mc *modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders))
func (mc *modelCache) allOrdered() []*imodels.ModelInfo {
m := make([]*imodels.ModelInfo, 0, len(mc.orders))
for _, table := range mc.orders {
m = append(m, mc.cache[table])
}
@@ -70,30 +63,30 @@ func (mc *modelCache) allOrdered() []*modelInfo {
}
// get model info by table name
func (mc *modelCache) get(table string) (mi *modelInfo, ok bool) {
func (mc *modelCache) get(table string) (mi *imodels.ModelInfo, ok bool) {
mi, ok = mc.cache[table]
return
}
// get model info by full name
func (mc *modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
func (mc *modelCache) getByFullName(name string) (mi *imodels.ModelInfo, ok bool) {
mi, ok = mc.cacheByFullName[name]
return
}
func (mc *modelCache) getByMd(md interface{}) (*modelInfo, bool) {
func (mc *modelCache) getByMd(md interface{}) (*imodels.ModelInfo, bool) {
val := reflect.ValueOf(md)
ind := reflect.Indirect(val)
typ := ind.Type()
name := getFullName(typ)
name := imodels.GetFullName(typ)
return mc.getByFullName(name)
}
// set model info to collection
func (mc *modelCache) set(table string, mi *modelInfo) *modelInfo {
func (mc *modelCache) set(table string, mi *imodels.ModelInfo) *imodels.ModelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
mc.cacheByFullName[mi.fullName] = mi
mc.cacheByFullName[mi.FullName] = mi
if mii == nil {
mc.orders = append(mc.orders, table)
}
@@ -106,8 +99,8 @@ func (mc *modelCache) clean() {
defer mc.Unlock()
mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo)
mc.cacheByFullName = make(map[string]*modelInfo)
mc.cache = make(map[string]*imodels.ModelInfo)
mc.cacheByFullName = make(map[string]*imodels.ModelInfo)
mc.done = false
}
@@ -120,7 +113,7 @@ func (mc *modelCache) bootstrap() {
}
var (
err error
models map[string]*modelInfo
models map[string]*imodels.ModelInfo
)
if dataBaseCache.getDefault() == nil {
err = fmt.Errorf("must have one register DataBase alias named `default`")
@@ -131,51 +124,51 @@ func (mc *modelCache) bootstrap() {
// RelManyToMany set the relTable
models = mc.all()
for _, mi := range models {
for _, fi := range mi.fields.columns {
if fi.rel || fi.reverse {
elm := fi.addrValue.Type().Elem()
if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany {
for _, fi := range mi.Fields.Columns {
if fi.Rel || fi.Reverse {
elm := fi.AddrValue.Type().Elem()
if fi.FieldType == RelReverseMany || fi.FieldType == RelManyToMany {
elm = elm.Elem()
}
// check the rel or reverse model already register
name := getFullName(elm)
name := imodels.GetFullName(elm)
mii, ok := mc.getByFullName(name)
if !ok || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
if !ok || mii.Pkg != elm.PkgPath() {
err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.FullName, elm.String())
goto end
}
fi.relModelInfo = mii
fi.RelModelInfo = mii
switch fi.fieldType {
switch fi.FieldType {
case RelManyToMany:
if fi.relThrough != "" {
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
rmi, ok := mc.getByFullName(fi.relThrough)
if !ok || pn != rmi.pkg {
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
if fi.RelThrough != "" {
if i := strings.LastIndex(fi.RelThrough, "."); i != -1 && len(fi.RelThrough) > (i+1) {
pn := fi.RelThrough[:i]
rmi, ok := mc.getByFullName(fi.RelThrough)
if !ok || pn != rmi.Pkg {
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.FullName, fi.RelThrough)
goto end
}
fi.relThroughModelInfo = rmi
fi.relTable = rmi.table
fi.RelThroughModelInfo = rmi
fi.RelTable = rmi.Table
} else {
err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.FullName, fi.RelThrough)
goto end
}
} else {
i := newM2MModelInfo(mi, mii)
if fi.relTable != "" {
i.table = fi.relTable
i := imodels.NewM2MModelInfo(mi, mii)
if fi.RelTable != "" {
i.Table = fi.RelTable
}
if v := mc.set(i.table, i); v != nil {
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
if v := mc.set(i.Table, i); v != nil {
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.RelTable)
goto end
}
fi.relTable = i.table
fi.relThroughModelInfo = i
fi.RelTable = i.Table
fi.RelThroughModelInfo = i
}
fi.relThroughModelInfo.isThrough = true
fi.RelThroughModelInfo.IsThrough = true
}
}
}
@@ -185,42 +178,42 @@ func (mc *modelCache) bootstrap() {
// if not exist, add a new field to the relModelInfo
models = mc.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsRel {
switch fi.fieldType {
for _, fi := range mi.Fields.FieldsRel {
switch fi.FieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
inModel := false
for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
if ffi.relModelInfo == mi {
for _, ffi := range fi.RelModelInfo.Fields.FieldsReverse {
if ffi.RelModelInfo == mi {
inModel = true
break
}
}
if !inModel {
rmi := fi.relModelInfo
ffi := new(fieldInfo)
ffi.name = mi.name
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
ffi.reverse = true
ffi.relModelInfo = mi
ffi.mi = rmi
if fi.fieldType == RelOneToOne {
ffi.fieldType = RelReverseOne
rmi := fi.RelModelInfo
ffi := new(imodels.FieldInfo)
ffi.Name = mi.Name
ffi.Column = ffi.Name
ffi.FullName = rmi.FullName + "." + ffi.Name
ffi.Reverse = true
ffi.RelModelInfo = mi
ffi.Mi = rmi
if fi.FieldType == RelOneToOne {
ffi.FieldType = RelReverseOne
} else {
ffi.fieldType = RelReverseMany
ffi.FieldType = RelReverseMany
}
if !rmi.fields.Add(ffi) {
if !rmi.Fields.Add(ffi) {
added := false
for cnt := 0; cnt < 5; cnt++ {
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
if added = rmi.fields.Add(ffi); added {
ffi.Name = fmt.Sprintf("%s%d", mi.Name, cnt)
ffi.Column = ffi.Name
ffi.FullName = rmi.FullName + "." + ffi.Name
if added = rmi.Fields.Add(ffi); added {
break
}
}
if !added {
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.FullName, ffi.FullName))
}
}
}
@@ -230,24 +223,24 @@ func (mc *modelCache) bootstrap() {
models = mc.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsRel {
switch fi.fieldType {
for _, fi := range mi.Fields.FieldsRel {
switch fi.FieldType {
case RelManyToMany:
for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel {
switch ffi.fieldType {
for _, ffi := range fi.RelThroughModelInfo.Fields.FieldsRel {
switch ffi.FieldType {
case RelOneToOne, RelForeignKey:
if ffi.relModelInfo == fi.relModelInfo {
fi.reverseFieldInfoTwo = ffi
if ffi.RelModelInfo == fi.RelModelInfo {
fi.ReverseFieldInfoTwo = ffi
}
if ffi.relModelInfo == mi {
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
if ffi.RelModelInfo == mi {
fi.ReverseField = ffi.Name
fi.ReverseFieldInfo = ffi
}
}
}
if fi.reverseFieldInfoTwo == nil {
if fi.ReverseFieldInfoTwo == nil {
err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
fi.relThroughModelInfo.fullName)
fi.RelThroughModelInfo.FullName)
goto end
}
}
@@ -256,63 +249,63 @@ func (mc *modelCache) bootstrap() {
models = mc.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsReverse {
switch fi.fieldType {
for _, fi := range mi.Fields.FieldsReverse {
switch fi.FieldType {
case RelReverseOne:
found := false
mForA:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
if ffi.relModelInfo == mi {
for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelOneToOne] {
if ffi.RelModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
fi.ReverseField = ffi.Name
fi.ReverseFieldInfo = ffi
ffi.reverseField = fi.name
ffi.reverseFieldInfo = fi
ffi.ReverseField = fi.Name
ffi.ReverseFieldInfo = fi
break mForA
}
}
if !found {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName)
goto end
}
case RelReverseMany:
found := false
mForB:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
if ffi.relModelInfo == mi {
for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelForeignKey] {
if ffi.RelModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
fi.ReverseField = ffi.Name
fi.ReverseFieldInfo = ffi
ffi.reverseField = fi.name
ffi.reverseFieldInfo = fi
ffi.ReverseField = fi.Name
ffi.ReverseFieldInfo = fi
break mForB
}
}
if !found {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
fi.relTable != "" && fi.relTable == ffi.relTable ||
fi.relThrough == "" && fi.relTable == ""
if ffi.relModelInfo == mi && conditions {
for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelManyToMany] {
conditions := fi.RelThrough != "" && fi.RelThrough == ffi.RelThrough ||
fi.RelTable != "" && fi.RelTable == ffi.RelTable ||
fi.RelThrough == "" && fi.RelTable == ""
if ffi.RelModelInfo == mi && conditions {
found = true
fi.reverseField = ffi.reverseFieldInfoTwo.name
fi.reverseFieldInfo = ffi.reverseFieldInfoTwo
fi.relThroughModelInfo = ffi.relThroughModelInfo
fi.reverseFieldInfoTwo = ffi.reverseFieldInfo
fi.reverseFieldInfoM2M = ffi
ffi.reverseFieldInfoM2M = fi
fi.ReverseField = ffi.ReverseFieldInfoTwo.Name
fi.ReverseFieldInfo = ffi.ReverseFieldInfoTwo
fi.RelThroughModelInfo = ffi.RelThroughModelInfo
fi.ReverseFieldInfoTwo = ffi.ReverseFieldInfo
fi.ReverseFieldInfoM2M = ffi
ffi.ReverseFieldInfoM2M = fi
break mForC
}
}
}
if !found {
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName)
goto end
}
}
@@ -334,7 +327,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
typ := reflect.Indirect(val).Type()
if val.Kind() != reflect.Ptr {
err = fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ))
err = fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", imodels.GetFullName(typ))
return
}
// For this case:
@@ -347,7 +340,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
if val.Elem().Kind() == reflect.Slice {
val = reflect.New(val.Elem().Type().Elem())
}
table := getTableName(val)
table := imodels.GetTableName(val)
if prefixOrSuffixStr != "" {
if prefixOrSuffix {
@@ -358,7 +351,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
}
// models's fullname is pkgpath + struct name
name := getFullName(typ)
name := imodels.GetFullName(typ)
if _, ok := mc.getByFullName(name); ok {
err = fmt.Errorf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
return
@@ -368,26 +361,26 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
return nil
}
mi := newModelInfo(val)
if mi.fields.pk == nil {
mi := imodels.NewModelInfo(val)
if mi.Fields.Pk == nil {
outFor:
for _, fi := range mi.fields.fieldsDB {
if strings.ToLower(fi.name) == "id" {
switch fi.addrValue.Elem().Kind() {
for _, fi := range mi.Fields.FieldsDB {
if strings.ToLower(fi.Name) == "id" {
switch fi.AddrValue.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
fi.auto = true
fi.pk = true
mi.fields.pk = fi
fi.Auto = true
fi.Pk = true
mi.Fields.Pk = fi
break outFor
}
}
}
}
mi.table = table
mi.pkg = typ.PkgPath()
mi.model = model
mi.manual = true
mi.Table = table
mi.Pkg = typ.PkgPath()
mi.Model = model
mi.Manual = true
mc.set(table, mi)
}
@@ -404,7 +397,7 @@ func (mc *modelCache) getDbDropSQL(al *alias) (queries []string, err error) {
Q := al.DbBaser.TableQuote()
for _, mi := range mc.allOrdered() {
queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.Table, Q))
}
return queries, nil
}
@@ -424,33 +417,33 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
for _, mi := range mc.allOrdered() {
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName)
sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.FullName)
sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.Table, Q)
columns := make([]string, 0, len(mi.fields.fieldsDB))
columns := make([]string, 0, len(mi.Fields.FieldsDB))
sqlIndexes := [][]string{}
var commentIndexes []int // store comment indexes for postgres
for i, fi := range mi.fields.fieldsDB {
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
for i, fi := range mi.Fields.FieldsDB {
column := fmt.Sprintf(" %s%s%s ", Q, fi.Column, Q)
col := getColumnTyp(al, fi)
if fi.auto {
if fi.Auto {
switch al.Driver {
case DRSqlite, DRPostgres:
column += T["auto"]
default:
column += col + " " + T["auto"]
}
} else if fi.pk {
} else if fi.Pk {
column += col + " " + T["pk"]
} else {
column += col
if !fi.null {
if !fi.Null {
column += " " + "NOT NULL"
}
@@ -461,42 +454,42 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
// Append attribute DEFAULT
column += getColumnDefault(fi)
if fi.unique {
if fi.Unique {
column += " " + "UNIQUE"
}
if fi.index {
sqlIndexes = append(sqlIndexes, []string{fi.column})
if fi.Index {
sqlIndexes = append(sqlIndexes, []string{fi.Column})
}
}
if strings.Contains(column, "%COL%") {
column = strings.Replace(column, "%COL%", fi.column, -1)
column = strings.Replace(column, "%COL%", fi.Column, -1)
}
if fi.description != "" && al.Driver != DRSqlite {
if fi.Description != "" && al.Driver != DRSqlite {
if al.Driver == DRPostgres {
commentIndexes = append(commentIndexes, i)
} else {
column += " " + fmt.Sprintf("COMMENT '%s'", fi.description)
column += " " + fmt.Sprintf("COMMENT '%s'", fi.Description)
}
}
columns = append(columns, column)
}
if mi.model != nil {
allnames := getTableUnique(mi.addrField)
if !mi.manual && len(mi.uniques) > 0 {
allnames = append(allnames, mi.uniques)
if mi.Model != nil {
allnames := imodels.GetTableUnique(mi.AddrField)
if !mi.Manual && len(mi.Uniques) > 0 {
allnames = append(allnames, mi.Uniques)
}
for _, names := range allnames {
cols := make([]string, 0, len(names))
for _, name := range names {
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
cols = append(cols, fi.column)
if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol {
cols = append(cols, fi.Column)
} else {
panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName))
panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.FullName))
}
}
column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
@@ -509,8 +502,8 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
if al.Driver == DRMySQL {
var engine string
if mi.model != nil {
engine = getTableEngine(mi.addrField)
if mi.Model != nil {
engine = imodels.GetTableEngine(mi.AddrField)
}
if engine == "" {
engine = al.Engine
@@ -524,24 +517,24 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
for _, index := range commentIndexes {
sql += fmt.Sprintf("\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';",
Q,
mi.table,
mi.Table,
Q,
Q,
mi.fields.fieldsDB[index].column,
mi.Fields.FieldsDB[index].Column,
Q,
mi.fields.fieldsDB[index].description)
mi.Fields.FieldsDB[index].Description)
}
}
queries = append(queries, sql)
if mi.model != nil {
for _, names := range getTableIndex(mi.addrField) {
if mi.Model != nil {
for _, names := range imodels.GetTableIndex(mi.AddrField) {
cols := make([]string, 0, len(names))
for _, name := range names {
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
cols = append(cols, fi.column)
if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol {
cols = append(cols, fi.Column)
} else {
panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName))
panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.FullName))
}
}
sqlIndexes = append(sqlIndexes, cols)
@@ -549,16 +542,16 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
}
for _, names := range sqlIndexes {
name := mi.table + "_" + strings.Join(names, "_")
name := mi.Table + "_" + strings.Join(names, "_")
cols := strings.Join(names, sep)
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.Table, Q, Q, cols, Q)
index := dbIndex{}
index.Table = mi.table
index.Table = mi.Table
index.Name = name
index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
tableIndexes[mi.Table] = append(tableIndexes[mi.Table], index)
}
}

View File

@@ -15,91 +15,47 @@
package orm
import (
"fmt"
"strconv"
"time"
"github.com/beego/beego/v2/client/orm/internal/models"
)
// Define the Type enum
const (
TypeBooleanField = 1 << iota
TypeVarCharField
TypeCharField
TypeTextField
TypeTimeField
TypeDateField
TypeDateTimeField
TypeBitField
TypeSmallIntegerField
TypeIntegerField
TypeBigIntegerField
TypePositiveBitField
TypePositiveSmallIntegerField
TypePositiveIntegerField
TypePositiveBigIntegerField
TypeFloatField
TypeDecimalField
TypeJSONField
TypeJsonbField
RelForeignKey
RelOneToOne
RelManyToMany
RelReverseOne
RelReverseMany
TypeBooleanField = models.TypeBooleanField
TypeVarCharField = models.TypeVarCharField
TypeCharField = models.TypeCharField
TypeTextField = models.TypeTextField
TypeTimeField = models.TypeTimeField
TypeDateField = models.TypeDateField
TypeDateTimeField = models.TypeDateTimeField
TypeBitField = models.TypeBitField
TypeSmallIntegerField = models.TypeSmallIntegerField
TypeIntegerField = models.TypeIntegerField
TypeBigIntegerField = models.TypeBigIntegerField
TypePositiveBitField = models.TypePositiveBitField
TypePositiveSmallIntegerField = models.TypePositiveSmallIntegerField
TypePositiveIntegerField = models.TypePositiveIntegerField
TypePositiveBigIntegerField = models.TypePositiveBigIntegerField
TypeFloatField = models.TypeFloatField
TypeDecimalField = models.TypeDecimalField
TypeJSONField = models.TypeJSONField
TypeJsonbField = models.TypeJsonbField
RelForeignKey = models.RelForeignKey
RelOneToOne = models.RelOneToOne
RelManyToMany = models.RelManyToMany
RelReverseOne = models.RelReverseOne
RelReverseMany = models.RelReverseMany
)
// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
IsRelField = ^-RelReverseMany >> 18 << 19
IsFieldType = ^-RelReverseMany<<1 + 1
IsIntegerField = models.IsIntegerField
IsPositiveIntegerField = models.IsPositiveIntegerField
IsRelField = models.IsRelField
IsFieldType = models.IsFieldType
)
// BooleanField A true/false field.
type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
e.Set(d)
case string:
v, err := StrTo(d).Bool()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
type BooleanField = models.BooleanField
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
@@ -108,43 +64,7 @@ var _ Fielder = new(BooleanField)
// required values tag: size
// The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"`
type CharField string
// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeVarCharField
}
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
type CharField = models.CharField
// verify CharField implement Fielder
var _ Fielder = new(CharField)
@@ -162,49 +82,7 @@ var _ Fielder = new(CharField)
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type TimeField time.Time
// Value return the time.Time
func (e TimeField) Value() time.Time {
return time.Time(e)
}
// Set set the TimeField's value
func (e *TimeField) Set(d time.Time) {
*e = TimeField(d)
}
// String convert time to string
func (e *TimeField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *TimeField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *TimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return time value
func (e *TimeField) RawValue() interface{} {
return e.Value()
}
type TimeField = models.TimeField
var _ Fielder = new(TimeField)
@@ -221,49 +99,7 @@ var _ Fielder = new(TimeField)
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datetime to string
func (e *DateField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatDate)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
type DateField = models.DateField
// verify DateField implement fielder interface
var _ Fielder = new(DateField)
@@ -271,513 +107,67 @@ var _ Fielder = new(DateField)
// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
// Value return the datetime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatDateTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
type DateTimeField = models.DateTimeField
// verify datetime implement fielder
var _ Fielder = new(DateTimeField)
var _ models.Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
// String return the string
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
e.Set(float64(d))
case float64:
e.Set(d)
case string:
v, err := StrTo(d).Float64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
type FloatField = models.FloatField
// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
// String convert smallint to string
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
e.Set(d)
case string:
v, err := StrTo(d).Int16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
type SmallIntegerField = models.SmallIntegerField
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
// IntegerField -2147483648 to 2147483647
type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
// String convert Int32 to string
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
e.Set(d)
case string:
v, err := StrTo(d).Int32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
type IntegerField = models.IntegerField
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
e.Set(d)
case string:
v, err := StrTo(d).Int64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
type BigIntegerField = models.BigIntegerField
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
e.Set(d)
case string:
v, err := StrTo(d).Uint16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
type PositiveSmallIntegerField = models.PositiveSmallIntegerField
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
e.Set(d)
case string:
v, err := StrTo(d).Uint32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
type PositiveIntegerField = models.PositiveIntegerField
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
e.Set(d)
case string:
v, err := StrTo(d).Uint64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
type PositiveBigIntegerField = models.PositiveBigIntegerField
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
// TextField A large text field.
type TextField string
// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
type TextField = models.TextField
// verify TextField implement Fielder
var _ Fielder = new(TextField)
// JSONField postgres json field.
type JSONField string
// Value return JSONField value
func (j JSONField) Value() string {
return string(j)
}
// Set the JSONField value
func (j *JSONField) Set(d string) {
*j = JSONField(d)
}
// String convert JSONField to string
func (j *JSONField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JSONField) FieldType() int {
return TypeJSONField
}
// SetRaw convert interface string to string
func (j *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JSONField value
func (j *JSONField) RawValue() interface{} {
return j.Value()
}
type JSONField = models.JSONField
// verify JSONField implement Fielder
var _ Fielder = new(JSONField)
var _ models.Fielder = new(JSONField)
// JsonbField postgres json field.
type JsonbField string
// Value return JsonbField value
func (j JsonbField) Value() string {
return string(j)
}
// Set the JsonbField value
func (j *JsonbField) Set(d string) {
*j = JsonbField(d)
}
// String convert JsonbField to string
func (j *JsonbField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JsonbField) FieldType() int {
return TypeJsonbField
}
// SetRaw convert interface string to string
func (j *JsonbField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JsonbField value
func (j *JsonbField) RawValue() interface{} {
return j.Value()
}
type JsonbField = models.JsonbField
// verify JsonbField implement Fielder
var _ Fielder = new(JsonbField)
var _ models.Fielder = new(JsonbField)

View File

@@ -1,148 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"os"
"reflect"
)
// single model info
type modelInfo struct {
manual bool
isThrough bool
pkg string
name string
fullName string
table string
model interface{}
fields *fields
addrField reflect.Value // store the original struct value
uniques []string
}
// new model info
func newModelInfo(val reflect.Value) (mi *modelInfo) {
mi = &modelInfo{}
mi.fields = newFields()
ind := reflect.Indirect(val)
mi.addrField = val
mi.name = ind.Type().Name()
mi.fullName = getFullName(ind.Type())
addModelFields(mi, ind, "", []int{})
return
}
// index: FieldByIndex returns the nested field corresponding to index
func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *fieldInfo
sf reflect.StructField
)
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
// if the field is unexported skip
if sf.PkgPath != "" {
continue
}
// add anonymous struct fields
if sf.Anonymous {
addModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue
}
fi, err = newFieldInfo(mi, field, sf, mName)
if err == errSkipField {
err = nil
continue
} else if err != nil {
break
}
// record current field index
fi.fieldIndex = append(fi.fieldIndex, index...)
fi.fieldIndex = append(fi.fieldIndex, i)
fi.mi = mi
fi.inModel = true
if !mi.fields.Add(fi) {
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if mi.fields.pk != nil {
err = fmt.Errorf("one model must have one pk field only")
break
} else {
mi.fields.pk = fi
}
}
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
}
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) {
mi = new(modelInfo)
mi.fields = newFields()
mi.table = m1.table + "_" + m2.table + "s"
mi.name = camelString(mi.table)
mi.fullName = m1.pkg + "." + mi.name
fa := new(fieldInfo) // pk
f1 := new(fieldInfo) // m1 table RelForeignKey
f2 := new(fieldInfo) // m2 table RelForeignKey
fa.fieldType = TypeBigIntegerField
fa.auto = true
fa.pk = true
fa.dbcol = true
fa.name = "Id"
fa.column = "id"
fa.fullName = mi.fullName + "." + fa.name
f1.dbcol = true
f2.dbcol = true
f1.fieldType = RelForeignKey
f2.fieldType = RelForeignKey
f1.name = camelString(m1.table)
f2.name = camelString(m2.table)
f1.fullName = mi.fullName + "." + f1.name
f2.fullName = mi.fullName + "." + f2.name
f1.column = m1.table + "_id"
f2.column = m2.table + "_id"
f1.rel = true
f2.rel = true
f1.relTable = m1.table
f2.relTable = m2.table
f1.relModelInfo = m1
f2.relModelInfo = m2
f1.mi = mi
f2.mi = mi
mi.fields.Add(fa)
mi.fields.Add(f1)
mi.fields.Add(f2)
mi.fields.pk = fa
mi.uniques = []string{f1.column, f2.column}
return
}

View File

@@ -22,6 +22,8 @@ import (
"strings"
"time"
"github.com/beego/beego/v2/client/orm/internal/models"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
@@ -79,7 +81,7 @@ func (e *SliceStringField) RawValue() interface{} {
return e.String()
}
var _ Fielder = new(SliceStringField)
var _ models.Fielder = new(SliceStringField)
// A json field.
type JSONFieldTest struct {
@@ -111,7 +113,7 @@ func (e *JSONFieldTest) RawValue() interface{} {
return e.String()
}
var _ Fielder = new(JSONFieldTest)
var _ models.Fielder = new(JSONFieldTest)
type Data struct {
ID int `orm:"column(id)"`

View File

@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build go1.8
// +build go1.8
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage
//
@@ -50,7 +47,6 @@
// // delete
// num, err = o.Delete(&u)
// }
//
package orm
import (
@@ -58,9 +54,12 @@ import (
"database/sql"
"errors"
"fmt"
"os"
"reflect"
"time"
ilogs "github.com/beego/beego/v2/client/orm/internal/logs"
iutils "github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/hints"
@@ -76,10 +75,10 @@ const (
// Define common vars
var (
Debug = false
DebugLog = NewLog(os.Stdout)
DebugLog = ilogs.DebugLog
DefaultRowsLimit = -1
DefaultRelsDepth = 2
DefaultTimeLoc = time.Local
DefaultTimeLoc = iutils.DefaultTimeLoc
ErrTxDone = errors.New("<TxOrmer.Commit/Rollback> transaction already done")
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found")
@@ -108,7 +107,7 @@ var (
)
// get model info and model reflect value
func (*ormBase) getMi(md interface{}) (mi *modelInfo) {
func (*ormBase) getMi(md interface{}) (mi *models.ModelInfo) {
val := reflect.ValueOf(md)
ind := reflect.Indirect(val)
typ := ind.Type()
@@ -117,19 +116,19 @@ func (*ormBase) getMi(md interface{}) (mi *modelInfo) {
}
// get need ptr model info and model reflect value
func (*ormBase) getPtrMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
func (*ormBase) getPtrMiInd(md interface{}) (mi *models.ModelInfo, ind reflect.Value) {
val := reflect.ValueOf(md)
ind = reflect.Indirect(val)
typ := ind.Type()
if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", models.GetFullName(typ)))
}
mi = getTypeMi(typ)
return
}
func getTypeMi(mdTyp reflect.Type) *modelInfo {
name := getFullName(mdTyp)
func getTypeMi(mdTyp reflect.Type) *models.ModelInfo {
name := models.GetFullName(mdTyp)
if mi, ok := defaultModelCache.getByFullName(name); ok {
return mi
}
@@ -137,10 +136,10 @@ func getTypeMi(mdTyp reflect.Type) *modelInfo {
}
// get field info from model info by given field name
func (*ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name)
func (*ormBase) getFieldInfo(mi *models.ModelInfo, name string) *models.FieldInfo {
fi, ok := mi.Fields.GetByAny(name)
if !ok {
panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.FullName))
}
return fi
}
@@ -180,11 +179,11 @@ func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1
return err == nil, id, err
}
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
id, vid := int64(0), ind.FieldByIndex(mi.Fields.Pk.FieldIndex)
if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint())
} else if mi.fields.pk.rel {
return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
} else if mi.Fields.Pk.Rel {
return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.Fields.Pk.RelModelInfo.Fields.Pk.Name)
} else {
id = vid.Int()
}
@@ -210,12 +209,12 @@ func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, err
}
// set auto pk field
func (*ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
func (*ormBase) setPk(mi *models.ModelInfo, ind reflect.Value, id int64) {
if mi.Fields.Pk != nil && mi.Fields.Pk.Auto {
if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetUint(uint64(id))
} else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetInt(id)
}
}
}
@@ -277,7 +276,7 @@ func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, col
}
// update model to database.
// cols set the columns those want to update.
// cols set the Columns those want to update.
func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
return o.UpdateWithCtx(context.Background(), md, cols...)
}
@@ -305,10 +304,10 @@ func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
fi := o.getFieldInfo(mi, name)
switch {
case fi.fieldType == RelManyToMany:
case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough:
case fi.FieldType == RelManyToMany:
case fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough:
default:
panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName))
panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.Name, mi.FullName))
}
return newQueryM2M(md, o, mi, fi, ind)
@@ -324,8 +323,9 @@ func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string
// args are limit, offset int and order string.
//
// example:
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// make sure the relation is defined in model struct tags.
func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) {
@@ -362,7 +362,7 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str
}
})
switch fi.fieldType {
switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelReverseOne:
limit = 1
offset = 0
@@ -376,11 +376,11 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str
qs.orders = order_clause.ParseOrder(order)
}
find := ind.FieldByIndex(fi.fieldIndex)
find := ind.FieldByIndex(fi.FieldIndex)
var nums int64
var err error
switch fi.fieldType {
switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelReverseOne:
val := reflect.New(find.Type().Elem())
container := val.Interface()
@@ -397,7 +397,7 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str
}
// get QuerySeter for related models to md model
func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) {
func (o *ormBase) queryRelated(md interface{}, name string) (*models.ModelInfo, *models.FieldInfo, reflect.Value, *querySet) {
mi, ind := o.getPtrMiInd(md)
fi := o.getFieldInfo(mi, name)
@@ -408,14 +408,14 @@ func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldI
var qs *querySet
switch fi.fieldType {
switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
if !fi.inModel {
if !fi.InModel {
break
}
qs = o.getRelQs(md, mi, fi)
case RelReverseOne, RelReverseMany:
if !fi.inModel {
if !fi.InModel {
break
}
qs = o.getReverseQs(md, mi, fi)
@@ -429,41 +429,41 @@ func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldI
}
// get reverse relation QuerySeter
func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
func (o *ormBase) getReverseQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet {
switch fi.FieldType {
case RelReverseOne, RelReverseMany:
default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.Name, mi.FullName))
}
var q *querySet
if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
q = newQuerySet(o, fi.relModelInfo).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
if fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough {
q = newQuerySet(o, fi.RelModelInfo).(*querySet)
q.cond = NewCondition().And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md)
} else {
q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
q = newQuerySet(o, fi.ReverseFieldInfo.Mi).(*querySet)
q.cond = NewCondition().And(fi.ReverseFieldInfo.Column, md)
}
return q
}
// get relation QuerySeter
func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
func (o *ormBase) getRelQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet {
switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.Name, mi.FullName))
}
q := newQuerySet(o, fi.relModelInfo).(*querySet)
q := newQuerySet(o, fi.RelModelInfo).(*querySet)
q.cond = NewCondition()
if fi.fieldType == RelManyToMany {
q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
if fi.FieldType == RelManyToMany {
q.cond = q.cond.And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md)
} else {
q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
q.cond = q.cond.And(fi.ReverseFieldInfo.Column, md)
}
return q
@@ -475,12 +475,12 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string
if table, ok := ptrStructOrTableName.(string); ok {
name = nameStrategyMap[defaultNameStrategy](table)
name = models.NameStrategyMap[models.DefaultNameStrategy](table)
if mi, ok := defaultModelCache.get(name); ok {
qs = newQuerySet(o, mi)
}
} else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
name = models.GetFullName(iutils.IndirectType(reflect.TypeOf(ptrStructOrTableName)))
if mi, ok := defaultModelCache.getByFullName(name); ok {
qs = newQuerySet(o, mi)
}

View File

@@ -22,23 +22,22 @@ import (
"log"
"strings"
"time"
"github.com/beego/beego/v2/client/orm/internal/logs"
)
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// costomer log func
var LogFunc func(query map[string]interface{})
type Log = logs.Log
// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
func NewLog(out io.Writer) *logs.Log {
d := new(logs.Log)
d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d
}
// LogFunc costomer log func
var LogFunc func(query map[string]interface{})
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
logMap := make(map[string]interface{})
sub := time.Since(t) / 1e5
@@ -64,7 +63,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if LogFunc != nil {
LogFunc(logMap)
}
DebugLog.Println(con)
logs.DebugLog.Println(con)
}
// statement query logger struct.

View File

@@ -18,11 +18,13 @@ import (
"context"
"fmt"
"reflect"
"github.com/beego/beego/v2/client/orm/internal/models"
)
// an insert queryer struct
type insertSet struct {
mi *modelInfo
mi *models.ModelInfo
orm *ormBase
stmt stmtQuerier
closed bool
@@ -42,23 +44,23 @@ func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, e
val := reflect.ValueOf(md)
ind := reflect.Indirect(val)
typ := ind.Type()
name := getFullName(typ)
name := models.GetFullName(typ)
if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name))
}
if name != o.mi.fullName {
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
if name != o.mi.FullName {
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.FullName, name))
}
id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil {
return id, err
}
if id > 0 {
if o.mi.fields.pk.auto {
if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
if o.mi.Fields.Pk.Auto {
if o.mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetUint(uint64(id))
} else {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)
ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetInt(id)
}
}
}
@@ -75,7 +77,7 @@ func (o *insertSet) Close() error {
}
// create new insert queryer.
func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) {
func newInsertSet(ctx context.Context, orm *ormBase, mi *models.ModelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm
bi.mi = mi

View File

@@ -17,22 +17,25 @@ package orm
import (
"context"
"reflect"
"github.com/beego/beego/v2/client/orm/internal/models"
)
// model to model struct
type queryM2M struct {
md interface{}
mi *modelInfo
fi *fieldInfo
mi *models.ModelInfo
fi *models.FieldInfo
qs *querySet
ind reflect.Value
}
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
@@ -41,9 +44,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi
mi := fi.relThroughModelInfo
mfi := fi.reverseFieldInfo
rfi := fi.reverseFieldInfoTwo
mi := fi.RelThroughModelInfo
mfi := fi.ReverseFieldInfo
rfi := fi.ReverseFieldInfoTwo
orm := o.qs.orm
dbase := orm.alias.DbBaser
@@ -52,9 +55,9 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e
var otherValues []interface{}
var otherNames []string
for _, colname := range mi.fields.dbcols {
if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column &&
mi.fields.columns[colname] != mi.fields.pk {
for _, colname := range mi.Fields.DBcols {
if colname != mfi.Column && colname != rfi.Column && colname != fi.Mi.Fields.Pk.Column &&
mi.Fields.Columns[colname] != mi.Fields.Pk {
otherNames = append(otherNames, colname)
}
}
@@ -83,7 +86,7 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e
panic(ErrMissPK)
}
names := []string{mfi.column, rfi.column}
names := []string{mfi.Column, rfi.Column}
values := make([]interface{}, 0, len(models)*2)
for _, md := range models {
@@ -93,7 +96,7 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e
if ind.Kind() != reflect.Struct {
v2 = ind.Interface()
} else {
_, v2, exist = getExistPk(fi.relModelInfo, ind)
_, v2, exist = getExistPk(fi.RelModelInfo, ind)
if !exist {
panic(ErrMissPK)
}
@@ -113,9 +116,9 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
qs := o.qs.Filter(fi.ReverseFieldInfo.Name, o.md)
return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
return qs.Filter(fi.ReverseFieldInfoTwo.Name+ExprSep+"in", mds).Delete()
}
// check model is existed in relationship of origin model
@@ -125,8 +128,8 @@ func (o *queryM2M) Exist(md interface{}) bool {
func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx)
return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).
Filter(fi.ReverseFieldInfoTwo.Name, md).ExistWithCtx(ctx)
}
// clean all models in related of origin model
@@ -136,7 +139,7 @@ func (o *queryM2M) Clear() (int64, error) {
func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx)
return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).DeleteWithCtx(ctx)
}
// count all related models of origin model
@@ -146,18 +149,18 @@ func (o *queryM2M) Count() (int64, error) {
func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx)
return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).CountWithCtx(ctx)
}
var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer.
func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
func newQueryM2M(md interface{}, o *ormBase, mi *models.ModelInfo, fi *models.FieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M)
qm2m.md = md
qm2m.mi = mi
qm2m.fi = fi
qm2m.ind = ind
qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet)
qm2m.qs = newQuerySet(o, fi.RelThroughModelInfo).(*querySet)
return qm2m
}

View File

@@ -18,6 +18,10 @@ import (
"context"
"fmt"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/hints"
)
@@ -43,9 +47,10 @@ const (
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
//
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift,
@@ -53,7 +58,7 @@ func ColValue(opt operator, value interface{}) interface{} {
default:
panic(fmt.Errorf("orm.ColValue wrong operator"))
}
v, err := StrTo(ToStr(value)).Int64()
v, err := utils.StrTo(utils.ToStr(value)).Int64()
if err != nil {
panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err))
}
@@ -65,7 +70,7 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct
type querySet struct {
mi *modelInfo
mi *models.ModelInfo
cond *Condition
related []string
relDepth int
@@ -112,13 +117,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
// set offset number
func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num)
o.offset = utils.ToInt64(num)
}
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit)
o.limit = utils.ToInt64(limit)
if len(args) > 0 {
o.setOffset(args[0])
}
@@ -260,8 +265,9 @@ func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) {
// return an insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
//
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) {
return o.PrepareInsertWithCtx(context.Background())
}
@@ -271,7 +277,7 @@ func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
}
// query all data and map to containers.
// cols means the columns when querying.
// cols means the Columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.AllWithCtx(context.Background(), container, cols...)
}
@@ -281,7 +287,7 @@ func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols .
}
// query one row data and map to containers.
// cols means the columns when querying.
// cols means the Columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error {
return o.OneWithCtx(context.Background(), container, cols...)
}
@@ -339,10 +345,11 @@ func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, ex
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
//
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}
@@ -353,16 +360,17 @@ func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, er
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
//
// to struct {
// Total int
// Found int
// }
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}
// create new QuerySeter.
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
func newQuerySet(orm *ormBase, mi *models.ModelInfo) QuerySeter {
o := new(querySet)
o.mi = mi
o.orm = orm

View File

@@ -20,6 +20,10 @@ import (
"reflect"
"time"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/pkg/errors"
)
@@ -95,7 +99,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
} else if v, ok := value.(bool); ok {
ind.SetBool(v)
} else {
v, _ := StrTo(ToStr(value)).Bool()
v, _ := utils.StrTo(utils.ToStr(value)).Bool()
ind.SetBool(v)
}
@@ -103,7 +107,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if value == nil {
ind.SetString("")
} else {
ind.SetString(ToStr(value))
ind.SetString(utils.ToStr(value))
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -117,7 +121,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
ind.SetInt(int64(val.Uint()))
default:
v, _ := StrTo(ToStr(value)).Int64()
v, _ := utils.StrTo(utils.ToStr(value)).Int64()
ind.SetInt(v)
}
}
@@ -132,7 +136,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
ind.SetUint(val.Uint())
default:
v, _ := StrTo(ToStr(value)).Uint64()
v, _ := utils.StrTo(utils.ToStr(value)).Uint64()
ind.SetUint(v)
}
}
@@ -145,7 +149,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Float64:
ind.SetFloat(val.Float())
default:
v, _ := StrTo(ToStr(value)).Float64()
v, _ := utils.StrTo(utils.ToStr(value)).Float64()
ind.SetFloat(v)
}
}
@@ -170,20 +174,20 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" {
if len(str) >= 19 {
str = str[:19]
t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
t, err := time.ParseInLocation(utils.FormatDateTime, str, o.orm.alias.TZ)
if err == nil {
t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t))
}
} else if len(str) >= 10 {
str = str[:10]
t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
t, err := time.ParseInLocation(utils.FormatDate, str, DefaultTimeLoc)
if err == nil {
ind.Set(reflect.ValueOf(t))
}
} else if len(str) >= 8 {
str = str[:8]
t, err := time.ParseInLocation(formatTime, str, DefaultTimeLoc)
t, err := time.ParseInLocation(utils.FormatTime, str, DefaultTimeLoc)
if err == nil {
ind.Set(reflect.ValueOf(t))
}
@@ -287,7 +291,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
sMi *models.ModelInfo
)
structMode := false
for _, container := range containers {
@@ -313,7 +317,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
}
structMode = true
fn := getFullName(typ)
fn := models.GetFullName(typ)
if mi, ok := defaultModelCache.getByFullName(fn); ok {
sMi = mi
}
@@ -370,16 +374,16 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
if fi := sMi.Fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
field := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsRelField > 0 {
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
field := ind.FieldByIndex(fi.FieldIndex)
if fi.FieldType&IsRelField > 0 {
mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type())
field.Set(mf)
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex)
}
if fi.isFielder {
fd := field.Addr().Interface().(Fielder)
if fi.IsFielder {
fd := field.Addr().Interface().(models.Fielder)
err := fd.SetRaw(value)
if err != nil {
return errors.Errorf("set raw error:%s", err)
@@ -406,12 +410,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// thanks @Gazeboxu.
tags := structTagMap[fe.Tag]
if tags == nil {
_, tags = parseStructTag(fe.Tag.Get(defaultStructTagName))
_, tags = models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName))
structTagMap[fe.Tag] = tags
}
var col string
if col = tags["column"]; col == "" {
col = nameStrategyMap[nameStrategy](fe.Name)
col = models.NameStrategyMap[models.NameStrategy](fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
@@ -443,13 +447,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
return nil
}
// query data rows and map to container
// QueryRows query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
var (
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
sMi *models.ModelInfo
)
structMode := false
for _, container := range containers {
@@ -474,7 +478,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
}
structMode = true
fn := getFullName(typ)
fn := models.GetFullName(typ)
if mi, ok := defaultModelCache.getByFullName(fn); ok {
sMi = mi
}
@@ -500,7 +504,6 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInd := sInds[0]
for rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
@@ -537,16 +540,16 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
if fi := sMi.Fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
field := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsRelField > 0 {
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
field := ind.FieldByIndex(fi.FieldIndex)
if fi.FieldType&IsRelField > 0 {
mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type())
field.Set(mf)
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex)
}
if fi.isFielder {
fd := field.Addr().Interface().(Fielder)
if fi.IsFielder {
fd := field.Addr().Interface().(models.Fielder)
err := fd.SetRaw(value)
if err != nil {
return 0, errors.Errorf("set raw error:%s", err)
@@ -570,10 +573,10 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
recursiveSetField(f)
}
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
_, tags := models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName))
var col string
if col = tags["column"]; col == "" {
col = nameStrategyMap[nameStrategy](fe.Name)
col = models.NameStrategyMap[models.NameStrategy](fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
@@ -593,16 +596,18 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
if err = rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++
}
if err = rows.Err(); err != nil {
return 0, err
}
if cnt > 0 {
if structMode {
sInds[0].Set(sInd)
@@ -730,6 +735,10 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
cnt++
}
if err = rs.Err(); err != nil {
return 0, err
}
switch v := container.(type) {
case *[]Params:
*v = maps
@@ -837,7 +846,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
}
default:
if id := ind.FieldByName(camelString(key)); id.IsValid() {
if id := ind.FieldByName(models.CamelString(key)); id.IsValid() {
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
}
}
@@ -845,6 +854,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
cnt++
}
if err = rs.Err(); err != nil {
return 0, err
}
if typ == 1 {
v, _ := container.(*Params)
*v = maps
@@ -874,10 +886,11 @@ func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
//
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(result, keyCol, valueCol)
}
@@ -888,10 +901,11 @@ func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, erro
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
//
// to struct {
// Total int
// Found int
// }
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
}

View File

@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build go1.8
// +build go1.8
package orm
import (
@@ -22,7 +19,6 @@ import (
"context"
"database/sql"
"fmt"
"io/ioutil"
"math"
"os"
"path/filepath"
@@ -32,6 +28,12 @@ import (
"testing"
"time"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
@@ -41,9 +43,9 @@ import (
var _ = os.PathSeparator
var (
testDate = formatDate + " -0700"
testDateTime = formatDateTime + " -0700"
testTime = formatTime + " -0700"
testDate = utils.FormatDate + " -0700"
testDateTime = utils.FormatDateTime + " -0700"
testTime = utils.FormatTime + " -0700"
)
type argAny []interface{}
@@ -72,7 +74,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er
case time.Time:
if v2, vo := b.(time.Time); vo {
if arg.Get(1) != nil {
format := ToStr(arg.Get(1))
format := utils.ToStr(arg.Get(1))
a = v.Format(format)
b = v2.Format(format)
ok = a == b
@@ -82,7 +84,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er
}
}
default:
ok = ToStr(a) == ToStr(b)
ok = utils.ToStr(a) == utils.ToStr(b)
}
ok = is && ok || !is && !ok
if !ok {
@@ -115,7 +117,7 @@ func getCaller(skip int) string {
pc, file, line, _ := runtime.Caller(skip)
fun := runtime.FuncForPC(pc)
_, fn := filepath.Split(file)
data, err := ioutil.ReadFile(file)
data, err := os.ReadFile(file)
var codes []string
if err == nil {
lines := bytes.Split(data, []byte{'\n'})
@@ -250,14 +252,14 @@ func TestRegisterModels(_ *testing.T) {
func TestModelSyntax(t *testing.T) {
user := &User{}
ind := reflect.ValueOf(user).Elem()
fn := getFullName(ind.Type())
fn := models.GetFullName(ind.Type())
_, ok := defaultModelCache.getByFullName(fn)
throwFail(t, AssertIs(ok, true))
mi, ok := defaultModelCache.get("user")
throwFail(t, AssertIs(ok, true))
if ok {
throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true))
throwFail(t, AssertIs(mi.Fields.GetByName("ShouldSkip") == nil, true))
}
}
@@ -561,7 +563,7 @@ func TestNullDataTypes(t *testing.T) {
assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second)
assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second)
// test support for pointer fields using RawSeter.QueryRows()
// test support for pointer Fields using RawSeter.QueryRows()
var dnList []*DataNull
Q := dDbBaser.TableQuote()
num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList)
@@ -1894,7 +1896,7 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(row.Id, 4))
throwFail(t, AssertIs(row.EmbedField.Email, "nobody@gmail.com"))
// test for sql.Null* fields
// test for sql.Null* Fields
nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true},
@@ -2003,7 +2005,7 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(l[1].Age, 30))
// test for sql.Null* fields
// test for sql.Null* Fields
nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true},
@@ -2616,7 +2618,7 @@ func TestSnake(t *testing.T) {
"tag_666Name": "tag_666_name",
}
for name, want := range cases {
got := snakeString(name)
got := models.SnakeString(name)
throwFail(t, AssertIs(got, want))
}
}
@@ -2637,10 +2639,10 @@ func TestIgnoreCaseTag(t *testing.T) {
if t == nil {
return
}
throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n"))
throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true))
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
throwFail(t, AssertIs(info.Fields.GetByName("NOO").Column, "n"))
throwFail(t, AssertIs(info.Fields.GetByName("Name01").Null, true))
throwFail(t, AssertIs(info.Fields.GetByName("Name02").Column, "Name"))
throwFail(t, AssertIs(info.Fields.GetByName("Name03").Column, "name"))
}
func TestInsertOrUpdate(t *testing.T) {
@@ -2934,9 +2936,9 @@ func TestDebugLog(t *testing.T) {
func captureDebugLogOutput(f func()) string {
var buf bytes.Buffer
DebugLog.SetOutput(&buf)
logs.DebugLog.SetOutput(&buf)
defer func() {
DebugLog.SetOutput(os.Stderr)
logs.DebugLog.SetOutput(os.Stderr)
}()
f()
return buf.String()

View File

@@ -28,7 +28,7 @@ type MySQLQueryBuilder struct {
tokens []string
}
// Select will join the fields
// Select will join the Fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
qb.tokens = append(qb.tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
@@ -94,7 +94,7 @@ func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
return qb
}
// OrderBy join the Order by fields
// OrderBy join the Order by Fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.tokens = append(qb.tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
@@ -124,7 +124,7 @@ func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
return qb
}
// GroupBy join the Group by fields
// GroupBy join the Group by Fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.tokens = append(qb.tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb

View File

@@ -19,7 +19,7 @@ func processingStr(str []string) string {
return s
}
// Select will join the fields
// Select will join the Fields
func (qb *PostgresQueryBuilder) Select(fields ...string) QueryBuilder {
var str string
n := len(fields)
@@ -121,7 +121,7 @@ func (qb *PostgresQueryBuilder) In(vals ...string) QueryBuilder {
return qb
}
// OrderBy join the Order by fields
// OrderBy join the Order by Fields
func (qb *PostgresQueryBuilder) OrderBy(fields ...string) QueryBuilder {
str := processingStr(fields)
qb.tokens = append(qb.tokens, "ORDER BY", str)
@@ -152,7 +152,7 @@ func (qb *PostgresQueryBuilder) Offset(offset int) QueryBuilder {
return qb
}
// GroupBy join the Group by fields
// GroupBy join the Group by Fields
func (qb *PostgresQueryBuilder) GroupBy(fields ...string) QueryBuilder {
str := processingStr(fields)
qb.tokens = append(qb.tokens, "GROUP BY", str)

View File

@@ -20,19 +20,23 @@ import (
"reflect"
"time"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/core/utils"
)
// TableNaming is usually used by model
// TableNameI is usually used by model
// when you custom your table name, please implement this interfaces
// for example:
// type User struct {
// ...
// }
// func (u *User) TableName() string {
// return "USER_TABLE"
// }
//
// type User struct {
// ...
// }
//
// func (u *User) TableName() string {
// return "USER_TABLE"
// }
type TableNameI interface {
TableName() string
}
@@ -40,12 +44,14 @@ type TableNameI interface {
// TableEngineI is usually used by model
// when you want to use specific engine, like myisam, you can implement this interface
// for example:
// type User struct {
// ...
// }
// func (u *User) TableEngine() string {
// return "myisam"
// }
//
// type User struct {
// ...
// }
//
// func (u *User) TableEngine() string {
// return "myisam"
// }
type TableEngineI interface {
TableEngine() string
}
@@ -53,12 +59,14 @@ type TableEngineI interface {
// TableIndexI is usually used by model
// when you want to create indexes, you can implement this interface
// for example:
// type User struct {
// ...
// }
// func (u *User) TableIndex() [][]string {
// return [][]string{{"Name"}}
// }
//
// type User struct {
// ...
// }
//
// func (u *User) TableIndex() [][]string {
// return [][]string{{"Name"}}
// }
type TableIndexI interface {
TableIndex() [][]string
}
@@ -66,12 +74,14 @@ type TableIndexI interface {
// TableUniqueI is usually used by model
// when you want to create unique indexes, you can implement this interface
// for example:
// type User struct {
// ...
// }
// func (u *User) TableUnique() [][]string {
// return [][]string{{"Email"}}
// }
//
// type User struct {
// ...
// }
//
// func (u *User) TableUnique() [][]string {
// return [][]string{{"Email"}}
// }
type TableUniqueI interface {
TableUnique() [][]string
}
@@ -87,22 +97,16 @@ type Driver interface {
Type() DriverType
}
// Fielder define field info
type Fielder interface {
String() string
FieldType() int
SetRaw(interface{}) error
RawValue() interface{}
}
type Fielder = models.Fielder
type TxBeginner interface {
// self control transaction
// Begin self control transaction
Begin() (TxOrmer, error)
BeginWithCtx(ctx context.Context) (TxOrmer, error)
BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error)
BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error)
// closure control transaction
// DoTx closure control transaction
DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error
DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error
DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error
@@ -138,27 +142,27 @@ type txEnder interface {
RollbackUnlessCommit() error
}
// Data Manipulation Language
// DML Data Manipulation Language
type DML interface {
// insert model data to database
// Insert insert model data to database
// for example:
// user := new(User)
// id, err = Ormer.Insert(user)
// user must be a pointer and Insert will set user's pk field
Insert(md interface{}) (int64, error)
InsertWithCtx(ctx context.Context, md interface{}) (int64, error)
// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
// InsertOrUpdate mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
// if colu type is integer : can use(+-*/), string : convert(colu,"value")
// postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value")
// if colu type is integer : can use(+-*/), string : colu || "value"
InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error)
InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error)
// insert some models to database
// InsertMulti inserts some models to database
InsertMulti(bulk int, mds interface{}) (int64, error)
InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error)
// update model to database.
// cols set the columns those want to update.
// find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns
// Update updates model to database.
// cols set the Columns those want to update.
// find model by Id(pk) field and update Columns specified by Fields, if cols is null then update all Columns
// for example:
// user := User{Id: 2}
// user.Langs = append(user.Langs, "zh-CN", "en-US")
@@ -167,11 +171,11 @@ type DML interface {
// num, err = Ormer.Update(&user, "Langs", "Extra")
Update(md interface{}, cols ...string) (int64, error)
UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error)
// delete model in database
// Delete deletes model in database
Delete(md interface{}, cols ...string) (int64, error)
DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error)
// return a raw query seter for raw sql string.
// Raw return a raw query seter for raw sql string.
// for example:
// ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
// // update user testing's name to slene
@@ -179,9 +183,9 @@ type DML interface {
RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter
}
// Data Query Language
// DQL Data Query Language
type DQL interface {
// read data to model
// Read reads data to model
// for example:
// this will find User by Id field
// u = &User{Id: user.Id}
@@ -192,16 +196,16 @@ type DQL interface {
Read(md interface{}, cols ...string) error
ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error
// Like Read(), but with "FOR UPDATE" clause, useful in transaction.
// ReadForUpdate Like Read(), but with "FOR UPDATE" clause, useful in transaction.
// Some databases are not support this feature.
ReadForUpdate(md interface{}, cols ...string) error
ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error
// Try to read a row from the database, or insert one if it doesn't exist
// ReadOrCreate Try to read a row from the database, or insert one if it doesn't exist
ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error)
// load related models to md model.
// LoadRelated load related models to md model.
// args are limit, offset int and order string.
//
// example:
@@ -216,20 +220,20 @@ type DQL interface {
LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error)
LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error)
// create a models to models queryer
// QueryM2M create a models to models queryer
// for example:
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer
// NOTE: this method is deprecated, context parameter will not take effect.
// QueryM2MWithCtx NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations.
// QueryTable return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter
// NOTE: this method is deprecated, context parameter will not take effect.
// QueryTableWithCtx NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
@@ -270,7 +274,7 @@ type Inserter interface {
// QuerySeter query seter
type QuerySeter interface {
// add condition expression to QuerySeter.
// Filter add condition expression to QuerySeter.
// for example:
// filter by UserName == 'slene'
// qs.Filter("UserName", "slene")
@@ -279,22 +283,22 @@ type QuerySeter interface {
// // time compare
// qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter
// add raw sql to querySeter.
// FilterRaw add raw sql to querySeter.
// for example:
// qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)")
// //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18)
FilterRaw(string, string) QuerySeter
// add NOT condition to querySeter.
// Exclude add NOT condition to querySeter.
// have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter
// set condition to QuerySeter.
// SetCond set condition to QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter
// get condition from QuerySeter.
// GetCond get condition from QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
@@ -304,7 +308,7 @@ type QuerySeter interface {
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond).Count()
GetCond() *Condition
// add LIMIT value.
// Limit add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
// if Limit <= 0 then Limit will be set to default limit ,eg 1000
// if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
@@ -312,19 +316,19 @@ type QuerySeter interface {
// qs.Limit(10, 2)
// // sql-> limit 10 offset 2
Limit(limit interface{}, args ...interface{}) QuerySeter
// add OFFSET value
// Offset add OFFSET value
// same as Limit function's args[0]
Offset(offset interface{}) QuerySeter
// add GROUP BY expression
// GroupBy add GROUP BY expression
// for example:
// qs.GroupBy("id")
GroupBy(exprs ...string) QuerySeter
// add ORDER expression.
// OrderBy add ORDER expression.
// "column" means ASC, "-column" means DESC.
// for example:
// qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter
// add ORDER expression by order clauses
// OrderClauses add ORDER expression by order clauses
// for example:
// OrderClauses(
// order_clause.Clause(
@@ -346,50 +350,50 @@ type QuerySeter interface {
// order_clause.Raw(),//default false.if true, do not check field is valid or not
// ))
OrderClauses(orders ...*order_clause.Order) QuerySeter
// add FORCE INDEX expression.
// ForceIndex add FORCE INDEX expression.
// for example:
// qs.ForceIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
ForceIndex(indexes ...string) QuerySeter
// add USE INDEX expression.
// UseIndex add USE INDEX expression.
// for example:
// qs.UseIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
UseIndex(indexes ...string) QuerySeter
// add IGNORE INDEX expression.
// IgnoreIndex add IGNORE INDEX expression.
// for example:
// qs.IgnoreIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
IgnoreIndex(indexes ...string) QuerySeter
// set relation model to query together.
// RelatedSel set relation model to query together.
// it will query relation models and assign to parent model.
// for example:
// // will load all related fields use left join .
// // will load all related Fields use left join .
// qs.RelatedSel().One(&user)
// // will load related field only profile
// qs.RelatedSel("profile").One(&user)
// user.Profile.Age = 32
RelatedSel(params ...interface{}) QuerySeter
// Set Distinct
// Distinct Set Distinct
// for example:
// o.QueryTable("policy").Filter("Groups__Group__Users__User", user).
// Distinct().
// All(&permissions)
Distinct() QuerySeter
// set FOR UPDATE to query.
// ForUpdate set FOR UPDATE to query.
// for example:
// o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users)
ForUpdate() QuerySeter
// return QuerySeter execution result number
// Count returns QuerySeter execution result number
// for example:
// num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
// check result empty or not after QuerySeter executed
// Exist check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0
Exist() bool
ExistWithCtx(context.Context) bool
// execute update with parameters
// Update execute update with parameters
// for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{
// "Nums": ColValue(Col_Minus, 50),
@@ -399,13 +403,13 @@ type QuerySeter interface {
// }) // user slene's name will change to slene2
Update(values Params) (int64, error)
UpdateWithCtx(ctx context.Context, values Params) (int64, error)
// delete from table
// Delete delete from table
// for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2
Delete() (int64, error)
DeleteWithCtx(context.Context) (int64, error)
// return an insert queryer.
// PrepareInsert return an insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
@@ -414,21 +418,21 @@ type QuerySeter interface {
// err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error)
PrepareInsertWithCtx(context.Context) (Inserter, error)
// query all data and map to containers.
// cols means the columns when querying.
// All query all data and map to containers.
// cols means the Columns when querying.
// for example:
// var users []*User
// qs.All(&users) // users[0],users[1],users[2] ...
All(container interface{}, cols ...string) (int64, error)
AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error)
// query one row data and map to containers.
// cols means the columns when querying.
// One query one row data and map to containers.
// cols means the Columns when querying.
// for example:
// var user User
// qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error
OneWithCtx(ctx context.Context, container interface{}, cols ...string) error
// query all data and map to []map[string]interface.
// Values query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
// for example:
@@ -436,21 +440,21 @@ type QuerySeter interface {
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error)
ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface
// ValuesList query all data and map to [][]interface
// it converts data to [][column_index]value
// for example:
// var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface.
// ValuesFlat query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value.
// for example:
// var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error)
ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// RowsToMap query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
@@ -461,7 +465,7 @@ type QuerySeter interface {
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// RowsToStruct query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
@@ -472,7 +476,7 @@ type QuerySeter interface {
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// aggregate func.
// Aggregate aggregate func.
// for example:
// type result struct {
// DeptName string
@@ -486,7 +490,7 @@ type QuerySeter interface {
// QueryM2Mer model to model query struct
// all operations are on the m2m table only, will not affect the origin model table
type QueryM2Mer interface {
// add models to origin models when creating queryM2M.
// Add adds models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
@@ -499,20 +503,20 @@ type QueryM2Mer interface {
// make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error)
AddWithCtx(context.Context, ...interface{}) (int64, error)
// remove models following the origin model relationship
// Remove removes models following the origin model relationship
// only delete rows from m2m table
// for example:
// tag3 := &Tag{Id:5,Name: "TestTag3"}
// num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error)
RemoveWithCtx(context.Context, ...interface{}) (int64, error)
// check model is existed in relationship of origin model
// Exist checks model is existed in relationship of origin model
Exist(interface{}) bool
ExistWithCtx(context.Context, interface{}) bool
// clean all models in related of origin model
// Clear cleans all models in related of origin model
Clear() (int64, error)
ClearWithCtx(context.Context) (int64, error)
// count all related models of origin model
// Count counts all related models of origin model
Count() (int64, error)
CountWithCtx(context.Context) (int64, error)
}
@@ -526,35 +530,36 @@ type RawPreparer interface {
// RawSeter raw query seter
// create From Ormer.Raw
// for example:
// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
// rs := Ormer.Raw(sql, 1)
//
// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
// rs := Ormer.Raw(sql, 1)
type RawSeter interface {
// execute sql and get result
// Exec execute sql and get result
Exec() (sql.Result, error)
// query data and map to container
// QueryRow query data and map to container
// for example:
// var name string
// var id int
// rs.QueryRow(&id,&name) // id==2 name=="slene"
QueryRow(containers ...interface{}) error
// query data rows and map to container
// QueryRows query data rows and map to container
// var ids []int
// var names []int
// query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
// num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
QueryRows(containers ...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
// query data to []map[string]interface
// Values query data to []map[string]interface
// see QuerySeter's Values
Values(container *[]Params, cols ...string) (int64, error)
// query data to [][]interface
// ValuesList query data to [][]interface
// see QuerySeter's ValuesList
ValuesList(container *[]ParamsList, cols ...string) (int64, error)
// query data to []interface
// ValuesFlat query data to []interface
// see QuerySeter's ValuesFlat
ValuesFlat(container *ParamsList, cols ...string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// RowsToMap query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
@@ -565,7 +570,7 @@ type RawSeter interface {
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// RowsToStruct query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
@@ -577,7 +582,7 @@ type RawSeter interface {
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// return prepared raw statement for used in times.
// Prepare return prepared raw statement for used in times.
// for example:
// pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
// r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
@@ -617,32 +622,32 @@ type dbQuerier interface {
// base database struct
type dbBaser interface {
Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Read(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error)
ReadValues(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Insert(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(context.Context, dbQuerier, *models.ModelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(context.Context, stmtQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error)
Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
Update(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, Params, *time.Location) (int64, error)
Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Delete(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error)
SupportUpdateJoin() bool
OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error)
GenerateOperatorSQL(*models.ModelInfo, *models.FieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*models.FieldInfo, string, *string)
PrepareInsert(context.Context, dbQuerier, *models.ModelInfo) (stmtQuerier, string, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
HasReturningID(*modelInfo, *string) bool
HasReturningID(*models.ModelInfo, *string) bool
TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string
@@ -651,8 +656,8 @@ type dbBaser interface {
ShowTablesQuery() string
ShowColumnsQuery(string) string
IndexExists(context.Context, dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(context.Context, dbQuerier, *modelInfo, []string) error
collectFieldValue(*models.ModelInfo, *models.FieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(context.Context, dbQuerier, *models.ModelInfo, []string) error
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
}

View File

@@ -15,305 +15,15 @@
package orm
import (
"fmt"
"math/big"
"reflect"
"strconv"
"strings"
"time"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/internal/utils"
)
type fn func(string) string
type StrTo = utils.StrTo
var (
nameStrategyMap = map[string]fn{
defaultNameStrategy: snakeString,
SnakeAcronymNameStrategy: snakeStringWithAcronym,
}
defaultNameStrategy = "snakeString"
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
nameStrategy = defaultNameStrategy
)
// StrTo is the target string
type StrTo string
// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
} else {
f.Clear()
}
}
// Clear string
func (f *StrTo) Clear() {
*f = StrTo(rune(0x1E))
}
// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(rune(0x1E))
}
// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) // octal
if !ok {
return v, err
}
return ni.Int64(), nil
}
return v, err
}
// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// Uint32 string to uint32
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10)
if !ok {
return v, err
}
return ni.Uint64(), nil
}
return v, err
}
// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
}
return ""
}
// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
s = strconv.FormatBool(v)
case float32:
s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32))
case float64:
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int32:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int64:
s = strconv.FormatInt(v, argInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint32:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint64:
s = strconv.FormatUint(v, argInt(args).Get(0, 10))
case string:
s = v
case []byte:
s = string(v)
default:
s = fmt.Sprintf("%v", v)
}
return s
}
// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
d = val.Int()
case uint, uint8, uint16, uint32, uint64:
d = int64(val.Uint())
default:
panic(fmt.Errorf("ToInt64 need numeric not `%T`", value))
}
return
}
func snakeStringWithAcronym(s string) string {
data := make([]byte, 0, len(s)*2)
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
before := false
after := false
if i > 0 {
before = s[i-1] >= 'a' && s[i-1] <= 'z'
}
if i+1 < num {
after = s[i+1] >= 'a' && s[i+1] <= 'z'
}
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
data = append(data, '_')
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// snake string, XxYy to xx_yy , XxYY to xx_y_y
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j {
data = append(data, '_')
}
if d != '_' {
j = true
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// SetNameStrategy set different name strategy
func SetNameStrategy(s string) {
if SnakeAcronymNameStrategy != s {
nameStrategy = defaultNameStrategy
}
nameStrategy = s
}
// camel string, xx_yy to XxYy
func camelString(s string) string {
data := make([]byte, 0, len(s))
flag, num := true, len(s)-1
for i := 0; i <= num; i++ {
d := s[i]
if d == '_' {
flag = true
continue
} else if flag {
if d >= 'a' && d <= 'z' {
d = d - 32
}
flag = false
}
data = append(data, d)
}
return string(data)
}
type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
} else if len(args) > 0 {
r = args[0]
}
return
}
type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
case reflect.Ptr:
return indirectType(v.Elem())
default:
return v
if models.SnakeAcronymNameStrategy != s {
models.NameStrategy = models.DefaultNameStrategy
}
models.NameStrategy = s
}