Merge pull request #5238 from flycash/develop

orm: refactor ORM introducing internal/models pkg
This commit is contained in:
Ming Deng 2023-06-09 10:35:58 +08:00 committed by GitHub
commit b7371715a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 2393 additions and 2086 deletions

View File

@ -2,6 +2,9 @@
- [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232) - [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232)
- [remove adapter package](https://github.com/beego/beego/pull/5239) - [remove adapter package](https://github.com/beego/beego/pull/5239)
## ORM refactoring
- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238)
# v2.1.0 # v2.1.0
- [unified gopkg.in/yaml version to v2](https://github.com/beego/beego/pull/5169) - [unified gopkg.in/yaml version to v2](https://github.com/beego/beego/pull/5169)
- [add non-block write log in asynchronous mode](https://github.com/beego/beego/pull/5150) - [add non-block write log in asynchronous mode](https://github.com/beego/beego/pull/5150)

View File

@ -1,4 +1,4 @@
Copyright 2014 astaxie Copyright 2014 Beego
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View File

@ -20,6 +20,10 @@ import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
type commander interface { type commander interface {
@ -53,7 +57,7 @@ func RunCommand() {
BootStrap() BootStrap()
args := argString(os.Args[2:]) args := utils.ArgString(os.Args[2:])
name := args.Get(0) name := args.Get(0)
if name == "help" { if name == "help" {
@ -112,7 +116,7 @@ func (d *commandSyncDb) Run() error {
for i, mi := range defaultModelCache.allOrdered() { for i, mi := range defaultModelCache.allOrdered() {
query := drops[i] query := drops[i]
if !d.noInfo { if !d.noInfo {
fmt.Printf("drop table `%s`\n", mi.table) fmt.Printf("drop table `%s`\n", mi.Table)
} }
_, err := db.Exec(query) _, err := db.Exec(query)
if d.verbose { if d.verbose {
@ -143,18 +147,18 @@ func (d *commandSyncDb) Run() error {
ctx := context.Background() ctx := context.Background()
for i, mi := range defaultModelCache.allOrdered() { for i, mi := range defaultModelCache.allOrdered() {
if !isApplicableTableForDB(mi.addrField, d.al.Name) { if !models.IsApplicableTableForDB(mi.AddrField, d.al.Name) {
fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name) fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.Table, d.al.Name)
continue continue
} }
if tables[mi.table] { if tables[mi.Table] {
if !d.noInfo { if !d.noInfo {
fmt.Printf("table `%s` already exists, skip\n", mi.table) fmt.Printf("table `%s` already exists, skip\n", mi.Table)
} }
var fields []*fieldInfo var fields []*models.FieldInfo
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.Table)
if err != nil { if err != nil {
if d.rtOnError { if d.rtOnError {
return err return err
@ -162,8 +166,8 @@ func (d *commandSyncDb) Run() error {
fmt.Printf(" %s\n", err.Error()) fmt.Printf(" %s\n", err.Error())
} }
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.Fields.FieldsDB {
if _, ok := columns[fi.column]; !ok { if _, ok := columns[fi.Column]; !ok {
fields = append(fields, fi) fields = append(fields, fi)
} }
} }
@ -172,7 +176,7 @@ func (d *commandSyncDb) Run() error {
query := getColumnAddQuery(d.al, fi) query := getColumnAddQuery(d.al, fi)
if !d.noInfo { if !d.noInfo {
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) fmt.Printf("add column `%s` for table `%s`\n", fi.FullName, mi.Table)
} }
_, err := db.Exec(query) _, err := db.Exec(query)
@ -187,7 +191,7 @@ func (d *commandSyncDb) Run() error {
} }
} }
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.Table] {
if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
if !d.noInfo { if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
@ -211,11 +215,11 @@ func (d *commandSyncDb) Run() error {
} }
if !d.noInfo { if !d.noInfo {
fmt.Printf("create table `%s` \n", mi.table) fmt.Printf("create table `%s` \n", mi.Table)
} }
queries := []string{createQueries[i]} queries := []string{createQueries[i]}
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.Table] {
queries = append(queries, idx.SQL) queries = append(queries, idx.SQL)
} }
@ -265,7 +269,7 @@ func (d *commandSQLAll) Run() error {
var all []string var all []string
for i, mi := range defaultModelCache.allOrdered() { for i, mi := range defaultModelCache.allOrdered() {
queries := []string{createQueries[i]} queries := []string{createQueries[i]}
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.Table] {
queries = append(queries, idx.SQL) queries = append(queries, idx.SQL)
} }
sql := strings.Join(queries, "\n") sql := strings.Join(queries, "\n")

View File

@ -17,6 +17,8 @@ package orm
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
type dbIndex struct { type dbIndex struct {
@ -26,17 +28,17 @@ type dbIndex struct {
} }
// get database column type string. // get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) { func getColumnTyp(al *alias, fi *models.FieldInfo) (col string) {
T := al.DbBaser.DbTypes() T := al.DbBaser.DbTypes()
fieldType := fi.fieldType fieldType := fi.FieldType
fieldSize := fi.size fieldSize := fi.Size
checkColumn: checkColumn:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
col = T["bool"] col = T["bool"]
case TypeVarCharField: case TypeVarCharField:
if al.Driver == DRPostgres && fi.toText { if al.Driver == DRPostgres && fi.ToText {
col = T["string-text"] col = T["string-text"]
} else { } else {
col = fmt.Sprintf(T["string"], fieldSize) col = fmt.Sprintf(T["string"], fieldSize)
@ -51,11 +53,11 @@ checkColumn:
col = T["time.Time-date"] col = T["time.Time-date"]
case TypeDateTimeField: case TypeDateTimeField:
// the precision of sqlite is not implemented // the precision of sqlite is not implemented
if al.Driver == 2 || fi.timePrecision == nil { if al.Driver == 2 || fi.TimePrecision == nil {
col = T["time.Time"] col = T["time.Time"]
} else { } else {
s := T["time.Time-precision"] s := T["time.Time-precision"]
col = fmt.Sprintf(s, *fi.timePrecision) col = fmt.Sprintf(s, *fi.TimePrecision)
} }
case TypeBitField: case TypeBitField:
@ -85,7 +87,7 @@ checkColumn:
if !strings.Contains(s, "%d") { if !strings.Contains(s, "%d") {
col = s col = s
} else { } else {
col = fmt.Sprintf(s, fi.digits, fi.decimals) col = fmt.Sprintf(s, fi.Digits, fi.Decimals)
} }
case TypeJSONField: case TypeJSONField:
if al.Driver != DRPostgres { if al.Driver != DRPostgres {
@ -100,8 +102,8 @@ checkColumn:
} }
col = T["jsonb"] col = T["jsonb"]
case RelForeignKey, RelOneToOne: case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.RelModelInfo.Fields.Pk.FieldType
fieldSize = fi.relModelInfo.fields.pk.size fieldSize = fi.RelModelInfo.Fields.Pk.Size
goto checkColumn goto checkColumn
} }
@ -109,34 +111,34 @@ checkColumn:
} }
// create alter sql string. // create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string { func getColumnAddQuery(al *alias, fi *models.FieldInfo) string {
Q := al.DbBaser.TableQuote() Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi) typ := getColumnTyp(al, fi)
if !fi.null { if !fi.Null {
typ += " " + "NOT NULL" typ += " " + "NOT NULL"
} }
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
Q, fi.mi.table, Q, Q, fi.Mi.Table, Q,
Q, fi.column, Q, Q, fi.Column, Q,
typ, getColumnDefault(fi), typ, getColumnDefault(fi),
) )
} }
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string { func getColumnDefault(fi *models.FieldInfo) string {
var v, t, d string var v, t, d string
// Skip default attribute if field is in relations // Skip default attribute if field is in relations
if fi.rel || fi.reverse { if fi.Rel || fi.Reverse {
return v return v
} }
t = " DEFAULT '%s' " t = " DEFAULT '%s' "
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on // These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType { switch fi.FieldType {
case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
return v return v
@ -153,14 +155,14 @@ func getColumnDefault(fi *fieldInfo) string {
d = "{}" d = "{}"
} }
if fi.colDefault { if fi.ColDefault {
if !fi.initial.Exist() { if !fi.Initial.Exist() {
v = fmt.Sprintf(t, "") v = fmt.Sprintf(t, "")
} else { } else {
v = fmt.Sprintf(t, fi.initial.String()) v = fmt.Sprintf(t, fi.Initial.String())
} }
} else { } else {
if !fi.null { if !fi.Null {
v = fmt.Sprintf(t, d) v = fmt.Sprintf(t, d)
} }
} }

View File

@ -23,13 +23,13 @@ import (
"strings" "strings"
"time" "time"
"github.com/beego/beego/v2/client/orm/hints" "github.com/beego/beego/v2/client/orm/internal/logs"
)
const ( "github.com/beego/beego/v2/client/orm/internal/utils"
formatTime = "15:04:05"
formatDate = "2006-01-02" "github.com/beego/beego/v2/client/orm/internal/models"
formatDateTime = "2006-01-02 15:04:05"
"github.com/beego/beego/v2/client/orm/hints"
) )
// ErrMissPK missing pk error // ErrMissPK missing pk error
@ -72,8 +72,8 @@ type dbBase struct {
// check dbBase implements dbBaser interface. // check dbBase implements dbBaser interface.
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
// get struct columns values as interface slice. // get struct Columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { func (d *dbBase) collectValues(mi *models.ModelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) {
if names == nil { if names == nil {
ns := make([]string, 0, len(cols)) ns := make([]string, 0, len(cols))
names = &ns names = &ns
@ -81,13 +81,13 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
values = make([]interface{}, 0, len(cols)) values = make([]interface{}, 0, len(cols))
for _, column := range cols { for _, column := range cols {
var fi *fieldInfo var fi *models.FieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil { if fi, _ = mi.Fields.GetByAny(column); fi != nil {
column = fi.column column = fi.Column
} else { } else {
panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.FullName))
} }
if !fi.dbcol || fi.auto && skipAuto { if !fi.DBcol || fi.Auto && skipAuto {
continue continue
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
@ -96,8 +96,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
// ignore empty value auto field // ignore empty value auto field
if insert && fi.auto { if insert && fi.Auto {
if fi.fieldType&IsPositiveIntegerField > 0 { if fi.FieldType&IsPositiveIntegerField > 0 {
if vu, ok := value.(uint64); !ok || vu == 0 { if vu, ok := value.(uint64); !ok || vu == 0 {
continue continue
} }
@ -106,7 +106,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
continue continue
} }
} }
autoFields = append(autoFields, fi.column) autoFields = append(autoFields, fi.Column)
} }
*names, values = append(*names, column), append(values, value) *names, values = append(*names, column), append(values, value)
@ -116,17 +116,17 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
// get one field value in struct column as interface. // get one field value in struct column as interface.
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { func (d *dbBase) collectFieldValue(mi *models.ModelInfo, fi *models.FieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
var value interface{} var value interface{}
if fi.pk { if fi.Pk {
_, value, _ = getExistPk(mi, ind) _, value, _ = getExistPk(mi, ind)
} else { } else {
field := ind.FieldByIndex(fi.fieldIndex) field := ind.FieldByIndex(fi.FieldIndex)
if fi.isFielder { if fi.IsFielder {
f := field.Addr().Interface().(Fielder) f := field.Addr().Interface().(models.Fielder)
value = f.RawValue() value = f.RawValue()
} else { } else {
switch fi.fieldType { switch fi.FieldType {
case TypeBooleanField: case TypeBooleanField:
if nb, ok := field.Interface().(sql.NullBool); ok { if nb, ok := field.Interface().(sql.NullBool); ok {
value = nil value = nil
@ -172,7 +172,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
vu := field.Interface() vu := field.Interface()
if _, ok := vu.(float32); ok { if _, ok := vu.(float32); ok {
value, _ = StrTo(ToStr(vu)).Float64() value, _ = utils.StrTo(utils.ToStr(vu)).Float64()
} else { } else {
value = field.Float() value = field.Float()
} }
@ -189,7 +189,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
default: default:
switch { switch {
case fi.fieldType&IsPositiveIntegerField > 0: case fi.FieldType&IsPositiveIntegerField > 0:
if field.Kind() == reflect.Ptr { if field.Kind() == reflect.Ptr {
if field.IsNil() { if field.IsNil() {
value = nil value = nil
@ -199,7 +199,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
value = field.Uint() value = field.Uint()
} }
case fi.fieldType&IsIntegerField > 0: case fi.FieldType&IsIntegerField > 0:
if ni, ok := field.Interface().(sql.NullInt64); ok { if ni, ok := field.Interface().(sql.NullInt64); ok {
value = nil value = nil
if ni.Valid { if ni.Valid {
@ -214,25 +214,25 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
value = field.Int() value = field.Int()
} }
case fi.fieldType&IsRelField > 0: case fi.FieldType&IsRelField > 0:
if field.IsNil() { if field.IsNil() {
value = nil value = nil
} else { } else {
if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { if _, vu, ok := getExistPk(fi.RelModelInfo, reflect.Indirect(field)); ok {
value = vu value = vu
} else { } else {
value = nil value = nil
} }
} }
if !fi.null && value == nil { if !fi.Null && value == nil {
return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) return nil, fmt.Errorf("field `%s` cannot be NULL", fi.FullName)
} }
} }
} }
} }
switch fi.fieldType { switch fi.FieldType {
case TypeTimeField, TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
if fi.autoNow || fi.autoNowAdd && insert { if fi.AutoNow || fi.AutoNowAdd && insert {
if insert { if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() { if t, ok := value.(time.Time); ok && !t.IsZero() {
break break
@ -241,8 +241,8 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
tnow := time.Now() tnow := time.Now()
d.ins.TimeToDB(&tnow, tz) d.ins.TimeToDB(&tnow, tz)
value = tnow value = tnow
if fi.isFielder { if fi.IsFielder {
f := field.Addr().Interface().(Fielder) f := field.Addr().Interface().(models.Fielder)
f.SetRaw(tnow.In(DefaultTimeLoc)) f.SetRaw(tnow.In(DefaultTimeLoc))
} else if field.Kind() == reflect.Ptr { } else if field.Kind() == reflect.Ptr {
v := tnow.In(DefaultTimeLoc) v := tnow.In(DefaultTimeLoc)
@ -253,8 +253,8 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
case TypeJSONField, TypeJsonbField: case TypeJSONField, TypeJsonbField:
if s, ok := value.(string); (ok && len(s) == 0) || value == nil { if s, ok := value.(string); (ok && len(s) == 0) || value == nil {
if fi.colDefault && fi.initial.Exist() { if fi.ColDefault && fi.Initial.Exist() {
value = fi.initial.String() value = fi.Initial.String()
} else { } else {
value = nil value = nil
} }
@ -265,14 +265,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
// PrepareInsert create insert sql preparation statement object. // PrepareInsert create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *models.ModelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
dbcols := make([]string, 0, len(mi.fields.dbcols)) dbcols := make([]string, 0, len(mi.Fields.DBcols))
marks := make([]string, 0, len(mi.fields.dbcols)) marks := make([]string, 0, len(mi.Fields.DBcols))
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.Fields.FieldsDB {
if !fi.auto { if !fi.Auto {
dbcols = append(dbcols, fi.column) dbcols = append(dbcols, fi.Column)
marks = append(marks, "?") marks = append(marks, "?")
} }
} }
@ -280,7 +280,7 @@ func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo)
sep := fmt.Sprintf("%s, %s", Q, Q) sep := fmt.Sprintf("%s, %s", Q, Q)
columns := strings.Join(dbcols, sep) columns := strings.Join(dbcols, sep)
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -291,8 +291,8 @@ func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo)
} }
// InsertStmt insert struct with prepared statement and given struct reflect value. // InsertStmt insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -311,7 +311,7 @@ func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo
} }
// query sql ,read records and persist in dbBaser. // query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
var whereCols []string var whereCols []string
var args []interface{} var args []interface{}
@ -336,8 +336,8 @@ func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind refle
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
sep := fmt.Sprintf("%s, %s", Q, Q) sep := fmt.Sprintf("%s, %s", Q, Q)
sels := strings.Join(mi.fields.dbcols, sep) sels := strings.Join(mi.Fields.DBcols, sep)
colsNum := len(mi.fields.dbcols) colsNum := len(mi.Fields.DBcols)
sep = fmt.Sprintf("%s = ? AND %s", Q, Q) sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
wheres := strings.Join(whereCols, sep) wheres := strings.Join(whereCols, sep)
@ -347,7 +347,7 @@ func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind refle
forUpdate = "FOR UPDATE" forUpdate = "FOR UPDATE"
} }
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.Table, Q, Q, wheres, Q, forUpdate)
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i := range refs { for i := range refs {
@ -364,17 +364,17 @@ func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind refle
} }
return err return err
} }
elm := reflect.New(mi.addrField.Elem().Type()) elm := reflect.New(mi.AddrField.Elem().Type())
mind := reflect.Indirect(elm) mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) d.setColsValues(mi, &mind, mi.Fields.DBcols, refs, tz)
ind.Set(mind) ind.Set(mind)
return nil return nil
} }
// Insert execute insert sql dbQuerier with given struct reflect.Value. // Insert execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols)) names := make([]string, 0, len(mi.Fields.DBcols))
values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) values, autoFields, err := d.collectValues(mi, ind, mi.Fields.DBcols, false, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -391,7 +391,7 @@ func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
} }
// InsertMulti multi-insert sql with given slice struct reflect.Value. // InsertMulti multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *models.ModelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var ( var (
cnt int64 cnt int64
nums int nums int
@ -399,32 +399,25 @@ func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, si
names []string names []string
) )
// typ := reflect.Indirect(mi.addrField).Type()
length, autoFields := sind.Len(), make([]string, 0, 1) length, autoFields := sind.Len(), make([]string, 0, 1)
for i := 1; i <= length; i++ { for i := 1; i <= length; i++ {
ind := reflect.Indirect(sind.Index(i - 1)) ind := reflect.Indirect(sind.Index(i - 1))
// Is this needed ?
// if !ind.Type().AssignableTo(typ) {
// return cnt, ErrArgs
// }
if i == 1 { if i == 1 {
var ( var (
vus []interface{} vus []interface{}
err error err error
) )
vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) vus, autoFields, err = d.collectValues(mi, ind, mi.Fields.DBcols, false, true, &names, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
values = make([]interface{}, bulk*len(vus)) values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus) nums += copy(values, vus)
} else { } else {
vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) vus, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, false, true, nil, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
@ -456,7 +449,7 @@ func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, si
// InsertValue execute insert sql with given struct and given values. // InsertValue execute insert sql with given struct and given values.
// insert the given values, not the field values in struct. // insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -474,7 +467,7 @@ func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, is
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
} }
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -487,7 +480,7 @@ func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, is
lastInsertId, err := res.LastInsertId() lastInsertId, err := res.LastInsertId()
if err != nil { if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable return lastInsertId, ErrLastInsertIdUnavailable
} else { } else {
return lastInsertId, nil return lastInsertId, nil
@ -504,7 +497,7 @@ func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, is
// InsertOrUpdate a row // InsertOrUpdate a row
// If your primary key or unique column conflict will update // If your primary key or unique column conflict will update
// If no will insert // If no will insert
func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
args0 := "" args0 := ""
iouStr := "" iouStr := ""
argsMap := map[string]string{} argsMap := map[string]string{}
@ -530,9 +523,9 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo,
} }
isMulti := false isMulti := false
names := make([]string, 0, len(mi.fields.dbcols)-1) names := make([]string, 0, len(mi.Fields.DBcols)-1)
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -556,7 +549,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo,
case DRPostgres: case DRPostgres:
if conflitValue != nil { if conflitValue != nil {
// postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.Table, args0)
updateValues = append(updateValues, conflitValue) updateValues = append(updateValues, conflitValue)
} else { } else {
return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
@ -581,7 +574,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo,
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
} }
// conflitValue maybe is a int,can`t use fmt.Sprintf // conflitValue maybe is a int,can`t use fmt.Sprintf
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -594,7 +587,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo,
lastInsertId, err := res.LastInsertId() lastInsertId, err := res.LastInsertId()
if err != nil { if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable return lastInsertId, ErrLastInsertIdUnavailable
} else { } else {
return lastInsertId, nil return lastInsertId, nil
@ -613,7 +606,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo,
} }
// Update execute update sql dbQuerier with given struct reflect.Value. // Update execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if !ok { if !ok {
return 0, ErrMissPK return 0, ErrMissPK
@ -621,10 +614,10 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
var setNames []string var setNames []string
// if specify cols length is zero, then commit all columns. // if specify cols length is zero, then commit all Columns.
if len(cols) == 0 { if len(cols) == 0 {
cols = mi.fields.dbcols cols = mi.Fields.DBcols
setNames = make([]string, 0, len(mi.fields.dbcols)-1) setNames = make([]string, 0, len(mi.Fields.DBcols)-1)
} else { } else {
setNames = make([]string, 0, len(cols)) setNames = make([]string, 0, len(cols))
} }
@ -637,11 +630,11 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
var findAutoNowAdd, findAutoNow bool var findAutoNowAdd, findAutoNow bool
var index int var index int
for i, col := range setNames { for i, col := range setNames {
if mi.fields.GetByColumn(col).autoNowAdd { if mi.Fields.GetByColumn(col).AutoNowAdd {
index = i index = i
findAutoNowAdd = true findAutoNowAdd = true
} }
if mi.fields.GetByColumn(col).autoNow { if mi.Fields.GetByColumn(col).AutoNow {
findAutoNow = true findAutoNow = true
} }
} }
@ -651,8 +644,8 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
} }
if !findAutoNow { if !findAutoNow {
for col, info := range mi.fields.columns { for col, info := range mi.Fields.Columns {
if info.autoNow { if info.AutoNow {
setNames = append(setNames, col) setNames = append(setNames, col)
setValues = append(setValues, time.Now()) setValues = append(setValues, time.Now())
} }
@ -666,7 +659,7 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
sep := fmt.Sprintf("%s = ?, %s", Q, Q) sep := fmt.Sprintf("%s = ?, %s", Q, Q)
setColumns := strings.Join(setNames, sep) setColumns := strings.Join(setNames, sep)
query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.Table, Q, Q, setColumns, Q, Q, pkName, Q)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -679,7 +672,7 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
// Delete execute delete sql dbQuerier with given struct reflect.Value. // Delete execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk. // delete index is pk.
func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
var whereCols []string var whereCols []string
var args []interface{} var args []interface{}
// if specify cols length > 0, then use it for where condition. // if specify cols length > 0, then use it for where condition.
@ -705,7 +698,7 @@ func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
sep := fmt.Sprintf("%s = ? AND %s", Q, Q) sep := fmt.Sprintf("%s = ? AND %s", Q, Q)
wheres := strings.Join(whereCols, sep) wheres := strings.Join(whereCols, sep)
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.Table, Q, Q, wheres, Q)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.ExecContext(ctx, query, args...) res, err := q.ExecContext(ctx, query, args...)
@ -727,14 +720,14 @@ func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref
// UpdateBatch update table-related record by querySet. // UpdateBatch update table-related record by querySet.
// need querySet not struct reflect.Value to update related records. // need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
for col, val := range params { for col, val := range params {
if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { if fi, ok := mi.Fields.GetByAny(col); !ok || !fi.DBcol {
panic(fmt.Errorf("wrong field/column name `%s`", col)) panic(fmt.Errorf("wrong field/column name `%s`", col))
} else { } else {
columns = append(columns, fi.column) columns = append(columns, fi.Column)
values = append(values, val) values = append(values, val)
} }
} }
@ -747,7 +740,7 @@ func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
var specifyIndexes string var specifyIndexes string
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) specifyIndexes = tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
} }
where, args := tables.getCondSQL(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
@ -798,13 +791,13 @@ func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
sets := strings.Join(cols, ", ") + " " sets := strings.Join(cols, ", ") + " "
if d.ins.SupportUpdateJoin() { if d.ins.SupportUpdateJoin() {
query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.table, Q, specifyIndexes, join, sets, where) query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.Table, Q, specifyIndexes, join, sets, where)
} else { } else {
supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s", supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s",
Q, mi.fields.pk.column, Q, Q, mi.Fields.Pk.Column, Q,
Q, mi.table, Q, Q, mi.Table, Q,
specifyIndexes, join, where) specifyIndexes, join, where)
query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.Table, Q, sets, Q, mi.Fields.Pk.Column, Q, supQuery)
} }
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -817,41 +810,41 @@ func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
// delete related records. // delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship. // do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *models.ModelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.Fields.FieldsReverse {
fi = fi.reverseFieldInfo fi = fi.ReverseFieldInfo
switch fi.onDelete { switch fi.OnDelete {
case odCascade: case models.OdCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.Name), args...)
_, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) _, err := d.DeleteBatch(ctx, q, nil, fi.Mi, cond, tz)
if err != nil { if err != nil {
return err return err
} }
case odSetDefault, odSetNULL: case models.OdSetDefault, models.OdSetNULL:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.Name), args...)
params := Params{fi.column: nil} params := Params{fi.Column: nil}
if fi.onDelete == odSetDefault { if fi.OnDelete == models.OdSetDefault {
params[fi.column] = fi.initial.String() params[fi.Column] = fi.Initial.String()
} }
_, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) _, err := d.UpdateBatch(ctx, q, nil, fi.Mi, cond, params, tz)
if err != nil { if err != nil {
return err return err
} }
case odDoNothing: case models.OdDoNothing:
} }
} }
return nil return nil
} }
// DeleteBatch delete table-related records. // DeleteBatch delete table-related records.
func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.skipEnd = true tables.skipEnd = true
var specifyIndexes string var specifyIndexes string
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) specifyIndexes = tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
} }
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
@ -863,8 +856,8 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
where, args := tables.getCondSQL(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
join := tables.getJoinSQL() join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) cols := fmt.Sprintf("T0.%s%s%s", Q, mi.Fields.Pk.Column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.table, Q, specifyIndexes, join, where) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.Table, Q, specifyIndexes, join, where)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -883,7 +876,7 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
if err := rs.Scan(&ref); err != nil { if err := rs.Scan(&ref); err != nil {
return 0, err return 0, err
} }
pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) pkValue, err := d.convertValueFromDB(mi.Fields.Pk, reflect.ValueOf(ref).Interface(), tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -900,7 +893,7 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
marks[i] = "?" marks[i] = "?"
} }
sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.Table, Q, Q, mi.Fields.Pk.Column, Q, sqlIn)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.ExecContext(ctx, query, args...) res, err := q.ExecContext(ctx, query, args...)
@ -921,7 +914,7 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi
} }
// ReadBatch read related records. // ReadBatch read related records.
func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -937,17 +930,17 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
typ := ind.Type().Elem() typ := ind.Type().Elem()
switch typ.Kind() { switch typ.Kind() {
case reflect.Ptr: case reflect.Ptr:
fn = getFullName(typ.Elem()) fn = models.GetFullName(typ.Elem())
case reflect.Struct: case reflect.Struct:
isPtr = false isPtr = false
fn = getFullName(typ) fn = models.GetFullName(typ)
name = getTableName(reflect.New(typ)) name = models.GetTableName(reflect.New(typ))
} }
} else { } else {
fn = getFullName(ind.Type()) fn = models.GetFullName(ind.Type())
name = getTableName(ind) name = models.GetTableName(ind)
} }
unregister = fn != mi.fullName unregister = fn != mi.FullName
} }
if unregister { if unregister {
@ -968,26 +961,26 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
maps = make(map[string]bool) maps = make(map[string]bool)
} }
for _, col := range cols { for _, col := range cols {
if fi, ok := mi.fields.GetByAny(col); ok { if fi, ok := mi.Fields.GetByAny(col); ok {
tCols = append(tCols, fi.column) tCols = append(tCols, fi.Column)
if hasRel { if hasRel {
maps[fi.column] = true maps[fi.Column] = true
} }
} else { } else {
return 0, fmt.Errorf("wrong field/column name `%s`", col) return 0, fmt.Errorf("wrong field/column name `%s`", col)
} }
} }
if hasRel { if hasRel {
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.Fields.FieldsDB {
if fi.fieldType&IsRelField > 0 { if fi.FieldType&IsRelField > 0 {
if !maps[fi.column] { if !maps[fi.Column] {
tCols = append(tCols, fi.column) tCols = append(tCols, fi.Column)
} }
} }
} }
} }
} else { } else {
tCols = mi.fields.dbcols tCols = mi.Fields.DBcols
} }
colsNum := len(tCols) colsNum := len(tCols)
@ -1002,13 +995,13 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
orderBy := tables.getOrderSQL(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, offset, rlimit) limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
for _, tbl := range tables.tables { for _, tbl := range tables.tables {
if tbl.sel { if tbl.sel {
colsNum += len(tbl.mi.fields.dbcols) colsNum += len(tbl.mi.Fields.DBcols)
sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.Fields.DBcols, sep), Q)
} }
} }
@ -1020,7 +1013,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
sels = qs.aggregate sels = qs.aggregate
} }
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, Q, mi.table, Q, sqlSelect, sels, Q, mi.Table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit) specifyIndexes, join, where, groupBy, orderBy, limit)
if qs.forUpdate { if qs.forUpdate {
@ -1039,7 +1032,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
slice := ind slice := ind
if unregister { if unregister {
mi, _ = defaultModelCache.get(name) mi, _ = defaultModelCache.get(name)
tCols = mi.fields.dbcols tCols = mi.Fields.DBcols
colsNum = len(tCols) colsNum = len(tCols)
} }
@ -1055,11 +1048,11 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
return 0, err return 0, err
} }
elm := reflect.New(mi.addrField.Elem().Type()) elm := reflect.New(mi.AddrField.Elem().Type())
mind := reflect.Indirect(elm) mind := reflect.Indirect(elm)
cacheV := make(map[string]*reflect.Value) cacheV := make(map[string]*reflect.Value)
cacheM := make(map[string]*modelInfo) cacheM := make(map[string]*models.ModelInfo)
trefs := refs trefs := refs
d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz)
@ -1078,18 +1071,18 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
last = *val last = *val
mmi = cacheM[names] mmi = cacheM[names]
} else { } else {
fi := mmi.fields.GetByName(name) fi := mmi.Fields.GetByName(name)
lastm := mmi lastm := mmi
mmi = fi.relModelInfo mmi = fi.RelModelInfo
field := last field := last
if last.Kind() != reflect.Invalid { if last.Kind() != reflect.Invalid {
field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) field = reflect.Indirect(last.FieldByIndex(fi.FieldIndex))
if field.IsValid() { if field.IsValid() {
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) d.setColsValues(mmi, &field, mmi.Fields.DBcols, trefs[:len(mmi.Fields.DBcols)], tz)
for _, fi := range mmi.fields.fieldsReverse { for _, fi := range mmi.Fields.FieldsReverse {
if fi.inModel && fi.reverseFieldInfo.mi == lastm { if fi.InModel && fi.ReverseFieldInfo.Mi == lastm {
if fi.reverseFieldInfo != nil { if fi.ReverseFieldInfo != nil {
f := field.FieldByIndex(fi.fieldIndex) f := field.FieldByIndex(fi.FieldIndex)
if f.Kind() == reflect.Ptr { if f.Kind() == reflect.Ptr {
f.Set(last.Addr()) f.Set(last.Addr())
} }
@ -1103,7 +1096,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
cacheM[names] = mmi cacheM[names] = mmi
} }
} }
trefs = trefs[len(mmi.fields.dbcols):] trefs = trefs[len(mmi.Fields.DBcols):]
} }
} }
@ -1146,7 +1139,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m
} }
// Count excute count sql and return count result int64. // Count excute count sql and return count result int64.
func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
@ -1154,12 +1147,12 @@ func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *model
groupBy := tables.getGroupSQL(qs.groups) groupBy := tables.getGroupSQL(qs.groups)
tables.getOrderSQL(qs.orders) tables.getOrderSQL(qs.orders)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s", query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s",
Q, mi.table, Q, Q, mi.Table, Q,
specifyIndexes, join, where, groupBy) specifyIndexes, join, where, groupBy)
if groupBy != "" { if groupBy != "" {
@ -1174,7 +1167,7 @@ func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *model
} }
// GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values. // GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
var sql string var sql string
params := getFlatParams(fi, args, tz) params := getFlatParams(fi, args, tz)
@ -1206,7 +1199,7 @@ func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator stri
params[0] = "IS NULL" params[0] = "IS NULL"
} }
case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith":
param := strings.Replace(ToStr(arg), `%`, `\%`, -1) param := strings.Replace(utils.ToStr(arg), `%`, `\%`, -1)
switch operator { switch operator {
case "iexact": case "iexact":
case "contains", "icontains": case "contains", "icontains":
@ -1234,18 +1227,18 @@ func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator stri
} }
// GenerateOperatorLeftCol gernerate sql string with inner function, such as UPPER(text). // GenerateOperatorLeftCol gernerate sql string with inner function, such as UPPER(text).
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { func (d *dbBase) GenerateOperatorLeftCol(*models.FieldInfo, string, *string) {
// default not use // default not use
} }
// set values to struct column. // set values to struct column.
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { func (d *dbBase) setColsValues(mi *models.ModelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
for i, column := range cols { for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
fi := mi.fields.GetByColumn(column) fi := mi.Fields.GetByColumn(column)
field := ind.FieldByIndex(fi.fieldIndex) field := ind.FieldByIndex(fi.FieldIndex)
value, err := d.convertValueFromDB(fi, val, tz) value, err := d.convertValueFromDB(fi, val, tz)
if err != nil { if err != nil {
@ -1261,7 +1254,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
} }
// convert value from database result to value following in field type. // convert value from database result to value following in field type.
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { func (d *dbBase) convertValueFromDB(fi *models.FieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
if val == nil { if val == nil {
return nil, nil return nil, nil
} }
@ -1269,17 +1262,17 @@ func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Loc
var value interface{} var value interface{}
var tErr error var tErr error
var str *StrTo var str *utils.StrTo
switch v := val.(type) { switch v := val.(type) {
case []byte: case []byte:
s := StrTo(string(v)) s := utils.StrTo(string(v))
str = &s str = &s
case string: case string:
s := StrTo(v) s := utils.StrTo(v)
str = &s str = &s
} }
fieldType := fi.fieldType fieldType := fi.FieldType
setValue: setValue:
switch { switch {
@ -1290,7 +1283,7 @@ setValue:
b := v == 1 b := v == 1
value = b value = b
default: default:
s := StrTo(ToStr(v)) s := utils.StrTo(utils.ToStr(v))
str = &s str = &s
} }
} }
@ -1304,7 +1297,7 @@ setValue:
} }
case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if str == nil { if str == nil {
value = ToStr(val) value = utils.ToStr(val)
} else { } else {
value = str.String() value = str.String()
} }
@ -1315,7 +1308,7 @@ setValue:
d.ins.TimeFromDB(&t, tz) d.ins.TimeFromDB(&t, tz)
value = t value = t
default: default:
s := StrTo(ToStr(t)) s := utils.StrTo(utils.ToStr(t))
str = &s str = &s
} }
} }
@ -1326,25 +1319,25 @@ setValue:
err error err error
) )
if fi.timePrecision != nil && len(s) >= (20+*fi.timePrecision) { if fi.TimePrecision != nil && len(s) >= (20+*fi.TimePrecision) {
layout := formatDateTime + "." layout := utils.FormatDateTime + "."
for i := 0; i < *fi.timePrecision; i++ { for i := 0; i < *fi.TimePrecision; i++ {
layout += "0" layout += "0"
} }
t, err = time.ParseInLocation(layout, s[:20+*fi.timePrecision], tz) t, err = time.ParseInLocation(layout, s[:20+*fi.TimePrecision], tz)
} else if len(s) >= 19 { } else if len(s) >= 19 {
s = s[:19] s = s[:19]
t, err = time.ParseInLocation(formatDateTime, s, tz) t, err = time.ParseInLocation(utils.FormatDateTime, s, tz)
} else if len(s) >= 10 { } else if len(s) >= 10 {
if len(s) > 10 { if len(s) > 10 {
s = s[:10] s = s[:10]
} }
t, err = time.ParseInLocation(formatDate, s, tz) t, err = time.ParseInLocation(utils.FormatDate, s, tz)
} else if len(s) >= 8 { } else if len(s) >= 8 {
if len(s) > 8 { if len(s) > 8 {
s = s[:8] s = s[:8]
} }
t, err = time.ParseInLocation(formatTime, s, tz) t, err = time.ParseInLocation(utils.FormatTime, s, tz)
} }
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
@ -1356,7 +1349,7 @@ setValue:
} }
case fieldType&IsIntegerField > 0: case fieldType&IsIntegerField > 0:
if str == nil { if str == nil {
s := StrTo(ToStr(val)) s := utils.StrTo(utils.ToStr(val))
str = &s str = &s
} }
if str != nil { if str != nil {
@ -1397,7 +1390,7 @@ setValue:
case float64: case float64:
value = v value = v
default: default:
s := StrTo(ToStr(v)) s := utils.StrTo(utils.ToStr(v))
str = &s str = &s
} }
} }
@ -1410,14 +1403,14 @@ setValue:
value = v value = v
} }
case fieldType&IsRelField > 0: case fieldType&IsRelField > 0:
fi = fi.relModelInfo.fields.pk fi = fi.RelModelInfo.Fields.Pk
fieldType = fi.fieldType fieldType = fi.FieldType
goto setValue goto setValue
} }
end: end:
if tErr != nil { if tErr != nil {
err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.AddrValue.Type(), fi.FullName, tErr)
return nil, err return nil, err
} }
@ -1425,9 +1418,9 @@ end:
} }
// set one value to struct column field. // set one value to struct column field.
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { func (d *dbBase) setFieldValue(fi *models.FieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
fieldType := fi.fieldType fieldType := fi.FieldType
isNative := !fi.isFielder isNative := !fi.IsFielder
setValue: setValue:
switch { switch {
@ -1594,20 +1587,20 @@ setValue:
} }
case fieldType&IsRelField > 0: case fieldType&IsRelField > 0:
if value != nil { if value != nil {
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.RelModelInfo.Fields.Pk.FieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type())
field.Set(mf) field.Set(mf)
f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) f := mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex)
field = f field = f
goto setValue goto setValue
} }
} }
if !isNative { if !isNative {
fd := field.Addr().Interface().(Fielder) fd := field.Addr().Interface().(models.Fielder)
err := fd.SetRaw(value) err := fd.SetRaw(value)
if err != nil { if err != nil {
err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.FullName, err)
return nil, err return nil, err
} }
} }
@ -1616,7 +1609,7 @@ setValue:
} }
// ReadValues query sql, read values , save to *[]ParamList. // ReadValues query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var ( var (
maps []Params maps []Params
lists []ParamsList lists []ParamsList
@ -1651,7 +1644,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
var ( var (
cols []string cols []string
infos []*fieldInfo infos []*models.FieldInfo
) )
hasExprs := len(exprs) > 0 hasExprs := len(exprs) > 0
@ -1660,20 +1653,20 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
if hasExprs { if hasExprs {
cols = make([]string, 0, len(exprs)) cols = make([]string, 0, len(exprs))
infos = make([]*fieldInfo, 0, len(exprs)) infos = make([]*models.FieldInfo, 0, len(exprs))
for _, ex := range exprs { for _, ex := range exprs {
index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
if !suc { if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", ex)) panic(fmt.Errorf("unknown field/column name `%s`", ex))
} }
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.Column, Q, Q, name, Q))
infos = append(infos, fi) infos = append(infos, fi)
} }
} else { } else {
cols = make([]string, 0, len(mi.fields.dbcols)) cols = make([]string, 0, len(mi.Fields.DBcols))
infos = make([]*fieldInfo, 0, len(exprs)) infos = make([]*models.FieldInfo, 0, len(exprs))
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.Fields.FieldsDB {
cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.Column, Q, Q, fi.Name, Q))
infos = append(infos, fi) infos = append(infos, fi)
} }
} }
@ -1683,7 +1676,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
orderBy := tables.getOrderSQL(qs.orders) orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, qs.offset, qs.limit) limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
join := tables.getJoinSQL() join := tables.getJoinSQL()
specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes)
sels := strings.Join(cols, ", ") sels := strings.Join(cols, ", ")
@ -1693,7 +1686,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *
} }
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, sqlSelect, sels,
Q, mi.table, Q, Q, mi.Table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit) specifyIndexes, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -1808,12 +1801,12 @@ func (d *dbBase) ReplaceMarks(query *string) {
} }
// flag of RETURNING sql. // flag of RETURNING sql.
func (d *dbBase) HasReturningID(*modelInfo, *string) bool { func (d *dbBase) HasReturningID(*models.ModelInfo, *string) bool {
return false return false
} }
// sync auto key // sync auto key
func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *models.ModelInfo, autoFields []string) error {
return nil return nil
} }
@ -1923,7 +1916,7 @@ func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes []
case hints.KeyIgnoreIndex: case hints.KeyIgnoreIndex:
useWay = `IGNORE` useWay = `IGNORE`
default: default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return `` return ``
} }

View File

@ -21,6 +21,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/logs"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
) )
@ -320,7 +322,7 @@ func detectTZ(al *alias) {
al.TZ = t.Location() al.TZ = t.Location()
} }
} else { } else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) logs.DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }
@ -347,7 +349,7 @@ func detectTZ(al *alias) {
if err == nil { if err == nil {
al.TZ = loc al.TZ = loc
} else { } else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) logs.DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }
} }
@ -479,7 +481,7 @@ end:
if db != nil { if db != nil {
db.Close() db.Close()
} }
DebugLog.Println(err.Error()) logs.DebugLog.Println(err.Error())
} }
return err return err

View File

@ -19,6 +19,10 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
// mysql operators. // mysql operators.
@ -72,28 +76,28 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql) var _ dbBaser = new(dbBaseMysql)
// get mysql operator. // OperatorSQL get mysql operator.
func (d *dbBaseMysql) OperatorSQL(operator string) string { func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator] return mysqlOperators[operator]
} }
// get mysql table field types. // DbTypes get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string { func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes return mysqlTypes
} }
// show table sql for mysql. // ShowTablesQuery show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string { func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
} }
// show columns sql of table for mysql. // ShowColumnsQuery show Columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string { func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table) "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
} }
// execute sql to check index exist. // IndexExists execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
@ -106,7 +110,7 @@ func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table strin
// If your primary key or unique column conflict will update // If your primary key or unique column conflict will update
// If no will insert // If no will insert
// Add "`" for mysql sql building // Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string var iouStr string
argsMap := map[string]string{} argsMap := map[string]string{}
@ -120,10 +124,9 @@ func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *model
} }
} }
isMulti := false names := make([]string, 0, len(mi.Fields.DBcols)-1)
names := make([]string, 0, len(mi.fields.dbcols)-1)
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -150,26 +153,17 @@ func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *model
qupdates := strings.Join(updates, ", ") qupdates := strings.Join(updates, ", ")
columns := strings.Join(names, sep) columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
// conflitValue maybe is an int,can`t use fmt.Sprintf // conflitValue maybe is an int,can`t use fmt.Sprintf
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) { if !d.ins.HasReturningID(mi, &query) {
res, err := q.ExecContext(ctx, query, values...) res, err := q.ExecContext(ctx, query, values...)
if err == nil { if err == nil {
if isMulti {
return res.RowsAffected()
}
lastInsertId, err := res.LastInsertId() lastInsertId, err := res.LastInsertId()
if err != nil { if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable return lastInsertId, ErrLastInsertIdUnavailable
} else { } else {
return lastInsertId, nil return lastInsertId, nil

View File

@ -19,6 +19,10 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/hints" "github.com/beego/beego/v2/client/orm/hints"
) )
@ -116,16 +120,16 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde
case hints.KeyIgnoreIndex: case hints.KeyIgnoreIndex:
hint = `NO_INDEX` hint = `NO_INDEX`
default: default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return `` return ``
} }
return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`)) return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`))
} }
// execute insert sql with given struct and given values. // InsertValue execute insert sql with given struct and given values.
// insert the given values, not the field values in struct. // insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -143,7 +147,7 @@ func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelIn
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
} }
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -156,7 +160,7 @@ func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelIn
lastInsertId, err := res.LastInsertId() lastInsertId, err := res.LastInsertId()
if err != nil { if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable return lastInsertId, ErrLastInsertIdUnavailable
} else { } else {
return lastInsertId, nil return lastInsertId, nil

View File

@ -18,6 +18,10 @@ import (
"context" "context"
"fmt" "fmt"
"strconv" "strconv"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
// postgresql operators. // postgresql operators.
@ -76,7 +80,7 @@ func (d *dbBasePostgres) OperatorSQL(operator string) string {
} }
// generate functioned sql string, such as contains(text). // generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) {
switch operator { switch operator {
case "contains", "startswith", "endswith": case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol) *leftCol = fmt.Sprintf("%s::text", *leftCol)
@ -128,20 +132,20 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
} }
// make returning sql support for postgresql. // make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { func (d *dbBasePostgres) HasReturningID(mi *models.ModelInfo, query *string) bool {
fi := mi.fields.pk fi := mi.Fields.Pk
if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { if fi.FieldType&IsPositiveIntegerField == 0 && fi.FieldType&IsIntegerField == 0 {
return false return false
} }
if query != nil { if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.Column)
} }
return true return true
} }
// sync auto key // sync auto key
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *models.ModelInfo, autoFields []string) error {
if len(autoFields) == 0 { if len(autoFields) == 0 {
return nil return nil
} }
@ -149,9 +153,9 @@ func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
for _, name := range autoFields { for _, name := range autoFields {
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
mi.table, name, mi.Table, name,
Q, name, Q, Q, name, Q,
Q, mi.table, Q) Q, mi.Table, Q)
if _, err := db.ExecContext(ctx, query); err != nil { if _, err := db.ExecContext(ctx, query); err != nil {
return err return err
} }
@ -164,9 +168,9 @@ func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
} }
// show table columns sql for postgresql. // show table Columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string { func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.Columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
} }
// get column types of postgresql. // get column types of postgresql.
@ -185,7 +189,7 @@ func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table st
// GenerateSpecifyIndex return a specifying index clause // GenerateSpecifyIndex return a specifying index clause
func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored") logs.DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
return `` return ``
} }

View File

@ -22,6 +22,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/hints" "github.com/beego/beego/v2/client/orm/hints"
) )
@ -74,9 +78,9 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax // override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate { if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") logs.DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
} }
return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
} }
@ -88,8 +92,8 @@ func (d *dbBaseSqlite) OperatorSQL(operator string) string {
// generate functioned sql for sqlite. // generate functioned sql for sqlite.
// only support DATE(text). // only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField { if fi.FieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol) *leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
} }
} }
@ -114,7 +118,7 @@ func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'" return "SELECT name FROM sqlite_master WHERE type = 'table'"
} }
// get columns in sqlite. // get Columns in sqlite.
func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
rows, err := db.QueryContext(ctx, query) rows, err := db.QueryContext(ctx, query)
@ -135,7 +139,7 @@ func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table strin
return columns, nil return columns, nil
} }
// get show columns sql in sqlite. // get show Columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table) return fmt.Sprintf("pragma table_info('%s')", table)
} }
@ -171,7 +175,7 @@ func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, inde
case hints.KeyUseIndex, hints.KeyForceIndex: case hints.KeyUseIndex, hints.KeyForceIndex:
return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`)) return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`))
default: default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return `` return ``
} }
} }

View File

@ -19,6 +19,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses" "github.com/beego/beego/v2/client/orm/clauses"
"github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/clauses/order_clause"
) )
@ -31,8 +33,8 @@ type dbTable struct {
names []string names []string
sel bool sel bool
inner bool inner bool
mi *modelInfo mi *models.ModelInfo
fi *fieldInfo fi *models.FieldInfo
jtl *dbTable jtl *dbTable
} }
@ -40,14 +42,14 @@ type dbTable struct {
type dbTables struct { type dbTables struct {
tablesM map[string]*dbTable tablesM map[string]*dbTable
tables []*dbTable tables []*dbTable
mi *modelInfo mi *models.ModelInfo
base dbBaser base dbBaser
skipEnd bool skipEnd bool
} }
// set table info to collection. // set table info to collection.
// if not exist, create new. // if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { func (t *dbTables) set(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok { if j, ok := t.tablesM[name]; ok {
j.name = name j.name = name
@ -64,7 +66,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
} }
// add table info to collection. // add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { func (t *dbTables) add(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; !ok { if _, ok := t.tablesM[name]; !ok {
i := len(t.tables) + 1 i := len(t.tables) + 1
@ -82,29 +84,29 @@ func (t *dbTables) get(name string) (*dbTable, bool) {
return j, ok return j, ok
} }
// get related fields info in recursive depth loop. // get related Fields info in recursive depth loop.
// loop once, depth decreases one. // loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { func (t *dbTables) loopDepth(depth int, prefix string, fi *models.FieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany { if depth < 0 || fi.FieldType == RelManyToMany {
return related return related
} }
if prefix == "" { if prefix == "" {
prefix = fi.name prefix = fi.Name
} else { } else {
prefix = prefix + ExprSep + fi.name prefix = prefix + ExprSep + fi.Name
} }
related = append(related, prefix) related = append(related, prefix)
depth-- depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel { for _, fi := range fi.RelModelInfo.Fields.FieldsRel {
related = t.loopDepth(depth, prefix, fi, related) related = t.loopDepth(depth, prefix, fi, related)
} }
return related return related
} }
// parse related fields. // parse related Fields.
func (t *dbTables) parseRelated(rels []string, depth int) { func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels) relsNum := len(rels)
related := make([]string, relsNum) related := make([]string, relsNum)
@ -117,7 +119,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
} }
relDepth-- relDepth--
for _, fi := range t.mi.fields.fieldsRel { for _, fi := range t.mi.Fields.FieldsRel {
related = t.loopDepth(relDepth, "", fi, related) related = t.loopDepth(relDepth, "", fi, related)
} }
@ -133,18 +135,18 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
inner := true inner := true
for _, ex := range exs { for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { if fi, ok := mmi.Fields.GetByAny(ex); ok && fi.Rel && fi.FieldType != RelManyToMany {
names = append(names, fi.name) names = append(names, fi.Name)
mmi = fi.relModelInfo mmi = fi.RelModelInfo
if fi.null || t.skipEnd { if fi.Null || t.skipEnd {
inner = false inner = false
} }
jt := t.set(names, mmi, fi, inner) jt := t.set(names, mmi, fi, inner)
jt.jtl = jtl jt.jtl = jtl
if fi.reverse { if fi.Reverse {
cancel = false cancel = false
} }
@ -185,24 +187,24 @@ func (t *dbTables) getJoinSQL() (join string) {
t1 = jt.jtl.index t1 = jt.jtl.index
} }
t2 = jt.index t2 = jt.index
table = jt.mi.table table = jt.mi.Table
switch { switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: case jt.fi.FieldType == RelManyToMany || jt.fi.FieldType == RelReverseMany || jt.fi.Reverse && jt.fi.ReverseFieldInfo.FieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column c1 = jt.fi.Mi.Fields.Pk.Column
for _, ffi := range jt.mi.fields.fieldsRel { for _, ffi := range jt.mi.Fields.FieldsRel {
if jt.fi.mi == ffi.relModelInfo { if jt.fi.Mi == ffi.RelModelInfo {
c2 = ffi.column c2 = ffi.Column
break break
} }
} }
default: default:
c1 = jt.fi.column c1 = jt.fi.Column
c2 = jt.fi.relModelInfo.fields.pk.column c2 = jt.fi.RelModelInfo.Fields.Pk.Column
if jt.fi.reverse { if jt.fi.Reverse {
c1 = jt.mi.fields.pk.column c1 = jt.mi.Fields.Pk.Column
c2 = jt.fi.reverseFieldInfo.column c2 = jt.fi.ReverseFieldInfo.Column
} }
} }
@ -213,11 +215,11 @@ func (t *dbTables) getJoinSQL() (join string) {
} }
// parse orm model struct field tag expression. // parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { func (t *dbTables) parseExprs(mi *models.ModelInfo, exprs []string) (index, name string, info *models.FieldInfo, success bool) {
var ( var (
jtl *dbTable jtl *dbTable
fi *fieldInfo fi *models.FieldInfo
fiN *fieldInfo fiN *models.FieldInfo
mmi = mi mmi = mi
) )
@ -238,38 +240,38 @@ loopFor:
} }
if i == 0 { if i == 0 {
fi, ok = mmi.fields.GetByAny(ex) fi, ok = mmi.Fields.GetByAny(ex)
} }
_ = okN _ = okN
if ok { if ok {
isRel := fi.rel || fi.reverse isRel := fi.Rel || fi.Reverse
names = append(names, fi.name) names = append(names, fi.Name)
switch { switch {
case fi.rel: case fi.Rel:
mmi = fi.relModelInfo mmi = fi.RelModelInfo
if fi.fieldType == RelManyToMany { if fi.FieldType == RelManyToMany {
mmi = fi.relThroughModelInfo mmi = fi.RelThroughModelInfo
} }
case fi.reverse: case fi.Reverse:
mmi = fi.reverseFieldInfo.mi mmi = fi.ReverseFieldInfo.Mi
} }
if i < num { if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1]) fiN, okN = mmi.Fields.GetByAny(exprs[i+1])
} }
if isRel && (!fi.mi.isThrough || num != i) { if isRel && (!fi.Mi.IsThrough || num != i) {
if fi.null || t.skipEnd { if fi.Null || t.skipEnd {
inner = false inner = false
} }
if t.skipEnd && okN || !t.skipEnd { if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk { if t.skipEnd && okN && fiN.Pk {
goto loopEnd goto loopEnd
} }
@ -295,20 +297,20 @@ loopFor:
info = fi info = fi
if jtl == nil { if jtl == nil {
name = fi.name name = fi.Name
} else { } else {
name = jtl.name + ExprSep + fi.name name = jtl.name + ExprSep + fi.Name
} }
switch { switch {
case fi.rel: case fi.Rel:
case fi.reverse: case fi.Reverse:
switch fi.reverseFieldInfo.fieldType { switch fi.ReverseFieldInfo.FieldType {
case RelOneToOne, RelForeignKey: case RelOneToOne, RelForeignKey:
index = jtl.index index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk info = fi.ReverseFieldInfo.Mi.Fields.Pk
name = info.name name = info.Name
} }
} }
@ -382,7 +384,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
} }
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSQL) where += fmt.Sprintf("%s %s ", leftCol, operSQL)
@ -415,7 +417,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
} }
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q))
} }
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
@ -449,7 +451,7 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
} }
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.Column, Q, order.SortString()))
} }
} }
@ -458,7 +460,7 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
} }
// generate limit sql. // generate limit sql.
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { func (t *dbTables) getLimitSQL(mi *models.ModelInfo, offset int64, limit int64) (limits string) {
if limit == 0 { if limit == 0 {
limit = int64(DefaultRowsLimit) limit = int64(DefaultRowsLimit)
} }
@ -490,7 +492,7 @@ func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string)
} }
// crete new tables collection. // crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables { func newDbTables(mi *models.ModelInfo, base dbBaser) *dbTables {
tables := &dbTables{} tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable) tables.tablesM = make(map[string]*dbTable)
tables.mi = mi tables.mi = mi

View File

@ -41,9 +41,9 @@ func (d *dbBaseTidb) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
} }
// show columns sql of table for mysql. // show Columns sql of table for mysql.
func (d *dbBaseTidb) ShowColumnsQuery(table string) string { func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table) "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
} }

View File

@ -18,6 +18,10 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
// get table alias. // get table alias.
@ -29,32 +33,32 @@ func getDbAlias(name string) *alias {
} }
// get pk column info. // get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { func getExistPk(mi *models.ModelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk fi := mi.Fields.Pk
v := ind.FieldByIndex(fi.fieldIndex) v := ind.FieldByIndex(fi.FieldIndex)
if fi.fieldType&IsPositiveIntegerField > 0 { if fi.FieldType&IsPositiveIntegerField > 0 {
vu := v.Uint() vu := v.Uint()
exist = vu > 0 exist = vu > 0
value = vu value = vu
} else if fi.fieldType&IsIntegerField > 0 { } else if fi.FieldType&IsIntegerField > 0 {
vu := v.Int() vu := v.Int()
exist = true exist = true
value = vu value = vu
} else if fi.fieldType&IsRelField > 0 { } else if fi.FieldType&IsRelField > 0 {
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) _, value, exist = getExistPk(fi.RelModelInfo, reflect.Indirect(v))
} else { } else {
vu := v.String() vu := v.String()
exist = vu != "" exist = vu != ""
value = vu value = vu
} }
column = fi.column column = fi.Column
return return
} }
// get fields description as flatted string. // get Fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { func getFlatParams(fi *models.FieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor: outFor:
for _, arg := range args { for _, arg := range args {
if arg == nil { if arg == nil {
@ -74,32 +78,32 @@ outFor:
case reflect.String: case reflect.String:
v := val.String() v := val.String()
if fi != nil { if fi != nil {
if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { if fi.FieldType == TypeTimeField || fi.FieldType == TypeDateField || fi.FieldType == TypeDateTimeField {
var t time.Time var t time.Time
var err error var err error
if len(v) >= 19 { if len(v) >= 19 {
s := v[:19] s := v[:19]
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) t, err = time.ParseInLocation(utils.FormatDateTime, s, DefaultTimeLoc)
} else if len(v) >= 10 { } else if len(v) >= 10 {
s := v s := v
if len(v) > 10 { if len(v) > 10 {
s = v[:10] s = v[:10]
} }
t, err = time.ParseInLocation(formatDate, s, tz) t, err = time.ParseInLocation(utils.FormatDate, s, tz)
} else { } else {
s := v s := v
if len(s) > 8 { if len(s) > 8 {
s = v[:8] s = v[:8]
} }
t, err = time.ParseInLocation(formatTime, s, tz) t, err = time.ParseInLocation(utils.FormatTime, s, tz)
} }
if err == nil { if err == nil {
if fi.fieldType == TypeDateField { if fi.FieldType == TypeDateField {
v = t.In(tz).Format(formatDate) v = t.In(tz).Format(utils.FormatDate)
} else if fi.fieldType == TypeDateTimeField { } else if fi.FieldType == TypeDateTimeField {
v = t.In(tz).Format(formatDateTime) v = t.In(tz).Format(utils.FormatDateTime)
} else { } else {
v = t.In(tz).Format(formatTime) v = t.In(tz).Format(utils.FormatTime)
} }
} }
} }
@ -110,7 +114,7 @@ outFor:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
arg = val.Uint() arg = val.Uint()
case reflect.Float32: case reflect.Float32:
arg, _ = StrTo(ToStr(arg)).Float64() arg, _ = utils.StrTo(utils.ToStr(arg)).Float64()
case reflect.Float64: case reflect.Float64:
arg = val.Float() arg = val.Float()
case reflect.Bool: case reflect.Bool:
@ -143,18 +147,18 @@ outFor:
continue outFor continue outFor
case reflect.Struct: case reflect.Struct:
if v, ok := arg.(time.Time); ok { if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField { if fi != nil && fi.FieldType == TypeDateField {
arg = v.In(tz).Format(formatDate) arg = v.In(tz).Format(utils.FormatDate)
} else if fi != nil && fi.fieldType == TypeDateTimeField { } else if fi != nil && fi.FieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime) arg = v.In(tz).Format(utils.FormatDateTime)
} else if fi != nil && fi.fieldType == TypeTimeField { } else if fi != nil && fi.FieldType == TypeTimeField {
arg = v.In(tz).Format(formatTime) arg = v.In(tz).Format(utils.FormatTime)
} else { } else {
arg = v.In(tz).Format(formatDateTime) arg = v.In(tz).Format(utils.FormatDateTime)
} }
} else { } else {
typ := val.Type() typ := val.Type()
name := getFullName(typ) name := models.GetFullName(typ)
var value interface{} var value interface{}
if mmi, ok := defaultModelCache.getByFullName(name); ok { if mmi, ok := defaultModelCache.getByFullName(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist { if _, vu, exist := getExistPk(mmi, val); exist {

View File

@ -20,6 +20,10 @@ import (
"reflect" "reflect"
"time" "time"
utils2 "github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/logs"
"github.com/beego/beego/v2/core/utils" "github.com/beego/beego/v2/core/utils"
) )
@ -192,13 +196,13 @@ func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QueryS
var ( var (
name string name string
md interface{} md interface{}
mi *modelInfo mi *models.ModelInfo
) )
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
name = table name = table
} else { } else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) name = models.GetFullName(utils2.IndirectType(reflect.TypeOf(ptrStructOrTableName)))
md = ptrStructOrTableName md = ptrStructOrTableName
} }
@ -303,7 +307,7 @@ func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, erro
func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
var ( var (
md interface{} md interface{}
mi *modelInfo mi *models.ModelInfo
) )
sind := reflect.Indirect(reflect.ValueOf(mds)) sind := reflect.Indirect(reflect.ValueOf(mds))

View File

@ -1,10 +1,10 @@
// Copyright 2020 // Copyright 2014 beego Author. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
@ -12,24 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package orm package buffers
import ( import "github.com/valyala/bytebufferpool"
"reflect"
"testing"
"github.com/stretchr/testify/assert" var _ Buffer = &bytebufferpool.ByteBuffer{}
)
type NotApplicableModel struct { type Buffer interface {
Id int Write(p []byte) (int, error)
WriteString(s string) (int, error)
WriteByte(c byte) error
} }
func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { func Get() Buffer {
return db == "default" return bytebufferpool.Get()
} }
func TestIsApplicableTableForDB(t *testing.T) { func Put(bf Buffer) {
assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) bytebufferpool.Put(bf.(*bytebufferpool.ByteBuffer))
assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default"))
} }

View File

@ -0,0 +1,20 @@
package logs
import (
"io"
"log"
"os"
)
var DebugLog = NewLog(os.Stdout)
// Log implement the log.Logger
type Log struct {
*log.Logger
}
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d
}

View File

@ -0,0 +1,785 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package models
import (
"fmt"
"strconv"
"time"
"github.com/beego/beego/v2/client/orm/internal/utils"
)
// Define the Type enum
const (
TypeBooleanField = 1 << iota
TypeVarCharField
TypeCharField
TypeTextField
TypeTimeField
TypeDateField
TypeDateTimeField
TypeBitField
TypeSmallIntegerField
TypeIntegerField
TypeBigIntegerField
TypePositiveBitField
TypePositiveSmallIntegerField
TypePositiveIntegerField
TypePositiveBigIntegerField
TypeFloatField
TypeDecimalField
TypeJSONField
TypeJsonbField
RelForeignKey
RelOneToOne
RelManyToMany
RelReverseOne
RelReverseMany
)
// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
IsRelField = ^-RelReverseMany >> 18 << 19
IsFieldType = ^-RelReverseMany<<1 + 1
)
// BooleanField A true/false field.
type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
e.Set(d)
case string:
v, err := utils.StrTo(d).Bool()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
// CharField A string field
// required values tag: size
// The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"`
type CharField string
// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeVarCharField
}
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
// verify CharField implement Fielder
var _ Fielder = new(CharField)
// TimeField A time, represented in go by a time.Time instance.
// only time values like 10:00:00
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type TimeField time.Time
// Value return the time.Time
func (e TimeField) Value() time.Time {
return time.Time(e)
}
// Set set the TimeField's value
func (e *TimeField) Set(d time.Time) {
*e = TimeField(d)
}
// String convert time to string
func (e *TimeField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *TimeField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *TimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := utils.TimeParse(d, utils.FormatTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return time value
func (e *TimeField) RawValue() interface{} {
return e.Value()
}
var _ Fielder = new(TimeField)
// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datetime to string
func (e *DateField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := utils.TimeParse(d, utils.FormatDate)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
// verify DateField implement fielder interface
var _ Fielder = new(DateField)
// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
// Value return the datetime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := utils.TimeParse(d, utils.FormatDateTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// verify datetime implement fielder
var _ Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
// String return the string
func (e *FloatField) String() string {
return utils.ToStr(e.Value(), -1, 32)
}
// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
e.Set(float64(d))
case float64:
e.Set(d)
case string:
v, err := utils.StrTo(d).Float64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
// String convert smallint to string
func (e *SmallIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
e.Set(d)
case string:
v, err := utils.StrTo(d).Int16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
// IntegerField -2147483648 to 2147483647
type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
// String convert Int32 to string
func (e *IntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
e.Set(d)
case string:
v, err := utils.StrTo(d).Int32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
e.Set(d)
case string:
v, err := utils.StrTo(d).Int64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
e.Set(d)
case string:
v, err := utils.StrTo(d).Uint16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
e.Set(d)
case string:
v, err := utils.StrTo(d).Uint32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return utils.ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
e.Set(d)
case string:
v, err := utils.StrTo(d).Uint64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
// TextField A large text field.
type TextField string
// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
// verify TextField implement Fielder
var _ Fielder = new(TextField)
// JSONField postgres json field.
type JSONField string
// Value return JSONField value
func (j JSONField) Value() string {
return string(j)
}
// Set the JSONField value
func (j *JSONField) Set(d string) {
*j = JSONField(d)
}
// String convert JSONField to string
func (j *JSONField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JSONField) FieldType() int {
return TypeJSONField
}
// SetRaw convert interface string to string
func (j *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JSONField value
func (j *JSONField) RawValue() interface{} {
return j.Value()
}
// verify JSONField implement Fielder
var _ Fielder = new(JSONField)
// JsonbField postgres json field.
type JsonbField string
// Value return JsonbField value
func (j JsonbField) Value() string {
return string(j)
}
// Set the JsonbField value
func (j *JsonbField) Set(d string) {
*j = JsonbField(d)
}
// String convert JsonbField to string
func (j *JsonbField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JsonbField) FieldType() int {
return TypeJsonbField
}
// SetRaw convert interface string to string
func (j *JsonbField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JsonbField value
func (j *JsonbField) RawValue() interface{} {
return j.Value()
}
// verify JsonbField implement Fielder
var _ Fielder = new(JsonbField)

View File

@ -12,147 +12,149 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package orm package models
import ( import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"github.com/beego/beego/v2/client/orm/internal/utils"
) )
var errSkipField = errors.New("skip field") var errSkipField = errors.New("skip field")
// field info collection // Fields field info collection
type fields struct { type Fields struct {
pk *fieldInfo Pk *FieldInfo
columns map[string]*fieldInfo Columns map[string]*FieldInfo
fields map[string]*fieldInfo Fields map[string]*FieldInfo
fieldsLow map[string]*fieldInfo FieldsLow map[string]*FieldInfo
fieldsByType map[int][]*fieldInfo FieldsByType map[int][]*FieldInfo
fieldsRel []*fieldInfo FieldsRel []*FieldInfo
fieldsReverse []*fieldInfo FieldsReverse []*FieldInfo
fieldsDB []*fieldInfo FieldsDB []*FieldInfo
rels []*fieldInfo Rels []*FieldInfo
orders []string Orders []string
dbcols []string DBcols []string
} }
// add field info // Add adds field info
func (f *fields) Add(fi *fieldInfo) (added bool) { func (f *Fields) Add(fi *FieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil { if f.Fields[fi.Name] == nil && f.Columns[fi.Column] == nil {
f.columns[fi.column] = fi f.Columns[fi.Column] = fi
f.fields[fi.name] = fi f.Fields[fi.Name] = fi
f.fieldsLow[strings.ToLower(fi.name)] = fi f.FieldsLow[strings.ToLower(fi.Name)] = fi
} else { } else {
return return
} }
if _, ok := f.fieldsByType[fi.fieldType]; !ok { if _, ok := f.FieldsByType[fi.FieldType]; !ok {
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) f.FieldsByType[fi.FieldType] = make([]*FieldInfo, 0)
} }
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) f.FieldsByType[fi.FieldType] = append(f.FieldsByType[fi.FieldType], fi)
f.orders = append(f.orders, fi.column) f.Orders = append(f.Orders, fi.Column)
if fi.dbcol { if fi.DBcol {
f.dbcols = append(f.dbcols, fi.column) f.DBcols = append(f.DBcols, fi.Column)
f.fieldsDB = append(f.fieldsDB, fi) f.FieldsDB = append(f.FieldsDB, fi)
} }
if fi.rel { if fi.Rel {
f.fieldsRel = append(f.fieldsRel, fi) f.FieldsRel = append(f.FieldsRel, fi)
} }
if fi.reverse { if fi.Reverse {
f.fieldsReverse = append(f.fieldsReverse, fi) f.FieldsReverse = append(f.FieldsReverse, fi)
} }
return true return true
} }
// get field info by name // GetByName get field info by name
func (f *fields) GetByName(name string) *fieldInfo { func (f *Fields) GetByName(name string) *FieldInfo {
return f.fields[name] return f.Fields[name]
} }
// get field info by column name // GetByColumn get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo { func (f *Fields) GetByColumn(column string) *FieldInfo {
return f.columns[column] return f.Columns[column]
} }
// get field info by string, name is prior // GetByAny get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) { func (f *Fields) GetByAny(name string) (*FieldInfo, bool) {
if fi, ok := f.fields[name]; ok { if fi, ok := f.Fields[name]; ok {
return fi, ok return fi, ok
} }
if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { if fi, ok := f.FieldsLow[strings.ToLower(name)]; ok {
return fi, ok return fi, ok
} }
if fi, ok := f.columns[name]; ok { if fi, ok := f.Columns[name]; ok {
return fi, ok return fi, ok
} }
return nil, false return nil, false
} }
// create new field info collection // NewFields create new field info collection
func newFields() *fields { func NewFields() *Fields {
f := new(fields) f := new(Fields)
f.fields = make(map[string]*fieldInfo) f.Fields = make(map[string]*FieldInfo)
f.fieldsLow = make(map[string]*fieldInfo) f.FieldsLow = make(map[string]*FieldInfo)
f.columns = make(map[string]*fieldInfo) f.Columns = make(map[string]*FieldInfo)
f.fieldsByType = make(map[int][]*fieldInfo) f.FieldsByType = make(map[int][]*FieldInfo)
return f return f
} }
// single field info // FieldInfo single field info
type fieldInfo struct { type FieldInfo struct {
dbcol bool // table column fk and onetoone DBcol bool // table column fk and onetoone
inModel bool InModel bool
auto bool Auto bool
pk bool Pk bool
null bool Null bool
index bool Index bool
unique bool Unique bool
colDefault bool // whether has default tag ColDefault bool // whether has default tag
toText bool ToText bool
autoNow bool AutoNow bool
autoNowAdd bool AutoNowAdd bool
rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true Rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true
reverse bool Reverse bool
isFielder bool // implement Fielder interface IsFielder bool // implement Fielder interface
mi *modelInfo Mi *ModelInfo
fieldIndex []int FieldIndex []int
fieldType int FieldType int
name string Name string
fullName string FullName string
column string Column string
addrValue reflect.Value AddrValue reflect.Value
sf reflect.StructField Sf reflect.StructField
initial StrTo // store the default value Initial utils.StrTo // store the default value
size int Size int
reverseField string ReverseField string
reverseFieldInfo *fieldInfo ReverseFieldInfo *FieldInfo
reverseFieldInfoTwo *fieldInfo ReverseFieldInfoTwo *FieldInfo
reverseFieldInfoM2M *fieldInfo ReverseFieldInfoM2M *FieldInfo
relTable string RelTable string
relThrough string RelThrough string
relThroughModelInfo *modelInfo RelThroughModelInfo *ModelInfo
relModelInfo *modelInfo RelModelInfo *ModelInfo
digits int Digits int
decimals int Decimals int
onDelete string OnDelete string
description string Description string
timePrecision *int TimePrecision *int
} }
// new field info // NewFieldInfo new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { func NewFieldInfo(mi *ModelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *FieldInfo, err error) {
var ( var (
tag string tag string
tagValue string tagValue string
initial StrTo // store the default value initial utils.StrTo // store the default value
fieldType int fieldType int
attrs map[string]bool attrs map[string]bool
tags map[string]string tags map[string]string
addrField reflect.Value addrField reflect.Value
) )
fi = new(fieldInfo) fi = new(FieldInfo)
// if field which CanAddr is the follow type // if field which CanAddr is the follow type
// A value is addressable if it is an element of a slice, // A value is addressable if it is an element of a slice,
@ -168,7 +170,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
} }
} }
attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) attrs, tags = ParseStructTag(sf.Tag.Get(DefaultStructTagName))
if _, ok := attrs["-"]; ok { if _, ok := attrs["-"]; ok {
return nil, errSkipField return nil, errSkipField
@ -187,7 +189,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
checkType: checkType:
switch f := addrField.Interface().(type) { switch f := addrField.Interface().(type) {
case Fielder: case Fielder:
fi.isFielder = true fi.IsFielder = true
if field.Kind() == reflect.Ptr { if field.Kind() == reflect.Ptr {
err = fmt.Errorf("the model Fielder can not be use ptr") err = fmt.Errorf("the model Fielder can not be use ptr")
goto end goto end
@ -211,9 +213,9 @@ checkType:
case "m2m": case "m2m":
fieldType = RelManyToMany fieldType = RelManyToMany
if tv := tags["rel_table"]; tv != "" { if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv fi.RelTable = tv
} else if tv := tags["rel_through"]; tv != "" { } else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv fi.RelThrough = tv
} }
break checkType break checkType
default: default:
@ -231,9 +233,9 @@ checkType:
case "many": case "many":
fieldType = RelReverseMany fieldType = RelReverseMany
if tv := tags["rel_table"]; tv != "" { if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv fi.RelTable = tv
} else if tv := tags["rel_through"]; tv != "" { } else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv fi.RelThrough = tv
} }
break checkType break checkType
default: default:
@ -295,117 +297,117 @@ checkType:
goto end goto end
} }
fi.fieldType = fieldType fi.FieldType = fieldType
fi.name = sf.Name fi.Name = sf.Name
fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) fi.Column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.addrValue = addrField fi.AddrValue = addrField
fi.sf = sf fi.Sf = sf
fi.fullName = mi.fullName + mName + "." + sf.Name fi.FullName = mi.FullName + mName + "." + sf.Name
fi.description = tags["description"] fi.Description = tags["description"]
fi.null = attrs["null"] fi.Null = attrs["null"]
fi.index = attrs["index"] fi.Index = attrs["index"]
fi.auto = attrs["auto"] fi.Auto = attrs["auto"]
fi.pk = attrs["pk"] fi.Pk = attrs["pk"]
fi.unique = attrs["unique"] fi.Unique = attrs["unique"]
// Mark object property if there is attribute "default" in the orm configuration // Mark object property if there is attribute "default" in the orm configuration
if _, ok := tags["default"]; ok { if _, ok := tags["default"]; ok {
fi.colDefault = true fi.ColDefault = true
} }
switch fieldType { switch fieldType {
case RelManyToMany, RelReverseMany, RelReverseOne: case RelManyToMany, RelReverseMany, RelReverseOne:
fi.null = false fi.Null = false
fi.index = false fi.Index = false
fi.auto = false fi.Auto = false
fi.pk = false fi.Pk = false
fi.unique = false fi.Unique = false
default: default:
fi.dbcol = true fi.DBcol = true
} }
switch fieldType { switch fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany: case RelForeignKey, RelOneToOne, RelManyToMany:
fi.rel = true fi.Rel = true
if fieldType == RelOneToOne { if fieldType == RelOneToOne {
fi.unique = true fi.Unique = true
} }
case RelReverseMany, RelReverseOne: case RelReverseMany, RelReverseOne:
fi.reverse = true fi.Reverse = true
} }
if fi.rel && fi.dbcol { if fi.Rel && fi.DBcol {
switch onDelete { switch onDelete {
case odCascade, odDoNothing: case OdCascade, OdDoNothing:
case odSetDefault: case OdSetDefault:
if !initial.Exist() { if !initial.Exist() {
err = errors.New("on_delete: set_default need set field a default value") err = errors.New("on_delete: set_default need set field a default value")
goto end goto end
} }
case odSetNULL: case OdSetNULL:
if !fi.null { if !fi.Null {
err = errors.New("on_delete: set_null need set field null") err = errors.New("on_delete: set_null need set field null")
goto end goto end
} }
default: default:
if onDelete == "" { if onDelete == "" {
onDelete = odCascade onDelete = OdCascade
} else { } else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end goto end
} }
} }
fi.onDelete = onDelete fi.OnDelete = onDelete
} }
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" { if size != "" {
v, e := StrTo(size).Int32() v, e := utils.StrTo(size).Int32()
if e != nil { if e != nil {
err = fmt.Errorf("wrong size value `%s`", size) err = fmt.Errorf("wrong size value `%s`", size)
} else { } else {
fi.size = int(v) fi.Size = int(v)
} }
} else { } else {
fi.size = 255 fi.Size = 255
fi.toText = true fi.ToText = true
} }
case TypeTextField: case TypeTextField:
fi.index = false fi.Index = false
fi.unique = false fi.Unique = false
case TypeTimeField, TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
if fieldType == TypeDateTimeField { if fieldType == TypeDateTimeField {
if precision != "" { if precision != "" {
v, e := StrTo(precision).Int() v, e := utils.StrTo(precision).Int()
if e != nil { if e != nil {
err = fmt.Errorf("convert %s to int error:%v", precision, e) err = fmt.Errorf("convert %s to int error:%v", precision, e)
} else { } else {
fi.timePrecision = &v fi.TimePrecision = &v
} }
} }
} }
if attrs["auto_now"] { if attrs["auto_now"] {
fi.autoNow = true fi.AutoNow = true
} else if attrs["auto_now_add"] { } else if attrs["auto_now_add"] {
fi.autoNowAdd = true fi.AutoNowAdd = true
} }
case TypeFloatField: case TypeFloatField:
case TypeDecimalField: case TypeDecimalField:
d1 := digits d1 := digits
d2 := decimals d2 := decimals
v1, er1 := StrTo(d1).Int8() v1, er1 := utils.StrTo(d1).Int8()
v2, er2 := StrTo(d2).Int8() v2, er2 := utils.StrTo(d2).Int8()
if er1 != nil || er2 != nil { if er1 != nil || er2 != nil {
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
goto end goto end
} }
fi.digits = int(v1) fi.Digits = int(v1)
fi.decimals = int(v2) fi.Decimals = int(v2)
default: default:
switch { switch {
case fieldType&IsIntegerField > 0: case fieldType&IsIntegerField > 0:
@ -414,33 +416,33 @@ checkType:
} }
if fieldType&IsIntegerField == 0 { if fieldType&IsIntegerField == 0 {
if fi.auto { if fi.Auto {
err = fmt.Errorf("non-integer type cannot set auto") err = fmt.Errorf("non-integer type cannot set auto")
goto end goto end
} }
} }
if fi.auto || fi.pk { if fi.Auto || fi.Pk {
if fi.auto { if fi.Auto {
switch addrField.Elem().Kind() { switch addrField.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
default: default:
err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind())
goto end goto end
} }
fi.pk = true fi.Pk = true
} }
fi.null = false fi.Null = false
fi.index = false fi.Index = false
fi.unique = false fi.Unique = false
} }
if fi.unique { if fi.Unique {
fi.index = false fi.Index = false
} }
// can not set default for these type // can not set default for these type
if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { if fi.Auto || fi.Pk || fi.Unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
initial.Clear() initial.Clear()
} }
@ -474,7 +476,7 @@ checkType:
} }
} }
fi.initial = initial fi.Initial = initial
end: end:
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -0,0 +1,148 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package models
import (
"fmt"
"os"
"reflect"
)
// ModelInfo single model info
type ModelInfo struct {
Manual bool
IsThrough bool
Pkg string
Name string
FullName string
Table string
Model interface{}
Fields *Fields
AddrField reflect.Value // store the original struct value
Uniques []string
}
// NewModelInfo new model info
func NewModelInfo(val reflect.Value) (mi *ModelInfo) {
mi = &ModelInfo{}
mi.Fields = NewFields()
ind := reflect.Indirect(val)
mi.AddrField = val
mi.Name = ind.Type().Name()
mi.FullName = GetFullName(ind.Type())
AddModelFields(mi, ind, "", []int{})
return
}
// AddModelFields index: FieldByIndex returns the nested field corresponding to index
func AddModelFields(mi *ModelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *FieldInfo
sf reflect.StructField
)
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
// if the field is unexported skip
if sf.PkgPath != "" {
continue
}
// add anonymous struct Fields
if sf.Anonymous {
AddModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue
}
fi, err = NewFieldInfo(mi, field, sf, mName)
if err == errSkipField {
err = nil
continue
} else if err != nil {
break
}
// record current field index
fi.FieldIndex = append(fi.FieldIndex, index...)
fi.FieldIndex = append(fi.FieldIndex, i)
fi.Mi = mi
fi.InModel = true
if !mi.Fields.Add(fi) {
err = fmt.Errorf("duplicate column name: %s", fi.Column)
break
}
if fi.Pk {
if mi.Fields.Pk != nil {
err = fmt.Errorf("one model must have one pk field only")
break
} else {
mi.Fields.Pk = fi
}
}
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
}
// NewM2MModelInfo combine related model info to new model info.
// prepare for relation models query.
func NewM2MModelInfo(m1, m2 *ModelInfo) (mi *ModelInfo) {
mi = new(ModelInfo)
mi.Fields = NewFields()
mi.Table = m1.Table + "_" + m2.Table + "s"
mi.Name = CamelString(mi.Table)
mi.FullName = m1.Pkg + "." + mi.Name
fa := new(FieldInfo) // pk
f1 := new(FieldInfo) // m1 table RelForeignKey
f2 := new(FieldInfo) // m2 table RelForeignKey
fa.FieldType = TypeBigIntegerField
fa.Auto = true
fa.Pk = true
fa.DBcol = true
fa.Name = "Id"
fa.Column = "id"
fa.FullName = mi.FullName + "." + fa.Name
f1.DBcol = true
f2.DBcol = true
f1.FieldType = RelForeignKey
f2.FieldType = RelForeignKey
f1.Name = CamelString(m1.Table)
f2.Name = CamelString(m2.Table)
f1.FullName = mi.FullName + "." + f1.Name
f2.FullName = mi.FullName + "." + f2.Name
f1.Column = m1.Table + "_id"
f2.Column = m2.Table + "_id"
f1.Rel = true
f2.Rel = true
f1.RelTable = m1.Table
f2.RelTable = m2.Table
f1.RelModelInfo = m1
f2.RelModelInfo = m2
f1.Mi = mi
f2.Mi = mi
mi.Fields.Add(fa)
mi.Fields.Add(f1)
mi.Fields.Add(f2)
mi.Fields.Pk = fa
mi.Uniques = []string{f1.Column, f2.Column}
return
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package orm package models
import ( import (
"database/sql" "database/sql"
@ -20,6 +20,8 @@ import (
"reflect" "reflect"
"strings" "strings"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/logs"
) )
// 1 is attr // 1 is attr
@ -48,15 +50,29 @@ var supportTag = map[string]int{
"precision": 2, "precision": 2,
} }
// get reflect.Type name with package path. type fn func(string) string
func getFullName(typ reflect.Type) string {
var (
NameStrategyMap = map[string]fn{
DefaultNameStrategy: SnakeString,
SnakeAcronymNameStrategy: SnakeStringWithAcronym,
}
DefaultNameStrategy = "snakeString"
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
NameStrategy = DefaultNameStrategy
defaultStructTagDelim = ";"
DefaultStructTagName = "orm"
)
// GetFullName get reflect.Type name with package path.
func GetFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name() return typ.PkgPath() + "." + typ.Name()
} }
// getTableName get struct table name. // GetTableName get struct table name.
// If the struct implement the TableName, then get the result as tablename // If the struct implement the TableName, then get the result as tablename
// else use the struct name which will apply snakeString. // else use the struct name which will apply snakeString.
func getTableName(val reflect.Value) string { func GetTableName(val reflect.Value) string {
if fun := val.MethodByName("TableName"); fun.IsValid() { if fun := val.MethodByName("TableName"); fun.IsValid() {
vals := fun.Call([]reflect.Value{}) vals := fun.Call([]reflect.Value{})
// has return and the first val is string // has return and the first val is string
@ -64,11 +80,11 @@ func getTableName(val reflect.Value) string {
return vals[0].String() return vals[0].String()
} }
} }
return snakeString(reflect.Indirect(val).Type().Name()) return SnakeString(reflect.Indirect(val).Type().Name())
} }
// get table engine, myisam or innodb. // GetTableEngine get table engine, myisam or innodb.
func getTableEngine(val reflect.Value) string { func GetTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine") fun := val.MethodByName("TableEngine")
if fun.IsValid() { if fun.IsValid() {
vals := fun.Call([]reflect.Value{}) vals := fun.Call([]reflect.Value{})
@ -79,8 +95,8 @@ func getTableEngine(val reflect.Value) string {
return "" return ""
} }
// get table index from method. // GetTableIndex get table index from method.
func getTableIndex(val reflect.Value) [][]string { func GetTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex") fun := val.MethodByName("TableIndex")
if fun.IsValid() { if fun.IsValid() {
vals := fun.Call([]reflect.Value{}) vals := fun.Call([]reflect.Value{})
@ -93,8 +109,8 @@ func getTableIndex(val reflect.Value) [][]string {
return nil return nil
} }
// get table unique from method // GetTableUnique get table unique from method
func getTableUnique(val reflect.Value) [][]string { func GetTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique") fun := val.MethodByName("TableUnique")
if fun.IsValid() { if fun.IsValid() {
vals := fun.Call([]reflect.Value{}) vals := fun.Call([]reflect.Value{})
@ -107,8 +123,8 @@ func getTableUnique(val reflect.Value) [][]string {
return nil return nil
} }
// get whether the table needs to be created for the database alias // IsApplicableTableForDB get whether the table needs to be created for the database alias
func isApplicableTableForDB(val reflect.Value, db string) bool { func IsApplicableTableForDB(val reflect.Value, db string) bool {
if !val.IsValid() { if !val.IsValid() {
return true return true
} }
@ -126,7 +142,7 @@ func isApplicableTableForDB(val reflect.Value, db string) bool {
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := col column := col
if col == "" { if col == "" {
column = nameStrategyMap[nameStrategy](sf.Name) column = NameStrategyMap[NameStrategy](sf.Name)
} }
switch ft { switch ft {
case RelForeignKey, RelOneToOne: case RelForeignKey, RelOneToOne:
@ -218,8 +234,8 @@ func getFieldType(val reflect.Value) (ft int, err error) {
return return
} }
// parse struct tag string // ParseStructTag parse struct tag string
func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { func ParseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
attrs = make(map[string]bool) attrs = make(map[string]bool)
tags = make(map[string]string) tags = make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) { for _, v := range strings.Split(data, defaultStructTagDelim) {
@ -236,8 +252,74 @@ func parseStructTag(data string) (attrs map[string]bool, tags map[string]string)
tags[name] = v tags[name] = v
} }
} else { } else {
DebugLog.Println("unsupport orm tag", v) logs.DebugLog.Println("unsupport orm tag", v)
} }
} }
return return
} }
func SnakeStringWithAcronym(s string) string {
data := make([]byte, 0, len(s)*2)
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
before := false
after := false
if i > 0 {
before = s[i-1] >= 'a' && s[i-1] <= 'z'
}
if i+1 < num {
after = s[i+1] >= 'a' && s[i+1] <= 'z'
}
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
data = append(data, '_')
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// SnakeString snake string, XxYy to xx_yy , XxYY to xx_y_y
func SnakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j {
data = append(data, '_')
}
if d != '_' {
j = true
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// CamelString camel string, xx_yy to XxYy
func CamelString(s string) string {
data := make([]byte, 0, len(s))
flag, num := true, len(s)-1
for i := 0; i <= num; i++ {
d := s[i]
if d == '_' {
flag = true
continue
} else if flag {
if d >= 'a' && d <= 'z' {
d = d - 32
}
flag = false
}
data = append(data, d)
}
return string(data)
}
const (
OdCascade = "cascade"
OdSetNULL = "set_null"
OdSetDefault = "set_default"
OdDoNothing = "do_nothing"
)

View File

@ -1,10 +1,10 @@
// Copyright 2014 beego Author. All Rights Reserved. // Copyright 2020
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
@ -12,27 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package orm package models
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestCamelString(t *testing.T) { type NotApplicableModel struct {
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} Id int
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} }
answer := make(map[string]string) func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool {
for i, v := range snake { return db == "default"
answer[v] = camel[i] }
}
for _, v := range snake { func TestIsApplicableTableForDB(t *testing.T) {
res := camelString(v) assert.False(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa"))
if res != answer[v] { assert.True(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default"))
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
} }
func TestSnakeString(t *testing.T) { func TestSnakeString(t *testing.T) {
@ -45,7 +44,7 @@ func TestSnakeString(t *testing.T) {
} }
for _, v := range camel { for _, v := range camel {
res := snakeString(v) res := SnakeString(v)
if res != answer[v] { if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v]) t.Error("Unit Test Fail:", v, res, answer[v])
} }
@ -62,7 +61,24 @@ func TestSnakeStringWithAcronym(t *testing.T) {
} }
for _, v := range camel { for _, v := range camel {
res := snakeStringWithAcronym(v) res := SnakeStringWithAcronym(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}
func TestCamelString(t *testing.T) {
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
answer := make(map[string]string)
for i, v := range snake {
answer[v] = camel[i]
}
for _, v := range snake {
res := CamelString(v)
if res != answer[v] { if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v]) t.Error("Unit Test Fail:", v, res, answer[v])
} }

View File

@ -0,0 +1,23 @@
// Copyright 2023 beego-dev. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package models
// Fielder define field info
type Fielder interface {
String() string
FieldType() int
SetRaw(interface{}) error
RawValue() interface{}
}

View File

@ -0,0 +1,249 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"fmt"
"math/big"
"reflect"
"strconv"
"time"
)
// StrTo is the target string
type StrTo string
// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
} else {
f.Clear()
}
}
// Clear string
func (f *StrTo) Clear() {
*f = StrTo(rune(0x1E))
}
// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(rune(0x1E))
}
// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) // octal
if !ok {
return v, err
}
return ni.Int64(), nil
}
return v, err
}
// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// Uint32 string to uint32
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10)
if !ok {
return v, err
}
return ni.Uint64(), nil
}
return v, err
}
// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
}
return ""
}
// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
s = strconv.FormatBool(v)
case float32:
s = strconv.FormatFloat(float64(v), 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 32))
case float64:
s = strconv.FormatFloat(v, 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int32:
s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10))
case int64:
s = strconv.FormatInt(v, ArgInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint32:
s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10))
case uint64:
s = strconv.FormatUint(v, ArgInt(args).Get(0, 10))
case string:
s = v
case []byte:
s = string(v)
default:
s = fmt.Sprintf("%v", v)
}
return s
}
// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
d = val.Int()
case uint, uint8, uint16, uint32, uint64:
d = int64(val.Uint())
default:
panic(fmt.Errorf("ToInt64 need numeric not `%T`", value))
}
return
}
type ArgString []string
// Get get string by index from string slice
func (a ArgString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
} else if len(args) > 0 {
r = args[0]
}
return
}
type ArgInt []int
// Get get int by index from int slice
func (a ArgInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
// TimeParse parse time to string with location
func TimeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// IndirectType get pointer indirect type
func IndirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
case reflect.Ptr:
return IndirectType(v.Elem())
default:
return v
}
}
const (
FormatTime = "15:04:05"
FormatDate = "2006-01-02"
FormatDateTime = "2006-01-02 15:04:05"
)
var (
DefaultTimeLoc = time.Local
)

View File

@ -17,6 +17,8 @@ package orm
import ( import (
"context" "context"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
// Invocation represents an "Orm" invocation // Invocation represents an "Orm" invocation
@ -27,7 +29,7 @@ type Invocation struct {
// the args are all arguments except context.Context // the args are all arguments except context.Context
Args []interface{} Args []interface{}
mi *modelInfo mi *models.ModelInfo
// f is the Orm operation // f is the Orm operation
f func(ctx context.Context) []interface{} f func(ctx context.Context) []interface{}
@ -39,7 +41,7 @@ type Invocation struct {
func (inv *Invocation) GetTableName() string { func (inv *Invocation) GetTableName() string {
if inv.mi != nil { if inv.mi != nil {
return inv.mi.table return inv.mi.Table
} }
return "" return ""
} }
@ -51,8 +53,8 @@ func (inv *Invocation) execute(ctx context.Context) []interface{} {
// GetPkFieldName return the primary key of this table // GetPkFieldName return the primary key of this table
// if not found, "" is returned // if not found, "" is returned
func (inv *Invocation) GetPkFieldName() string { func (inv *Invocation) GetPkFieldName() string {
if inv.mi.fields.pk != nil { if inv.mi.Fields.Pk != nil {
return inv.mi.fields.pk.name return inv.mi.Fields.Pk.Name
} }
return "" return ""
} }

View File

@ -17,6 +17,8 @@ package orm
import ( import (
"testing" "testing"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -53,10 +55,10 @@ func TestDbBase_GetTables(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
assert.NotNil(t, mi) assert.NotNil(t, mi)
engine := getTableEngine(mi.addrField) engine := models.GetTableEngine(mi.AddrField)
assert.Equal(t, "innodb", engine) assert.Equal(t, "innodb", engine)
uniques := getTableUnique(mi.addrField) uniques := models.GetTableUnique(mi.AddrField)
assert.Equal(t, [][]string{{"unique1"}, {"unique2"}}, uniques) assert.Equal(t, [][]string{{"unique1"}, {"unique2"}}, uniques)
indexes := getTableIndex(mi.addrField) indexes := models.GetTableIndex(mi.AddrField)
assert.Equal(t, [][]string{{"index1"}, {"index2"}}, indexes) assert.Equal(t, [][]string{{"index1"}, {"index2"}}, indexes)
} }

View File

@ -21,15 +21,8 @@ import (
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync" "sync"
)
const ( imodels "github.com/beego/beego/v2/client/orm/internal/models"
odCascade = "cascade"
odSetNULL = "set_null"
odSetDefault = "set_default"
odDoNothing = "do_nothing"
defaultStructTagName = "orm"
defaultStructTagDelim = ";"
) )
var defaultModelCache = NewModelCacheHandler() var defaultModelCache = NewModelCacheHandler()
@ -38,22 +31,22 @@ var defaultModelCache = NewModelCacheHandler()
type modelCache struct { type modelCache struct {
sync.RWMutex // only used outsite for bootStrap sync.RWMutex // only used outsite for bootStrap
orders []string orders []string
cache map[string]*modelInfo cache map[string]*imodels.ModelInfo
cacheByFullName map[string]*modelInfo cacheByFullName map[string]*imodels.ModelInfo
done bool done bool
} }
// NewModelCacheHandler generator of modelCache // NewModelCacheHandler generator of modelCache
func NewModelCacheHandler() *modelCache { func NewModelCacheHandler() *modelCache {
return &modelCache{ return &modelCache{
cache: make(map[string]*modelInfo), cache: make(map[string]*imodels.ModelInfo),
cacheByFullName: make(map[string]*modelInfo), cacheByFullName: make(map[string]*imodels.ModelInfo),
} }
} }
// get all model info // get all model info
func (mc *modelCache) all() map[string]*modelInfo { func (mc *modelCache) all() map[string]*imodels.ModelInfo {
m := make(map[string]*modelInfo, len(mc.cache)) m := make(map[string]*imodels.ModelInfo, len(mc.cache))
for k, v := range mc.cache { for k, v := range mc.cache {
m[k] = v m[k] = v
} }
@ -61,8 +54,8 @@ func (mc *modelCache) all() map[string]*modelInfo {
} }
// get ordered model info // get ordered model info
func (mc *modelCache) allOrdered() []*modelInfo { func (mc *modelCache) allOrdered() []*imodels.ModelInfo {
m := make([]*modelInfo, 0, len(mc.orders)) m := make([]*imodels.ModelInfo, 0, len(mc.orders))
for _, table := range mc.orders { for _, table := range mc.orders {
m = append(m, mc.cache[table]) m = append(m, mc.cache[table])
} }
@ -70,30 +63,30 @@ func (mc *modelCache) allOrdered() []*modelInfo {
} }
// get model info by table name // get model info by table name
func (mc *modelCache) get(table string) (mi *modelInfo, ok bool) { func (mc *modelCache) get(table string) (mi *imodels.ModelInfo, ok bool) {
mi, ok = mc.cache[table] mi, ok = mc.cache[table]
return return
} }
// get model info by full name // get model info by full name
func (mc *modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { func (mc *modelCache) getByFullName(name string) (mi *imodels.ModelInfo, ok bool) {
mi, ok = mc.cacheByFullName[name] mi, ok = mc.cacheByFullName[name]
return return
} }
func (mc *modelCache) getByMd(md interface{}) (*modelInfo, bool) { func (mc *modelCache) getByMd(md interface{}) (*imodels.ModelInfo, bool) {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
name := getFullName(typ) name := imodels.GetFullName(typ)
return mc.getByFullName(name) return mc.getByFullName(name)
} }
// set model info to collection // set model info to collection
func (mc *modelCache) set(table string, mi *modelInfo) *modelInfo { func (mc *modelCache) set(table string, mi *imodels.ModelInfo) *imodels.ModelInfo {
mii := mc.cache[table] mii := mc.cache[table]
mc.cache[table] = mi mc.cache[table] = mi
mc.cacheByFullName[mi.fullName] = mi mc.cacheByFullName[mi.FullName] = mi
if mii == nil { if mii == nil {
mc.orders = append(mc.orders, table) mc.orders = append(mc.orders, table)
} }
@ -106,8 +99,8 @@ func (mc *modelCache) clean() {
defer mc.Unlock() defer mc.Unlock()
mc.orders = make([]string, 0) mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo) mc.cache = make(map[string]*imodels.ModelInfo)
mc.cacheByFullName = make(map[string]*modelInfo) mc.cacheByFullName = make(map[string]*imodels.ModelInfo)
mc.done = false mc.done = false
} }
@ -120,7 +113,7 @@ func (mc *modelCache) bootstrap() {
} }
var ( var (
err error err error
models map[string]*modelInfo models map[string]*imodels.ModelInfo
) )
if dataBaseCache.getDefault() == nil { if dataBaseCache.getDefault() == nil {
err = fmt.Errorf("must have one register DataBase alias named `default`") err = fmt.Errorf("must have one register DataBase alias named `default`")
@ -131,51 +124,51 @@ func (mc *modelCache) bootstrap() {
// RelManyToMany set the relTable // RelManyToMany set the relTable
models = mc.all() models = mc.all()
for _, mi := range models { for _, mi := range models {
for _, fi := range mi.fields.columns { for _, fi := range mi.Fields.Columns {
if fi.rel || fi.reverse { if fi.Rel || fi.Reverse {
elm := fi.addrValue.Type().Elem() elm := fi.AddrValue.Type().Elem()
if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { if fi.FieldType == RelReverseMany || fi.FieldType == RelManyToMany {
elm = elm.Elem() elm = elm.Elem()
} }
// check the rel or reverse model already register // check the rel or reverse model already register
name := getFullName(elm) name := imodels.GetFullName(elm)
mii, ok := mc.getByFullName(name) mii, ok := mc.getByFullName(name)
if !ok || mii.pkg != elm.PkgPath() { if !ok || mii.Pkg != elm.PkgPath() {
err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.FullName, elm.String())
goto end goto end
} }
fi.relModelInfo = mii fi.RelModelInfo = mii
switch fi.fieldType { switch fi.FieldType {
case RelManyToMany: case RelManyToMany:
if fi.relThrough != "" { if fi.RelThrough != "" {
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { if i := strings.LastIndex(fi.RelThrough, "."); i != -1 && len(fi.RelThrough) > (i+1) {
pn := fi.relThrough[:i] pn := fi.RelThrough[:i]
rmi, ok := mc.getByFullName(fi.relThrough) rmi, ok := mc.getByFullName(fi.RelThrough)
if !ok || pn != rmi.pkg { if !ok || pn != rmi.Pkg {
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.FullName, fi.RelThrough)
goto end goto end
} }
fi.relThroughModelInfo = rmi fi.RelThroughModelInfo = rmi
fi.relTable = rmi.table fi.RelTable = rmi.Table
} else { } else {
err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.FullName, fi.RelThrough)
goto end goto end
} }
} else { } else {
i := newM2MModelInfo(mi, mii) i := imodels.NewM2MModelInfo(mi, mii)
if fi.relTable != "" { if fi.RelTable != "" {
i.table = fi.relTable i.Table = fi.RelTable
} }
if v := mc.set(i.table, i); v != nil { if v := mc.set(i.Table, i); v != nil {
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.RelTable)
goto end goto end
} }
fi.relTable = i.table fi.RelTable = i.Table
fi.relThroughModelInfo = i fi.RelThroughModelInfo = i
} }
fi.relThroughModelInfo.isThrough = true fi.RelThroughModelInfo.IsThrough = true
} }
} }
} }
@ -185,42 +178,42 @@ func (mc *modelCache) bootstrap() {
// if not exist, add a new field to the relModelInfo // if not exist, add a new field to the relModelInfo
models = mc.all() models = mc.all()
for _, mi := range models { for _, mi := range models {
for _, fi := range mi.fields.fieldsRel { for _, fi := range mi.Fields.FieldsRel {
switch fi.fieldType { switch fi.FieldType {
case RelForeignKey, RelOneToOne, RelManyToMany: case RelForeignKey, RelOneToOne, RelManyToMany:
inModel := false inModel := false
for _, ffi := range fi.relModelInfo.fields.fieldsReverse { for _, ffi := range fi.RelModelInfo.Fields.FieldsReverse {
if ffi.relModelInfo == mi { if ffi.RelModelInfo == mi {
inModel = true inModel = true
break break
} }
} }
if !inModel { if !inModel {
rmi := fi.relModelInfo rmi := fi.RelModelInfo
ffi := new(fieldInfo) ffi := new(imodels.FieldInfo)
ffi.name = mi.name ffi.Name = mi.Name
ffi.column = ffi.name ffi.Column = ffi.Name
ffi.fullName = rmi.fullName + "." + ffi.name ffi.FullName = rmi.FullName + "." + ffi.Name
ffi.reverse = true ffi.Reverse = true
ffi.relModelInfo = mi ffi.RelModelInfo = mi
ffi.mi = rmi ffi.Mi = rmi
if fi.fieldType == RelOneToOne { if fi.FieldType == RelOneToOne {
ffi.fieldType = RelReverseOne ffi.FieldType = RelReverseOne
} else { } else {
ffi.fieldType = RelReverseMany ffi.FieldType = RelReverseMany
} }
if !rmi.fields.Add(ffi) { if !rmi.Fields.Add(ffi) {
added := false added := false
for cnt := 0; cnt < 5; cnt++ { for cnt := 0; cnt < 5; cnt++ {
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) ffi.Name = fmt.Sprintf("%s%d", mi.Name, cnt)
ffi.column = ffi.name ffi.Column = ffi.Name
ffi.fullName = rmi.fullName + "." + ffi.name ffi.FullName = rmi.FullName + "." + ffi.Name
if added = rmi.fields.Add(ffi); added { if added = rmi.Fields.Add(ffi); added {
break break
} }
} }
if !added { if !added {
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.FullName, ffi.FullName))
} }
} }
} }
@ -230,24 +223,24 @@ func (mc *modelCache) bootstrap() {
models = mc.all() models = mc.all()
for _, mi := range models { for _, mi := range models {
for _, fi := range mi.fields.fieldsRel { for _, fi := range mi.Fields.FieldsRel {
switch fi.fieldType { switch fi.FieldType {
case RelManyToMany: case RelManyToMany:
for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { for _, ffi := range fi.RelThroughModelInfo.Fields.FieldsRel {
switch ffi.fieldType { switch ffi.FieldType {
case RelOneToOne, RelForeignKey: case RelOneToOne, RelForeignKey:
if ffi.relModelInfo == fi.relModelInfo { if ffi.RelModelInfo == fi.RelModelInfo {
fi.reverseFieldInfoTwo = ffi fi.ReverseFieldInfoTwo = ffi
} }
if ffi.relModelInfo == mi { if ffi.RelModelInfo == mi {
fi.reverseField = ffi.name fi.ReverseField = ffi.Name
fi.reverseFieldInfo = ffi fi.ReverseFieldInfo = ffi
} }
} }
} }
if fi.reverseFieldInfoTwo == nil { if fi.ReverseFieldInfoTwo == nil {
err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
fi.relThroughModelInfo.fullName) fi.RelThroughModelInfo.FullName)
goto end goto end
} }
} }
@ -256,63 +249,63 @@ func (mc *modelCache) bootstrap() {
models = mc.all() models = mc.all()
for _, mi := range models { for _, mi := range models {
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.Fields.FieldsReverse {
switch fi.fieldType { switch fi.FieldType {
case RelReverseOne: case RelReverseOne:
found := false found := false
mForA: mForA:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelOneToOne] {
if ffi.relModelInfo == mi { if ffi.RelModelInfo == mi {
found = true found = true
fi.reverseField = ffi.name fi.ReverseField = ffi.Name
fi.reverseFieldInfo = ffi fi.ReverseFieldInfo = ffi
ffi.reverseField = fi.name ffi.ReverseField = fi.Name
ffi.reverseFieldInfo = fi ffi.ReverseFieldInfo = fi
break mForA break mForA
} }
} }
if !found { if !found {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName)
goto end goto end
} }
case RelReverseMany: case RelReverseMany:
found := false found := false
mForB: mForB:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelForeignKey] {
if ffi.relModelInfo == mi { if ffi.RelModelInfo == mi {
found = true found = true
fi.reverseField = ffi.name fi.ReverseField = ffi.Name
fi.reverseFieldInfo = ffi fi.ReverseFieldInfo = ffi
ffi.reverseField = fi.name ffi.ReverseField = fi.Name
ffi.reverseFieldInfo = fi ffi.ReverseFieldInfo = fi
break mForB break mForB
} }
} }
if !found { if !found {
mForC: mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelManyToMany] {
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || conditions := fi.RelThrough != "" && fi.RelThrough == ffi.RelThrough ||
fi.relTable != "" && fi.relTable == ffi.relTable || fi.RelTable != "" && fi.RelTable == ffi.RelTable ||
fi.relThrough == "" && fi.relTable == "" fi.RelThrough == "" && fi.RelTable == ""
if ffi.relModelInfo == mi && conditions { if ffi.RelModelInfo == mi && conditions {
found = true found = true
fi.reverseField = ffi.reverseFieldInfoTwo.name fi.ReverseField = ffi.ReverseFieldInfoTwo.Name
fi.reverseFieldInfo = ffi.reverseFieldInfoTwo fi.ReverseFieldInfo = ffi.ReverseFieldInfoTwo
fi.relThroughModelInfo = ffi.relThroughModelInfo fi.RelThroughModelInfo = ffi.RelThroughModelInfo
fi.reverseFieldInfoTwo = ffi.reverseFieldInfo fi.ReverseFieldInfoTwo = ffi.ReverseFieldInfo
fi.reverseFieldInfoM2M = ffi fi.ReverseFieldInfoM2M = ffi
ffi.reverseFieldInfoM2M = fi ffi.ReverseFieldInfoM2M = fi
break mForC break mForC
} }
} }
} }
if !found { if !found {
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName)
goto end goto end
} }
} }
@ -334,7 +327,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
typ := reflect.Indirect(val).Type() typ := reflect.Indirect(val).Type()
if val.Kind() != reflect.Ptr { if val.Kind() != reflect.Ptr {
err = fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)) err = fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", imodels.GetFullName(typ))
return return
} }
// For this case: // For this case:
@ -347,7 +340,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
if val.Elem().Kind() == reflect.Slice { if val.Elem().Kind() == reflect.Slice {
val = reflect.New(val.Elem().Type().Elem()) val = reflect.New(val.Elem().Type().Elem())
} }
table := getTableName(val) table := imodels.GetTableName(val)
if prefixOrSuffixStr != "" { if prefixOrSuffixStr != "" {
if prefixOrSuffix { if prefixOrSuffix {
@ -358,7 +351,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
} }
// models's fullname is pkgpath + struct name // models's fullname is pkgpath + struct name
name := getFullName(typ) name := imodels.GetFullName(typ)
if _, ok := mc.getByFullName(name); ok { if _, ok := mc.getByFullName(name); ok {
err = fmt.Errorf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name) err = fmt.Errorf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
return return
@ -368,26 +361,26 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo
return nil return nil
} }
mi := newModelInfo(val) mi := imodels.NewModelInfo(val)
if mi.fields.pk == nil { if mi.Fields.Pk == nil {
outFor: outFor:
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.Fields.FieldsDB {
if strings.ToLower(fi.name) == "id" { if strings.ToLower(fi.Name) == "id" {
switch fi.addrValue.Elem().Kind() { switch fi.AddrValue.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
fi.auto = true fi.Auto = true
fi.pk = true fi.Pk = true
mi.fields.pk = fi mi.Fields.Pk = fi
break outFor break outFor
} }
} }
} }
} }
mi.table = table mi.Table = table
mi.pkg = typ.PkgPath() mi.Pkg = typ.PkgPath()
mi.model = model mi.Model = model
mi.manual = true mi.Manual = true
mc.set(table, mi) mc.set(table, mi)
} }
@ -404,7 +397,7 @@ func (mc *modelCache) getDbDropSQL(al *alias) (queries []string, err error) {
Q := al.DbBaser.TableQuote() Q := al.DbBaser.TableQuote()
for _, mi := range mc.allOrdered() { for _, mi := range mc.allOrdered() {
queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.Table, Q))
} }
return queries, nil return queries, nil
} }
@ -424,33 +417,33 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
for _, mi := range mc.allOrdered() { for _, mi := range mc.allOrdered() {
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.FullName)
sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.Table, Q)
columns := make([]string, 0, len(mi.fields.fieldsDB)) columns := make([]string, 0, len(mi.Fields.FieldsDB))
sqlIndexes := [][]string{} sqlIndexes := [][]string{}
var commentIndexes []int // store comment indexes for postgres var commentIndexes []int // store comment indexes for postgres
for i, fi := range mi.fields.fieldsDB { for i, fi := range mi.Fields.FieldsDB {
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) column := fmt.Sprintf(" %s%s%s ", Q, fi.Column, Q)
col := getColumnTyp(al, fi) col := getColumnTyp(al, fi)
if fi.auto { if fi.Auto {
switch al.Driver { switch al.Driver {
case DRSqlite, DRPostgres: case DRSqlite, DRPostgres:
column += T["auto"] column += T["auto"]
default: default:
column += col + " " + T["auto"] column += col + " " + T["auto"]
} }
} else if fi.pk { } else if fi.Pk {
column += col + " " + T["pk"] column += col + " " + T["pk"]
} else { } else {
column += col column += col
if !fi.null { if !fi.Null {
column += " " + "NOT NULL" column += " " + "NOT NULL"
} }
@ -461,42 +454,42 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
// Append attribute DEFAULT // Append attribute DEFAULT
column += getColumnDefault(fi) column += getColumnDefault(fi)
if fi.unique { if fi.Unique {
column += " " + "UNIQUE" column += " " + "UNIQUE"
} }
if fi.index { if fi.Index {
sqlIndexes = append(sqlIndexes, []string{fi.column}) sqlIndexes = append(sqlIndexes, []string{fi.Column})
} }
} }
if strings.Contains(column, "%COL%") { if strings.Contains(column, "%COL%") {
column = strings.Replace(column, "%COL%", fi.column, -1) column = strings.Replace(column, "%COL%", fi.Column, -1)
} }
if fi.description != "" && al.Driver != DRSqlite { if fi.Description != "" && al.Driver != DRSqlite {
if al.Driver == DRPostgres { if al.Driver == DRPostgres {
commentIndexes = append(commentIndexes, i) commentIndexes = append(commentIndexes, i)
} else { } else {
column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) column += " " + fmt.Sprintf("COMMENT '%s'", fi.Description)
} }
} }
columns = append(columns, column) columns = append(columns, column)
} }
if mi.model != nil { if mi.Model != nil {
allnames := getTableUnique(mi.addrField) allnames := imodels.GetTableUnique(mi.AddrField)
if !mi.manual && len(mi.uniques) > 0 { if !mi.Manual && len(mi.Uniques) > 0 {
allnames = append(allnames, mi.uniques) allnames = append(allnames, mi.Uniques)
} }
for _, names := range allnames { for _, names := range allnames {
cols := make([]string, 0, len(names)) cols := make([]string, 0, len(names))
for _, name := range names { for _, name := range names {
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol {
cols = append(cols, fi.column) cols = append(cols, fi.Column)
} else { } else {
panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.FullName))
} }
} }
column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
@ -509,8 +502,8 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
if al.Driver == DRMySQL { if al.Driver == DRMySQL {
var engine string var engine string
if mi.model != nil { if mi.Model != nil {
engine = getTableEngine(mi.addrField) engine = imodels.GetTableEngine(mi.AddrField)
} }
if engine == "" { if engine == "" {
engine = al.Engine engine = al.Engine
@ -524,24 +517,24 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
for _, index := range commentIndexes { for _, index := range commentIndexes {
sql += fmt.Sprintf("\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';", sql += fmt.Sprintf("\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';",
Q, Q,
mi.table, mi.Table,
Q, Q,
Q, Q,
mi.fields.fieldsDB[index].column, mi.Fields.FieldsDB[index].Column,
Q, Q,
mi.fields.fieldsDB[index].description) mi.Fields.FieldsDB[index].Description)
} }
} }
queries = append(queries, sql) queries = append(queries, sql)
if mi.model != nil { if mi.Model != nil {
for _, names := range getTableIndex(mi.addrField) { for _, names := range imodels.GetTableIndex(mi.AddrField) {
cols := make([]string, 0, len(names)) cols := make([]string, 0, len(names))
for _, name := range names { for _, name := range names {
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol {
cols = append(cols, fi.column) cols = append(cols, fi.Column)
} else { } else {
panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.FullName))
} }
} }
sqlIndexes = append(sqlIndexes, cols) sqlIndexes = append(sqlIndexes, cols)
@ -549,16 +542,16 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes
} }
for _, names := range sqlIndexes { for _, names := range sqlIndexes {
name := mi.table + "_" + strings.Join(names, "_") name := mi.Table + "_" + strings.Join(names, "_")
cols := strings.Join(names, sep) cols := strings.Join(names, sep)
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.Table, Q, Q, cols, Q)
index := dbIndex{} index := dbIndex{}
index.Table = mi.table index.Table = mi.Table
index.Name = name index.Name = name
index.SQL = sql index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index) tableIndexes[mi.Table] = append(tableIndexes[mi.Table], index)
} }
} }

View File

@ -15,91 +15,47 @@
package orm package orm
import ( import (
"fmt" "github.com/beego/beego/v2/client/orm/internal/models"
"strconv"
"time"
) )
// Define the Type enum // Define the Type enum
const ( const (
TypeBooleanField = 1 << iota TypeBooleanField = models.TypeBooleanField
TypeVarCharField TypeVarCharField = models.TypeVarCharField
TypeCharField TypeCharField = models.TypeCharField
TypeTextField TypeTextField = models.TypeTextField
TypeTimeField TypeTimeField = models.TypeTimeField
TypeDateField TypeDateField = models.TypeDateField
TypeDateTimeField TypeDateTimeField = models.TypeDateTimeField
TypeBitField TypeBitField = models.TypeBitField
TypeSmallIntegerField TypeSmallIntegerField = models.TypeSmallIntegerField
TypeIntegerField TypeIntegerField = models.TypeIntegerField
TypeBigIntegerField TypeBigIntegerField = models.TypeBigIntegerField
TypePositiveBitField TypePositiveBitField = models.TypePositiveBitField
TypePositiveSmallIntegerField TypePositiveSmallIntegerField = models.TypePositiveSmallIntegerField
TypePositiveIntegerField TypePositiveIntegerField = models.TypePositiveIntegerField
TypePositiveBigIntegerField TypePositiveBigIntegerField = models.TypePositiveBigIntegerField
TypeFloatField TypeFloatField = models.TypeFloatField
TypeDecimalField TypeDecimalField = models.TypeDecimalField
TypeJSONField TypeJSONField = models.TypeJSONField
TypeJsonbField TypeJsonbField = models.TypeJsonbField
RelForeignKey RelForeignKey = models.RelForeignKey
RelOneToOne RelOneToOne = models.RelOneToOne
RelManyToMany RelManyToMany = models.RelManyToMany
RelReverseOne RelReverseOne = models.RelReverseOne
RelReverseMany RelReverseMany = models.RelReverseMany
) )
// Define some logic enum // Define some logic enum
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 IsIntegerField = models.IsIntegerField
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 IsPositiveIntegerField = models.IsPositiveIntegerField
IsRelField = ^-RelReverseMany >> 18 << 19 IsRelField = models.IsRelField
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = models.IsFieldType
) )
// BooleanField A true/false field. // BooleanField A true/false field.
type BooleanField bool type BooleanField = models.BooleanField
// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
e.Set(d)
case string:
v, err := StrTo(d).Bool()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
// verify the BooleanField implement the Fielder interface // verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField) var _ Fielder = new(BooleanField)
@ -108,43 +64,7 @@ var _ Fielder = new(BooleanField)
// required values tag: size // required values tag: size
// The size is enforced at the database level and in modelss validation. // The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"` // eg: `orm:"size(120)"`
type CharField string type CharField = models.CharField
// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeVarCharField
}
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
// verify CharField implement Fielder // verify CharField implement Fielder
var _ Fielder = new(CharField) var _ Fielder = new(CharField)
@ -162,49 +82,7 @@ var _ Fielder = new(CharField)
// Note that the current date is always used; its not just a default value that you can override. // Note that the current date is always used; its not just a default value that you can override.
// //
// eg: `orm:"auto_now"` or `orm:"auto_now_add"` // eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type TimeField time.Time type TimeField = models.TimeField
// Value return the time.Time
func (e TimeField) Value() time.Time {
return time.Time(e)
}
// Set set the TimeField's value
func (e *TimeField) Set(d time.Time) {
*e = TimeField(d)
}
// String convert time to string
func (e *TimeField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *TimeField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *TimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return time value
func (e *TimeField) RawValue() interface{} {
return e.Value()
}
var _ Fielder = new(TimeField) var _ Fielder = new(TimeField)
@ -221,49 +99,7 @@ var _ Fielder = new(TimeField)
// Note that the current date is always used; its not just a default value that you can override. // Note that the current date is always used; its not just a default value that you can override.
// //
// eg: `orm:"auto_now"` or `orm:"auto_now_add"` // eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time type DateField = models.DateField
// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datetime to string
func (e *DateField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatDate)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
// verify DateField implement fielder interface // verify DateField implement fielder interface
var _ Fielder = new(DateField) var _ Fielder = new(DateField)
@ -271,513 +107,67 @@ var _ Fielder = new(DateField)
// DateTimeField A date, represented in go by a time.Time instance. // DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05 // datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField. // Takes the same extra arguments as DateField.
type DateTimeField time.Time type DateTimeField = models.DateTimeField
// Value return the datetime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatDateTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// verify datetime implement fielder // verify datetime implement fielder
var _ Fielder = new(DateTimeField) var _ models.Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value. // FloatField A floating-point number represented in go by a float32 value.
type FloatField float64 type FloatField = models.FloatField
// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
// String return the string
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
e.Set(float64(d))
case float64:
e.Set(d)
case string:
v, err := StrTo(d).Float64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
// verify FloatField implement Fielder // verify FloatField implement Fielder
var _ Fielder = new(FloatField) var _ Fielder = new(FloatField)
// SmallIntegerField -32768 to 32767 // SmallIntegerField -32768 to 32767
type SmallIntegerField int16 type SmallIntegerField = models.SmallIntegerField
// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
// String convert smallint to string
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
e.Set(d)
case string:
v, err := StrTo(d).Int16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify SmallIntegerField implement Fielder // verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField) var _ Fielder = new(SmallIntegerField)
// IntegerField -2147483648 to 2147483647 // IntegerField -2147483648 to 2147483647
type IntegerField int32 type IntegerField = models.IntegerField
// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
// String convert Int32 to string
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
e.Set(d)
case string:
v, err := StrTo(d).Int32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
// verify IntegerField implement Fielder // verify IntegerField implement Fielder
var _ Fielder = new(IntegerField) var _ Fielder = new(IntegerField)
// BigIntegerField -9223372036854775808 to 9223372036854775807. // BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64 type BigIntegerField = models.BigIntegerField
// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
e.Set(d)
case string:
v, err := StrTo(d).Int64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify BigIntegerField implement Fielder // verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField) var _ Fielder = new(BigIntegerField)
// PositiveSmallIntegerField 0 to 65535 // PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16 type PositiveSmallIntegerField = models.PositiveSmallIntegerField
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
e.Set(d)
case string:
v, err := StrTo(d).Uint16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveSmallIntegerField implement Fielder // verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField) var _ Fielder = new(PositiveSmallIntegerField)
// PositiveIntegerField 0 to 4294967295 // PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32 type PositiveIntegerField = models.PositiveIntegerField
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
e.Set(d)
case string:
v, err := StrTo(d).Uint32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveIntegerField implement Fielder // verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField) var _ Fielder = new(PositiveIntegerField)
// PositiveBigIntegerField 0 to 18446744073709551615 // PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64 type PositiveBigIntegerField = models.PositiveBigIntegerField
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
e.Set(d)
case string:
v, err := StrTo(d).Uint64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveBigIntegerField implement Fielder // verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField) var _ Fielder = new(PositiveBigIntegerField)
// TextField A large text field. // TextField A large text field.
type TextField string type TextField = models.TextField
// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
// verify TextField implement Fielder // verify TextField implement Fielder
var _ Fielder = new(TextField) var _ Fielder = new(TextField)
// JSONField postgres json field. // JSONField postgres json field.
type JSONField string type JSONField = models.JSONField
// Value return JSONField value
func (j JSONField) Value() string {
return string(j)
}
// Set the JSONField value
func (j *JSONField) Set(d string) {
*j = JSONField(d)
}
// String convert JSONField to string
func (j *JSONField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JSONField) FieldType() int {
return TypeJSONField
}
// SetRaw convert interface string to string
func (j *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JSONField value
func (j *JSONField) RawValue() interface{} {
return j.Value()
}
// verify JSONField implement Fielder // verify JSONField implement Fielder
var _ Fielder = new(JSONField) var _ models.Fielder = new(JSONField)
// JsonbField postgres json field. // JsonbField postgres json field.
type JsonbField string type JsonbField = models.JsonbField
// Value return JsonbField value
func (j JsonbField) Value() string {
return string(j)
}
// Set the JsonbField value
func (j *JsonbField) Set(d string) {
*j = JsonbField(d)
}
// String convert JsonbField to string
func (j *JsonbField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JsonbField) FieldType() int {
return TypeJsonbField
}
// SetRaw convert interface string to string
func (j *JsonbField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JsonbField value
func (j *JsonbField) RawValue() interface{} {
return j.Value()
}
// verify JsonbField implement Fielder // verify JsonbField implement Fielder
var _ Fielder = new(JsonbField) var _ models.Fielder = new(JsonbField)

View File

@ -1,148 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"os"
"reflect"
)
// single model info
type modelInfo struct {
manual bool
isThrough bool
pkg string
name string
fullName string
table string
model interface{}
fields *fields
addrField reflect.Value // store the original struct value
uniques []string
}
// new model info
func newModelInfo(val reflect.Value) (mi *modelInfo) {
mi = &modelInfo{}
mi.fields = newFields()
ind := reflect.Indirect(val)
mi.addrField = val
mi.name = ind.Type().Name()
mi.fullName = getFullName(ind.Type())
addModelFields(mi, ind, "", []int{})
return
}
// index: FieldByIndex returns the nested field corresponding to index
func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *fieldInfo
sf reflect.StructField
)
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
// if the field is unexported skip
if sf.PkgPath != "" {
continue
}
// add anonymous struct fields
if sf.Anonymous {
addModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue
}
fi, err = newFieldInfo(mi, field, sf, mName)
if err == errSkipField {
err = nil
continue
} else if err != nil {
break
}
// record current field index
fi.fieldIndex = append(fi.fieldIndex, index...)
fi.fieldIndex = append(fi.fieldIndex, i)
fi.mi = mi
fi.inModel = true
if !mi.fields.Add(fi) {
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if mi.fields.pk != nil {
err = fmt.Errorf("one model must have one pk field only")
break
} else {
mi.fields.pk = fi
}
}
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
}
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) {
mi = new(modelInfo)
mi.fields = newFields()
mi.table = m1.table + "_" + m2.table + "s"
mi.name = camelString(mi.table)
mi.fullName = m1.pkg + "." + mi.name
fa := new(fieldInfo) // pk
f1 := new(fieldInfo) // m1 table RelForeignKey
f2 := new(fieldInfo) // m2 table RelForeignKey
fa.fieldType = TypeBigIntegerField
fa.auto = true
fa.pk = true
fa.dbcol = true
fa.name = "Id"
fa.column = "id"
fa.fullName = mi.fullName + "." + fa.name
f1.dbcol = true
f2.dbcol = true
f1.fieldType = RelForeignKey
f2.fieldType = RelForeignKey
f1.name = camelString(m1.table)
f2.name = camelString(m2.table)
f1.fullName = mi.fullName + "." + f1.name
f2.fullName = mi.fullName + "." + f2.name
f1.column = m1.table + "_id"
f2.column = m2.table + "_id"
f1.rel = true
f2.rel = true
f1.relTable = m1.table
f2.relTable = m2.table
f1.relModelInfo = m1
f2.relModelInfo = m2
f1.mi = mi
f2.mi = mi
mi.fields.Add(fa)
mi.fields.Add(f1)
mi.fields.Add(f2)
mi.fields.pk = fa
mi.uniques = []string{f1.column, f2.column}
return
}

View File

@ -22,6 +22,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/models"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -79,7 +81,7 @@ func (e *SliceStringField) RawValue() interface{} {
return e.String() return e.String()
} }
var _ Fielder = new(SliceStringField) var _ models.Fielder = new(SliceStringField)
// A json field. // A json field.
type JSONFieldTest struct { type JSONFieldTest struct {
@ -111,7 +113,7 @@ func (e *JSONFieldTest) RawValue() interface{} {
return e.String() return e.String()
} }
var _ Fielder = new(JSONFieldTest) var _ models.Fielder = new(JSONFieldTest)
type Data struct { type Data struct {
ID int `orm:"column(id)"` ID int `orm:"column(id)"`

View File

@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//go:build go1.8
// +build go1.8
// Package orm provide ORM for MySQL/PostgreSQL/sqlite // Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage // Simple Usage
// //
@ -57,9 +54,12 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"os"
"reflect" "reflect"
"time"
ilogs "github.com/beego/beego/v2/client/orm/internal/logs"
iutils "github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/hints" "github.com/beego/beego/v2/client/orm/hints"
@ -75,10 +75,10 @@ const (
// Define common vars // Define common vars
var ( var (
Debug = false Debug = false
DebugLog = NewLog(os.Stdout) DebugLog = ilogs.DebugLog
DefaultRowsLimit = -1 DefaultRowsLimit = -1
DefaultRelsDepth = 2 DefaultRelsDepth = 2
DefaultTimeLoc = time.Local DefaultTimeLoc = iutils.DefaultTimeLoc
ErrTxDone = errors.New("<TxOrmer.Commit/Rollback> transaction already done") ErrTxDone = errors.New("<TxOrmer.Commit/Rollback> transaction already done")
ErrMultiRows = errors.New("<QuerySeter> return multi rows") ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found") ErrNoRows = errors.New("<QuerySeter> no row found")
@ -107,7 +107,7 @@ var (
) )
// get model info and model reflect value // get model info and model reflect value
func (*ormBase) getMi(md interface{}) (mi *modelInfo) { func (*ormBase) getMi(md interface{}) (mi *models.ModelInfo) {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
@ -116,19 +116,19 @@ func (*ormBase) getMi(md interface{}) (mi *modelInfo) {
} }
// get need ptr model info and model reflect value // get need ptr model info and model reflect value
func (*ormBase) getPtrMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { func (*ormBase) getPtrMiInd(md interface{}) (mi *models.ModelInfo, ind reflect.Value) {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind = reflect.Indirect(val) ind = reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
if val.Kind() != reflect.Ptr { if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", models.GetFullName(typ)))
} }
mi = getTypeMi(typ) mi = getTypeMi(typ)
return return
} }
func getTypeMi(mdTyp reflect.Type) *modelInfo { func getTypeMi(mdTyp reflect.Type) *models.ModelInfo {
name := getFullName(mdTyp) name := models.GetFullName(mdTyp)
if mi, ok := defaultModelCache.getByFullName(name); ok { if mi, ok := defaultModelCache.getByFullName(name); ok {
return mi return mi
} }
@ -136,10 +136,10 @@ func getTypeMi(mdTyp reflect.Type) *modelInfo {
} }
// get field info from model info by given field name // get field info from model info by given field name
func (*ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { func (*ormBase) getFieldInfo(mi *models.ModelInfo, name string) *models.FieldInfo {
fi, ok := mi.fields.GetByAny(name) fi, ok := mi.Fields.GetByAny(name)
if !ok { if !ok {
panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName)) panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.FullName))
} }
return fi return fi
} }
@ -179,11 +179,11 @@ func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1
return err == nil, id, err return err == nil, id, err
} }
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) id, vid := int64(0), ind.FieldByIndex(mi.Fields.Pk.FieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint()) id = int64(vid.Uint())
} else if mi.fields.pk.rel { } else if mi.Fields.Pk.Rel {
return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.Fields.Pk.RelModelInfo.Fields.Pk.Name)
} else { } else {
id = vid.Int() id = vid.Int()
} }
@ -209,12 +209,12 @@ func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, err
} }
// set auto pk field // set auto pk field
func (*ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { func (*ormBase) setPk(mi *models.ModelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto { if mi.Fields.Pk.Auto {
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetUint(uint64(id))
} else { } else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetInt(id)
} }
} }
} }
@ -276,7 +276,7 @@ func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, col
} }
// update model to database. // update model to database.
// cols set the columns those want to update. // cols set the Columns those want to update.
func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
return o.UpdateWithCtx(context.Background(), md, cols...) return o.UpdateWithCtx(context.Background(), md, cols...)
} }
@ -304,10 +304,10 @@ func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
switch { switch {
case fi.fieldType == RelManyToMany: case fi.FieldType == RelManyToMany:
case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: case fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough:
default: default:
panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.Name, mi.FullName))
} }
return newQueryM2M(md, o, mi, fi, ind) return newQueryM2M(md, o, mi, fi, ind)
@ -362,7 +362,7 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str
} }
}) })
switch fi.fieldType { switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelReverseOne: case RelOneToOne, RelForeignKey, RelReverseOne:
limit = 1 limit = 1
offset = 0 offset = 0
@ -376,11 +376,11 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str
qs.orders = order_clause.ParseOrder(order) qs.orders = order_clause.ParseOrder(order)
} }
find := ind.FieldByIndex(fi.fieldIndex) find := ind.FieldByIndex(fi.FieldIndex)
var nums int64 var nums int64
var err error var err error
switch fi.fieldType { switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelReverseOne: case RelOneToOne, RelForeignKey, RelReverseOne:
val := reflect.New(find.Type().Elem()) val := reflect.New(find.Type().Elem())
container := val.Interface() container := val.Interface()
@ -397,7 +397,7 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str
} }
// get QuerySeter for related models to md model // get QuerySeter for related models to md model
func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { func (o *ormBase) queryRelated(md interface{}, name string) (*models.ModelInfo, *models.FieldInfo, reflect.Value, *querySet) {
mi, ind := o.getPtrMiInd(md) mi, ind := o.getPtrMiInd(md)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
@ -408,14 +408,14 @@ func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldI
var qs *querySet var qs *querySet
switch fi.fieldType { switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelManyToMany: case RelOneToOne, RelForeignKey, RelManyToMany:
if !fi.inModel { if !fi.InModel {
break break
} }
qs = o.getRelQs(md, mi, fi) qs = o.getRelQs(md, mi, fi)
case RelReverseOne, RelReverseMany: case RelReverseOne, RelReverseMany:
if !fi.inModel { if !fi.InModel {
break break
} }
qs = o.getReverseQs(md, mi, fi) qs = o.getReverseQs(md, mi, fi)
@ -429,41 +429,41 @@ func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldI
} }
// get reverse relation QuerySeter // get reverse relation QuerySeter
func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *ormBase) getReverseQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet {
switch fi.fieldType { switch fi.FieldType {
case RelReverseOne, RelReverseMany: case RelReverseOne, RelReverseMany:
default: default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.Name, mi.FullName))
} }
var q *querySet var q *querySet
if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { if fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough {
q = newQuerySet(o, fi.relModelInfo).(*querySet) q = newQuerySet(o, fi.RelModelInfo).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) q.cond = NewCondition().And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md)
} else { } else {
q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) q = newQuerySet(o, fi.ReverseFieldInfo.Mi).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) q.cond = NewCondition().And(fi.ReverseFieldInfo.Column, md)
} }
return q return q
} }
// get relation QuerySeter // get relation QuerySeter
func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *ormBase) getRelQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet {
switch fi.fieldType { switch fi.FieldType {
case RelOneToOne, RelForeignKey, RelManyToMany: case RelOneToOne, RelForeignKey, RelManyToMany:
default: default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.Name, mi.FullName))
} }
q := newQuerySet(o, fi.relModelInfo).(*querySet) q := newQuerySet(o, fi.RelModelInfo).(*querySet)
q.cond = NewCondition() q.cond = NewCondition()
if fi.fieldType == RelManyToMany { if fi.FieldType == RelManyToMany {
q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) q.cond = q.cond.And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md)
} else { } else {
q.cond = q.cond.And(fi.reverseFieldInfo.column, md) q.cond = q.cond.And(fi.ReverseFieldInfo.Column, md)
} }
return q return q
@ -475,12 +475,12 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string var name string
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
name = nameStrategyMap[defaultNameStrategy](table) name = models.NameStrategyMap[models.DefaultNameStrategy](table)
if mi, ok := defaultModelCache.get(name); ok { if mi, ok := defaultModelCache.get(name); ok {
qs = newQuerySet(o, mi) qs = newQuerySet(o, mi)
} }
} else { } else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) name = models.GetFullName(iutils.IndirectType(reflect.TypeOf(ptrStructOrTableName)))
if mi, ok := defaultModelCache.getByFullName(name); ok { if mi, ok := defaultModelCache.getByFullName(name); ok {
qs = newQuerySet(o, mi) qs = newQuerySet(o, mi)
} }

View File

@ -22,23 +22,22 @@ import (
"log" "log"
"strings" "strings"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/logs"
) )
// Log implement the log.Logger type Log = logs.Log
type Log struct {
*log.Logger
}
// costomer log func
var LogFunc func(query map[string]interface{})
// NewLog set io.Writer to create a Logger. // NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *logs.Log {
d := new(Log) d := new(logs.Log)
d.Logger = log.New(out, "[ORM]", log.LstdFlags) d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d return d
} }
// LogFunc costomer log func
var LogFunc func(query map[string]interface{})
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
logMap := make(map[string]interface{}) logMap := make(map[string]interface{})
sub := time.Since(t) / 1e5 sub := time.Since(t) / 1e5
@ -64,7 +63,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if LogFunc != nil { if LogFunc != nil {
LogFunc(logMap) LogFunc(logMap)
} }
DebugLog.Println(con) logs.DebugLog.Println(con)
} }
// statement query logger struct. // statement query logger struct.

View File

@ -18,11 +18,13 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
// an insert queryer struct // an insert queryer struct
type insertSet struct { type insertSet struct {
mi *modelInfo mi *models.ModelInfo
orm *ormBase orm *ormBase
stmt stmtQuerier stmt stmtQuerier
closed bool closed bool
@ -42,23 +44,23 @@ func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, e
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
name := getFullName(typ) name := models.GetFullName(typ)
if val.Kind() != reflect.Ptr { if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name)) panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name))
} }
if name != o.mi.fullName { if name != o.mi.FullName {
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name)) panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.FullName, name))
} }
id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
if id > 0 { if id > 0 {
if o.mi.fields.pk.auto { if o.mi.Fields.Pk.Auto {
if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { if o.mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetUint(uint64(id))
} else { } else {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetInt(id)
} }
} }
} }
@ -75,7 +77,7 @@ func (o *insertSet) Close() error {
} }
// create new insert queryer. // create new insert queryer.
func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { func newInsertSet(ctx context.Context, orm *ormBase, mi *models.ModelInfo) (Inserter, error) {
bi := new(insertSet) bi := new(insertSet)
bi.orm = orm bi.orm = orm
bi.mi = mi bi.mi = mi

View File

@ -17,13 +17,15 @@ package orm
import ( import (
"context" "context"
"reflect" "reflect"
"github.com/beego/beego/v2/client/orm/internal/models"
) )
// model to model struct // model to model struct
type queryM2M struct { type queryM2M struct {
md interface{} md interface{}
mi *modelInfo mi *models.ModelInfo
fi *fieldInfo fi *models.FieldInfo
qs *querySet qs *querySet
ind reflect.Value ind reflect.Value
} }
@ -42,9 +44,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
mi := fi.relThroughModelInfo mi := fi.RelThroughModelInfo
mfi := fi.reverseFieldInfo mfi := fi.ReverseFieldInfo
rfi := fi.reverseFieldInfoTwo rfi := fi.ReverseFieldInfoTwo
orm := o.qs.orm orm := o.qs.orm
dbase := orm.alias.DbBaser dbase := orm.alias.DbBaser
@ -53,9 +55,9 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e
var otherValues []interface{} var otherValues []interface{}
var otherNames []string var otherNames []string
for _, colname := range mi.fields.dbcols { for _, colname := range mi.Fields.DBcols {
if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && if colname != mfi.Column && colname != rfi.Column && colname != fi.Mi.Fields.Pk.Column &&
mi.fields.columns[colname] != mi.fields.pk { mi.Fields.Columns[colname] != mi.Fields.Pk {
otherNames = append(otherNames, colname) otherNames = append(otherNames, colname)
} }
} }
@ -84,7 +86,7 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e
panic(ErrMissPK) panic(ErrMissPK)
} }
names := []string{mfi.column, rfi.column} names := []string{mfi.Column, rfi.Column}
values := make([]interface{}, 0, len(models)*2) values := make([]interface{}, 0, len(models)*2)
for _, md := range models { for _, md := range models {
@ -94,7 +96,7 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e
if ind.Kind() != reflect.Struct { if ind.Kind() != reflect.Struct {
v2 = ind.Interface() v2 = ind.Interface()
} else { } else {
_, v2, exist = getExistPk(fi.relModelInfo, ind) _, v2, exist = getExistPk(fi.RelModelInfo, ind)
if !exist { if !exist {
panic(ErrMissPK) panic(ErrMissPK)
} }
@ -114,9 +116,9 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) qs := o.qs.Filter(fi.ReverseFieldInfo.Name, o.md)
return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() return qs.Filter(fi.ReverseFieldInfoTwo.Name+ExprSep+"in", mds).Delete()
} }
// check model is existed in relationship of origin model // check model is existed in relationship of origin model
@ -126,8 +128,8 @@ func (o *queryM2M) Exist(md interface{}) bool {
func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md). return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) Filter(fi.ReverseFieldInfoTwo.Name, md).ExistWithCtx(ctx)
} }
// clean all models in related of origin model // clean all models in related of origin model
@ -137,7 +139,7 @@ func (o *queryM2M) Clear() (int64, error) {
func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).DeleteWithCtx(ctx)
} }
// count all related models of origin model // count all related models of origin model
@ -147,18 +149,18 @@ func (o *queryM2M) Count() (int64, error) {
func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).CountWithCtx(ctx)
} }
var _ QueryM2Mer = new(queryM2M) var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer. // create new M2M queryer.
func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { func newQueryM2M(md interface{}, o *ormBase, mi *models.ModelInfo, fi *models.FieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M) qm2m := new(queryM2M)
qm2m.md = md qm2m.md = md
qm2m.mi = mi qm2m.mi = mi
qm2m.fi = fi qm2m.fi = fi
qm2m.ind = ind qm2m.ind = ind
qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) qm2m.qs = newQuerySet(o, fi.RelThroughModelInfo).(*querySet)
return qm2m return qm2m
} }

View File

@ -18,6 +18,10 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/client/orm/hints" "github.com/beego/beego/v2/client/orm/hints"
) )
@ -54,7 +58,7 @@ func ColValue(opt operator, value interface{}) interface{} {
default: default:
panic(fmt.Errorf("orm.ColValue wrong operator")) panic(fmt.Errorf("orm.ColValue wrong operator"))
} }
v, err := StrTo(ToStr(value)).Int64() v, err := utils.StrTo(utils.ToStr(value)).Int64()
if err != nil { if err != nil {
panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err))
} }
@ -66,7 +70,7 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct // real query struct
type querySet struct { type querySet struct {
mi *modelInfo mi *models.ModelInfo
cond *Condition cond *Condition
related []string related []string
relDepth int relDepth int
@ -113,13 +117,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
// set offset number // set offset number
func (o *querySet) setOffset(num interface{}) { func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num) o.offset = utils.ToInt64(num)
} }
// add LIMIT value. // add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset. // args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit) o.limit = utils.ToInt64(limit)
if len(args) > 0 { if len(args) > 0 {
o.setOffset(args[0]) o.setOffset(args[0])
} }
@ -273,7 +277,7 @@ func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) {
} }
// query all data and map to containers. // query all data and map to containers.
// cols means the columns when querying. // cols means the Columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) { func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.AllWithCtx(context.Background(), container, cols...) return o.AllWithCtx(context.Background(), container, cols...)
} }
@ -283,7 +287,7 @@ func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols .
} }
// query one row data and map to containers. // query one row data and map to containers.
// cols means the columns when querying. // cols means the Columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error { func (o *querySet) One(container interface{}, cols ...string) error {
return o.OneWithCtx(context.Background(), container, cols...) return o.OneWithCtx(context.Background(), container, cols...)
} }
@ -366,7 +370,7 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
} }
// create new QuerySeter. // create new QuerySeter.
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { func newQuerySet(orm *ormBase, mi *models.ModelInfo) QuerySeter {
o := new(querySet) o := new(querySet)
o.mi = mi o.mi = mi
o.orm = orm o.orm = orm

View File

@ -20,6 +20,10 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -95,7 +99,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
} else if v, ok := value.(bool); ok { } else if v, ok := value.(bool); ok {
ind.SetBool(v) ind.SetBool(v)
} else { } else {
v, _ := StrTo(ToStr(value)).Bool() v, _ := utils.StrTo(utils.ToStr(value)).Bool()
ind.SetBool(v) ind.SetBool(v)
} }
@ -103,7 +107,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if value == nil { if value == nil {
ind.SetString("") ind.SetString("")
} else { } else {
ind.SetString(ToStr(value)) ind.SetString(utils.ToStr(value))
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@ -117,7 +121,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
ind.SetInt(int64(val.Uint())) ind.SetInt(int64(val.Uint()))
default: default:
v, _ := StrTo(ToStr(value)).Int64() v, _ := utils.StrTo(utils.ToStr(value)).Int64()
ind.SetInt(v) ind.SetInt(v)
} }
} }
@ -132,7 +136,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
ind.SetUint(val.Uint()) ind.SetUint(val.Uint())
default: default:
v, _ := StrTo(ToStr(value)).Uint64() v, _ := utils.StrTo(utils.ToStr(value)).Uint64()
ind.SetUint(v) ind.SetUint(v)
} }
} }
@ -145,7 +149,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Float64: case reflect.Float64:
ind.SetFloat(val.Float()) ind.SetFloat(val.Float())
default: default:
v, _ := StrTo(ToStr(value)).Float64() v, _ := utils.StrTo(utils.ToStr(value)).Float64()
ind.SetFloat(v) ind.SetFloat(v)
} }
} }
@ -170,20 +174,20 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" { if str != "" {
if len(str) >= 19 { if len(str) >= 19 {
str = str[:19] str = str[:19]
t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) t, err := time.ParseInLocation(utils.FormatDateTime, str, o.orm.alias.TZ)
if err == nil { if err == nil {
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t)) ind.Set(reflect.ValueOf(t))
} }
} else if len(str) >= 10 { } else if len(str) >= 10 {
str = str[:10] str = str[:10]
t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) t, err := time.ParseInLocation(utils.FormatDate, str, DefaultTimeLoc)
if err == nil { if err == nil {
ind.Set(reflect.ValueOf(t)) ind.Set(reflect.ValueOf(t))
} }
} else if len(str) >= 8 { } else if len(str) >= 8 {
str = str[:8] str = str[:8]
t, err := time.ParseInLocation(formatTime, str, DefaultTimeLoc) t, err := time.ParseInLocation(utils.FormatTime, str, DefaultTimeLoc)
if err == nil { if err == nil {
ind.Set(reflect.ValueOf(t)) ind.Set(reflect.ValueOf(t))
} }
@ -287,7 +291,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
refs = make([]interface{}, 0, len(containers)) refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value sInds []reflect.Value
eTyps []reflect.Type eTyps []reflect.Type
sMi *modelInfo sMi *models.ModelInfo
) )
structMode := false structMode := false
for _, container := range containers { for _, container := range containers {
@ -313,7 +317,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
} }
structMode = true structMode = true
fn := getFullName(typ) fn := models.GetFullName(typ)
if mi, ok := defaultModelCache.getByFullName(fn); ok { if mi, ok := defaultModelCache.getByFullName(fn); ok {
sMi = mi sMi = mi
} }
@ -370,16 +374,16 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if sMi != nil { if sMi != nil {
for _, col := range columns { for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil { if fi := sMi.Fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface() value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
field := ind.FieldByIndex(fi.fieldIndex) field := ind.FieldByIndex(fi.FieldIndex)
if fi.fieldType&IsRelField > 0 { if fi.FieldType&IsRelField > 0 {
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type())
field.Set(mf) field.Set(mf)
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex)
} }
if fi.isFielder { if fi.IsFielder {
fd := field.Addr().Interface().(Fielder) fd := field.Addr().Interface().(models.Fielder)
err := fd.SetRaw(value) err := fd.SetRaw(value)
if err != nil { if err != nil {
return errors.Errorf("set raw error:%s", err) return errors.Errorf("set raw error:%s", err)
@ -406,12 +410,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// thanks @Gazeboxu. // thanks @Gazeboxu.
tags := structTagMap[fe.Tag] tags := structTagMap[fe.Tag]
if tags == nil { if tags == nil {
_, tags = parseStructTag(fe.Tag.Get(defaultStructTagName)) _, tags = models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName))
structTagMap[fe.Tag] = tags structTagMap[fe.Tag] = tags
} }
var col string var col string
if col = tags["column"]; col == "" { if col = tags["column"]; col == "" {
col = nameStrategyMap[nameStrategy](fe.Name) col = models.NameStrategyMap[models.NameStrategy](fe.Name)
} }
if v, ok := columnsMp[col]; ok { if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface() value := reflect.ValueOf(v).Elem().Interface()
@ -449,7 +453,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs = make([]interface{}, 0, len(containers)) refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value sInds []reflect.Value
eTyps []reflect.Type eTyps []reflect.Type
sMi *modelInfo sMi *models.ModelInfo
) )
structMode := false structMode := false
for _, container := range containers { for _, container := range containers {
@ -474,7 +478,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
} }
structMode = true structMode = true
fn := getFullName(typ) fn := models.GetFullName(typ)
if mi, ok := defaultModelCache.getByFullName(fn); ok { if mi, ok := defaultModelCache.getByFullName(fn); ok {
sMi = mi sMi = mi
} }
@ -537,16 +541,16 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
if sMi != nil { if sMi != nil {
for _, col := range columns { for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil { if fi := sMi.Fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface() value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
field := ind.FieldByIndex(fi.fieldIndex) field := ind.FieldByIndex(fi.FieldIndex)
if fi.fieldType&IsRelField > 0 { if fi.FieldType&IsRelField > 0 {
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type())
field.Set(mf) field.Set(mf)
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex)
} }
if fi.isFielder { if fi.IsFielder {
fd := field.Addr().Interface().(Fielder) fd := field.Addr().Interface().(models.Fielder)
err := fd.SetRaw(value) err := fd.SetRaw(value)
if err != nil { if err != nil {
return 0, errors.Errorf("set raw error:%s", err) return 0, errors.Errorf("set raw error:%s", err)
@ -570,10 +574,10 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
recursiveSetField(f) recursiveSetField(f)
} }
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) _, tags := models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName))
var col string var col string
if col = tags["column"]; col == "" { if col = tags["column"]; col == "" {
col = nameStrategyMap[nameStrategy](fe.Name) col = models.NameStrategyMap[models.NameStrategy](fe.Name)
} }
if v, ok := columnsMp[col]; ok { if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface() value := reflect.ValueOf(v).Elem().Interface()
@ -837,7 +841,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
} }
default: default:
if id := ind.FieldByName(camelString(key)); id.IsValid() { if id := ind.FieldByName(models.CamelString(key)); id.IsValid() {
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
} }
} }

View File

@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//go:build go1.8
// +build go1.8
package orm package orm
import ( import (
@ -32,6 +29,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/logs"
"github.com/beego/beego/v2/client/orm/internal/utils"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/clauses/order_clause"
@ -41,9 +44,9 @@ import (
var _ = os.PathSeparator var _ = os.PathSeparator
var ( var (
testDate = formatDate + " -0700" testDate = utils.FormatDate + " -0700"
testDateTime = formatDateTime + " -0700" testDateTime = utils.FormatDateTime + " -0700"
testTime = formatTime + " -0700" testTime = utils.FormatTime + " -0700"
) )
type argAny []interface{} type argAny []interface{}
@ -72,7 +75,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er
case time.Time: case time.Time:
if v2, vo := b.(time.Time); vo { if v2, vo := b.(time.Time); vo {
if arg.Get(1) != nil { if arg.Get(1) != nil {
format := ToStr(arg.Get(1)) format := utils.ToStr(arg.Get(1))
a = v.Format(format) a = v.Format(format)
b = v2.Format(format) b = v2.Format(format)
ok = a == b ok = a == b
@ -82,7 +85,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er
} }
} }
default: default:
ok = ToStr(a) == ToStr(b) ok = utils.ToStr(a) == utils.ToStr(b)
} }
ok = is && ok || !is && !ok ok = is && ok || !is && !ok
if !ok { if !ok {
@ -250,14 +253,14 @@ func TestRegisterModels(_ *testing.T) {
func TestModelSyntax(t *testing.T) { func TestModelSyntax(t *testing.T) {
user := &User{} user := &User{}
ind := reflect.ValueOf(user).Elem() ind := reflect.ValueOf(user).Elem()
fn := getFullName(ind.Type()) fn := models.GetFullName(ind.Type())
_, ok := defaultModelCache.getByFullName(fn) _, ok := defaultModelCache.getByFullName(fn)
throwFail(t, AssertIs(ok, true)) throwFail(t, AssertIs(ok, true))
mi, ok := defaultModelCache.get("user") mi, ok := defaultModelCache.get("user")
throwFail(t, AssertIs(ok, true)) throwFail(t, AssertIs(ok, true))
if ok { if ok {
throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) throwFail(t, AssertIs(mi.Fields.GetByName("ShouldSkip") == nil, true))
} }
} }
@ -561,7 +564,7 @@ func TestNullDataTypes(t *testing.T) {
assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second)
assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second)
// test support for pointer fields using RawSeter.QueryRows() // test support for pointer Fields using RawSeter.QueryRows()
var dnList []*DataNull var dnList []*DataNull
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList)
@ -1894,7 +1897,7 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(row.Id, 4)) throwFail(t, AssertIs(row.Id, 4))
throwFail(t, AssertIs(row.EmbedField.Email, "nobody@gmail.com")) throwFail(t, AssertIs(row.EmbedField.Email, "nobody@gmail.com"))
// test for sql.Null* fields // test for sql.Null* Fields
nData := &DataNull{ nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true}, NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true},
@ -2003,7 +2006,7 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(l[1].Age, 30)) throwFailNow(t, AssertIs(l[1].Age, 30))
// test for sql.Null* fields // test for sql.Null* Fields
nData := &DataNull{ nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true}, NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true},
@ -2616,7 +2619,7 @@ func TestSnake(t *testing.T) {
"tag_666Name": "tag_666_name", "tag_666Name": "tag_666_name",
} }
for name, want := range cases { for name, want := range cases {
got := snakeString(name) got := models.SnakeString(name)
throwFail(t, AssertIs(got, want)) throwFail(t, AssertIs(got, want))
} }
} }
@ -2637,10 +2640,10 @@ func TestIgnoreCaseTag(t *testing.T) {
if t == nil { if t == nil {
return return
} }
throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) throwFail(t, AssertIs(info.Fields.GetByName("NOO").Column, "n"))
throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) throwFail(t, AssertIs(info.Fields.GetByName("Name01").Null, true))
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) throwFail(t, AssertIs(info.Fields.GetByName("Name02").Column, "Name"))
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) throwFail(t, AssertIs(info.Fields.GetByName("Name03").Column, "name"))
} }
func TestInsertOrUpdate(t *testing.T) { func TestInsertOrUpdate(t *testing.T) {
@ -2934,9 +2937,9 @@ func TestDebugLog(t *testing.T) {
func captureDebugLogOutput(f func()) string { func captureDebugLogOutput(f func()) string {
var buf bytes.Buffer var buf bytes.Buffer
DebugLog.SetOutput(&buf) logs.DebugLog.SetOutput(&buf)
defer func() { defer func() {
DebugLog.SetOutput(os.Stderr) logs.DebugLog.SetOutput(os.Stderr)
}() }()
f() f()
return buf.String() return buf.String()

View File

@ -28,7 +28,7 @@ type MySQLQueryBuilder struct {
tokens []string tokens []string
} }
// Select will join the fields // Select will join the Fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
qb.tokens = append(qb.tokens, "SELECT", strings.Join(fields, CommaSpace)) qb.tokens = append(qb.tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb return qb
@ -94,7 +94,7 @@ func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
return qb return qb
} }
// OrderBy join the Order by fields // OrderBy join the Order by Fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.tokens = append(qb.tokens, "ORDER BY", strings.Join(fields, CommaSpace)) qb.tokens = append(qb.tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb return qb
@ -124,7 +124,7 @@ func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
return qb return qb
} }
// GroupBy join the Group by fields // GroupBy join the Group by Fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.tokens = append(qb.tokens, "GROUP BY", strings.Join(fields, CommaSpace)) qb.tokens = append(qb.tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb return qb

View File

@ -19,7 +19,7 @@ func processingStr(str []string) string {
return s return s
} }
// Select will join the fields // Select will join the Fields
func (qb *PostgresQueryBuilder) Select(fields ...string) QueryBuilder { func (qb *PostgresQueryBuilder) Select(fields ...string) QueryBuilder {
var str string var str string
n := len(fields) n := len(fields)
@ -121,7 +121,7 @@ func (qb *PostgresQueryBuilder) In(vals ...string) QueryBuilder {
return qb return qb
} }
// OrderBy join the Order by fields // OrderBy join the Order by Fields
func (qb *PostgresQueryBuilder) OrderBy(fields ...string) QueryBuilder { func (qb *PostgresQueryBuilder) OrderBy(fields ...string) QueryBuilder {
str := processingStr(fields) str := processingStr(fields)
qb.tokens = append(qb.tokens, "ORDER BY", str) qb.tokens = append(qb.tokens, "ORDER BY", str)
@ -152,7 +152,7 @@ func (qb *PostgresQueryBuilder) Offset(offset int) QueryBuilder {
return qb return qb
} }
// GroupBy join the Group by fields // GroupBy join the Group by Fields
func (qb *PostgresQueryBuilder) GroupBy(fields ...string) QueryBuilder { func (qb *PostgresQueryBuilder) GroupBy(fields ...string) QueryBuilder {
str := processingStr(fields) str := processingStr(fields)
qb.tokens = append(qb.tokens, "GROUP BY", str) qb.tokens = append(qb.tokens, "GROUP BY", str)

View File

@ -20,11 +20,13 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/beego/beego/v2/client/orm/internal/models"
"github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/clauses/order_clause"
"github.com/beego/beego/v2/core/utils" "github.com/beego/beego/v2/core/utils"
) )
// TableNaming is usually used by model // TableNameI is usually used by model
// when you custom your table name, please implement this interfaces // when you custom your table name, please implement this interfaces
// for example: // for example:
// //
@ -95,22 +97,16 @@ type Driver interface {
Type() DriverType Type() DriverType
} }
// Fielder define field info type Fielder = models.Fielder
type Fielder interface {
String() string
FieldType() int
SetRaw(interface{}) error
RawValue() interface{}
}
type TxBeginner interface { type TxBeginner interface {
// self control transaction // Begin self control transaction
Begin() (TxOrmer, error) Begin() (TxOrmer, error)
BeginWithCtx(ctx context.Context) (TxOrmer, error) BeginWithCtx(ctx context.Context) (TxOrmer, error)
BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error)
BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error)
// closure control transaction // DoTx closure control transaction
DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error
DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error
DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error
@ -146,27 +142,27 @@ type txEnder interface {
RollbackUnlessCommit() error RollbackUnlessCommit() error
} }
// Data Manipulation Language // DML Data Manipulation Language
type DML interface { type DML interface {
// insert model data to database // Insert insert model data to database
// for example: // for example:
// user := new(User) // user := new(User)
// id, err = Ormer.Insert(user) // id, err = Ormer.Insert(user)
// user must be a pointer and Insert will set user's pk field // user must be a pointer and Insert will set user's pk field
Insert(md interface{}) (int64, error) Insert(md interface{}) (int64, error)
InsertWithCtx(ctx context.Context, md interface{}) (int64, error) InsertWithCtx(ctx context.Context, md interface{}) (int64, error)
// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") // InsertOrUpdate mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
// if colu type is integer : can use(+-*/), string : convert(colu,"value") // if colu type is integer : can use(+-*/), string : convert(colu,"value")
// postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value")
// if colu type is integer : can use(+-*/), string : colu || "value" // if colu type is integer : can use(+-*/), string : colu || "value"
InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error)
InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error)
// insert some models to database // InsertMulti inserts some models to database
InsertMulti(bulk int, mds interface{}) (int64, error) InsertMulti(bulk int, mds interface{}) (int64, error)
InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error)
// update model to database. // Update updates model to database.
// cols set the columns those want to update. // cols set the Columns those want to update.
// find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns // find model by Id(pk) field and update Columns specified by Fields, if cols is null then update all Columns
// for example: // for example:
// user := User{Id: 2} // user := User{Id: 2}
// user.Langs = append(user.Langs, "zh-CN", "en-US") // user.Langs = append(user.Langs, "zh-CN", "en-US")
@ -175,11 +171,11 @@ type DML interface {
// num, err = Ormer.Update(&user, "Langs", "Extra") // num, err = Ormer.Update(&user, "Langs", "Extra")
Update(md interface{}, cols ...string) (int64, error) Update(md interface{}, cols ...string) (int64, error)
UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error)
// delete model in database // Delete deletes model in database
Delete(md interface{}, cols ...string) (int64, error) Delete(md interface{}, cols ...string) (int64, error)
DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error)
// return a raw query seter for raw sql string. // Raw return a raw query seter for raw sql string.
// for example: // for example:
// ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
// // update user testing's name to slene // // update user testing's name to slene
@ -187,9 +183,9 @@ type DML interface {
RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter
} }
// Data Query Language // DQL Data Query Language
type DQL interface { type DQL interface {
// read data to model // Read reads data to model
// for example: // for example:
// this will find User by Id field // this will find User by Id field
// u = &User{Id: user.Id} // u = &User{Id: user.Id}
@ -200,16 +196,16 @@ type DQL interface {
Read(md interface{}, cols ...string) error Read(md interface{}, cols ...string) error
ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error
// Like Read(), but with "FOR UPDATE" clause, useful in transaction. // ReadForUpdate Like Read(), but with "FOR UPDATE" clause, useful in transaction.
// Some databases are not support this feature. // Some databases are not support this feature.
ReadForUpdate(md interface{}, cols ...string) error ReadForUpdate(md interface{}, cols ...string) error
ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error
// Try to read a row from the database, or insert one if it doesn't exist // ReadOrCreate Try to read a row from the database, or insert one if it doesn't exist
ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error)
// load related models to md model. // LoadRelated load related models to md model.
// args are limit, offset int and order string. // args are limit, offset int and order string.
// //
// example: // example:
@ -224,20 +220,20 @@ type DQL interface {
LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error)
LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error)
// create a models to models queryer // QueryM2M create a models to models queryer
// for example: // for example:
// post := Post{Id: 4} // post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags") // m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer QueryM2M(md interface{}, name string) QueryM2Mer
// NOTE: this method is deprecated, context parameter will not take effect. // QueryM2MWithCtx NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations. // QueryTable return a QuerySeter for table operations.
// table name can be string or struct. // table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter QueryTable(ptrStructOrTableName interface{}) QuerySeter
// NOTE: this method is deprecated, context parameter will not take effect. // QueryTableWithCtx NOTE: this method is deprecated, context parameter will not take effect.
// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
@ -278,7 +274,7 @@ type Inserter interface {
// QuerySeter query seter // QuerySeter query seter
type QuerySeter interface { type QuerySeter interface {
// add condition expression to QuerySeter. // Filter add condition expression to QuerySeter.
// for example: // for example:
// filter by UserName == 'slene' // filter by UserName == 'slene'
// qs.Filter("UserName", "slene") // qs.Filter("UserName", "slene")
@ -287,22 +283,22 @@ type QuerySeter interface {
// // time compare // // time compare
// qs.Filter("created", time.Now()) // qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
// add raw sql to querySeter. // FilterRaw add raw sql to querySeter.
// for example: // for example:
// qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)")
// //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18)
FilterRaw(string, string) QuerySeter FilterRaw(string, string) QuerySeter
// add NOT condition to querySeter. // Exclude add NOT condition to querySeter.
// have the same usage as Filter // have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
// set condition to QuerySeter. // SetCond set condition to QuerySeter.
// sql's where condition // sql's where condition
// cond := orm.NewCondition() // cond := orm.NewCondition()
// cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count() // num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter SetCond(*Condition) QuerySeter
// get condition from QuerySeter. // GetCond get condition from QuerySeter.
// sql's where condition // sql's where condition
// cond := orm.NewCondition() // cond := orm.NewCondition()
// cond = cond.And("profile__isnull", false).AndNot("status__in", 1) // cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
@ -312,7 +308,7 @@ type QuerySeter interface {
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond).Count() // num, err := qs.SetCond(cond).Count()
GetCond() *Condition GetCond() *Condition
// add LIMIT value. // Limit add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset. // args[0] means offset, e.g. LIMIT num,offset.
// if Limit <= 0 then Limit will be set to default limit ,eg 1000 // if Limit <= 0 then Limit will be set to default limit ,eg 1000
// if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
@ -320,19 +316,19 @@ type QuerySeter interface {
// qs.Limit(10, 2) // qs.Limit(10, 2)
// // sql-> limit 10 offset 2 // // sql-> limit 10 offset 2
Limit(limit interface{}, args ...interface{}) QuerySeter Limit(limit interface{}, args ...interface{}) QuerySeter
// add OFFSET value // Offset add OFFSET value
// same as Limit function's args[0] // same as Limit function's args[0]
Offset(offset interface{}) QuerySeter Offset(offset interface{}) QuerySeter
// add GROUP BY expression // GroupBy add GROUP BY expression
// for example: // for example:
// qs.GroupBy("id") // qs.GroupBy("id")
GroupBy(exprs ...string) QuerySeter GroupBy(exprs ...string) QuerySeter
// add ORDER expression. // OrderBy add ORDER expression.
// "column" means ASC, "-column" means DESC. // "column" means ASC, "-column" means DESC.
// for example: // for example:
// qs.OrderBy("-status") // qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter OrderBy(exprs ...string) QuerySeter
// add ORDER expression by order clauses // OrderClauses add ORDER expression by order clauses
// for example: // for example:
// OrderClauses( // OrderClauses(
// order_clause.Clause( // order_clause.Clause(
@ -354,50 +350,50 @@ type QuerySeter interface {
// order_clause.Raw(),//default false.if true, do not check field is valid or not // order_clause.Raw(),//default false.if true, do not check field is valid or not
// )) // ))
OrderClauses(orders ...*order_clause.Order) QuerySeter OrderClauses(orders ...*order_clause.Order) QuerySeter
// add FORCE INDEX expression. // ForceIndex add FORCE INDEX expression.
// for example: // for example:
// qs.ForceIndex(`idx_name1`,`idx_name2`) // qs.ForceIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
ForceIndex(indexes ...string) QuerySeter ForceIndex(indexes ...string) QuerySeter
// add USE INDEX expression. // UseIndex add USE INDEX expression.
// for example: // for example:
// qs.UseIndex(`idx_name1`,`idx_name2`) // qs.UseIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
UseIndex(indexes ...string) QuerySeter UseIndex(indexes ...string) QuerySeter
// add IGNORE INDEX expression. // IgnoreIndex add IGNORE INDEX expression.
// for example: // for example:
// qs.IgnoreIndex(`idx_name1`,`idx_name2`) // qs.IgnoreIndex(`idx_name1`,`idx_name2`)
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
IgnoreIndex(indexes ...string) QuerySeter IgnoreIndex(indexes ...string) QuerySeter
// set relation model to query together. // RelatedSel set relation model to query together.
// it will query relation models and assign to parent model. // it will query relation models and assign to parent model.
// for example: // for example:
// // will load all related fields use left join . // // will load all related Fields use left join .
// qs.RelatedSel().One(&user) // qs.RelatedSel().One(&user)
// // will load related field only profile // // will load related field only profile
// qs.RelatedSel("profile").One(&user) // qs.RelatedSel("profile").One(&user)
// user.Profile.Age = 32 // user.Profile.Age = 32
RelatedSel(params ...interface{}) QuerySeter RelatedSel(params ...interface{}) QuerySeter
// Set Distinct // Distinct Set Distinct
// for example: // for example:
// o.QueryTable("policy").Filter("Groups__Group__Users__User", user). // o.QueryTable("policy").Filter("Groups__Group__Users__User", user).
// Distinct(). // Distinct().
// All(&permissions) // All(&permissions)
Distinct() QuerySeter Distinct() QuerySeter
// set FOR UPDATE to query. // ForUpdate set FOR UPDATE to query.
// for example: // for example:
// o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users)
ForUpdate() QuerySeter ForUpdate() QuerySeter
// return QuerySeter execution result number // Count returns QuerySeter execution result number
// for example: // for example:
// num, err = qs.Filter("profile__age__gt", 28).Count() // num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error) Count() (int64, error)
CountWithCtx(context.Context) (int64, error) CountWithCtx(context.Context) (int64, error)
// check result empty or not after QuerySeter executed // Exist check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0 // the same as QuerySeter.Count > 0
Exist() bool Exist() bool
ExistWithCtx(context.Context) bool ExistWithCtx(context.Context) bool
// execute update with parameters // Update execute update with parameters
// for example: // for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{ // num, err = qs.Filter("user_name", "slene").Update(Params{
// "Nums": ColValue(Col_Minus, 50), // "Nums": ColValue(Col_Minus, 50),
@ -407,13 +403,13 @@ type QuerySeter interface {
// }) // user slene's name will change to slene2 // }) // user slene's name will change to slene2
Update(values Params) (int64, error) Update(values Params) (int64, error)
UpdateWithCtx(ctx context.Context, values Params) (int64, error) UpdateWithCtx(ctx context.Context, values Params) (int64, error)
// delete from table // Delete delete from table
// for example: // for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2 // //delete two user who's name is testing1 or testing2
Delete() (int64, error) Delete() (int64, error)
DeleteWithCtx(context.Context) (int64, error) DeleteWithCtx(context.Context) (int64, error)
// return an insert queryer. // PrepareInsert return an insert queryer.
// it can be used in times. // it can be used in times.
// example: // example:
// i,err := sq.PrepareInsert() // i,err := sq.PrepareInsert()
@ -422,21 +418,21 @@ type QuerySeter interface {
// err = i.Close() //don't forget call Close // err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error) PrepareInsert() (Inserter, error)
PrepareInsertWithCtx(context.Context) (Inserter, error) PrepareInsertWithCtx(context.Context) (Inserter, error)
// query all data and map to containers. // All query all data and map to containers.
// cols means the columns when querying. // cols means the Columns when querying.
// for example: // for example:
// var users []*User // var users []*User
// qs.All(&users) // users[0],users[1],users[2] ... // qs.All(&users) // users[0],users[1],users[2] ...
All(container interface{}, cols ...string) (int64, error) All(container interface{}, cols ...string) (int64, error)
AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error)
// query one row data and map to containers. // One query one row data and map to containers.
// cols means the columns when querying. // cols means the Columns when querying.
// for example: // for example:
// var user User // var user User
// qs.One(&user) //user.UserName == "slene" // qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error One(container interface{}, cols ...string) error
OneWithCtx(ctx context.Context, container interface{}, cols ...string) error OneWithCtx(ctx context.Context, container interface{}, cols ...string) error
// query all data and map to []map[string]interface. // Values query all data and map to []map[string]interface.
// expres means condition expression. // expres means condition expression.
// it converts data to []map[column]value. // it converts data to []map[column]value.
// for example: // for example:
@ -444,21 +440,21 @@ type QuerySeter interface {
// qs.Values(&maps) //maps[0]["UserName"]=="slene" // qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error) Values(results *[]Params, exprs ...string) (int64, error)
ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface // ValuesList query all data and map to [][]interface
// it converts data to [][column_index]value // it converts data to [][column_index]value
// for example: // for example:
// var list []ParamsList // var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene" // qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error) ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface. // ValuesFlat query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value. // it's designed for one column record set, auto change to []value, not [][column]value.
// for example: // for example:
// var list ParamsList // var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene" // qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error) ValuesFlat(result *ParamsList, expr string) (int64, error)
ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name. // RowsToMap query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value" // keyCol = "name", valueCol = "value"
// table data // table data
// name | value // name | value
@ -469,7 +465,7 @@ type QuerySeter interface {
// "found": 200, // "found": 200,
// } // }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error) RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name. // RowsToStruct query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value" // keyCol = "name", valueCol = "value"
// table data // table data
// name | value // name | value
@ -480,7 +476,7 @@ type QuerySeter interface {
// Found int // Found int
// } // }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// aggregate func. // Aggregate aggregate func.
// for example: // for example:
// type result struct { // type result struct {
// DeptName string // DeptName string
@ -494,7 +490,7 @@ type QuerySeter interface {
// QueryM2Mer model to model query struct // QueryM2Mer model to model query struct
// all operations are on the m2m table only, will not affect the origin model table // all operations are on the m2m table only, will not affect the origin model table
type QueryM2Mer interface { type QueryM2Mer interface {
// add models to origin models when creating queryM2M. // Add adds models to origin models when creating queryM2M.
// example: // example:
// m2m := orm.QueryM2M(post,"Tag") // m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{}) // m2m.Add(&Tag1{},&Tag2{})
@ -507,20 +503,20 @@ type QueryM2Mer interface {
// make sure the relation is defined in post model struct tag. // make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error) Add(...interface{}) (int64, error)
AddWithCtx(context.Context, ...interface{}) (int64, error) AddWithCtx(context.Context, ...interface{}) (int64, error)
// remove models following the origin model relationship // Remove removes models following the origin model relationship
// only delete rows from m2m table // only delete rows from m2m table
// for example: // for example:
// tag3 := &Tag{Id:5,Name: "TestTag3"} // tag3 := &Tag{Id:5,Name: "TestTag3"}
// num, err = m2m.Remove(tag3) // num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error) Remove(...interface{}) (int64, error)
RemoveWithCtx(context.Context, ...interface{}) (int64, error) RemoveWithCtx(context.Context, ...interface{}) (int64, error)
// check model is existed in relationship of origin model // Exist checks model is existed in relationship of origin model
Exist(interface{}) bool Exist(interface{}) bool
ExistWithCtx(context.Context, interface{}) bool ExistWithCtx(context.Context, interface{}) bool
// clean all models in related of origin model // Clear cleans all models in related of origin model
Clear() (int64, error) Clear() (int64, error)
ClearWithCtx(context.Context) (int64, error) ClearWithCtx(context.Context) (int64, error)
// count all related models of origin model // Count counts all related models of origin model
Count() (int64, error) Count() (int64, error)
CountWithCtx(context.Context) (int64, error) CountWithCtx(context.Context) (int64, error)
} }
@ -538,32 +534,32 @@ type RawPreparer interface {
// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) // sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
// rs := Ormer.Raw(sql, 1) // rs := Ormer.Raw(sql, 1)
type RawSeter interface { type RawSeter interface {
// execute sql and get result // Exec execute sql and get result
Exec() (sql.Result, error) Exec() (sql.Result, error)
// query data and map to container // QueryRow query data and map to container
// for example: // for example:
// var name string // var name string
// var id int // var id int
// rs.QueryRow(&id,&name) // id==2 name=="slene" // rs.QueryRow(&id,&name) // id==2 name=="slene"
QueryRow(containers ...interface{}) error QueryRow(containers ...interface{}) error
// query data rows and map to container // QueryRows query data rows and map to container
// var ids []int // var ids []int
// var names []int // var names []int
// query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
// num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
QueryRows(containers ...interface{}) (int64, error) QueryRows(containers ...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter SetArgs(...interface{}) RawSeter
// query data to []map[string]interface // Values query data to []map[string]interface
// see QuerySeter's Values // see QuerySeter's Values
Values(container *[]Params, cols ...string) (int64, error) Values(container *[]Params, cols ...string) (int64, error)
// query data to [][]interface // ValuesList query data to [][]interface
// see QuerySeter's ValuesList // see QuerySeter's ValuesList
ValuesList(container *[]ParamsList, cols ...string) (int64, error) ValuesList(container *[]ParamsList, cols ...string) (int64, error)
// query data to []interface // ValuesFlat query data to []interface
// see QuerySeter's ValuesFlat // see QuerySeter's ValuesFlat
ValuesFlat(container *ParamsList, cols ...string) (int64, error) ValuesFlat(container *ParamsList, cols ...string) (int64, error)
// query all rows into map[string]interface with specify key and value column name. // RowsToMap query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value" // keyCol = "name", valueCol = "value"
// table data // table data
// name | value // name | value
@ -574,7 +570,7 @@ type RawSeter interface {
// "found": 200, // "found": 200,
// } // }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error) RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name. // RowsToStruct query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value" // keyCol = "name", valueCol = "value"
// table data // table data
// name | value // name | value
@ -586,7 +582,7 @@ type RawSeter interface {
// } // }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// return prepared raw statement for used in times. // Prepare return prepared raw statement for used in times.
// for example: // for example:
// pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
// r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
@ -626,32 +622,32 @@ type dbQuerier interface {
// base database struct // base database struct
type dbBaser interface { type dbBaser interface {
Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error Read(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string, bool) error
ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) ReadBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) Count(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error)
ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) ReadValues(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) InsertOrUpdate(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) InsertMulti(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertValue(context.Context, dbQuerier, *models.ModelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(context.Context, stmtQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error)
Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error)
UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) UpdateBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, Params, *time.Location) (int64, error)
Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Delete(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error)
DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) DeleteBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error)
SupportUpdateJoin() bool SupportUpdateJoin() bool
OperatorSQL(string) string OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorSQL(*models.ModelInfo, *models.FieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*models.FieldInfo, string, *string)
PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(context.Context, dbQuerier, *models.ModelInfo) (stmtQuerier, string, error)
MaxLimit() uint64 MaxLimit() uint64
TableQuote() string TableQuote() string
ReplaceMarks(*string) ReplaceMarks(*string)
HasReturningID(*modelInfo, *string) bool HasReturningID(*models.ModelInfo, *string) bool
TimeFromDB(*time.Time, *time.Location) TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location) TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string DbTypes() map[string]string
@ -660,8 +656,8 @@ type dbBaser interface {
ShowTablesQuery() string ShowTablesQuery() string
ShowColumnsQuery(string) string ShowColumnsQuery(string) string
IndexExists(context.Context, dbQuerier, string, string) bool IndexExists(context.Context, dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) collectFieldValue(*models.ModelInfo, *models.FieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(context.Context, dbQuerier, *modelInfo, []string) error setval(context.Context, dbQuerier, *models.ModelInfo, []string) error
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
} }

View File

@ -15,305 +15,15 @@
package orm package orm
import ( import (
"fmt" "github.com/beego/beego/v2/client/orm/internal/models"
"math/big" "github.com/beego/beego/v2/client/orm/internal/utils"
"reflect"
"strconv"
"strings"
"time"
) )
type fn func(string) string type StrTo = utils.StrTo
var (
nameStrategyMap = map[string]fn{
defaultNameStrategy: snakeString,
SnakeAcronymNameStrategy: snakeStringWithAcronym,
}
defaultNameStrategy = "snakeString"
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
nameStrategy = defaultNameStrategy
)
// StrTo is the target string
type StrTo string
// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
} else {
f.Clear()
}
}
// Clear string
func (f *StrTo) Clear() {
*f = StrTo(rune(0x1E))
}
// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(rune(0x1E))
}
// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) // octal
if !ok {
return v, err
}
return ni.Int64(), nil
}
return v, err
}
// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// Uint32 string to uint32
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10)
if !ok {
return v, err
}
return ni.Uint64(), nil
}
return v, err
}
// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
}
return ""
}
// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
s = strconv.FormatBool(v)
case float32:
s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32))
case float64:
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int32:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int64:
s = strconv.FormatInt(v, argInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint32:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint64:
s = strconv.FormatUint(v, argInt(args).Get(0, 10))
case string:
s = v
case []byte:
s = string(v)
default:
s = fmt.Sprintf("%v", v)
}
return s
}
// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
d = val.Int()
case uint, uint8, uint16, uint32, uint64:
d = int64(val.Uint())
default:
panic(fmt.Errorf("ToInt64 need numeric not `%T`", value))
}
return
}
func snakeStringWithAcronym(s string) string {
data := make([]byte, 0, len(s)*2)
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
before := false
after := false
if i > 0 {
before = s[i-1] >= 'a' && s[i-1] <= 'z'
}
if i+1 < num {
after = s[i+1] >= 'a' && s[i+1] <= 'z'
}
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
data = append(data, '_')
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// snake string, XxYy to xx_yy , XxYY to xx_y_y
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j {
data = append(data, '_')
}
if d != '_' {
j = true
}
data = append(data, d)
}
return strings.ToLower(string(data))
}
// SetNameStrategy set different name strategy
func SetNameStrategy(s string) { func SetNameStrategy(s string) {
if SnakeAcronymNameStrategy != s { if models.SnakeAcronymNameStrategy != s {
nameStrategy = defaultNameStrategy models.NameStrategy = models.DefaultNameStrategy
}
nameStrategy = s
}
// camel string, xx_yy to XxYy
func camelString(s string) string {
data := make([]byte, 0, len(s))
flag, num := true, len(s)-1
for i := 0; i <= num; i++ {
d := s[i]
if d == '_' {
flag = true
continue
} else if flag {
if d >= 'a' && d <= 'z' {
d = d - 32
}
flag = false
}
data = append(data, d)
}
return string(data)
}
type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
} else if len(args) > 0 {
r = args[0]
}
return
}
type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
case reflect.Ptr:
return indirectType(v.Elem())
default:
return v
} }
models.NameStrategy = s
} }

1
go.mod
View File

@ -31,6 +31,7 @@ require (
github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18 github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
github.com/valyala/bytebufferpool v1.0.0
go.etcd.io/etcd/client/v3 v3.5.9 go.etcd.io/etcd/client/v3 v3.5.9
go.opentelemetry.io/otel v1.11.2 go.opentelemetry.io/otel v1.11.2
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.11.2 go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.11.2

2
go.sum
View File

@ -197,6 +197,8 @@ github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112/go.mod h1:Z4AUp2K
github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg=
github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ=
github.com/ugorji/go v0.0.0-20171122102828-84cb69a8af83/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= github.com/ugorji/go v0.0.0-20171122102828-84cb69a8af83/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=