diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c639490..27ac0cc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ - [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232) - [remove adapter package](https://github.com/beego/beego/pull/5239) +## ORM refactoring +- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238) + # v2.1.0 - [unified gopkg.in/yaml version to v2](https://github.com/beego/beego/pull/5169) - [add non-block write log in asynchronous mode](https://github.com/beego/beego/pull/5150) diff --git a/LICENSE b/LICENSE index 26050108..b947dac3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2014 astaxie +Copyright 2014 Beego Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/client/orm/cmd.go b/client/orm/cmd.go index 9819badb..cd6fd5cc 100644 --- a/client/orm/cmd.go +++ b/client/orm/cmd.go @@ -20,6 +20,10 @@ import ( "fmt" "os" "strings" + + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" ) type commander interface { @@ -53,7 +57,7 @@ func RunCommand() { BootStrap() - args := argString(os.Args[2:]) + args := utils.ArgString(os.Args[2:]) name := args.Get(0) if name == "help" { @@ -112,7 +116,7 @@ func (d *commandSyncDb) Run() error { for i, mi := range defaultModelCache.allOrdered() { query := drops[i] if !d.noInfo { - fmt.Printf("drop table `%s`\n", mi.table) + fmt.Printf("drop table `%s`\n", mi.Table) } _, err := db.Exec(query) if d.verbose { @@ -143,18 +147,18 @@ func (d *commandSyncDb) Run() error { ctx := context.Background() for i, mi := range defaultModelCache.allOrdered() { - if !isApplicableTableForDB(mi.addrField, d.al.Name) { - fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name) + if !models.IsApplicableTableForDB(mi.AddrField, d.al.Name) { + fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.Table, d.al.Name) continue } - if tables[mi.table] { + if tables[mi.Table] { if !d.noInfo { - fmt.Printf("table `%s` already exists, skip\n", mi.table) + fmt.Printf("table `%s` already exists, skip\n", mi.Table) } - var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) + var fields []*models.FieldInfo + columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.Table) if err != nil { if d.rtOnError { return err @@ -162,8 +166,8 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } - for _, fi := range mi.fields.fieldsDB { - if _, ok := columns[fi.column]; !ok { + for _, fi := range mi.Fields.FieldsDB { + if _, ok := columns[fi.Column]; !ok { fields = append(fields, fi) } } @@ -172,7 +176,7 @@ func (d *commandSyncDb) Run() error { query := getColumnAddQuery(d.al, fi) if !d.noInfo { - fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) + fmt.Printf("add column `%s` for table `%s`\n", fi.FullName, mi.Table) } _, err := db.Exec(query) @@ -187,7 +191,7 @@ func (d *commandSyncDb) Run() error { } } - for _, idx := range indexes[mi.table] { + for _, idx := range indexes[mi.Table] { if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) @@ -211,11 +215,11 @@ func (d *commandSyncDb) Run() error { } if !d.noInfo { - fmt.Printf("create table `%s` \n", mi.table) + fmt.Printf("create table `%s` \n", mi.Table) } queries := []string{createQueries[i]} - for _, idx := range indexes[mi.table] { + for _, idx := range indexes[mi.Table] { queries = append(queries, idx.SQL) } @@ -265,7 +269,7 @@ func (d *commandSQLAll) Run() error { var all []string for i, mi := range defaultModelCache.allOrdered() { queries := []string{createQueries[i]} - for _, idx := range indexes[mi.table] { + for _, idx := range indexes[mi.Table] { queries = append(queries, idx.SQL) } sql := strings.Join(queries, "\n") diff --git a/client/orm/cmd_utils.go b/client/orm/cmd_utils.go index 7b795b22..9c9f1303 100644 --- a/client/orm/cmd_utils.go +++ b/client/orm/cmd_utils.go @@ -17,6 +17,8 @@ package orm import ( "fmt" "strings" + + "github.com/beego/beego/v2/client/orm/internal/models" ) type dbIndex struct { @@ -26,17 +28,17 @@ type dbIndex struct { } // get database column type string. -func getColumnTyp(al *alias, fi *fieldInfo) (col string) { +func getColumnTyp(al *alias, fi *models.FieldInfo) (col string) { T := al.DbBaser.DbTypes() - fieldType := fi.fieldType - fieldSize := fi.size + fieldType := fi.FieldType + fieldSize := fi.Size checkColumn: switch fieldType { case TypeBooleanField: col = T["bool"] case TypeVarCharField: - if al.Driver == DRPostgres && fi.toText { + if al.Driver == DRPostgres && fi.ToText { col = T["string-text"] } else { col = fmt.Sprintf(T["string"], fieldSize) @@ -51,11 +53,11 @@ checkColumn: col = T["time.Time-date"] case TypeDateTimeField: // the precision of sqlite is not implemented - if al.Driver == 2 || fi.timePrecision == nil { + if al.Driver == 2 || fi.TimePrecision == nil { col = T["time.Time"] } else { s := T["time.Time-precision"] - col = fmt.Sprintf(s, *fi.timePrecision) + col = fmt.Sprintf(s, *fi.TimePrecision) } case TypeBitField: @@ -85,7 +87,7 @@ checkColumn: if !strings.Contains(s, "%d") { col = s } else { - col = fmt.Sprintf(s, fi.digits, fi.decimals) + col = fmt.Sprintf(s, fi.Digits, fi.Decimals) } case TypeJSONField: if al.Driver != DRPostgres { @@ -100,8 +102,8 @@ checkColumn: } col = T["jsonb"] case RelForeignKey, RelOneToOne: - fieldType = fi.relModelInfo.fields.pk.fieldType - fieldSize = fi.relModelInfo.fields.pk.size + fieldType = fi.RelModelInfo.Fields.Pk.FieldType + fieldSize = fi.RelModelInfo.Fields.Pk.Size goto checkColumn } @@ -109,34 +111,34 @@ checkColumn: } // create alter sql string. -func getColumnAddQuery(al *alias, fi *fieldInfo) string { +func getColumnAddQuery(al *alias, fi *models.FieldInfo) string { Q := al.DbBaser.TableQuote() typ := getColumnTyp(al, fi) - if !fi.null { + if !fi.Null { typ += " " + "NOT NULL" } return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", - Q, fi.mi.table, Q, - Q, fi.column, Q, + Q, fi.Mi.Table, Q, + Q, fi.Column, Q, typ, getColumnDefault(fi), ) } // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands -func getColumnDefault(fi *fieldInfo) string { +func getColumnDefault(fi *models.FieldInfo) string { var v, t, d string // Skip default attribute if field is in relations - if fi.rel || fi.reverse { + if fi.Rel || fi.Reverse { return v } t = " DEFAULT '%s' " // These defaults will be useful if there no config value orm:"default" and NOT NULL is on - switch fi.fieldType { + switch fi.FieldType { case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: return v @@ -153,14 +155,14 @@ func getColumnDefault(fi *fieldInfo) string { d = "{}" } - if fi.colDefault { - if !fi.initial.Exist() { + if fi.ColDefault { + if !fi.Initial.Exist() { v = fmt.Sprintf(t, "") } else { - v = fmt.Sprintf(t, fi.initial.String()) + v = fmt.Sprintf(t, fi.Initial.String()) } } else { - if !fi.null { + if !fi.Null { v = fmt.Sprintf(t, d) } } diff --git a/client/orm/db.go b/client/orm/db.go index cbaa81ad..e4b53e36 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -23,13 +23,13 @@ import ( "strings" "time" - "github.com/beego/beego/v2/client/orm/hints" -) + "github.com/beego/beego/v2/client/orm/internal/logs" -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/hints" ) // ErrMissPK missing pk error @@ -72,8 +72,8 @@ type dbBase struct { // check dbBase implements dbBaser interface. var _ dbBaser = new(dbBase) -// get struct columns values as interface slice. -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { +// get struct Columns values as interface slice. +func (d *dbBase) collectValues(mi *models.ModelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { if names == nil { ns := make([]string, 0, len(cols)) names = &ns @@ -81,13 +81,13 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, values = make([]interface{}, 0, len(cols)) for _, column := range cols { - var fi *fieldInfo - if fi, _ = mi.fields.GetByAny(column); fi != nil { - column = fi.column + var fi *models.FieldInfo + if fi, _ = mi.Fields.GetByAny(column); fi != nil { + column = fi.Column } else { - panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) + panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.FullName)) } - if !fi.dbcol || fi.auto && skipAuto { + if !fi.DBcol || fi.Auto && skipAuto { continue } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) @@ -96,8 +96,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } // ignore empty value auto field - if insert && fi.auto { - if fi.fieldType&IsPositiveIntegerField > 0 { + if insert && fi.Auto { + if fi.FieldType&IsPositiveIntegerField > 0 { if vu, ok := value.(uint64); !ok || vu == 0 { continue } @@ -106,7 +106,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, continue } } - autoFields = append(autoFields, fi.column) + autoFields = append(autoFields, fi.Column) } *names, values = append(*names, column), append(values, value) @@ -116,17 +116,17 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } // get one field value in struct column as interface. -func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { +func (d *dbBase) collectFieldValue(mi *models.ModelInfo, fi *models.FieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { var value interface{} - if fi.pk { + if fi.Pk { _, value, _ = getExistPk(mi, ind) } else { - field := ind.FieldByIndex(fi.fieldIndex) - if fi.isFielder { - f := field.Addr().Interface().(Fielder) + field := ind.FieldByIndex(fi.FieldIndex) + if fi.IsFielder { + f := field.Addr().Interface().(models.Fielder) value = f.RawValue() } else { - switch fi.fieldType { + switch fi.FieldType { case TypeBooleanField: if nb, ok := field.Interface().(sql.NullBool); ok { value = nil @@ -172,7 +172,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { vu := field.Interface() if _, ok := vu.(float32); ok { - value, _ = StrTo(ToStr(vu)).Float64() + value, _ = utils.StrTo(utils.ToStr(vu)).Float64() } else { value = field.Float() } @@ -189,7 +189,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } default: switch { - case fi.fieldType&IsPositiveIntegerField > 0: + case fi.FieldType&IsPositiveIntegerField > 0: if field.Kind() == reflect.Ptr { if field.IsNil() { value = nil @@ -199,7 +199,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { value = field.Uint() } - case fi.fieldType&IsIntegerField > 0: + case fi.FieldType&IsIntegerField > 0: if ni, ok := field.Interface().(sql.NullInt64); ok { value = nil if ni.Valid { @@ -214,25 +214,25 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { value = field.Int() } - case fi.fieldType&IsRelField > 0: + case fi.FieldType&IsRelField > 0: if field.IsNil() { value = nil } else { - if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { + if _, vu, ok := getExistPk(fi.RelModelInfo, reflect.Indirect(field)); ok { value = vu } else { value = nil } } - if !fi.null && value == nil { - return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) + if !fi.Null && value == nil { + return nil, fmt.Errorf("field `%s` cannot be NULL", fi.FullName) } } } } - switch fi.fieldType { + switch fi.FieldType { case TypeTimeField, TypeDateField, TypeDateTimeField: - if fi.autoNow || fi.autoNowAdd && insert { + if fi.AutoNow || fi.AutoNowAdd && insert { if insert { if t, ok := value.(time.Time); ok && !t.IsZero() { break @@ -241,8 +241,8 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val tnow := time.Now() d.ins.TimeToDB(&tnow, tz) value = tnow - if fi.isFielder { - f := field.Addr().Interface().(Fielder) + if fi.IsFielder { + f := field.Addr().Interface().(models.Fielder) f.SetRaw(tnow.In(DefaultTimeLoc)) } else if field.Kind() == reflect.Ptr { v := tnow.In(DefaultTimeLoc) @@ -253,8 +253,8 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } case TypeJSONField, TypeJsonbField: if s, ok := value.(string); (ok && len(s) == 0) || value == nil { - if fi.colDefault && fi.initial.Exist() { - value = fi.initial.String() + if fi.ColDefault && fi.Initial.Exist() { + value = fi.Initial.String() } else { value = nil } @@ -265,14 +265,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } // PrepareInsert create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { +func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *models.ModelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() - dbcols := make([]string, 0, len(mi.fields.dbcols)) - marks := make([]string, 0, len(mi.fields.dbcols)) - for _, fi := range mi.fields.fieldsDB { - if !fi.auto { - dbcols = append(dbcols, fi.column) + dbcols := make([]string, 0, len(mi.Fields.DBcols)) + marks := make([]string, 0, len(mi.Fields.DBcols)) + for _, fi := range mi.Fields.FieldsDB { + if !fi.Auto { + dbcols = append(dbcols, fi.Column) marks = append(marks, "?") } } @@ -280,7 +280,7 @@ func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) sep := fmt.Sprintf("%s, %s", Q, Q) columns := strings.Join(dbcols, sep) - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks) d.ins.ReplaceMarks(&query) @@ -291,8 +291,8 @@ func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) } // InsertStmt insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) +func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, nil, tz) if err != nil { return 0, err } @@ -311,7 +311,7 @@ func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -336,8 +336,8 @@ func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind refle Q := d.ins.TableQuote() sep := fmt.Sprintf("%s, %s", Q, Q) - sels := strings.Join(mi.fields.dbcols, sep) - colsNum := len(mi.fields.dbcols) + sels := strings.Join(mi.Fields.DBcols, sep) + colsNum := len(mi.Fields.DBcols) sep = fmt.Sprintf("%s = ? AND %s", Q, Q) wheres := strings.Join(whereCols, sep) @@ -347,7 +347,7 @@ func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind refle forUpdate = "FOR UPDATE" } - query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.Table, Q, Q, wheres, Q, forUpdate) refs := make([]interface{}, colsNum) for i := range refs { @@ -364,17 +364,17 @@ func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind refle } return err } - elm := reflect.New(mi.addrField.Elem().Type()) + elm := reflect.New(mi.AddrField.Elem().Type()) mind := reflect.Indirect(elm) - d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) + d.setColsValues(mi, &mind, mi.Fields.DBcols, refs, tz) ind.Set(mind) return nil } // Insert execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names := make([]string, 0, len(mi.fields.dbcols)) - values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) +func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + names := make([]string, 0, len(mi.Fields.DBcols)) + values, autoFields, err := d.collectValues(mi, ind, mi.Fields.DBcols, false, true, &names, tz) if err != nil { return 0, err } @@ -391,7 +391,7 @@ func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref } // InsertMulti multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { +func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *models.ModelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 nums int @@ -399,32 +399,25 @@ func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, si names []string ) - // typ := reflect.Indirect(mi.addrField).Type() - length, autoFields := sind.Len(), make([]string, 0, 1) for i := 1; i <= length; i++ { ind := reflect.Indirect(sind.Index(i - 1)) - // Is this needed ? - // if !ind.Type().AssignableTo(typ) { - // return cnt, ErrArgs - // } - if i == 1 { var ( vus []interface{} err error ) - vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + vus, autoFields, err = d.collectValues(mi, ind, mi.Fields.DBcols, false, true, &names, tz) if err != nil { return cnt, err } values = make([]interface{}, bulk*len(vus)) nums += copy(values, vus) } else { - vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) + vus, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, false, true, nil, tz) if err != nil { return cnt, err } @@ -456,7 +449,7 @@ func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, si // InsertValue execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -474,7 +467,7 @@ func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, is qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks } - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks) d.ins.ReplaceMarks(&query) @@ -487,7 +480,7 @@ func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, is lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable } else { return lastInsertId, nil @@ -504,7 +497,7 @@ func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, is // InsertOrUpdate a row // If your primary key or unique column conflict will update // If no will insert -func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { args0 := "" iouStr := "" argsMap := map[string]string{} @@ -530,9 +523,9 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, } isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) + names := make([]string, 0, len(mi.Fields.DBcols)-1) Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ) if err != nil { return 0, err } @@ -556,7 +549,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, case DRPostgres: if conflitValue != nil { // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) + updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.Table, args0) updateValues = append(updateValues, conflitValue) } else { return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) @@ -581,7 +574,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks } // conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr) d.ins.ReplaceMarks(&query) @@ -594,7 +587,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable } else { return lastInsertId, nil @@ -613,7 +606,7 @@ func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, } // Update execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if !ok { return 0, ErrMissPK @@ -621,10 +614,10 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref var setNames []string - // if specify cols length is zero, then commit all columns. + // if specify cols length is zero, then commit all Columns. if len(cols) == 0 { - cols = mi.fields.dbcols - setNames = make([]string, 0, len(mi.fields.dbcols)-1) + cols = mi.Fields.DBcols + setNames = make([]string, 0, len(mi.Fields.DBcols)-1) } else { setNames = make([]string, 0, len(cols)) } @@ -637,11 +630,11 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref var findAutoNowAdd, findAutoNow bool var index int for i, col := range setNames { - if mi.fields.GetByColumn(col).autoNowAdd { + if mi.Fields.GetByColumn(col).AutoNowAdd { index = i findAutoNowAdd = true } - if mi.fields.GetByColumn(col).autoNow { + if mi.Fields.GetByColumn(col).AutoNow { findAutoNow = true } } @@ -651,8 +644,8 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref } if !findAutoNow { - for col, info := range mi.fields.columns { - if info.autoNow { + for col, info := range mi.Fields.Columns { + if info.AutoNow { setNames = append(setNames, col) setValues = append(setValues, time.Now()) } @@ -666,7 +659,7 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref sep := fmt.Sprintf("%s = ?, %s", Q, Q) setColumns := strings.Join(setNames, sep) - query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) + query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.Table, Q, Q, setColumns, Q, Q, pkName, Q) d.ins.ReplaceMarks(&query) @@ -679,7 +672,7 @@ func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref // Delete execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { var whereCols []string var args []interface{} // if specify cols length > 0, then use it for where condition. @@ -705,7 +698,7 @@ func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref sep := fmt.Sprintf("%s = ? AND %s", Q, Q) wheres := strings.Join(whereCols, sep) - query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) + query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.Table, Q, Q, wheres, Q) d.ins.ReplaceMarks(&query) res, err := q.ExecContext(ctx, query, args...) @@ -727,14 +720,14 @@ func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind ref // UpdateBatch update table-related record by querySet. // need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { +func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { - if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { + if fi, ok := mi.Fields.GetByAny(col); !ok || !fi.DBcol { panic(fmt.Errorf("wrong field/column name `%s`", col)) } else { - columns = append(columns, fi.column) + columns = append(columns, fi.Column) values = append(values, val) } } @@ -747,7 +740,7 @@ func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) - specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) + specifyIndexes = tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) } where, args := tables.getCondSQL(cond, false, tz) @@ -798,13 +791,13 @@ func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi sets := strings.Join(cols, ", ") + " " if d.ins.SupportUpdateJoin() { - query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.table, Q, specifyIndexes, join, sets, where) + query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.Table, Q, specifyIndexes, join, sets, where) } else { supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s", - Q, mi.fields.pk.column, Q, - Q, mi.table, Q, + Q, mi.Fields.Pk.Column, Q, + Q, mi.Table, Q, specifyIndexes, join, where) - query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) + query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.Table, Q, sets, Q, mi.Fields.Pk.Column, Q, supQuery) } d.ins.ReplaceMarks(&query) @@ -817,41 +810,41 @@ func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi // delete related records. // do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { - for _, fi := range mi.fields.fieldsReverse { - fi = fi.reverseFieldInfo - switch fi.onDelete { - case odCascade: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) +func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *models.ModelInfo, args []interface{}, tz *time.Location) error { + for _, fi := range mi.Fields.FieldsReverse { + fi = fi.ReverseFieldInfo + switch fi.OnDelete { + case models.OdCascade: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.Name), args...) + _, err := d.DeleteBatch(ctx, q, nil, fi.Mi, cond, tz) if err != nil { return err } - case odSetDefault, odSetNULL: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - params := Params{fi.column: nil} - if fi.onDelete == odSetDefault { - params[fi.column] = fi.initial.String() + case models.OdSetDefault, models.OdSetNULL: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.Name), args...) + params := Params{fi.Column: nil} + if fi.OnDelete == models.OdSetDefault { + params[fi.Column] = fi.Initial.String() } - _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) + _, err := d.UpdateBatch(ctx, q, nil, fi.Mi, cond, params, tz) if err != nil { return err } - case odDoNothing: + case models.OdDoNothing: } } return nil } // DeleteBatch delete table-related records. -func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { +func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) - specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) + specifyIndexes = tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) } if cond == nil || cond.IsEmpty() { @@ -863,8 +856,8 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi where, args := tables.getCondSQL(cond, false, tz) join := tables.getJoinSQL() - cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.table, Q, specifyIndexes, join, where) + cols := fmt.Sprintf("T0.%s%s%s", Q, mi.Fields.Pk.Column, Q) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.Table, Q, specifyIndexes, join, where) d.ins.ReplaceMarks(&query) @@ -883,7 +876,7 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi if err := rs.Scan(&ref); err != nil { return 0, err } - pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) + pkValue, err := d.convertValueFromDB(mi.Fields.Pk, reflect.ValueOf(ref).Interface(), tz) if err != nil { return 0, err } @@ -900,7 +893,7 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi marks[i] = "?" } sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) + query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.Table, Q, Q, mi.Fields.Pk.Column, Q, sqlIn) d.ins.ReplaceMarks(&query) res, err := q.ExecContext(ctx, query, args...) @@ -921,7 +914,7 @@ func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi } // ReadBatch read related records. -func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -937,17 +930,17 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m typ := ind.Type().Elem() switch typ.Kind() { case reflect.Ptr: - fn = getFullName(typ.Elem()) + fn = models.GetFullName(typ.Elem()) case reflect.Struct: isPtr = false - fn = getFullName(typ) - name = getTableName(reflect.New(typ)) + fn = models.GetFullName(typ) + name = models.GetTableName(reflect.New(typ)) } } else { - fn = getFullName(ind.Type()) - name = getTableName(ind) + fn = models.GetFullName(ind.Type()) + name = models.GetTableName(ind) } - unregister = fn != mi.fullName + unregister = fn != mi.FullName } if unregister { @@ -968,26 +961,26 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m maps = make(map[string]bool) } for _, col := range cols { - if fi, ok := mi.fields.GetByAny(col); ok { - tCols = append(tCols, fi.column) + if fi, ok := mi.Fields.GetByAny(col); ok { + tCols = append(tCols, fi.Column) if hasRel { - maps[fi.column] = true + maps[fi.Column] = true } } else { return 0, fmt.Errorf("wrong field/column name `%s`", col) } } if hasRel { - for _, fi := range mi.fields.fieldsDB { - if fi.fieldType&IsRelField > 0 { - if !maps[fi.column] { - tCols = append(tCols, fi.column) + for _, fi := range mi.Fields.FieldsDB { + if fi.FieldType&IsRelField > 0 { + if !maps[fi.Column] { + tCols = append(tCols, fi.Column) } } } } } else { - tCols = mi.fields.dbcols + tCols = mi.Fields.DBcols } colsNum := len(tCols) @@ -1002,13 +995,13 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m orderBy := tables.getOrderSQL(qs.orders) limit := tables.getLimitSQL(mi, offset, rlimit) join := tables.getJoinSQL() - specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) + specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) for _, tbl := range tables.tables { if tbl.sel { - colsNum += len(tbl.mi.fields.dbcols) + colsNum += len(tbl.mi.Fields.DBcols) sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) - sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) + sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.Fields.DBcols, sep), Q) } } @@ -1020,7 +1013,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m sels = qs.aggregate } query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", - sqlSelect, sels, Q, mi.table, Q, + sqlSelect, sels, Q, mi.Table, Q, specifyIndexes, join, where, groupBy, orderBy, limit) if qs.forUpdate { @@ -1039,7 +1032,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m slice := ind if unregister { mi, _ = defaultModelCache.get(name) - tCols = mi.fields.dbcols + tCols = mi.Fields.DBcols colsNum = len(tCols) } @@ -1055,11 +1048,11 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m return 0, err } - elm := reflect.New(mi.addrField.Elem().Type()) + elm := reflect.New(mi.AddrField.Elem().Type()) mind := reflect.Indirect(elm) cacheV := make(map[string]*reflect.Value) - cacheM := make(map[string]*modelInfo) + cacheM := make(map[string]*models.ModelInfo) trefs := refs d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) @@ -1078,18 +1071,18 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m last = *val mmi = cacheM[names] } else { - fi := mmi.fields.GetByName(name) + fi := mmi.Fields.GetByName(name) lastm := mmi - mmi = fi.relModelInfo + mmi = fi.RelModelInfo field := last if last.Kind() != reflect.Invalid { - field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) + field = reflect.Indirect(last.FieldByIndex(fi.FieldIndex)) if field.IsValid() { - d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) - for _, fi := range mmi.fields.fieldsReverse { - if fi.inModel && fi.reverseFieldInfo.mi == lastm { - if fi.reverseFieldInfo != nil { - f := field.FieldByIndex(fi.fieldIndex) + d.setColsValues(mmi, &field, mmi.Fields.DBcols, trefs[:len(mmi.Fields.DBcols)], tz) + for _, fi := range mmi.Fields.FieldsReverse { + if fi.InModel && fi.ReverseFieldInfo.Mi == lastm { + if fi.ReverseFieldInfo != nil { + f := field.FieldByIndex(fi.FieldIndex) if f.Kind() == reflect.Ptr { f.Set(last.Addr()) } @@ -1103,7 +1096,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m cacheM[names] = mmi } } - trefs = trefs[len(mmi.fields.dbcols):] + trefs = trefs[len(mmi.Fields.DBcols):] } } @@ -1146,7 +1139,7 @@ func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *m } // Count excute count sql and return count result int64. -func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -1154,12 +1147,12 @@ func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *model groupBy := tables.getGroupSQL(qs.groups) tables.getOrderSQL(qs.orders) join := tables.getJoinSQL() - specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) + specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) Q := d.ins.TableQuote() query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s", - Q, mi.table, Q, + Q, mi.Table, Q, specifyIndexes, join, where, groupBy) if groupBy != "" { @@ -1174,7 +1167,7 @@ func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *model } // GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values. -func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { +func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { var sql string params := getFlatParams(fi, args, tz) @@ -1206,7 +1199,7 @@ func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator stri params[0] = "IS NULL" } case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": - param := strings.Replace(ToStr(arg), `%`, `\%`, -1) + param := strings.Replace(utils.ToStr(arg), `%`, `\%`, -1) switch operator { case "iexact": case "contains", "icontains": @@ -1234,18 +1227,18 @@ func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator stri } // GenerateOperatorLeftCol gernerate sql string with inner function, such as UPPER(text). -func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { +func (d *dbBase) GenerateOperatorLeftCol(*models.FieldInfo, string, *string) { // default not use } // set values to struct column. -func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { +func (d *dbBase) setColsValues(mi *models.ModelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() - fi := mi.fields.GetByColumn(column) + fi := mi.Fields.GetByColumn(column) - field := ind.FieldByIndex(fi.fieldIndex) + field := ind.FieldByIndex(fi.FieldIndex) value, err := d.convertValueFromDB(fi, val, tz) if err != nil { @@ -1261,7 +1254,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, } // convert value from database result to value following in field type. -func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { +func (d *dbBase) convertValueFromDB(fi *models.FieldInfo, val interface{}, tz *time.Location) (interface{}, error) { if val == nil { return nil, nil } @@ -1269,17 +1262,17 @@ func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Loc var value interface{} var tErr error - var str *StrTo + var str *utils.StrTo switch v := val.(type) { case []byte: - s := StrTo(string(v)) + s := utils.StrTo(string(v)) str = &s case string: - s := StrTo(v) + s := utils.StrTo(v) str = &s } - fieldType := fi.fieldType + fieldType := fi.FieldType setValue: switch { @@ -1290,7 +1283,7 @@ setValue: b := v == 1 value = b default: - s := StrTo(ToStr(v)) + s := utils.StrTo(utils.ToStr(v)) str = &s } } @@ -1304,7 +1297,7 @@ setValue: } case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: if str == nil { - value = ToStr(val) + value = utils.ToStr(val) } else { value = str.String() } @@ -1315,7 +1308,7 @@ setValue: d.ins.TimeFromDB(&t, tz) value = t default: - s := StrTo(ToStr(t)) + s := utils.StrTo(utils.ToStr(t)) str = &s } } @@ -1326,25 +1319,25 @@ setValue: err error ) - if fi.timePrecision != nil && len(s) >= (20+*fi.timePrecision) { - layout := formatDateTime + "." - for i := 0; i < *fi.timePrecision; i++ { + if fi.TimePrecision != nil && len(s) >= (20+*fi.TimePrecision) { + layout := utils.FormatDateTime + "." + for i := 0; i < *fi.TimePrecision; i++ { layout += "0" } - t, err = time.ParseInLocation(layout, s[:20+*fi.timePrecision], tz) + t, err = time.ParseInLocation(layout, s[:20+*fi.TimePrecision], tz) } else if len(s) >= 19 { s = s[:19] - t, err = time.ParseInLocation(formatDateTime, s, tz) + t, err = time.ParseInLocation(utils.FormatDateTime, s, tz) } else if len(s) >= 10 { if len(s) > 10 { s = s[:10] } - t, err = time.ParseInLocation(formatDate, s, tz) + t, err = time.ParseInLocation(utils.FormatDate, s, tz) } else if len(s) >= 8 { if len(s) > 8 { s = s[:8] } - t, err = time.ParseInLocation(formatTime, s, tz) + t, err = time.ParseInLocation(utils.FormatTime, s, tz) } t = t.In(DefaultTimeLoc) @@ -1356,7 +1349,7 @@ setValue: } case fieldType&IsIntegerField > 0: if str == nil { - s := StrTo(ToStr(val)) + s := utils.StrTo(utils.ToStr(val)) str = &s } if str != nil { @@ -1397,7 +1390,7 @@ setValue: case float64: value = v default: - s := StrTo(ToStr(v)) + s := utils.StrTo(utils.ToStr(v)) str = &s } } @@ -1410,14 +1403,14 @@ setValue: value = v } case fieldType&IsRelField > 0: - fi = fi.relModelInfo.fields.pk - fieldType = fi.fieldType + fi = fi.RelModelInfo.Fields.Pk + fieldType = fi.FieldType goto setValue } end: if tErr != nil { - err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) + err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.AddrValue.Type(), fi.FullName, tErr) return nil, err } @@ -1425,9 +1418,9 @@ end: } // set one value to struct column field. -func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { - fieldType := fi.fieldType - isNative := !fi.isFielder +func (d *dbBase) setFieldValue(fi *models.FieldInfo, value interface{}, field reflect.Value) (interface{}, error) { + fieldType := fi.FieldType + isNative := !fi.IsFielder setValue: switch { @@ -1594,20 +1587,20 @@ setValue: } case fieldType&IsRelField > 0: if value != nil { - fieldType = fi.relModelInfo.fields.pk.fieldType - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + fieldType = fi.RelModelInfo.Fields.Pk.FieldType + mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type()) field.Set(mf) - f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + f := mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex) field = f goto setValue } } if !isNative { - fd := field.Addr().Interface().(Fielder) + fd := field.Addr().Interface().(models.Fielder) err := fd.SetRaw(value) if err != nil { - err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) + err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.FullName, err) return nil, err } } @@ -1616,7 +1609,7 @@ setValue: } // ReadValues query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { +func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params lists []ParamsList @@ -1651,7 +1644,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi * var ( cols []string - infos []*fieldInfo + infos []*models.FieldInfo ) hasExprs := len(exprs) > 0 @@ -1660,20 +1653,20 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi * if hasExprs { cols = make([]string, 0, len(exprs)) - infos = make([]*fieldInfo, 0, len(exprs)) + infos = make([]*models.FieldInfo, 0, len(exprs)) for _, ex := range exprs { index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) if !suc { panic(fmt.Errorf("unknown field/column name `%s`", ex)) } - cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) + cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.Column, Q, Q, name, Q)) infos = append(infos, fi) } } else { - cols = make([]string, 0, len(mi.fields.dbcols)) - infos = make([]*fieldInfo, 0, len(exprs)) - for _, fi := range mi.fields.fieldsDB { - cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) + cols = make([]string, 0, len(mi.Fields.DBcols)) + infos = make([]*models.FieldInfo, 0, len(exprs)) + for _, fi := range mi.Fields.FieldsDB { + cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.Column, Q, Q, fi.Name, Q)) infos = append(infos, fi) } } @@ -1683,7 +1676,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi * orderBy := tables.getOrderSQL(qs.orders) limit := tables.getLimitSQL(mi, qs.offset, qs.limit) join := tables.getJoinSQL() - specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) + specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) sels := strings.Join(cols, ", ") @@ -1693,7 +1686,7 @@ func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi * } query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", sqlSelect, sels, - Q, mi.table, Q, + Q, mi.Table, Q, specifyIndexes, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) @@ -1808,12 +1801,12 @@ func (d *dbBase) ReplaceMarks(query *string) { } // flag of RETURNING sql. -func (d *dbBase) HasReturningID(*modelInfo, *string) bool { +func (d *dbBase) HasReturningID(*models.ModelInfo, *string) bool { return false } // sync auto key -func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *models.ModelInfo, autoFields []string) error { return nil } @@ -1923,7 +1916,7 @@ func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes [] case hints.KeyIgnoreIndex: useWay = `IGNORE` default: - DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") return `` } diff --git a/client/orm/db_alias.go b/client/orm/db_alias.go index 28c8ab8e..ff0b962f 100644 --- a/client/orm/db_alias.go +++ b/client/orm/db_alias.go @@ -21,6 +21,8 @@ import ( "sync" "time" + "github.com/beego/beego/v2/client/orm/internal/logs" + lru "github.com/hashicorp/golang-lru" ) @@ -320,7 +322,7 @@ func detectTZ(al *alias) { al.TZ = t.Location() } } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + logs.DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) } } @@ -347,7 +349,7 @@ func detectTZ(al *alias) { if err == nil { al.TZ = loc } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + logs.DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) } } } @@ -479,7 +481,7 @@ end: if db != nil { db.Close() } - DebugLog.Println(err.Error()) + logs.DebugLog.Println(err.Error()) } return err diff --git a/client/orm/db_mysql.go b/client/orm/db_mysql.go index 75d24b2a..889d807f 100644 --- a/client/orm/db_mysql.go +++ b/client/orm/db_mysql.go @@ -19,6 +19,10 @@ import ( "fmt" "reflect" "strings" + + "github.com/beego/beego/v2/client/orm/internal/logs" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // mysql operators. @@ -72,28 +76,28 @@ type dbBaseMysql struct { var _ dbBaser = new(dbBaseMysql) -// get mysql operator. +// OperatorSQL get mysql operator. func (d *dbBaseMysql) OperatorSQL(operator string) string { return mysqlOperators[operator] } -// get mysql table field types. +// DbTypes get mysql table field types. func (d *dbBaseMysql) DbTypes() map[string]string { return mysqlTypes } -// show table sql for mysql. +// ShowTablesQuery show table sql for mysql. func (d *dbBaseMysql) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" } -// show columns sql of table for mysql. +// ShowColumnsQuery show Columns sql of table for mysql. func (d *dbBaseMysql) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table) } -// execute sql to check index exist. +// IndexExists execute sql to check index exist. func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) @@ -106,7 +110,7 @@ func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table strin // If your primary key or unique column conflict will update // If no will insert // Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { var iouStr string argsMap := map[string]string{} @@ -120,10 +124,9 @@ func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *model } } - isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) + names := make([]string, 0, len(mi.Fields.DBcols)-1) Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.TZ) if err != nil { return 0, err } @@ -150,26 +153,17 @@ func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *model qupdates := strings.Join(updates, ", ") columns := strings.Join(names, sep) - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } // conflitValue maybe is an int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr) d.ins.ReplaceMarks(&query) - if isMulti || !d.ins.HasReturningID(mi, &query) { + if !d.ins.HasReturningID(mi, &query) { res, err := q.ExecContext(ctx, query, values...) if err == nil { - if isMulti { - return res.RowsAffected() - } - lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable } else { return lastInsertId, nil diff --git a/client/orm/db_oracle.go b/client/orm/db_oracle.go index a3b93ff3..5057f358 100644 --- a/client/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -19,6 +19,10 @@ import ( "fmt" "strings" + "github.com/beego/beego/v2/client/orm/internal/logs" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/hints" ) @@ -116,16 +120,16 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde case hints.KeyIgnoreIndex: hint = `NO_INDEX` default: - DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") return `` } return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`)) } -// execute insert sql with given struct and given values. +// InsertValue execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -143,7 +147,7 @@ func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelIn qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks } - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks) d.ins.ReplaceMarks(&query) @@ -156,7 +160,7 @@ func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelIn lastInsertId, err := res.LastInsertId() if err != nil { - DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + logs.DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) return lastInsertId, ErrLastInsertIdUnavailable } else { return lastInsertId, nil diff --git a/client/orm/db_postgres.go b/client/orm/db_postgres.go index b2f321db..b52b2578 100644 --- a/client/orm/db_postgres.go +++ b/client/orm/db_postgres.go @@ -18,6 +18,10 @@ import ( "context" "fmt" "strconv" + + "github.com/beego/beego/v2/client/orm/internal/logs" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // postgresql operators. @@ -76,7 +80,7 @@ func (d *dbBasePostgres) OperatorSQL(operator string) string { } // generate functioned sql string, such as contains(text). -func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { +func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) { switch operator { case "contains", "startswith", "endswith": *leftCol = fmt.Sprintf("%s::text", *leftCol) @@ -128,20 +132,20 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { } // make returning sql support for postgresql. -func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { - fi := mi.fields.pk - if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { +func (d *dbBasePostgres) HasReturningID(mi *models.ModelInfo, query *string) bool { + fi := mi.Fields.Pk + if fi.FieldType&IsPositiveIntegerField == 0 && fi.FieldType&IsIntegerField == 0 { return false } if query != nil { - *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.Column) } return true } // sync auto key -func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *models.ModelInfo, autoFields []string) error { if len(autoFields) == 0 { return nil } @@ -149,9 +153,9 @@ func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo Q := d.ins.TableQuote() for _, name := range autoFields { query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", - mi.table, name, + mi.Table, name, Q, name, Q, - Q, mi.table, Q) + Q, mi.Table, Q) if _, err := db.ExecContext(ctx, query); err != nil { return err } @@ -164,9 +168,9 @@ func (d *dbBasePostgres) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" } -// show table columns sql for postgresql. +// show table Columns sql for postgresql. func (d *dbBasePostgres) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) + return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.Columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) } // get column types of postgresql. @@ -185,7 +189,7 @@ func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table st // GenerateSpecifyIndex return a specifying index clause func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { - DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored") + logs.DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored") return `` } diff --git a/client/orm/db_sqlite.go b/client/orm/db_sqlite.go index 6a4b3131..e353404e 100644 --- a/client/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -22,6 +22,10 @@ import ( "strings" "time" + "github.com/beego/beego/v2/client/orm/internal/logs" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/hints" ) @@ -74,9 +78,9 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) // override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { if isForUpdate { - DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") + logs.DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") } return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) } @@ -88,8 +92,8 @@ func (d *dbBaseSqlite) OperatorSQL(operator string) string { // generate functioned sql for sqlite. // only support DATE(text). -func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { - if fi.fieldType == TypeDateField { +func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) { + if fi.FieldType == TypeDateField { *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) } } @@ -114,7 +118,7 @@ func (d *dbBaseSqlite) ShowTablesQuery() string { return "SELECT name FROM sqlite_master WHERE type = 'table'" } -// get columns in sqlite. +// get Columns in sqlite. func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) rows, err := db.QueryContext(ctx, query) @@ -135,7 +139,7 @@ func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table strin return columns, nil } -// get show columns sql in sqlite. +// get show Columns sql in sqlite. func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { return fmt.Sprintf("pragma table_info('%s')", table) } @@ -171,7 +175,7 @@ func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, inde case hints.KeyUseIndex, hints.KeyForceIndex: return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`)) default: - DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + logs.DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") return `` } } diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index a0b355ca..9d30afb3 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -19,6 +19,8 @@ import ( "strings" "time" + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/clauses" "github.com/beego/beego/v2/client/orm/clauses/order_clause" ) @@ -31,8 +33,8 @@ type dbTable struct { names []string sel bool inner bool - mi *modelInfo - fi *fieldInfo + mi *models.ModelInfo + fi *models.FieldInfo jtl *dbTable } @@ -40,14 +42,14 @@ type dbTable struct { type dbTables struct { tablesM map[string]*dbTable tables []*dbTable - mi *modelInfo + mi *models.ModelInfo base dbBaser skipEnd bool } // set table info to collection. // if not exist, create new. -func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { +func (t *dbTables) set(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) *dbTable { name := strings.Join(names, ExprSep) if j, ok := t.tablesM[name]; ok { j.name = name @@ -64,7 +66,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) } // add table info to collection. -func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { +func (t *dbTables) add(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) (*dbTable, bool) { name := strings.Join(names, ExprSep) if _, ok := t.tablesM[name]; !ok { i := len(t.tables) + 1 @@ -82,29 +84,29 @@ func (t *dbTables) get(name string) (*dbTable, bool) { return j, ok } -// get related fields info in recursive depth loop. +// get related Fields info in recursive depth loop. // loop once, depth decreases one. -func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { - if depth < 0 || fi.fieldType == RelManyToMany { +func (t *dbTables) loopDepth(depth int, prefix string, fi *models.FieldInfo, related []string) []string { + if depth < 0 || fi.FieldType == RelManyToMany { return related } if prefix == "" { - prefix = fi.name + prefix = fi.Name } else { - prefix = prefix + ExprSep + fi.name + prefix = prefix + ExprSep + fi.Name } related = append(related, prefix) depth-- - for _, fi := range fi.relModelInfo.fields.fieldsRel { + for _, fi := range fi.RelModelInfo.Fields.FieldsRel { related = t.loopDepth(depth, prefix, fi, related) } return related } -// parse related fields. +// parse related Fields. func (t *dbTables) parseRelated(rels []string, depth int) { relsNum := len(rels) related := make([]string, relsNum) @@ -117,7 +119,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) { } relDepth-- - for _, fi := range t.mi.fields.fieldsRel { + for _, fi := range t.mi.Fields.FieldsRel { related = t.loopDepth(relDepth, "", fi, related) } @@ -133,18 +135,18 @@ func (t *dbTables) parseRelated(rels []string, depth int) { inner := true for _, ex := range exs { - if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { - names = append(names, fi.name) - mmi = fi.relModelInfo + if fi, ok := mmi.Fields.GetByAny(ex); ok && fi.Rel && fi.FieldType != RelManyToMany { + names = append(names, fi.Name) + mmi = fi.RelModelInfo - if fi.null || t.skipEnd { + if fi.Null || t.skipEnd { inner = false } jt := t.set(names, mmi, fi, inner) jt.jtl = jtl - if fi.reverse { + if fi.Reverse { cancel = false } @@ -185,24 +187,24 @@ func (t *dbTables) getJoinSQL() (join string) { t1 = jt.jtl.index } t2 = jt.index - table = jt.mi.table + table = jt.mi.Table switch { - case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: - c1 = jt.fi.mi.fields.pk.column - for _, ffi := range jt.mi.fields.fieldsRel { - if jt.fi.mi == ffi.relModelInfo { - c2 = ffi.column + case jt.fi.FieldType == RelManyToMany || jt.fi.FieldType == RelReverseMany || jt.fi.Reverse && jt.fi.ReverseFieldInfo.FieldType == RelManyToMany: + c1 = jt.fi.Mi.Fields.Pk.Column + for _, ffi := range jt.mi.Fields.FieldsRel { + if jt.fi.Mi == ffi.RelModelInfo { + c2 = ffi.Column break } } default: - c1 = jt.fi.column - c2 = jt.fi.relModelInfo.fields.pk.column + c1 = jt.fi.Column + c2 = jt.fi.RelModelInfo.Fields.Pk.Column - if jt.fi.reverse { - c1 = jt.mi.fields.pk.column - c2 = jt.fi.reverseFieldInfo.column + if jt.fi.Reverse { + c1 = jt.mi.Fields.Pk.Column + c2 = jt.fi.ReverseFieldInfo.Column } } @@ -213,11 +215,11 @@ func (t *dbTables) getJoinSQL() (join string) { } // parse orm model struct field tag expression. -func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { +func (t *dbTables) parseExprs(mi *models.ModelInfo, exprs []string) (index, name string, info *models.FieldInfo, success bool) { var ( jtl *dbTable - fi *fieldInfo - fiN *fieldInfo + fi *models.FieldInfo + fiN *models.FieldInfo mmi = mi ) @@ -238,38 +240,38 @@ loopFor: } if i == 0 { - fi, ok = mmi.fields.GetByAny(ex) + fi, ok = mmi.Fields.GetByAny(ex) } _ = okN if ok { - isRel := fi.rel || fi.reverse + isRel := fi.Rel || fi.Reverse - names = append(names, fi.name) + names = append(names, fi.Name) switch { - case fi.rel: - mmi = fi.relModelInfo - if fi.fieldType == RelManyToMany { - mmi = fi.relThroughModelInfo + case fi.Rel: + mmi = fi.RelModelInfo + if fi.FieldType == RelManyToMany { + mmi = fi.RelThroughModelInfo } - case fi.reverse: - mmi = fi.reverseFieldInfo.mi + case fi.Reverse: + mmi = fi.ReverseFieldInfo.Mi } if i < num { - fiN, okN = mmi.fields.GetByAny(exprs[i+1]) + fiN, okN = mmi.Fields.GetByAny(exprs[i+1]) } - if isRel && (!fi.mi.isThrough || num != i) { - if fi.null || t.skipEnd { + if isRel && (!fi.Mi.IsThrough || num != i) { + if fi.Null || t.skipEnd { inner = false } if t.skipEnd && okN || !t.skipEnd { - if t.skipEnd && okN && fiN.pk { + if t.skipEnd && okN && fiN.Pk { goto loopEnd } @@ -295,20 +297,20 @@ loopFor: info = fi if jtl == nil { - name = fi.name + name = fi.Name } else { - name = jtl.name + ExprSep + fi.name + name = jtl.name + ExprSep + fi.Name } switch { - case fi.rel: + case fi.Rel: - case fi.reverse: - switch fi.reverseFieldInfo.fieldType { + case fi.Reverse: + switch fi.ReverseFieldInfo.FieldType { case RelOneToOne, RelForeignKey: index = jtl.index - info = fi.reverseFieldInfo.mi.fields.pk - name = info.name + info = fi.ReverseFieldInfo.Mi.Fields.Pk + name = info.Name } } @@ -382,7 +384,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) } - leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) + leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) where += fmt.Sprintf("%s %s ", leftCol, operSQL) @@ -415,7 +417,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } - groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) + groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q)) } groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) @@ -449,7 +451,7 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) } - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.Column, Q, order.SortString())) } } @@ -458,7 +460,7 @@ func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { } // generate limit sql. -func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { +func (t *dbTables) getLimitSQL(mi *models.ModelInfo, offset int64, limit int64) (limits string) { if limit == 0 { limit = int64(DefaultRowsLimit) } @@ -490,7 +492,7 @@ func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string) } // crete new tables collection. -func newDbTables(mi *modelInfo, base dbBaser) *dbTables { +func newDbTables(mi *models.ModelInfo, base dbBaser) *dbTables { tables := &dbTables{} tables.tablesM = make(map[string]*dbTable) tables.mi = mi diff --git a/client/orm/db_tidb.go b/client/orm/db_tidb.go index 48c5b4e7..8d91b091 100644 --- a/client/orm/db_tidb.go +++ b/client/orm/db_tidb.go @@ -41,9 +41,9 @@ func (d *dbBaseTidb) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" } -// show columns sql of table for mysql. +// show Columns sql of table for mysql. func (d *dbBaseTidb) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table) } diff --git a/client/orm/db_utils.go b/client/orm/db_utils.go index 01f5a028..45c95f85 100644 --- a/client/orm/db_utils.go +++ b/client/orm/db_utils.go @@ -18,6 +18,10 @@ import ( "fmt" "reflect" "time" + + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // get table alias. @@ -29,32 +33,32 @@ func getDbAlias(name string) *alias { } // get pk column info. -func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { - fi := mi.fields.pk +func getExistPk(mi *models.ModelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { + fi := mi.Fields.Pk - v := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsPositiveIntegerField > 0 { + v := ind.FieldByIndex(fi.FieldIndex) + if fi.FieldType&IsPositiveIntegerField > 0 { vu := v.Uint() exist = vu > 0 value = vu - } else if fi.fieldType&IsIntegerField > 0 { + } else if fi.FieldType&IsIntegerField > 0 { vu := v.Int() exist = true value = vu - } else if fi.fieldType&IsRelField > 0 { - _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) + } else if fi.FieldType&IsRelField > 0 { + _, value, exist = getExistPk(fi.RelModelInfo, reflect.Indirect(v)) } else { vu := v.String() exist = vu != "" value = vu } - column = fi.column + column = fi.Column return } -// get fields description as flatted string. -func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { +// get Fields description as flatted string. +func getFlatParams(fi *models.FieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { outFor: for _, arg := range args { if arg == nil { @@ -74,32 +78,32 @@ outFor: case reflect.String: v := val.String() if fi != nil { - if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { + if fi.FieldType == TypeTimeField || fi.FieldType == TypeDateField || fi.FieldType == TypeDateTimeField { var t time.Time var err error if len(v) >= 19 { s := v[:19] - t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) + t, err = time.ParseInLocation(utils.FormatDateTime, s, DefaultTimeLoc) } else if len(v) >= 10 { s := v if len(v) > 10 { s = v[:10] } - t, err = time.ParseInLocation(formatDate, s, tz) + t, err = time.ParseInLocation(utils.FormatDate, s, tz) } else { s := v if len(s) > 8 { s = v[:8] } - t, err = time.ParseInLocation(formatTime, s, tz) + t, err = time.ParseInLocation(utils.FormatTime, s, tz) } if err == nil { - if fi.fieldType == TypeDateField { - v = t.In(tz).Format(formatDate) - } else if fi.fieldType == TypeDateTimeField { - v = t.In(tz).Format(formatDateTime) + if fi.FieldType == TypeDateField { + v = t.In(tz).Format(utils.FormatDate) + } else if fi.FieldType == TypeDateTimeField { + v = t.In(tz).Format(utils.FormatDateTime) } else { - v = t.In(tz).Format(formatTime) + v = t.In(tz).Format(utils.FormatTime) } } } @@ -110,7 +114,7 @@ outFor: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: arg = val.Uint() case reflect.Float32: - arg, _ = StrTo(ToStr(arg)).Float64() + arg, _ = utils.StrTo(utils.ToStr(arg)).Float64() case reflect.Float64: arg = val.Float() case reflect.Bool: @@ -143,18 +147,18 @@ outFor: continue outFor case reflect.Struct: if v, ok := arg.(time.Time); ok { - if fi != nil && fi.fieldType == TypeDateField { - arg = v.In(tz).Format(formatDate) - } else if fi != nil && fi.fieldType == TypeDateTimeField { - arg = v.In(tz).Format(formatDateTime) - } else if fi != nil && fi.fieldType == TypeTimeField { - arg = v.In(tz).Format(formatTime) + if fi != nil && fi.FieldType == TypeDateField { + arg = v.In(tz).Format(utils.FormatDate) + } else if fi != nil && fi.FieldType == TypeDateTimeField { + arg = v.In(tz).Format(utils.FormatDateTime) + } else if fi != nil && fi.FieldType == TypeTimeField { + arg = v.In(tz).Format(utils.FormatTime) } else { - arg = v.In(tz).Format(formatDateTime) + arg = v.In(tz).Format(utils.FormatDateTime) } } else { typ := val.Type() - name := getFullName(typ) + name := models.GetFullName(typ) var value interface{} if mmi, ok := defaultModelCache.getByFullName(name); ok { if _, vu, exist := getExistPk(mmi, val); exist { diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index 3b23284d..c8a28967 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -20,6 +20,10 @@ import ( "reflect" "time" + utils2 "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/utils" ) @@ -192,13 +196,13 @@ func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QueryS var ( name string md interface{} - mi *modelInfo + mi *models.ModelInfo ) if table, ok := ptrStructOrTableName.(string); ok { name = table } else { - name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) + name = models.GetFullName(utils2.IndirectType(reflect.TypeOf(ptrStructOrTableName))) md = ptrStructOrTableName } @@ -303,7 +307,7 @@ func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, erro func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { var ( md interface{} - mi *modelInfo + mi *models.ModelInfo ) sind := reflect.Indirect(reflect.ValueOf(mds)) diff --git a/client/orm/models_utils_test.go b/client/orm/internal/buffers/buffers.go similarity index 50% rename from client/orm/models_utils_test.go rename to client/orm/internal/buffers/buffers.go index 4dceda1c..045c00e0 100644 --- a/client/orm/models_utils_test.go +++ b/client/orm/internal/buffers/buffers.go @@ -1,10 +1,10 @@ -// Copyright 2020 +// Copyright 2014 beego Author. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,24 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package buffers -import ( - "reflect" - "testing" +import "github.com/valyala/bytebufferpool" - "github.com/stretchr/testify/assert" -) +var _ Buffer = &bytebufferpool.ByteBuffer{} -type NotApplicableModel struct { - Id int +type Buffer interface { + Write(p []byte) (int, error) + WriteString(s string) (int, error) + WriteByte(c byte) error } -func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { - return db == "default" +func Get() Buffer { + return bytebufferpool.Get() } -func TestIsApplicableTableForDB(t *testing.T) { - assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) - assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) +func Put(bf Buffer) { + bytebufferpool.Put(bf.(*bytebufferpool.ByteBuffer)) } diff --git a/client/orm/internal/logs/log.go b/client/orm/internal/logs/log.go new file mode 100644 index 00000000..3ddde3ad --- /dev/null +++ b/client/orm/internal/logs/log.go @@ -0,0 +1,20 @@ +package logs + +import ( + "io" + "log" + "os" +) + +var DebugLog = NewLog(os.Stdout) + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +func NewLog(out io.Writer) *Log { + d := new(Log) + d.Logger = log.New(out, "[ORM]", log.LstdFlags) + return d +} diff --git a/client/orm/internal/models/models_fields.go b/client/orm/internal/models/models_fields.go new file mode 100644 index 00000000..70e5aafa --- /dev/null +++ b/client/orm/internal/models/models_fields.go @@ -0,0 +1,785 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package models + +import ( + "fmt" + "strconv" + "time" + + "github.com/beego/beego/v2/client/orm/internal/utils" +) + +// Define the Type enum +const ( + TypeBooleanField = 1 << iota + TypeVarCharField + TypeCharField + TypeTextField + TypeTimeField + TypeDateField + TypeDateTimeField + TypeBitField + TypeSmallIntegerField + TypeIntegerField + TypeBigIntegerField + TypePositiveBitField + TypePositiveSmallIntegerField + TypePositiveIntegerField + TypePositiveBigIntegerField + TypeFloatField + TypeDecimalField + TypeJSONField + TypeJsonbField + RelForeignKey + RelOneToOne + RelManyToMany + RelReverseOne + RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 + IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 + IsRelField = ^-RelReverseMany >> 18 << 19 + IsFieldType = ^-RelReverseMany<<1 + 1 +) + +// BooleanField A true/false field. +type BooleanField bool + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return bool(e) +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + *e = BooleanField(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return strconv.FormatBool(e.Value()) +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return TypeBooleanField +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + switch d := value.(type) { + case bool: + e.Set(d) + case string: + v, err := utils.StrTo(d).Bool() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return e.Value() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField string + +// Value return the CharField's Value +func (e CharField) Value() string { + return string(e) +} + +// Set CharField value +func (e *CharField) Set(d string) { + *e = CharField(d) +} + +// String return the CharField +func (e *CharField) String() string { + return e.Value() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return TypeVarCharField +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return e.Value() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField time.Time + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + *e = TimeField(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := utils.TimeParse(d, utils.FormatTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return e.Value() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField time.Time + +// Value return the time.Time +func (e DateField) Value() time.Time { + return time.Time(e) +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + *e = DateField(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := utils.TimeParse(d, utils.FormatDate) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return e.Value() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField time.Time + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + *e = DateTimeField(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return e.Value().String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return TypeDateTimeField +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := utils.TimeParse(d, utils.FormatDateTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return e.Value() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField float64 + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return float64(e) +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + *e = FloatField(d) +} + +// String return the string +func (e *FloatField) String() string { + return utils.ToStr(e.Value(), -1, 32) +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return TypeFloatField +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + switch d := value.(type) { + case float32: + e.Set(float64(d)) + case float64: + e.Set(d) + case string: + v, err := utils.StrTo(d).Float64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return e.Value() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField int16 + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return int16(e) +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + *e = SmallIntegerField(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return utils.ToStr(e.Value()) +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return TypeSmallIntegerField +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int16: + e.Set(d) + case string: + v, err := utils.StrTo(d).Int16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField int32 + +// Value return the int32 +func (e IntegerField) Value() int32 { + return int32(e) +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + *e = IntegerField(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return utils.ToStr(e.Value()) +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return TypeIntegerField +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int32: + e.Set(d) + case string: + v, err := utils.StrTo(d).Int32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return e.Value() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField int64 + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return int64(e) +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + *e = BigIntegerField(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return utils.ToStr(e.Value()) +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return TypeBigIntegerField +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int64: + e.Set(d) + case string: + v, err := utils.StrTo(d).Int64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField uint16 + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return uint16(e) +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + *e = PositiveSmallIntegerField(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return utils.ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return TypePositiveSmallIntegerField +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint16: + e.Set(d) + case string: + v, err := utils.StrTo(d).Uint16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField uint32 + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return uint32(e) +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + *e = PositiveIntegerField(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return utils.ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint32: + e.Set(d) + case string: + v, err := utils.StrTo(d).Uint32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField uint64 + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return uint64(e) +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + *e = PositiveBigIntegerField(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return utils.ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint64: + e.Set(d) + case string: + v, err := utils.StrTo(d).Uint64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField string + +// Value return TextField value +func (e TextField) Value() string { + return string(e) +} + +// Set the TextField value +func (e *TextField) Set(d string) { + *e = TextField(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return e.Value() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return TypeTextField +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return e.Value() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField string + +// Value return JSONField value +func (j JSONField) Value() string { + return string(j) +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + *j = JSONField(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return TypeJSONField +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return j.Value() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField string + +// Value return JsonbField value +func (j JsonbField) Value() string { + return string(j) +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + *j = JsonbField(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return TypeJsonbField +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return j.Value() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/client/orm/models_info_f.go b/client/orm/internal/models/models_info_f.go similarity index 60% rename from client/orm/models_info_f.go rename to client/orm/internal/models/models_info_f.go index 6a9e7a99..3e5e4d6b 100644 --- a/client/orm/models_info_f.go +++ b/client/orm/internal/models/models_info_f.go @@ -12,147 +12,149 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package models import ( "errors" "fmt" "reflect" "strings" + + "github.com/beego/beego/v2/client/orm/internal/utils" ) var errSkipField = errors.New("skip field") -// field info collection -type fields struct { - pk *fieldInfo - columns map[string]*fieldInfo - fields map[string]*fieldInfo - fieldsLow map[string]*fieldInfo - fieldsByType map[int][]*fieldInfo - fieldsRel []*fieldInfo - fieldsReverse []*fieldInfo - fieldsDB []*fieldInfo - rels []*fieldInfo - orders []string - dbcols []string +// Fields field info collection +type Fields struct { + Pk *FieldInfo + Columns map[string]*FieldInfo + Fields map[string]*FieldInfo + FieldsLow map[string]*FieldInfo + FieldsByType map[int][]*FieldInfo + FieldsRel []*FieldInfo + FieldsReverse []*FieldInfo + FieldsDB []*FieldInfo + Rels []*FieldInfo + Orders []string + DBcols []string } -// add field info -func (f *fields) Add(fi *fieldInfo) (added bool) { - if f.fields[fi.name] == nil && f.columns[fi.column] == nil { - f.columns[fi.column] = fi - f.fields[fi.name] = fi - f.fieldsLow[strings.ToLower(fi.name)] = fi +// Add adds field info +func (f *Fields) Add(fi *FieldInfo) (added bool) { + if f.Fields[fi.Name] == nil && f.Columns[fi.Column] == nil { + f.Columns[fi.Column] = fi + f.Fields[fi.Name] = fi + f.FieldsLow[strings.ToLower(fi.Name)] = fi } else { return } - if _, ok := f.fieldsByType[fi.fieldType]; !ok { - f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) + if _, ok := f.FieldsByType[fi.FieldType]; !ok { + f.FieldsByType[fi.FieldType] = make([]*FieldInfo, 0) } - f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) - f.orders = append(f.orders, fi.column) - if fi.dbcol { - f.dbcols = append(f.dbcols, fi.column) - f.fieldsDB = append(f.fieldsDB, fi) + f.FieldsByType[fi.FieldType] = append(f.FieldsByType[fi.FieldType], fi) + f.Orders = append(f.Orders, fi.Column) + if fi.DBcol { + f.DBcols = append(f.DBcols, fi.Column) + f.FieldsDB = append(f.FieldsDB, fi) } - if fi.rel { - f.fieldsRel = append(f.fieldsRel, fi) + if fi.Rel { + f.FieldsRel = append(f.FieldsRel, fi) } - if fi.reverse { - f.fieldsReverse = append(f.fieldsReverse, fi) + if fi.Reverse { + f.FieldsReverse = append(f.FieldsReverse, fi) } return true } -// get field info by name -func (f *fields) GetByName(name string) *fieldInfo { - return f.fields[name] +// GetByName get field info by name +func (f *Fields) GetByName(name string) *FieldInfo { + return f.Fields[name] } -// get field info by column name -func (f *fields) GetByColumn(column string) *fieldInfo { - return f.columns[column] +// GetByColumn get field info by column name +func (f *Fields) GetByColumn(column string) *FieldInfo { + return f.Columns[column] } -// get field info by string, name is prior -func (f *fields) GetByAny(name string) (*fieldInfo, bool) { - if fi, ok := f.fields[name]; ok { +// GetByAny get field info by string, name is prior +func (f *Fields) GetByAny(name string) (*FieldInfo, bool) { + if fi, ok := f.Fields[name]; ok { return fi, ok } - if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { + if fi, ok := f.FieldsLow[strings.ToLower(name)]; ok { return fi, ok } - if fi, ok := f.columns[name]; ok { + if fi, ok := f.Columns[name]; ok { return fi, ok } return nil, false } -// create new field info collection -func newFields() *fields { - f := new(fields) - f.fields = make(map[string]*fieldInfo) - f.fieldsLow = make(map[string]*fieldInfo) - f.columns = make(map[string]*fieldInfo) - f.fieldsByType = make(map[int][]*fieldInfo) +// NewFields create new field info collection +func NewFields() *Fields { + f := new(Fields) + f.Fields = make(map[string]*FieldInfo) + f.FieldsLow = make(map[string]*FieldInfo) + f.Columns = make(map[string]*FieldInfo) + f.FieldsByType = make(map[int][]*FieldInfo) return f } -// single field info -type fieldInfo struct { - dbcol bool // table column fk and onetoone - inModel bool - auto bool - pk bool - null bool - index bool - unique bool - colDefault bool // whether has default tag - toText bool - autoNow bool - autoNowAdd bool - rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true - reverse bool - isFielder bool // implement Fielder interface - mi *modelInfo - fieldIndex []int - fieldType int - name string - fullName string - column string - addrValue reflect.Value - sf reflect.StructField - initial StrTo // store the default value - size int - reverseField string - reverseFieldInfo *fieldInfo - reverseFieldInfoTwo *fieldInfo - reverseFieldInfoM2M *fieldInfo - relTable string - relThrough string - relThroughModelInfo *modelInfo - relModelInfo *modelInfo - digits int - decimals int - onDelete string - description string - timePrecision *int +// FieldInfo single field info +type FieldInfo struct { + DBcol bool // table column fk and onetoone + InModel bool + Auto bool + Pk bool + Null bool + Index bool + Unique bool + ColDefault bool // whether has default tag + ToText bool + AutoNow bool + AutoNowAdd bool + Rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true + Reverse bool + IsFielder bool // implement Fielder interface + Mi *ModelInfo + FieldIndex []int + FieldType int + Name string + FullName string + Column string + AddrValue reflect.Value + Sf reflect.StructField + Initial utils.StrTo // store the default value + Size int + ReverseField string + ReverseFieldInfo *FieldInfo + ReverseFieldInfoTwo *FieldInfo + ReverseFieldInfoM2M *FieldInfo + RelTable string + RelThrough string + RelThroughModelInfo *ModelInfo + RelModelInfo *ModelInfo + Digits int + Decimals int + OnDelete string + Description string + TimePrecision *int } -// new field info -func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { +// NewFieldInfo new field info +func NewFieldInfo(mi *ModelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *FieldInfo, err error) { var ( tag string tagValue string - initial StrTo // store the default value + initial utils.StrTo // store the default value fieldType int attrs map[string]bool tags map[string]string addrField reflect.Value ) - fi = new(fieldInfo) + fi = new(FieldInfo) // if field which CanAddr is the follow type // A value is addressable if it is an element of a slice, @@ -168,7 +170,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN } } - attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) + attrs, tags = ParseStructTag(sf.Tag.Get(DefaultStructTagName)) if _, ok := attrs["-"]; ok { return nil, errSkipField @@ -187,7 +189,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN checkType: switch f := addrField.Interface().(type) { case Fielder: - fi.isFielder = true + fi.IsFielder = true if field.Kind() == reflect.Ptr { err = fmt.Errorf("the model Fielder can not be use ptr") goto end @@ -211,9 +213,9 @@ checkType: case "m2m": fieldType = RelManyToMany if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv + fi.RelTable = tv } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv + fi.RelThrough = tv } break checkType default: @@ -231,9 +233,9 @@ checkType: case "many": fieldType = RelReverseMany if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv + fi.RelTable = tv } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv + fi.RelThrough = tv } break checkType default: @@ -295,117 +297,117 @@ checkType: goto end } - fi.fieldType = fieldType - fi.name = sf.Name - fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) - fi.addrValue = addrField - fi.sf = sf - fi.fullName = mi.fullName + mName + "." + sf.Name + fi.FieldType = fieldType + fi.Name = sf.Name + fi.Column = getColumnName(fieldType, addrField, sf, tags["column"]) + fi.AddrValue = addrField + fi.Sf = sf + fi.FullName = mi.FullName + mName + "." + sf.Name - fi.description = tags["description"] - fi.null = attrs["null"] - fi.index = attrs["index"] - fi.auto = attrs["auto"] - fi.pk = attrs["pk"] - fi.unique = attrs["unique"] + fi.Description = tags["description"] + fi.Null = attrs["null"] + fi.Index = attrs["index"] + fi.Auto = attrs["auto"] + fi.Pk = attrs["pk"] + fi.Unique = attrs["unique"] // Mark object property if there is attribute "default" in the orm configuration if _, ok := tags["default"]; ok { - fi.colDefault = true + fi.ColDefault = true } switch fieldType { case RelManyToMany, RelReverseMany, RelReverseOne: - fi.null = false - fi.index = false - fi.auto = false - fi.pk = false - fi.unique = false + fi.Null = false + fi.Index = false + fi.Auto = false + fi.Pk = false + fi.Unique = false default: - fi.dbcol = true + fi.DBcol = true } switch fieldType { case RelForeignKey, RelOneToOne, RelManyToMany: - fi.rel = true + fi.Rel = true if fieldType == RelOneToOne { - fi.unique = true + fi.Unique = true } case RelReverseMany, RelReverseOne: - fi.reverse = true + fi.Reverse = true } - if fi.rel && fi.dbcol { + if fi.Rel && fi.DBcol { switch onDelete { - case odCascade, odDoNothing: - case odSetDefault: + case OdCascade, OdDoNothing: + case OdSetDefault: if !initial.Exist() { err = errors.New("on_delete: set_default need set field a default value") goto end } - case odSetNULL: - if !fi.null { + case OdSetNULL: + if !fi.Null { err = errors.New("on_delete: set_null need set field null") goto end } default: if onDelete == "" { - onDelete = odCascade + onDelete = OdCascade } else { err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) goto end } } - fi.onDelete = onDelete + fi.OnDelete = onDelete } switch fieldType { case TypeBooleanField: case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: if size != "" { - v, e := StrTo(size).Int32() + v, e := utils.StrTo(size).Int32() if e != nil { err = fmt.Errorf("wrong size value `%s`", size) } else { - fi.size = int(v) + fi.Size = int(v) } } else { - fi.size = 255 - fi.toText = true + fi.Size = 255 + fi.ToText = true } case TypeTextField: - fi.index = false - fi.unique = false + fi.Index = false + fi.Unique = false case TypeTimeField, TypeDateField, TypeDateTimeField: if fieldType == TypeDateTimeField { if precision != "" { - v, e := StrTo(precision).Int() + v, e := utils.StrTo(precision).Int() if e != nil { err = fmt.Errorf("convert %s to int error:%v", precision, e) } else { - fi.timePrecision = &v + fi.TimePrecision = &v } } } if attrs["auto_now"] { - fi.autoNow = true + fi.AutoNow = true } else if attrs["auto_now_add"] { - fi.autoNowAdd = true + fi.AutoNowAdd = true } case TypeFloatField: case TypeDecimalField: d1 := digits d2 := decimals - v1, er1 := StrTo(d1).Int8() - v2, er2 := StrTo(d2).Int8() + v1, er1 := utils.StrTo(d1).Int8() + v2, er2 := utils.StrTo(d2).Int8() if er1 != nil || er2 != nil { err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) goto end } - fi.digits = int(v1) - fi.decimals = int(v2) + fi.Digits = int(v1) + fi.Decimals = int(v2) default: switch { case fieldType&IsIntegerField > 0: @@ -414,33 +416,33 @@ checkType: } if fieldType&IsIntegerField == 0 { - if fi.auto { + if fi.Auto { err = fmt.Errorf("non-integer type cannot set auto") goto end } } - if fi.auto || fi.pk { - if fi.auto { + if fi.Auto || fi.Pk { + if fi.Auto { switch addrField.Elem().Kind() { case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: default: err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) goto end } - fi.pk = true + fi.Pk = true } - fi.null = false - fi.index = false - fi.unique = false + fi.Null = false + fi.Index = false + fi.Unique = false } - if fi.unique { - fi.index = false + if fi.Unique { + fi.Index = false } // can not set default for these type - if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { + if fi.Auto || fi.Pk || fi.Unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { initial.Clear() } @@ -474,7 +476,7 @@ checkType: } } - fi.initial = initial + fi.Initial = initial end: if err != nil { return nil, err diff --git a/client/orm/internal/models/models_info_m.go b/client/orm/internal/models/models_info_m.go new file mode 100644 index 00000000..0dee0aa8 --- /dev/null +++ b/client/orm/internal/models/models_info_m.go @@ -0,0 +1,148 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package models + +import ( + "fmt" + "os" + "reflect" +) + +// ModelInfo single model info +type ModelInfo struct { + Manual bool + IsThrough bool + Pkg string + Name string + FullName string + Table string + Model interface{} + Fields *Fields + AddrField reflect.Value // store the original struct value + Uniques []string +} + +// NewModelInfo new model info +func NewModelInfo(val reflect.Value) (mi *ModelInfo) { + mi = &ModelInfo{} + mi.Fields = NewFields() + ind := reflect.Indirect(val) + mi.AddrField = val + mi.Name = ind.Type().Name() + mi.FullName = GetFullName(ind.Type()) + AddModelFields(mi, ind, "", []int{}) + return +} + +// AddModelFields index: FieldByIndex returns the nested field corresponding to index +func AddModelFields(mi *ModelInfo, ind reflect.Value, mName string, index []int) { + var ( + err error + fi *FieldInfo + sf reflect.StructField + ) + + for i := 0; i < ind.NumField(); i++ { + field := ind.Field(i) + sf = ind.Type().Field(i) + // if the field is unexported skip + if sf.PkgPath != "" { + continue + } + // add anonymous struct Fields + if sf.Anonymous { + AddModelFields(mi, field, mName+"."+sf.Name, append(index, i)) + continue + } + + fi, err = NewFieldInfo(mi, field, sf, mName) + if err == errSkipField { + err = nil + continue + } else if err != nil { + break + } + // record current field index + fi.FieldIndex = append(fi.FieldIndex, index...) + fi.FieldIndex = append(fi.FieldIndex, i) + fi.Mi = mi + fi.InModel = true + if !mi.Fields.Add(fi) { + err = fmt.Errorf("duplicate column name: %s", fi.Column) + break + } + if fi.Pk { + if mi.Fields.Pk != nil { + err = fmt.Errorf("one model must have one pk field only") + break + } else { + mi.Fields.Pk = fi + } + } + } + + if err != nil { + fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) + os.Exit(2) + } +} + +// NewM2MModelInfo combine related model info to new model info. +// prepare for relation models query. +func NewM2MModelInfo(m1, m2 *ModelInfo) (mi *ModelInfo) { + mi = new(ModelInfo) + mi.Fields = NewFields() + mi.Table = m1.Table + "_" + m2.Table + "s" + mi.Name = CamelString(mi.Table) + mi.FullName = m1.Pkg + "." + mi.Name + + fa := new(FieldInfo) // pk + f1 := new(FieldInfo) // m1 table RelForeignKey + f2 := new(FieldInfo) // m2 table RelForeignKey + fa.FieldType = TypeBigIntegerField + fa.Auto = true + fa.Pk = true + fa.DBcol = true + fa.Name = "Id" + fa.Column = "id" + fa.FullName = mi.FullName + "." + fa.Name + + f1.DBcol = true + f2.DBcol = true + f1.FieldType = RelForeignKey + f2.FieldType = RelForeignKey + f1.Name = CamelString(m1.Table) + f2.Name = CamelString(m2.Table) + f1.FullName = mi.FullName + "." + f1.Name + f2.FullName = mi.FullName + "." + f2.Name + f1.Column = m1.Table + "_id" + f2.Column = m2.Table + "_id" + f1.Rel = true + f2.Rel = true + f1.RelTable = m1.Table + f2.RelTable = m2.Table + f1.RelModelInfo = m1 + f2.RelModelInfo = m2 + f1.Mi = mi + f2.Mi = mi + + mi.Fields.Add(fa) + mi.Fields.Add(f1) + mi.Fields.Add(f2) + mi.Fields.Pk = fa + + mi.Uniques = []string{f1.Column, f2.Column} + return +} diff --git a/client/orm/models_utils.go b/client/orm/internal/models/models_utils.go similarity index 67% rename from client/orm/models_utils.go rename to client/orm/internal/models/models_utils.go index b2e5760e..b5204606 100644 --- a/client/orm/models_utils.go +++ b/client/orm/internal/models/models_utils.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package models import ( "database/sql" @@ -20,6 +20,8 @@ import ( "reflect" "strings" "time" + + "github.com/beego/beego/v2/client/orm/internal/logs" ) // 1 is attr @@ -48,15 +50,29 @@ var supportTag = map[string]int{ "precision": 2, } -// get reflect.Type name with package path. -func getFullName(typ reflect.Type) string { +type fn func(string) string + +var ( + NameStrategyMap = map[string]fn{ + DefaultNameStrategy: SnakeString, + SnakeAcronymNameStrategy: SnakeStringWithAcronym, + } + DefaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + NameStrategy = DefaultNameStrategy + defaultStructTagDelim = ";" + DefaultStructTagName = "orm" +) + +// GetFullName get reflect.Type name with package path. +func GetFullName(typ reflect.Type) string { return typ.PkgPath() + "." + typ.Name() } -// getTableName get struct table name. +// GetTableName get struct table name. // If the struct implement the TableName, then get the result as tablename // else use the struct name which will apply snakeString. -func getTableName(val reflect.Value) string { +func GetTableName(val reflect.Value) string { if fun := val.MethodByName("TableName"); fun.IsValid() { vals := fun.Call([]reflect.Value{}) // has return and the first val is string @@ -64,11 +80,11 @@ func getTableName(val reflect.Value) string { return vals[0].String() } } - return snakeString(reflect.Indirect(val).Type().Name()) + return SnakeString(reflect.Indirect(val).Type().Name()) } -// get table engine, myisam or innodb. -func getTableEngine(val reflect.Value) string { +// GetTableEngine get table engine, myisam or innodb. +func GetTableEngine(val reflect.Value) string { fun := val.MethodByName("TableEngine") if fun.IsValid() { vals := fun.Call([]reflect.Value{}) @@ -79,8 +95,8 @@ func getTableEngine(val reflect.Value) string { return "" } -// get table index from method. -func getTableIndex(val reflect.Value) [][]string { +// GetTableIndex get table index from method. +func GetTableIndex(val reflect.Value) [][]string { fun := val.MethodByName("TableIndex") if fun.IsValid() { vals := fun.Call([]reflect.Value{}) @@ -93,8 +109,8 @@ func getTableIndex(val reflect.Value) [][]string { return nil } -// get table unique from method -func getTableUnique(val reflect.Value) [][]string { +// GetTableUnique get table unique from method +func GetTableUnique(val reflect.Value) [][]string { fun := val.MethodByName("TableUnique") if fun.IsValid() { vals := fun.Call([]reflect.Value{}) @@ -107,8 +123,8 @@ func getTableUnique(val reflect.Value) [][]string { return nil } -// get whether the table needs to be created for the database alias -func isApplicableTableForDB(val reflect.Value, db string) bool { +// IsApplicableTableForDB get whether the table needs to be created for the database alias +func IsApplicableTableForDB(val reflect.Value, db string) bool { if !val.IsValid() { return true } @@ -126,7 +142,7 @@ func isApplicableTableForDB(val reflect.Value, db string) bool { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := col if col == "" { - column = nameStrategyMap[nameStrategy](sf.Name) + column = NameStrategyMap[NameStrategy](sf.Name) } switch ft { case RelForeignKey, RelOneToOne: @@ -218,8 +234,8 @@ func getFieldType(val reflect.Value) (ft int, err error) { return } -// parse struct tag string -func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { +// ParseStructTag parse struct tag string +func ParseStructTag(data string) (attrs map[string]bool, tags map[string]string) { attrs = make(map[string]bool) tags = make(map[string]string) for _, v := range strings.Split(data, defaultStructTagDelim) { @@ -236,8 +252,74 @@ func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) tags[name] = v } } else { - DebugLog.Println("unsupport orm tag", v) + logs.DebugLog.Println("unsupport orm tag", v) } } return } + +func SnakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data)) +} + +// SnakeString snake string, XxYy to xx_yy , XxYY to xx_y_y +func SnakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data)) +} + +// CamelString camel string, xx_yy to XxYy +func CamelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data) +} + +const ( + OdCascade = "cascade" + OdSetNULL = "set_null" + OdSetDefault = "set_default" + OdDoNothing = "do_nothing" +) diff --git a/client/orm/utils_test.go b/client/orm/internal/models/models_utils_test.go similarity index 75% rename from client/orm/utils_test.go rename to client/orm/internal/models/models_utils_test.go index 7d94cada..40bffc66 100644 --- a/client/orm/utils_test.go +++ b/client/orm/internal/models/models_utils_test.go @@ -1,10 +1,10 @@ -// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,27 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package models import ( + "reflect" "testing" + + "github.com/stretchr/testify/assert" ) -func TestCamelString(t *testing.T) { - snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} +type NotApplicableModel struct { + Id int +} - answer := make(map[string]string) - for i, v := range snake { - answer[v] = camel[i] - } +func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { + return db == "default" +} - for _, v := range snake { - res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } +func TestIsApplicableTableForDB(t *testing.T) { + assert.False(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) + assert.True(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) } func TestSnakeString(t *testing.T) { @@ -45,7 +44,7 @@ func TestSnakeString(t *testing.T) { } for _, v := range camel { - res := snakeString(v) + res := SnakeString(v) if res != answer[v] { t.Error("Unit Test Fail:", v, res, answer[v]) } @@ -62,7 +61,24 @@ func TestSnakeStringWithAcronym(t *testing.T) { } for _, v := range camel { - res := snakeStringWithAcronym(v) + res := SnakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := CamelString(v) if res != answer[v] { t.Error("Unit Test Fail:", v, res, answer[v]) } diff --git a/client/orm/internal/models/types.go b/client/orm/internal/models/types.go new file mode 100644 index 00000000..f3b7989a --- /dev/null +++ b/client/orm/internal/models/types.go @@ -0,0 +1,23 @@ +// Copyright 2023 beego-dev. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package models + +// Fielder define field info +type Fielder interface { + String() string + FieldType() int + SetRaw(interface{}) error + RawValue() interface{} +} diff --git a/client/orm/internal/utils/utils.go b/client/orm/internal/utils/utils.go new file mode 100644 index 00000000..5e338487 --- /dev/null +++ b/client/orm/internal/utils/utils.go @@ -0,0 +1,249 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "math/big" + "reflect" + "strconv" + "time" +) + +// StrTo is the target string +type StrTo string + +// Set string +func (f *StrTo) Set(v string) { + if v != "" { + *f = StrTo(v) + } else { + f.Clear() + } +} + +// Clear string +func (f *StrTo) Clear() { + *f = StrTo(rune(0x1E)) +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return string(f) != string(rune(0x1E)) +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return strconv.ParseBool(f.String()) +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + v, err := strconv.ParseFloat(f.String(), 32) + return float32(v), err +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return strconv.ParseFloat(f.String(), 64) +} + +// Int string to int +func (f StrTo) Int() (int, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int(v), err +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + v, err := strconv.ParseInt(f.String(), 10, 8) + return int8(v), err +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + v, err := strconv.ParseInt(f.String(), 10, 16) + return int16(v), err +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int32(v), err +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + v, err := strconv.ParseInt(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) // octal + if !ok { + return v, err + } + return ni.Int64(), nil + } + return v, err +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint(v), err +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + v, err := strconv.ParseUint(f.String(), 10, 8) + return uint8(v), err +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + v, err := strconv.ParseUint(f.String(), 10, 16) + return uint16(v), err +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint32(v), err +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + v, err := strconv.ParseUint(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) + if !ok { + return v, err + } + return ni.Uint64(), nil + } + return v, err +} + +// String string to string +func (f StrTo) String() string { + if f.Exist() { + return string(f) + } + return "" +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, ArgInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, ArgInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +type ArgString []string + +// Get get string by index from string slice +func (a ArgString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type ArgInt []int + +// Get get int by index from int slice +func (a ArgInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// TimeParse parse time to string with location +func TimeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// IndirectType get pointer indirect type +func IndirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return IndirectType(v.Elem()) + default: + return v + } +} + +const ( + FormatTime = "15:04:05" + FormatDate = "2006-01-02" + FormatDateTime = "2006-01-02 15:04:05" +) + +var ( + DefaultTimeLoc = time.Local +) diff --git a/client/orm/invocation.go b/client/orm/invocation.go index 9e7c1974..48fdbf6e 100644 --- a/client/orm/invocation.go +++ b/client/orm/invocation.go @@ -17,6 +17,8 @@ package orm import ( "context" "time" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // Invocation represents an "Orm" invocation @@ -27,7 +29,7 @@ type Invocation struct { // the args are all arguments except context.Context Args []interface{} - mi *modelInfo + mi *models.ModelInfo // f is the Orm operation f func(ctx context.Context) []interface{} @@ -39,7 +41,7 @@ type Invocation struct { func (inv *Invocation) GetTableName() string { if inv.mi != nil { - return inv.mi.table + return inv.mi.Table } return "" } @@ -51,8 +53,8 @@ func (inv *Invocation) execute(ctx context.Context) []interface{} { // GetPkFieldName return the primary key of this table // if not found, "" is returned func (inv *Invocation) GetPkFieldName() string { - if inv.mi.fields.pk != nil { - return inv.mi.fields.pk.name + if inv.mi.Fields.Pk != nil { + return inv.mi.Fields.Pk.Name } return "" } diff --git a/client/orm/model_utils_test.go b/client/orm/model_utils_test.go index be97c58d..d3d57cdf 100644 --- a/client/orm/model_utils_test.go +++ b/client/orm/model_utils_test.go @@ -17,6 +17,8 @@ package orm import ( "testing" + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/stretchr/testify/assert" ) @@ -53,10 +55,10 @@ func TestDbBase_GetTables(t *testing.T) { assert.True(t, ok) assert.NotNil(t, mi) - engine := getTableEngine(mi.addrField) + engine := models.GetTableEngine(mi.AddrField) assert.Equal(t, "innodb", engine) - uniques := getTableUnique(mi.addrField) + uniques := models.GetTableUnique(mi.AddrField) assert.Equal(t, [][]string{{"unique1"}, {"unique2"}}, uniques) - indexes := getTableIndex(mi.addrField) + indexes := models.GetTableIndex(mi.AddrField) assert.Equal(t, [][]string{{"index1"}, {"index2"}}, indexes) } diff --git a/client/orm/models.go b/client/orm/models.go index 94630ba5..542ced59 100644 --- a/client/orm/models.go +++ b/client/orm/models.go @@ -21,15 +21,8 @@ import ( "runtime/debug" "strings" "sync" -) -const ( - odCascade = "cascade" - odSetNULL = "set_null" - odSetDefault = "set_default" - odDoNothing = "do_nothing" - defaultStructTagName = "orm" - defaultStructTagDelim = ";" + imodels "github.com/beego/beego/v2/client/orm/internal/models" ) var defaultModelCache = NewModelCacheHandler() @@ -38,22 +31,22 @@ var defaultModelCache = NewModelCacheHandler() type modelCache struct { sync.RWMutex // only used outsite for bootStrap orders []string - cache map[string]*modelInfo - cacheByFullName map[string]*modelInfo + cache map[string]*imodels.ModelInfo + cacheByFullName map[string]*imodels.ModelInfo done bool } // NewModelCacheHandler generator of modelCache func NewModelCacheHandler() *modelCache { return &modelCache{ - cache: make(map[string]*modelInfo), - cacheByFullName: make(map[string]*modelInfo), + cache: make(map[string]*imodels.ModelInfo), + cacheByFullName: make(map[string]*imodels.ModelInfo), } } // get all model info -func (mc *modelCache) all() map[string]*modelInfo { - m := make(map[string]*modelInfo, len(mc.cache)) +func (mc *modelCache) all() map[string]*imodels.ModelInfo { + m := make(map[string]*imodels.ModelInfo, len(mc.cache)) for k, v := range mc.cache { m[k] = v } @@ -61,8 +54,8 @@ func (mc *modelCache) all() map[string]*modelInfo { } // get ordered model info -func (mc *modelCache) allOrdered() []*modelInfo { - m := make([]*modelInfo, 0, len(mc.orders)) +func (mc *modelCache) allOrdered() []*imodels.ModelInfo { + m := make([]*imodels.ModelInfo, 0, len(mc.orders)) for _, table := range mc.orders { m = append(m, mc.cache[table]) } @@ -70,30 +63,30 @@ func (mc *modelCache) allOrdered() []*modelInfo { } // get model info by table name -func (mc *modelCache) get(table string) (mi *modelInfo, ok bool) { +func (mc *modelCache) get(table string) (mi *imodels.ModelInfo, ok bool) { mi, ok = mc.cache[table] return } // get model info by full name -func (mc *modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { +func (mc *modelCache) getByFullName(name string) (mi *imodels.ModelInfo, ok bool) { mi, ok = mc.cacheByFullName[name] return } -func (mc *modelCache) getByMd(md interface{}) (*modelInfo, bool) { +func (mc *modelCache) getByMd(md interface{}) (*imodels.ModelInfo, bool) { val := reflect.ValueOf(md) ind := reflect.Indirect(val) typ := ind.Type() - name := getFullName(typ) + name := imodels.GetFullName(typ) return mc.getByFullName(name) } // set model info to collection -func (mc *modelCache) set(table string, mi *modelInfo) *modelInfo { +func (mc *modelCache) set(table string, mi *imodels.ModelInfo) *imodels.ModelInfo { mii := mc.cache[table] mc.cache[table] = mi - mc.cacheByFullName[mi.fullName] = mi + mc.cacheByFullName[mi.FullName] = mi if mii == nil { mc.orders = append(mc.orders, table) } @@ -106,8 +99,8 @@ func (mc *modelCache) clean() { defer mc.Unlock() mc.orders = make([]string, 0) - mc.cache = make(map[string]*modelInfo) - mc.cacheByFullName = make(map[string]*modelInfo) + mc.cache = make(map[string]*imodels.ModelInfo) + mc.cacheByFullName = make(map[string]*imodels.ModelInfo) mc.done = false } @@ -120,7 +113,7 @@ func (mc *modelCache) bootstrap() { } var ( err error - models map[string]*modelInfo + models map[string]*imodels.ModelInfo ) if dataBaseCache.getDefault() == nil { err = fmt.Errorf("must have one register DataBase alias named `default`") @@ -131,51 +124,51 @@ func (mc *modelCache) bootstrap() { // RelManyToMany set the relTable models = mc.all() for _, mi := range models { - for _, fi := range mi.fields.columns { - if fi.rel || fi.reverse { - elm := fi.addrValue.Type().Elem() - if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { + for _, fi := range mi.Fields.Columns { + if fi.Rel || fi.Reverse { + elm := fi.AddrValue.Type().Elem() + if fi.FieldType == RelReverseMany || fi.FieldType == RelManyToMany { elm = elm.Elem() } // check the rel or reverse model already register - name := getFullName(elm) + name := imodels.GetFullName(elm) mii, ok := mc.getByFullName(name) - if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + if !ok || mii.Pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.FullName, elm.String()) goto end } - fi.relModelInfo = mii + fi.RelModelInfo = mii - switch fi.fieldType { + switch fi.FieldType { case RelManyToMany: - if fi.relThrough != "" { - if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { - pn := fi.relThrough[:i] - rmi, ok := mc.getByFullName(fi.relThrough) - if !ok || pn != rmi.pkg { - err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) + if fi.RelThrough != "" { + if i := strings.LastIndex(fi.RelThrough, "."); i != -1 && len(fi.RelThrough) > (i+1) { + pn := fi.RelThrough[:i] + rmi, ok := mc.getByFullName(fi.RelThrough) + if !ok || pn != rmi.Pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.FullName, fi.RelThrough) goto end } - fi.relThroughModelInfo = rmi - fi.relTable = rmi.table + fi.RelThroughModelInfo = rmi + fi.RelTable = rmi.Table } else { - err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.FullName, fi.RelThrough) goto end } } else { - i := newM2MModelInfo(mi, mii) - if fi.relTable != "" { - i.table = fi.relTable + i := imodels.NewM2MModelInfo(mi, mii) + if fi.RelTable != "" { + i.Table = fi.RelTable } - if v := mc.set(i.table, i); v != nil { - err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + if v := mc.set(i.Table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.RelTable) goto end } - fi.relTable = i.table - fi.relThroughModelInfo = i + fi.RelTable = i.Table + fi.RelThroughModelInfo = i } - fi.relThroughModelInfo.isThrough = true + fi.RelThroughModelInfo.IsThrough = true } } } @@ -185,42 +178,42 @@ func (mc *modelCache) bootstrap() { // if not exist, add a new field to the relModelInfo models = mc.all() for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { + for _, fi := range mi.Fields.FieldsRel { + switch fi.FieldType { case RelForeignKey, RelOneToOne, RelManyToMany: inModel := false - for _, ffi := range fi.relModelInfo.fields.fieldsReverse { - if ffi.relModelInfo == mi { + for _, ffi := range fi.RelModelInfo.Fields.FieldsReverse { + if ffi.RelModelInfo == mi { inModel = true break } } if !inModel { - rmi := fi.relModelInfo - ffi := new(fieldInfo) - ffi.name = mi.name - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - ffi.reverse = true - ffi.relModelInfo = mi - ffi.mi = rmi - if fi.fieldType == RelOneToOne { - ffi.fieldType = RelReverseOne + rmi := fi.RelModelInfo + ffi := new(imodels.FieldInfo) + ffi.Name = mi.Name + ffi.Column = ffi.Name + ffi.FullName = rmi.FullName + "." + ffi.Name + ffi.Reverse = true + ffi.RelModelInfo = mi + ffi.Mi = rmi + if fi.FieldType == RelOneToOne { + ffi.FieldType = RelReverseOne } else { - ffi.fieldType = RelReverseMany + ffi.FieldType = RelReverseMany } - if !rmi.fields.Add(ffi) { + if !rmi.Fields.Add(ffi) { added := false for cnt := 0; cnt < 5; cnt++ { - ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - if added = rmi.fields.Add(ffi); added { + ffi.Name = fmt.Sprintf("%s%d", mi.Name, cnt) + ffi.Column = ffi.Name + ffi.FullName = rmi.FullName + "." + ffi.Name + if added = rmi.Fields.Add(ffi); added { break } } if !added { - panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.FullName, ffi.FullName)) } } } @@ -230,24 +223,24 @@ func (mc *modelCache) bootstrap() { models = mc.all() for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { + for _, fi := range mi.Fields.FieldsRel { + switch fi.FieldType { case RelManyToMany: - for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { - switch ffi.fieldType { + for _, ffi := range fi.RelThroughModelInfo.Fields.FieldsRel { + switch ffi.FieldType { case RelOneToOne, RelForeignKey: - if ffi.relModelInfo == fi.relModelInfo { - fi.reverseFieldInfoTwo = ffi + if ffi.RelModelInfo == fi.RelModelInfo { + fi.ReverseFieldInfoTwo = ffi } - if ffi.relModelInfo == mi { - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi + if ffi.RelModelInfo == mi { + fi.ReverseField = ffi.Name + fi.ReverseFieldInfo = ffi } } } - if fi.reverseFieldInfoTwo == nil { + if fi.ReverseFieldInfoTwo == nil { err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", - fi.relThroughModelInfo.fullName) + fi.RelThroughModelInfo.FullName) goto end } } @@ -256,63 +249,63 @@ func (mc *modelCache) bootstrap() { models = mc.all() for _, mi := range models { - for _, fi := range mi.fields.fieldsReverse { - switch fi.fieldType { + for _, fi := range mi.Fields.FieldsReverse { + switch fi.FieldType { case RelReverseOne: found := false mForA: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { - if ffi.relModelInfo == mi { + for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelOneToOne] { + if ffi.RelModelInfo == mi { found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi + fi.ReverseField = ffi.Name + fi.ReverseFieldInfo = ffi - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi + ffi.ReverseField = fi.Name + ffi.ReverseFieldInfo = fi break mForA } } if !found { - err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName) goto end } case RelReverseMany: found := false mForB: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { - if ffi.relModelInfo == mi { + for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelForeignKey] { + if ffi.RelModelInfo == mi { found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi + fi.ReverseField = ffi.Name + fi.ReverseFieldInfo = ffi - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi + ffi.ReverseField = fi.Name + ffi.ReverseFieldInfo = fi break mForB } } if !found { mForC: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || - fi.relTable != "" && fi.relTable == ffi.relTable || - fi.relThrough == "" && fi.relTable == "" - if ffi.relModelInfo == mi && conditions { + for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelManyToMany] { + conditions := fi.RelThrough != "" && fi.RelThrough == ffi.RelThrough || + fi.RelTable != "" && fi.RelTable == ffi.RelTable || + fi.RelThrough == "" && fi.RelTable == "" + if ffi.RelModelInfo == mi && conditions { found = true - fi.reverseField = ffi.reverseFieldInfoTwo.name - fi.reverseFieldInfo = ffi.reverseFieldInfoTwo - fi.relThroughModelInfo = ffi.relThroughModelInfo - fi.reverseFieldInfoTwo = ffi.reverseFieldInfo - fi.reverseFieldInfoM2M = ffi - ffi.reverseFieldInfoM2M = fi + fi.ReverseField = ffi.ReverseFieldInfoTwo.Name + fi.ReverseFieldInfo = ffi.ReverseFieldInfoTwo + fi.RelThroughModelInfo = ffi.RelThroughModelInfo + fi.ReverseFieldInfoTwo = ffi.ReverseFieldInfo + fi.ReverseFieldInfoM2M = ffi + ffi.ReverseFieldInfoM2M = fi break mForC } } } if !found { - err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName) goto end } } @@ -334,7 +327,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo typ := reflect.Indirect(val).Type() if val.Kind() != reflect.Ptr { - err = fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ)) + err = fmt.Errorf(" cannot use non-ptr model struct `%s`", imodels.GetFullName(typ)) return } // For this case: @@ -347,7 +340,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo if val.Elem().Kind() == reflect.Slice { val = reflect.New(val.Elem().Type().Elem()) } - table := getTableName(val) + table := imodels.GetTableName(val) if prefixOrSuffixStr != "" { if prefixOrSuffix { @@ -358,7 +351,7 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo } // models's fullname is pkgpath + struct name - name := getFullName(typ) + name := imodels.GetFullName(typ) if _, ok := mc.getByFullName(name); ok { err = fmt.Errorf(" model `%s` repeat register, must be unique\n", name) return @@ -368,26 +361,26 @@ func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, mo return nil } - mi := newModelInfo(val) - if mi.fields.pk == nil { + mi := imodels.NewModelInfo(val) + if mi.Fields.Pk == nil { outFor: - for _, fi := range mi.fields.fieldsDB { - if strings.ToLower(fi.name) == "id" { - switch fi.addrValue.Elem().Kind() { + for _, fi := range mi.Fields.FieldsDB { + if strings.ToLower(fi.Name) == "id" { + switch fi.AddrValue.Elem().Kind() { case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - mi.fields.pk = fi + fi.Auto = true + fi.Pk = true + mi.Fields.Pk = fi break outFor } } } } - mi.table = table - mi.pkg = typ.PkgPath() - mi.model = model - mi.manual = true + mi.Table = table + mi.Pkg = typ.PkgPath() + mi.Model = model + mi.Manual = true mc.set(table, mi) } @@ -404,7 +397,7 @@ func (mc *modelCache) getDbDropSQL(al *alias) (queries []string, err error) { Q := al.DbBaser.TableQuote() for _, mi := range mc.allOrdered() { - queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) + queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.Table, Q)) } return queries, nil } @@ -424,33 +417,33 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes for _, mi := range mc.allOrdered() { sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.FullName) sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.Table, Q) - columns := make([]string, 0, len(mi.fields.fieldsDB)) + columns := make([]string, 0, len(mi.Fields.FieldsDB)) sqlIndexes := [][]string{} var commentIndexes []int // store comment indexes for postgres - for i, fi := range mi.fields.fieldsDB { - column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) + for i, fi := range mi.Fields.FieldsDB { + column := fmt.Sprintf(" %s%s%s ", Q, fi.Column, Q) col := getColumnTyp(al, fi) - if fi.auto { + if fi.Auto { switch al.Driver { case DRSqlite, DRPostgres: column += T["auto"] default: column += col + " " + T["auto"] } - } else if fi.pk { + } else if fi.Pk { column += col + " " + T["pk"] } else { column += col - if !fi.null { + if !fi.Null { column += " " + "NOT NULL" } @@ -461,42 +454,42 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes // Append attribute DEFAULT column += getColumnDefault(fi) - if fi.unique { + if fi.Unique { column += " " + "UNIQUE" } - if fi.index { - sqlIndexes = append(sqlIndexes, []string{fi.column}) + if fi.Index { + sqlIndexes = append(sqlIndexes, []string{fi.Column}) } } if strings.Contains(column, "%COL%") { - column = strings.Replace(column, "%COL%", fi.column, -1) + column = strings.Replace(column, "%COL%", fi.Column, -1) } - if fi.description != "" && al.Driver != DRSqlite { + if fi.Description != "" && al.Driver != DRSqlite { if al.Driver == DRPostgres { commentIndexes = append(commentIndexes, i) } else { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + column += " " + fmt.Sprintf("COMMENT '%s'", fi.Description) } } columns = append(columns, column) } - if mi.model != nil { - allnames := getTableUnique(mi.addrField) - if !mi.manual && len(mi.uniques) > 0 { - allnames = append(allnames, mi.uniques) + if mi.Model != nil { + allnames := imodels.GetTableUnique(mi.AddrField) + if !mi.Manual && len(mi.Uniques) > 0 { + allnames = append(allnames, mi.Uniques) } for _, names := range allnames { cols := make([]string, 0, len(names)) for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) + if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol { + cols = append(cols, fi.Column) } else { - panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.FullName)) } } column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) @@ -509,8 +502,8 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes if al.Driver == DRMySQL { var engine string - if mi.model != nil { - engine = getTableEngine(mi.addrField) + if mi.Model != nil { + engine = imodels.GetTableEngine(mi.AddrField) } if engine == "" { engine = al.Engine @@ -524,24 +517,24 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes for _, index := range commentIndexes { sql += fmt.Sprintf("\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';", Q, - mi.table, + mi.Table, Q, Q, - mi.fields.fieldsDB[index].column, + mi.Fields.FieldsDB[index].Column, Q, - mi.fields.fieldsDB[index].description) + mi.Fields.FieldsDB[index].Description) } } queries = append(queries, sql) - if mi.model != nil { - for _, names := range getTableIndex(mi.addrField) { + if mi.Model != nil { + for _, names := range imodels.GetTableIndex(mi.AddrField) { cols := make([]string, 0, len(names)) for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) + if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol { + cols = append(cols, fi.Column) } else { - panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.FullName)) } } sqlIndexes = append(sqlIndexes, cols) @@ -549,16 +542,16 @@ func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes } for _, names := range sqlIndexes { - name := mi.table + "_" + strings.Join(names, "_") + name := mi.Table + "_" + strings.Join(names, "_") cols := strings.Join(names, sep) - sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.Table, Q, Q, cols, Q) index := dbIndex{} - index.Table = mi.table + index.Table = mi.Table index.Name = name index.SQL = sql - tableIndexes[mi.table] = append(tableIndexes[mi.table], index) + tableIndexes[mi.Table] = append(tableIndexes[mi.Table], index) } } diff --git a/client/orm/models_fields.go b/client/orm/models_fields.go index b4fad94f..4f07ea18 100644 --- a/client/orm/models_fields.go +++ b/client/orm/models_fields.go @@ -15,91 +15,47 @@ package orm import ( - "fmt" - "strconv" - "time" + "github.com/beego/beego/v2/client/orm/internal/models" ) // Define the Type enum const ( - TypeBooleanField = 1 << iota - TypeVarCharField - TypeCharField - TypeTextField - TypeTimeField - TypeDateField - TypeDateTimeField - TypeBitField - TypeSmallIntegerField - TypeIntegerField - TypeBigIntegerField - TypePositiveBitField - TypePositiveSmallIntegerField - TypePositiveIntegerField - TypePositiveBigIntegerField - TypeFloatField - TypeDecimalField - TypeJSONField - TypeJsonbField - RelForeignKey - RelOneToOne - RelManyToMany - RelReverseOne - RelReverseMany + TypeBooleanField = models.TypeBooleanField + TypeVarCharField = models.TypeVarCharField + TypeCharField = models.TypeCharField + TypeTextField = models.TypeTextField + TypeTimeField = models.TypeTimeField + TypeDateField = models.TypeDateField + TypeDateTimeField = models.TypeDateTimeField + TypeBitField = models.TypeBitField + TypeSmallIntegerField = models.TypeSmallIntegerField + TypeIntegerField = models.TypeIntegerField + TypeBigIntegerField = models.TypeBigIntegerField + TypePositiveBitField = models.TypePositiveBitField + TypePositiveSmallIntegerField = models.TypePositiveSmallIntegerField + TypePositiveIntegerField = models.TypePositiveIntegerField + TypePositiveBigIntegerField = models.TypePositiveBigIntegerField + TypeFloatField = models.TypeFloatField + TypeDecimalField = models.TypeDecimalField + TypeJSONField = models.TypeJSONField + TypeJsonbField = models.TypeJsonbField + RelForeignKey = models.RelForeignKey + RelOneToOne = models.RelOneToOne + RelManyToMany = models.RelManyToMany + RelReverseOne = models.RelReverseOne + RelReverseMany = models.RelReverseMany ) // Define some logic enum const ( - IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 - IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 - IsRelField = ^-RelReverseMany >> 18 << 19 - IsFieldType = ^-RelReverseMany<<1 + 1 + IsIntegerField = models.IsIntegerField + IsPositiveIntegerField = models.IsPositiveIntegerField + IsRelField = models.IsRelField + IsFieldType = models.IsFieldType ) // BooleanField A true/false field. -type BooleanField bool - -// Value return the BooleanField -func (e BooleanField) Value() bool { - return bool(e) -} - -// Set will set the BooleanField -func (e *BooleanField) Set(d bool) { - *e = BooleanField(d) -} - -// String format the Bool to string -func (e *BooleanField) String() string { - return strconv.FormatBool(e.Value()) -} - -// FieldType return BooleanField the type -func (e *BooleanField) FieldType() int { - return TypeBooleanField -} - -// SetRaw set the interface to bool -func (e *BooleanField) SetRaw(value interface{}) error { - switch d := value.(type) { - case bool: - e.Set(d) - case string: - v, err := StrTo(d).Bool() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the current value -func (e *BooleanField) RawValue() interface{} { - return e.Value() -} +type BooleanField = models.BooleanField // verify the BooleanField implement the Fielder interface var _ Fielder = new(BooleanField) @@ -108,43 +64,7 @@ var _ Fielder = new(BooleanField) // required values tag: size // The size is enforced at the database level and in models’s validation. // eg: `orm:"size(120)"` -type CharField string - -// Value return the CharField's Value -func (e CharField) Value() string { - return string(e) -} - -// Set CharField value -func (e *CharField) Set(d string) { - *e = CharField(d) -} - -// String return the CharField -func (e *CharField) String() string { - return e.Value() -} - -// FieldType return the enum type -func (e *CharField) FieldType() int { - return TypeVarCharField -} - -// SetRaw set the interface to string -func (e *CharField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - e.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the CharField value -func (e *CharField) RawValue() interface{} { - return e.Value() -} +type CharField = models.CharField // verify CharField implement Fielder var _ Fielder = new(CharField) @@ -162,49 +82,7 @@ var _ Fielder = new(CharField) // Note that the current date is always used; it’s not just a default value that you can override. // // eg: `orm:"auto_now"` or `orm:"auto_now_add"` -type TimeField time.Time - -// Value return the time.Time -func (e TimeField) Value() time.Time { - return time.Time(e) -} - -// Set set the TimeField's value -func (e *TimeField) Set(d time.Time) { - *e = TimeField(d) -} - -// String convert time to string -func (e *TimeField) String() string { - return e.Value().String() -} - -// FieldType return enum type Date -func (e *TimeField) FieldType() int { - return TypeDateField -} - -// SetRaw convert the interface to time.Time. Allow string and time.Time -func (e *TimeField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatTime) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return time value -func (e *TimeField) RawValue() interface{} { - return e.Value() -} +type TimeField = models.TimeField var _ Fielder = new(TimeField) @@ -221,49 +99,7 @@ var _ Fielder = new(TimeField) // Note that the current date is always used; it’s not just a default value that you can override. // // eg: `orm:"auto_now"` or `orm:"auto_now_add"` -type DateField time.Time - -// Value return the time.Time -func (e DateField) Value() time.Time { - return time.Time(e) -} - -// Set set the DateField's value -func (e *DateField) Set(d time.Time) { - *e = DateField(d) -} - -// String convert datetime to string -func (e *DateField) String() string { - return e.Value().String() -} - -// FieldType return enum type Date -func (e *DateField) FieldType() int { - return TypeDateField -} - -// SetRaw convert the interface to time.Time. Allow string and time.Time -func (e *DateField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatDate) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return Date value -func (e *DateField) RawValue() interface{} { - return e.Value() -} +type DateField = models.DateField // verify DateField implement fielder interface var _ Fielder = new(DateField) @@ -271,513 +107,67 @@ var _ Fielder = new(DateField) // DateTimeField A date, represented in go by a time.Time instance. // datetime values like 2006-01-02 15:04:05 // Takes the same extra arguments as DateField. -type DateTimeField time.Time - -// Value return the datetime value -func (e DateTimeField) Value() time.Time { - return time.Time(e) -} - -// Set set the time.Time to datetime -func (e *DateTimeField) Set(d time.Time) { - *e = DateTimeField(d) -} - -// String return the time's String -func (e *DateTimeField) String() string { - return e.Value().String() -} - -// FieldType return the enum TypeDateTimeField -func (e *DateTimeField) FieldType() int { - return TypeDateTimeField -} - -// SetRaw convert the string or time.Time to DateTimeField -func (e *DateTimeField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatDateTime) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the datetime value -func (e *DateTimeField) RawValue() interface{} { - return e.Value() -} +type DateTimeField = models.DateTimeField // verify datetime implement fielder -var _ Fielder = new(DateTimeField) +var _ models.Fielder = new(DateTimeField) // FloatField A floating-point number represented in go by a float32 value. -type FloatField float64 - -// Value return the FloatField value -func (e FloatField) Value() float64 { - return float64(e) -} - -// Set the Float64 -func (e *FloatField) Set(d float64) { - *e = FloatField(d) -} - -// String return the string -func (e *FloatField) String() string { - return ToStr(e.Value(), -1, 32) -} - -// FieldType return the enum type -func (e *FloatField) FieldType() int { - return TypeFloatField -} - -// SetRaw converter interface Float64 float32 or string to FloatField -func (e *FloatField) SetRaw(value interface{}) error { - switch d := value.(type) { - case float32: - e.Set(float64(d)) - case float64: - e.Set(d) - case string: - v, err := StrTo(d).Float64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the FloatField value -func (e *FloatField) RawValue() interface{} { - return e.Value() -} +type FloatField = models.FloatField // verify FloatField implement Fielder var _ Fielder = new(FloatField) // SmallIntegerField -32768 to 32767 -type SmallIntegerField int16 - -// Value return int16 value -func (e SmallIntegerField) Value() int16 { - return int16(e) -} - -// Set the SmallIntegerField value -func (e *SmallIntegerField) Set(d int16) { - *e = SmallIntegerField(d) -} - -// String convert smallint to string -func (e *SmallIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type SmallIntegerField -func (e *SmallIntegerField) FieldType() int { - return TypeSmallIntegerField -} - -// SetRaw convert interface int16/string to int16 -func (e *SmallIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int16: - e.Set(d) - case string: - v, err := StrTo(d).Int16() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return smallint value -func (e *SmallIntegerField) RawValue() interface{} { - return e.Value() -} +type SmallIntegerField = models.SmallIntegerField // verify SmallIntegerField implement Fielder var _ Fielder = new(SmallIntegerField) // IntegerField -2147483648 to 2147483647 -type IntegerField int32 - -// Value return the int32 -func (e IntegerField) Value() int32 { - return int32(e) -} - -// Set IntegerField value -func (e *IntegerField) Set(d int32) { - *e = IntegerField(d) -} - -// String convert Int32 to string -func (e *IntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return the enum type -func (e *IntegerField) FieldType() int { - return TypeIntegerField -} - -// SetRaw convert interface int32/string to int32 -func (e *IntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int32: - e.Set(d) - case string: - v, err := StrTo(d).Int32() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return IntegerField value -func (e *IntegerField) RawValue() interface{} { - return e.Value() -} +type IntegerField = models.IntegerField // verify IntegerField implement Fielder var _ Fielder = new(IntegerField) // BigIntegerField -9223372036854775808 to 9223372036854775807. -type BigIntegerField int64 - -// Value return int64 -func (e BigIntegerField) Value() int64 { - return int64(e) -} - -// Set the BigIntegerField value -func (e *BigIntegerField) Set(d int64) { - *e = BigIntegerField(d) -} - -// String convert BigIntegerField to string -func (e *BigIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *BigIntegerField) FieldType() int { - return TypeBigIntegerField -} - -// SetRaw convert interface int64/string to int64 -func (e *BigIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int64: - e.Set(d) - case string: - v, err := StrTo(d).Int64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return BigIntegerField value -func (e *BigIntegerField) RawValue() interface{} { - return e.Value() -} +type BigIntegerField = models.BigIntegerField // verify BigIntegerField implement Fielder var _ Fielder = new(BigIntegerField) // PositiveSmallIntegerField 0 to 65535 -type PositiveSmallIntegerField uint16 - -// Value return uint16 -func (e PositiveSmallIntegerField) Value() uint16 { - return uint16(e) -} - -// Set PositiveSmallIntegerField value -func (e *PositiveSmallIntegerField) Set(d uint16) { - *e = PositiveSmallIntegerField(d) -} - -// String convert uint16 to string -func (e *PositiveSmallIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveSmallIntegerField) FieldType() int { - return TypePositiveSmallIntegerField -} - -// SetRaw convert Interface uint16/string to uint16 -func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint16: - e.Set(d) - case string: - v, err := StrTo(d).Uint16() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue returns PositiveSmallIntegerField value -func (e *PositiveSmallIntegerField) RawValue() interface{} { - return e.Value() -} +type PositiveSmallIntegerField = models.PositiveSmallIntegerField // verify PositiveSmallIntegerField implement Fielder var _ Fielder = new(PositiveSmallIntegerField) // PositiveIntegerField 0 to 4294967295 -type PositiveIntegerField uint32 - -// Value return PositiveIntegerField value. Uint32 -func (e PositiveIntegerField) Value() uint32 { - return uint32(e) -} - -// Set the PositiveIntegerField value -func (e *PositiveIntegerField) Set(d uint32) { - *e = PositiveIntegerField(d) -} - -// String convert PositiveIntegerField to string -func (e *PositiveIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveIntegerField) FieldType() int { - return TypePositiveIntegerField -} - -// SetRaw convert interface uint32/string to Uint32 -func (e *PositiveIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint32: - e.Set(d) - case string: - v, err := StrTo(d).Uint32() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the PositiveIntegerField Value -func (e *PositiveIntegerField) RawValue() interface{} { - return e.Value() -} +type PositiveIntegerField = models.PositiveIntegerField // verify PositiveIntegerField implement Fielder var _ Fielder = new(PositiveIntegerField) // PositiveBigIntegerField 0 to 18446744073709551615 -type PositiveBigIntegerField uint64 - -// Value return uint64 -func (e PositiveBigIntegerField) Value() uint64 { - return uint64(e) -} - -// Set PositiveBigIntegerField value -func (e *PositiveBigIntegerField) Set(d uint64) { - *e = PositiveBigIntegerField(d) -} - -// String convert PositiveBigIntegerField to string -func (e *PositiveBigIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveBigIntegerField) FieldType() int { - return TypePositiveIntegerField -} - -// SetRaw convert interface uint64/string to Uint64 -func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint64: - e.Set(d) - case string: - v, err := StrTo(d).Uint64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return PositiveBigIntegerField value -func (e *PositiveBigIntegerField) RawValue() interface{} { - return e.Value() -} +type PositiveBigIntegerField = models.PositiveBigIntegerField // verify PositiveBigIntegerField implement Fielder var _ Fielder = new(PositiveBigIntegerField) // TextField A large text field. -type TextField string - -// Value return TextField value -func (e TextField) Value() string { - return string(e) -} - -// Set the TextField value -func (e *TextField) Set(d string) { - *e = TextField(d) -} - -// String convert TextField to string -func (e *TextField) String() string { - return e.Value() -} - -// FieldType return enum type -func (e *TextField) FieldType() int { - return TypeTextField -} - -// SetRaw convert interface string to string -func (e *TextField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - e.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return TextField value -func (e *TextField) RawValue() interface{} { - return e.Value() -} +type TextField = models.TextField // verify TextField implement Fielder var _ Fielder = new(TextField) // JSONField postgres json field. -type JSONField string - -// Value return JSONField value -func (j JSONField) Value() string { - return string(j) -} - -// Set the JSONField value -func (j *JSONField) Set(d string) { - *j = JSONField(d) -} - -// String convert JSONField to string -func (j *JSONField) String() string { - return j.Value() -} - -// FieldType return enum type -func (j *JSONField) FieldType() int { - return TypeJSONField -} - -// SetRaw convert interface string to string -func (j *JSONField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - j.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return JSONField value -func (j *JSONField) RawValue() interface{} { - return j.Value() -} +type JSONField = models.JSONField // verify JSONField implement Fielder -var _ Fielder = new(JSONField) +var _ models.Fielder = new(JSONField) // JsonbField postgres json field. -type JsonbField string - -// Value return JsonbField value -func (j JsonbField) Value() string { - return string(j) -} - -// Set the JsonbField value -func (j *JsonbField) Set(d string) { - *j = JsonbField(d) -} - -// String convert JsonbField to string -func (j *JsonbField) String() string { - return j.Value() -} - -// FieldType return enum type -func (j *JsonbField) FieldType() int { - return TypeJsonbField -} - -// SetRaw convert interface string to string -func (j *JsonbField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - j.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return JsonbField value -func (j *JsonbField) RawValue() interface{} { - return j.Value() -} +type JsonbField = models.JsonbField // verify JsonbField implement Fielder -var _ Fielder = new(JsonbField) +var _ models.Fielder = new(JsonbField) diff --git a/client/orm/models_info_m.go b/client/orm/models_info_m.go deleted file mode 100644 index b94480ca..00000000 --- a/client/orm/models_info_m.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" -) - -// single model info -type modelInfo struct { - manual bool - isThrough bool - pkg string - name string - fullName string - table string - model interface{} - fields *fields - addrField reflect.Value // store the original struct value - uniques []string -} - -// new model info -func newModelInfo(val reflect.Value) (mi *modelInfo) { - mi = &modelInfo{} - mi.fields = newFields() - ind := reflect.Indirect(val) - mi.addrField = val - mi.name = ind.Type().Name() - mi.fullName = getFullName(ind.Type()) - addModelFields(mi, ind, "", []int{}) - return -} - -// index: FieldByIndex returns the nested field corresponding to index -func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { - var ( - err error - fi *fieldInfo - sf reflect.StructField - ) - - for i := 0; i < ind.NumField(); i++ { - field := ind.Field(i) - sf = ind.Type().Field(i) - // if the field is unexported skip - if sf.PkgPath != "" { - continue - } - // add anonymous struct fields - if sf.Anonymous { - addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) - continue - } - - fi, err = newFieldInfo(mi, field, sf, mName) - if err == errSkipField { - err = nil - continue - } else if err != nil { - break - } - // record current field index - fi.fieldIndex = append(fi.fieldIndex, index...) - fi.fieldIndex = append(fi.fieldIndex, i) - fi.mi = mi - fi.inModel = true - if !mi.fields.Add(fi) { - err = fmt.Errorf("duplicate column name: %s", fi.column) - break - } - if fi.pk { - if mi.fields.pk != nil { - err = fmt.Errorf("one model must have one pk field only") - break - } else { - mi.fields.pk = fi - } - } - } - - if err != nil { - fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) - os.Exit(2) - } -} - -// combine related model info to new model info. -// prepare for relation models query. -func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { - mi = new(modelInfo) - mi.fields = newFields() - mi.table = m1.table + "_" + m2.table + "s" - mi.name = camelString(mi.table) - mi.fullName = m1.pkg + "." + mi.name - - fa := new(fieldInfo) // pk - f1 := new(fieldInfo) // m1 table RelForeignKey - f2 := new(fieldInfo) // m2 table RelForeignKey - fa.fieldType = TypeBigIntegerField - fa.auto = true - fa.pk = true - fa.dbcol = true - fa.name = "Id" - fa.column = "id" - fa.fullName = mi.fullName + "." + fa.name - - f1.dbcol = true - f2.dbcol = true - f1.fieldType = RelForeignKey - f2.fieldType = RelForeignKey - f1.name = camelString(m1.table) - f2.name = camelString(m2.table) - f1.fullName = mi.fullName + "." + f1.name - f2.fullName = mi.fullName + "." + f2.name - f1.column = m1.table + "_id" - f2.column = m2.table + "_id" - f1.rel = true - f2.rel = true - f1.relTable = m1.table - f2.relTable = m2.table - f1.relModelInfo = m1 - f2.relModelInfo = m2 - f1.mi = mi - f2.mi = mi - - mi.fields.Add(fa) - mi.fields.Add(f1) - mi.fields.Add(f2) - mi.fields.pk = fa - - mi.uniques = []string{f1.column, f2.column} - return -} diff --git a/client/orm/models_test.go b/client/orm/models_test.go index ea8a89fc..52bafd9e 100644 --- a/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -22,6 +22,8 @@ import ( "strings" "time" + "github.com/beego/beego/v2/client/orm/internal/models" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -79,7 +81,7 @@ func (e *SliceStringField) RawValue() interface{} { return e.String() } -var _ Fielder = new(SliceStringField) +var _ models.Fielder = new(SliceStringField) // A json field. type JSONFieldTest struct { @@ -111,7 +113,7 @@ func (e *JSONFieldTest) RawValue() interface{} { return e.String() } -var _ Fielder = new(JSONFieldTest) +var _ models.Fielder = new(JSONFieldTest) type Data struct { ID int `orm:"column(id)"` diff --git a/client/orm/orm.go b/client/orm/orm.go index 3f4a374b..bd84e6df 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build go1.8 -// +build go1.8 - // Package orm provide ORM for MySQL/PostgreSQL/sqlite // Simple Usage // @@ -57,9 +54,12 @@ import ( "database/sql" "errors" "fmt" - "os" "reflect" - "time" + + ilogs "github.com/beego/beego/v2/client/orm/internal/logs" + iutils "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/hints" @@ -75,10 +75,10 @@ const ( // Define common vars var ( Debug = false - DebugLog = NewLog(os.Stdout) + DebugLog = ilogs.DebugLog DefaultRowsLimit = -1 DefaultRelsDepth = 2 - DefaultTimeLoc = time.Local + DefaultTimeLoc = iutils.DefaultTimeLoc ErrTxDone = errors.New(" transaction already done") ErrMultiRows = errors.New(" return multi rows") ErrNoRows = errors.New(" no row found") @@ -107,7 +107,7 @@ var ( ) // get model info and model reflect value -func (*ormBase) getMi(md interface{}) (mi *modelInfo) { +func (*ormBase) getMi(md interface{}) (mi *models.ModelInfo) { val := reflect.ValueOf(md) ind := reflect.Indirect(val) typ := ind.Type() @@ -116,19 +116,19 @@ func (*ormBase) getMi(md interface{}) (mi *modelInfo) { } // get need ptr model info and model reflect value -func (*ormBase) getPtrMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { +func (*ormBase) getPtrMiInd(md interface{}) (mi *models.ModelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", models.GetFullName(typ))) } mi = getTypeMi(typ) return } -func getTypeMi(mdTyp reflect.Type) *modelInfo { - name := getFullName(mdTyp) +func getTypeMi(mdTyp reflect.Type) *models.ModelInfo { + name := models.GetFullName(mdTyp) if mi, ok := defaultModelCache.getByFullName(name); ok { return mi } @@ -136,10 +136,10 @@ func getTypeMi(mdTyp reflect.Type) *modelInfo { } // get field info from model info by given field name -func (*ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { - fi, ok := mi.fields.GetByAny(name) +func (*ormBase) getFieldInfo(mi *models.ModelInfo, name string) *models.FieldInfo { + fi, ok := mi.Fields.GetByAny(name) if !ok { - panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) + panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.FullName)) } return fi } @@ -179,11 +179,11 @@ func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 return err == nil, id, err } - id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + id, vid := int64(0), ind.FieldByIndex(mi.Fields.Pk.FieldIndex) + if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 { id = int64(vid.Uint()) - } else if mi.fields.pk.rel { - return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) + } else if mi.Fields.Pk.Rel { + return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.Fields.Pk.RelModelInfo.Fields.Pk.Name) } else { id = vid.Int() } @@ -209,12 +209,12 @@ func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, err } // set auto pk field -func (*ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) +func (*ormBase) setPk(mi *models.ModelInfo, ind reflect.Value, id int64) { + if mi.Fields.Pk.Auto { + if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetUint(uint64(id)) } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) + ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetInt(id) } } } @@ -276,7 +276,7 @@ func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, col } // update model to database. -// cols set the columns those want to update. +// cols set the Columns those want to update. func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { return o.UpdateWithCtx(context.Background(), md, cols...) } @@ -304,10 +304,10 @@ func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { fi := o.getFieldInfo(mi, name) switch { - case fi.fieldType == RelManyToMany: - case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: + case fi.FieldType == RelManyToMany: + case fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough: default: - panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) + panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.Name, mi.FullName)) } return newQueryM2M(md, o, mi, fi, ind) @@ -362,7 +362,7 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str } }) - switch fi.fieldType { + switch fi.FieldType { case RelOneToOne, RelForeignKey, RelReverseOne: limit = 1 offset = 0 @@ -376,11 +376,11 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str qs.orders = order_clause.ParseOrder(order) } - find := ind.FieldByIndex(fi.fieldIndex) + find := ind.FieldByIndex(fi.FieldIndex) var nums int64 var err error - switch fi.fieldType { + switch fi.FieldType { case RelOneToOne, RelForeignKey, RelReverseOne: val := reflect.New(find.Type().Elem()) container := val.Interface() @@ -397,7 +397,7 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str } // get QuerySeter for related models to md model -func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { +func (o *ormBase) queryRelated(md interface{}, name string) (*models.ModelInfo, *models.FieldInfo, reflect.Value, *querySet) { mi, ind := o.getPtrMiInd(md) fi := o.getFieldInfo(mi, name) @@ -408,14 +408,14 @@ func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldI var qs *querySet - switch fi.fieldType { + switch fi.FieldType { case RelOneToOne, RelForeignKey, RelManyToMany: - if !fi.inModel { + if !fi.InModel { break } qs = o.getRelQs(md, mi, fi) case RelReverseOne, RelReverseMany: - if !fi.inModel { + if !fi.InModel { break } qs = o.getReverseQs(md, mi, fi) @@ -429,41 +429,41 @@ func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldI } // get reverse relation QuerySeter -func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { +func (o *ormBase) getReverseQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet { + switch fi.FieldType { case RelReverseOne, RelReverseMany: default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) + panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.Name, mi.FullName)) } var q *querySet - if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { - q = newQuerySet(o, fi.relModelInfo).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + if fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough { + q = newQuerySet(o, fi.RelModelInfo).(*querySet) + q.cond = NewCondition().And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md) } else { - q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) + q = newQuerySet(o, fi.ReverseFieldInfo.Mi).(*querySet) + q.cond = NewCondition().And(fi.ReverseFieldInfo.Column, md) } return q } // get relation QuerySeter -func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { +func (o *ormBase) getRelQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet { + switch fi.FieldType { case RelOneToOne, RelForeignKey, RelManyToMany: default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.Name, mi.FullName)) } - q := newQuerySet(o, fi.relModelInfo).(*querySet) + q := newQuerySet(o, fi.RelModelInfo).(*querySet) q.cond = NewCondition() - if fi.fieldType == RelManyToMany { - q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + if fi.FieldType == RelManyToMany { + q.cond = q.cond.And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md) } else { - q.cond = q.cond.And(fi.reverseFieldInfo.column, md) + q.cond = q.cond.And(fi.ReverseFieldInfo.Column, md) } return q @@ -475,12 +475,12 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { - name = nameStrategyMap[defaultNameStrategy](table) + name = models.NameStrategyMap[models.DefaultNameStrategy](table) if mi, ok := defaultModelCache.get(name); ok { qs = newQuerySet(o, mi) } } else { - name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) + name = models.GetFullName(iutils.IndirectType(reflect.TypeOf(ptrStructOrTableName))) if mi, ok := defaultModelCache.getByFullName(name); ok { qs = newQuerySet(o, mi) } diff --git a/client/orm/orm_log.go b/client/orm/orm_log.go index e6f8bc83..50ebc3a6 100644 --- a/client/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -22,23 +22,22 @@ import ( "log" "strings" "time" + + "github.com/beego/beego/v2/client/orm/internal/logs" ) -// Log implement the log.Logger -type Log struct { - *log.Logger -} - -// costomer log func -var LogFunc func(query map[string]interface{}) +type Log = logs.Log // NewLog set io.Writer to create a Logger. -func NewLog(out io.Writer) *Log { - d := new(Log) +func NewLog(out io.Writer) *logs.Log { + d := new(logs.Log) d.Logger = log.New(out, "[ORM]", log.LstdFlags) return d } +// LogFunc costomer log func +var LogFunc func(query map[string]interface{}) + func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { logMap := make(map[string]interface{}) sub := time.Since(t) / 1e5 @@ -64,7 +63,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error if LogFunc != nil { LogFunc(logMap) } - DebugLog.Println(con) + logs.DebugLog.Println(con) } // statement query logger struct. diff --git a/client/orm/orm_object.go b/client/orm/orm_object.go index 50c1ca41..55395fe5 100644 --- a/client/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -18,11 +18,13 @@ import ( "context" "fmt" "reflect" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // an insert queryer struct type insertSet struct { - mi *modelInfo + mi *models.ModelInfo orm *ormBase stmt stmtQuerier closed bool @@ -42,23 +44,23 @@ func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, e val := reflect.ValueOf(md) ind := reflect.Indirect(val) typ := ind.Type() - name := getFullName(typ) + name := models.GetFullName(typ) if val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", name)) } - if name != o.mi.fullName { - panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) + if name != o.mi.FullName { + panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.FullName, name)) } id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } if id > 0 { - if o.mi.fields.pk.auto { - if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) + if o.mi.Fields.Pk.Auto { + if o.mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetUint(uint64(id)) } else { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) + ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetInt(id) } } } @@ -75,7 +77,7 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { +func newInsertSet(ctx context.Context, orm *ormBase, mi *models.ModelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi diff --git a/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go index 44312ae3..6dc66b3d 100644 --- a/client/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -17,13 +17,15 @@ package orm import ( "context" "reflect" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // model to model struct type queryM2M struct { md interface{} - mi *modelInfo - fi *fieldInfo + mi *models.ModelInfo + fi *models.FieldInfo qs *querySet ind reflect.Value } @@ -42,9 +44,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi - mi := fi.relThroughModelInfo - mfi := fi.reverseFieldInfo - rfi := fi.reverseFieldInfoTwo + mi := fi.RelThroughModelInfo + mfi := fi.ReverseFieldInfo + rfi := fi.ReverseFieldInfoTwo orm := o.qs.orm dbase := orm.alias.DbBaser @@ -53,9 +55,9 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e var otherValues []interface{} var otherNames []string - for _, colname := range mi.fields.dbcols { - if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && - mi.fields.columns[colname] != mi.fields.pk { + for _, colname := range mi.Fields.DBcols { + if colname != mfi.Column && colname != rfi.Column && colname != fi.Mi.Fields.Pk.Column && + mi.Fields.Columns[colname] != mi.Fields.Pk { otherNames = append(otherNames, colname) } } @@ -84,7 +86,7 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e panic(ErrMissPK) } - names := []string{mfi.column, rfi.column} + names := []string{mfi.Column, rfi.Column} values := make([]interface{}, 0, len(models)*2) for _, md := range models { @@ -94,7 +96,7 @@ func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, e if ind.Kind() != reflect.Struct { v2 = ind.Interface() } else { - _, v2, exist = getExistPk(fi.relModelInfo, ind) + _, v2, exist = getExistPk(fi.RelModelInfo, ind) if !exist { panic(ErrMissPK) } @@ -114,9 +116,9 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi - qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) + qs := o.qs.Filter(fi.ReverseFieldInfo.Name, o.md) - return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() + return qs.Filter(fi.ReverseFieldInfoTwo.Name+ExprSep+"in", mds).Delete() } // check model is existed in relationship of origin model @@ -126,8 +128,8 @@ func (o *queryM2M) Exist(md interface{}) bool { func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) + return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md). + Filter(fi.ReverseFieldInfoTwo.Name, md).ExistWithCtx(ctx) } // clean all models in related of origin model @@ -137,7 +139,7 @@ func (o *queryM2M) Clear() (int64, error) { func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) + return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).DeleteWithCtx(ctx) } // count all related models of origin model @@ -147,18 +149,18 @@ func (o *queryM2M) Count() (int64, error) { func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) + return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).CountWithCtx(ctx) } var _ QueryM2Mer = new(queryM2M) // create new M2M queryer. -func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { +func newQueryM2M(md interface{}, o *ormBase, mi *models.ModelInfo, fi *models.FieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) qm2m.md = md qm2m.mi = mi qm2m.fi = fi qm2m.ind = ind - qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) + qm2m.qs = newQuerySet(o, fi.RelThroughModelInfo).(*querySet) return qm2m } diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index c922a37f..8464741b 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,6 +18,10 @@ import ( "context" "fmt" + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/hints" ) @@ -54,7 +58,7 @@ func ColValue(opt operator, value interface{}) interface{} { default: panic(fmt.Errorf("orm.ColValue wrong operator")) } - v, err := StrTo(ToStr(value)).Int64() + v, err := utils.StrTo(utils.ToStr(value)).Int64() if err != nil { panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) } @@ -66,7 +70,7 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo + mi *models.ModelInfo cond *Condition related []string relDepth int @@ -113,13 +117,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { // set offset number func (o *querySet) setOffset(num interface{}) { - o.offset = ToInt64(num) + o.offset = utils.ToInt64(num) } // add LIMIT value. // args[0] means offset, e.g. LIMIT num,offset. func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { - o.limit = ToInt64(limit) + o.limit = utils.ToInt64(limit) if len(args) > 0 { o.setOffset(args[0]) } @@ -273,7 +277,7 @@ func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { } // query all data and map to containers. -// cols means the columns when querying. +// cols means the Columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { return o.AllWithCtx(context.Background(), container, cols...) } @@ -283,7 +287,7 @@ func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols . } // query one row data and map to containers. -// cols means the columns when querying. +// cols means the Columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { return o.OneWithCtx(context.Background(), container, cols...) } @@ -366,7 +370,7 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) } // create new QuerySeter. -func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { +func newQuerySet(orm *ormBase, mi *models.ModelInfo) QuerySeter { o := new(querySet) o.mi = mi o.orm = orm diff --git a/client/orm/orm_raw.go b/client/orm/orm_raw.go index f40d7c86..d1b1d0a7 100644 --- a/client/orm/orm_raw.go +++ b/client/orm/orm_raw.go @@ -20,6 +20,10 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/pkg/errors" ) @@ -95,7 +99,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } else if v, ok := value.(bool); ok { ind.SetBool(v) } else { - v, _ := StrTo(ToStr(value)).Bool() + v, _ := utils.StrTo(utils.ToStr(value)).Bool() ind.SetBool(v) } @@ -103,7 +107,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { if value == nil { ind.SetString("") } else { - ind.SetString(ToStr(value)) + ind.SetString(utils.ToStr(value)) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -117,7 +121,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: ind.SetInt(int64(val.Uint())) default: - v, _ := StrTo(ToStr(value)).Int64() + v, _ := utils.StrTo(utils.ToStr(value)).Int64() ind.SetInt(v) } } @@ -132,7 +136,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: ind.SetUint(val.Uint()) default: - v, _ := StrTo(ToStr(value)).Uint64() + v, _ := utils.StrTo(utils.ToStr(value)).Uint64() ind.SetUint(v) } } @@ -145,7 +149,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Float64: ind.SetFloat(val.Float()) default: - v, _ := StrTo(ToStr(value)).Float64() + v, _ := utils.StrTo(utils.ToStr(value)).Float64() ind.SetFloat(v) } } @@ -170,20 +174,20 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { if str != "" { if len(str) >= 19 { str = str[:19] - t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) + t, err := time.ParseInLocation(utils.FormatDateTime, str, o.orm.alias.TZ) if err == nil { t = t.In(DefaultTimeLoc) ind.Set(reflect.ValueOf(t)) } } else if len(str) >= 10 { str = str[:10] - t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) + t, err := time.ParseInLocation(utils.FormatDate, str, DefaultTimeLoc) if err == nil { ind.Set(reflect.ValueOf(t)) } } else if len(str) >= 8 { str = str[:8] - t, err := time.ParseInLocation(formatTime, str, DefaultTimeLoc) + t, err := time.ParseInLocation(utils.FormatTime, str, DefaultTimeLoc) if err == nil { ind.Set(reflect.ValueOf(t)) } @@ -287,7 +291,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { refs = make([]interface{}, 0, len(containers)) sInds []reflect.Value eTyps []reflect.Type - sMi *modelInfo + sMi *models.ModelInfo ) structMode := false for _, container := range containers { @@ -313,7 +317,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { } structMode = true - fn := getFullName(typ) + fn := models.GetFullName(typ) if mi, ok := defaultModelCache.getByFullName(fn); ok { sMi = mi } @@ -370,16 +374,16 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { if sMi != nil { for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { + if fi := sMi.Fields.GetByColumn(col); fi != nil { value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field := ind.FieldByIndex(fi.FieldIndex) + if fi.FieldType&IsRelField > 0 { + mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type()) field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex) } - if fi.isFielder { - fd := field.Addr().Interface().(Fielder) + if fi.IsFielder { + fd := field.Addr().Interface().(models.Fielder) err := fd.SetRaw(value) if err != nil { return errors.Errorf("set raw error:%s", err) @@ -406,12 +410,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { // thanks @Gazeboxu. tags := structTagMap[fe.Tag] if tags == nil { - _, tags = parseStructTag(fe.Tag.Get(defaultStructTagName)) + _, tags = models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName)) structTagMap[fe.Tag] = tags } var col string if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) + col = models.NameStrategyMap[models.NameStrategy](fe.Name) } if v, ok := columnsMp[col]; ok { value := reflect.ValueOf(v).Elem().Interface() @@ -449,7 +453,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { refs = make([]interface{}, 0, len(containers)) sInds []reflect.Value eTyps []reflect.Type - sMi *modelInfo + sMi *models.ModelInfo ) structMode := false for _, container := range containers { @@ -474,7 +478,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { } structMode = true - fn := getFullName(typ) + fn := models.GetFullName(typ) if mi, ok := defaultModelCache.getByFullName(fn); ok { sMi = mi } @@ -537,16 +541,16 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { if sMi != nil { for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { + if fi := sMi.Fields.GetByColumn(col); fi != nil { value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field := ind.FieldByIndex(fi.FieldIndex) + if fi.FieldType&IsRelField > 0 { + mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type()) field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex) } - if fi.isFielder { - fd := field.Addr().Interface().(Fielder) + if fi.IsFielder { + fd := field.Addr().Interface().(models.Fielder) err := fd.SetRaw(value) if err != nil { return 0, errors.Errorf("set raw error:%s", err) @@ -570,10 +574,10 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { recursiveSetField(f) } - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + _, tags := models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName)) var col string if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) + col = models.NameStrategyMap[models.NameStrategy](fe.Name) } if v, ok := columnsMp[col]; ok { value := reflect.ValueOf(v).Elem().Interface() @@ -837,7 +841,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in } default: - if id := ind.FieldByName(camelString(key)); id.IsValid() { + if id := ind.FieldByName(models.CamelString(key)); id.IsValid() { o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) } } diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 4fbd3a20..0c31e085 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build go1.8 -// +build go1.8 - package orm import ( @@ -32,6 +29,12 @@ import ( "testing" "time" + "github.com/beego/beego/v2/client/orm/internal/logs" + + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/stretchr/testify/assert" "github.com/beego/beego/v2/client/orm/clauses/order_clause" @@ -41,9 +44,9 @@ import ( var _ = os.PathSeparator var ( - testDate = formatDate + " -0700" - testDateTime = formatDateTime + " -0700" - testTime = formatTime + " -0700" + testDate = utils.FormatDate + " -0700" + testDateTime = utils.FormatDateTime + " -0700" + testTime = utils.FormatTime + " -0700" ) type argAny []interface{} @@ -72,7 +75,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er case time.Time: if v2, vo := b.(time.Time); vo { if arg.Get(1) != nil { - format := ToStr(arg.Get(1)) + format := utils.ToStr(arg.Get(1)) a = v.Format(format) b = v2.Format(format) ok = a == b @@ -82,7 +85,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er } } default: - ok = ToStr(a) == ToStr(b) + ok = utils.ToStr(a) == utils.ToStr(b) } ok = is && ok || !is && !ok if !ok { @@ -250,14 +253,14 @@ func TestRegisterModels(_ *testing.T) { func TestModelSyntax(t *testing.T) { user := &User{} ind := reflect.ValueOf(user).Elem() - fn := getFullName(ind.Type()) + fn := models.GetFullName(ind.Type()) _, ok := defaultModelCache.getByFullName(fn) throwFail(t, AssertIs(ok, true)) mi, ok := defaultModelCache.get("user") throwFail(t, AssertIs(ok, true)) if ok { - throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) + throwFail(t, AssertIs(mi.Fields.GetByName("ShouldSkip") == nil, true)) } } @@ -561,7 +564,7 @@ func TestNullDataTypes(t *testing.T) { assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) - // test support for pointer fields using RawSeter.QueryRows() + // test support for pointer Fields using RawSeter.QueryRows() var dnList []*DataNull Q := dDbBaser.TableQuote() num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) @@ -1894,7 +1897,7 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(row.Id, 4)) throwFail(t, AssertIs(row.EmbedField.Email, "nobody@gmail.com")) - // test for sql.Null* fields + // test for sql.Null* Fields nData := &DataNull{ NullString: sql.NullString{String: "test sql.null", Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true}, @@ -2003,7 +2006,7 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) throwFailNow(t, AssertIs(l[1].Age, 30)) - // test for sql.Null* fields + // test for sql.Null* Fields nData := &DataNull{ NullString: sql.NullString{String: "test sql.null", Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true}, @@ -2616,7 +2619,7 @@ func TestSnake(t *testing.T) { "tag_666Name": "tag_666_name", } for name, want := range cases { - got := snakeString(name) + got := models.SnakeString(name) throwFail(t, AssertIs(got, want)) } } @@ -2637,10 +2640,10 @@ func TestIgnoreCaseTag(t *testing.T) { if t == nil { return } - throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) - throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) - throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) - throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) + throwFail(t, AssertIs(info.Fields.GetByName("NOO").Column, "n")) + throwFail(t, AssertIs(info.Fields.GetByName("Name01").Null, true)) + throwFail(t, AssertIs(info.Fields.GetByName("Name02").Column, "Name")) + throwFail(t, AssertIs(info.Fields.GetByName("Name03").Column, "name")) } func TestInsertOrUpdate(t *testing.T) { @@ -2934,9 +2937,9 @@ func TestDebugLog(t *testing.T) { func captureDebugLogOutput(f func()) string { var buf bytes.Buffer - DebugLog.SetOutput(&buf) + logs.DebugLog.SetOutput(&buf) defer func() { - DebugLog.SetOutput(os.Stderr) + logs.DebugLog.SetOutput(os.Stderr) }() f() return buf.String() diff --git a/client/orm/qb_mysql.go b/client/orm/qb_mysql.go index 19130496..df65e11d 100644 --- a/client/orm/qb_mysql.go +++ b/client/orm/qb_mysql.go @@ -28,7 +28,7 @@ type MySQLQueryBuilder struct { tokens []string } -// Select will join the fields +// Select will join the Fields func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { qb.tokens = append(qb.tokens, "SELECT", strings.Join(fields, CommaSpace)) return qb @@ -94,7 +94,7 @@ func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { return qb } -// OrderBy join the Order by fields +// OrderBy join the Order by Fields func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { qb.tokens = append(qb.tokens, "ORDER BY", strings.Join(fields, CommaSpace)) return qb @@ -124,7 +124,7 @@ func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { return qb } -// GroupBy join the Group by fields +// GroupBy join the Group by Fields func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { qb.tokens = append(qb.tokens, "GROUP BY", strings.Join(fields, CommaSpace)) return qb diff --git a/client/orm/qb_postgres.go b/client/orm/qb_postgres.go index d7f21692..3e5ec1c6 100644 --- a/client/orm/qb_postgres.go +++ b/client/orm/qb_postgres.go @@ -19,7 +19,7 @@ func processingStr(str []string) string { return s } -// Select will join the fields +// Select will join the Fields func (qb *PostgresQueryBuilder) Select(fields ...string) QueryBuilder { var str string n := len(fields) @@ -121,7 +121,7 @@ func (qb *PostgresQueryBuilder) In(vals ...string) QueryBuilder { return qb } -// OrderBy join the Order by fields +// OrderBy join the Order by Fields func (qb *PostgresQueryBuilder) OrderBy(fields ...string) QueryBuilder { str := processingStr(fields) qb.tokens = append(qb.tokens, "ORDER BY", str) @@ -152,7 +152,7 @@ func (qb *PostgresQueryBuilder) Offset(offset int) QueryBuilder { return qb } -// GroupBy join the Group by fields +// GroupBy join the Group by Fields func (qb *PostgresQueryBuilder) GroupBy(fields ...string) QueryBuilder { str := processingStr(fields) qb.tokens = append(qb.tokens, "GROUP BY", str) diff --git a/client/orm/types.go b/client/orm/types.go index 140eeda3..649d29fc 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -20,11 +20,13 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/core/utils" ) -// TableNaming is usually used by model +// TableNameI is usually used by model // when you custom your table name, please implement this interfaces // for example: // @@ -95,22 +97,16 @@ type Driver interface { Type() DriverType } -// Fielder define field info -type Fielder interface { - String() string - FieldType() int - SetRaw(interface{}) error - RawValue() interface{} -} +type Fielder = models.Fielder type TxBeginner interface { - // self control transaction + // Begin self control transaction Begin() (TxOrmer, error) BeginWithCtx(ctx context.Context) (TxOrmer, error) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) - // closure control transaction + // DoTx closure control transaction DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error @@ -146,27 +142,27 @@ type txEnder interface { RollbackUnlessCommit() error } -// Data Manipulation Language +// DML Data Manipulation Language type DML interface { - // insert model data to database + // Insert insert model data to database // for example: // user := new(User) // id, err = Ormer.Insert(user) // user must be a pointer and Insert will set user's pk field Insert(md interface{}) (int64, error) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) - // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // InsertOrUpdate mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") // if colu type is integer : can use(+-*/), string : convert(colu,"value") // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") // if colu type is integer : can use(+-*/), string : colu || "value" InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) - // insert some models to database + // InsertMulti inserts some models to database InsertMulti(bulk int, mds interface{}) (int64, error) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) - // update model to database. - // cols set the columns those want to update. - // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // Update updates model to database. + // cols set the Columns those want to update. + // find model by Id(pk) field and update Columns specified by Fields, if cols is null then update all Columns // for example: // user := User{Id: 2} // user.Langs = append(user.Langs, "zh-CN", "en-US") @@ -175,11 +171,11 @@ type DML interface { // num, err = Ormer.Update(&user, "Langs", "Extra") Update(md interface{}, cols ...string) (int64, error) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) - // delete model in database + // Delete deletes model in database Delete(md interface{}, cols ...string) (int64, error) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) - // return a raw query seter for raw sql string. + // Raw return a raw query seter for raw sql string. // for example: // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() // // update user testing's name to slene @@ -187,9 +183,9 @@ type DML interface { RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter } -// Data Query Language +// DQL Data Query Language type DQL interface { - // read data to model + // Read reads data to model // for example: // this will find User by Id field // u = &User{Id: user.Id} @@ -200,16 +196,16 @@ type DQL interface { Read(md interface{}, cols ...string) error ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error - // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // ReadForUpdate Like Read(), but with "FOR UPDATE" clause, useful in transaction. // Some databases are not support this feature. ReadForUpdate(md interface{}, cols ...string) error ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error - // Try to read a row from the database, or insert one if it doesn't exist + // ReadOrCreate Try to read a row from the database, or insert one if it doesn't exist ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) - // load related models to md model. + // LoadRelated load related models to md model. // args are limit, offset int and order string. // // example: @@ -224,20 +220,20 @@ type DQL interface { LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) - // create a models to models queryer + // QueryM2M create a models to models queryer // for example: // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") QueryM2M(md interface{}, name string) QueryM2Mer - // NOTE: this method is deprecated, context parameter will not take effect. + // QueryM2MWithCtx NOTE: this method is deprecated, context parameter will not take effect. // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer - // return a QuerySeter for table operations. + // QueryTable return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter - // NOTE: this method is deprecated, context parameter will not take effect. + // QueryTableWithCtx NOTE: this method is deprecated, context parameter will not take effect. // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter @@ -278,7 +274,7 @@ type Inserter interface { // QuerySeter query seter type QuerySeter interface { - // add condition expression to QuerySeter. + // Filter add condition expression to QuerySeter. // for example: // filter by UserName == 'slene' // qs.Filter("UserName", "slene") @@ -287,22 +283,22 @@ type QuerySeter interface { // // time compare // qs.Filter("created", time.Now()) Filter(string, ...interface{}) QuerySeter - // add raw sql to querySeter. + // FilterRaw add raw sql to querySeter. // for example: // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) FilterRaw(string, string) QuerySeter - // add NOT condition to querySeter. + // Exclude add NOT condition to querySeter. // have the same usage as Filter Exclude(string, ...interface{}) QuerySeter - // set condition to QuerySeter. + // SetCond set condition to QuerySeter. // sql's where condition // cond := orm.NewCondition() // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 // num, err := qs.SetCond(cond1).Count() SetCond(*Condition) QuerySeter - // get condition from QuerySeter. + // GetCond get condition from QuerySeter. // sql's where condition // cond := orm.NewCondition() // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) @@ -312,7 +308,7 @@ type QuerySeter interface { // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 // num, err := qs.SetCond(cond).Count() GetCond() *Condition - // add LIMIT value. + // Limit add LIMIT value. // args[0] means offset, e.g. LIMIT num,offset. // if Limit <= 0 then Limit will be set to default limit ,eg 1000 // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 @@ -320,19 +316,19 @@ type QuerySeter interface { // qs.Limit(10, 2) // // sql-> limit 10 offset 2 Limit(limit interface{}, args ...interface{}) QuerySeter - // add OFFSET value + // Offset add OFFSET value // same as Limit function's args[0] Offset(offset interface{}) QuerySeter - // add GROUP BY expression + // GroupBy add GROUP BY expression // for example: // qs.GroupBy("id") GroupBy(exprs ...string) QuerySeter - // add ORDER expression. + // OrderBy add ORDER expression. // "column" means ASC, "-column" means DESC. // for example: // qs.OrderBy("-status") OrderBy(exprs ...string) QuerySeter - // add ORDER expression by order clauses + // OrderClauses add ORDER expression by order clauses // for example: // OrderClauses( // order_clause.Clause( @@ -354,50 +350,50 @@ type QuerySeter interface { // order_clause.Raw(),//default false.if true, do not check field is valid or not // )) OrderClauses(orders ...*order_clause.Order) QuerySeter - // add FORCE INDEX expression. + // ForceIndex add FORCE INDEX expression. // for example: // qs.ForceIndex(`idx_name1`,`idx_name2`) // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive ForceIndex(indexes ...string) QuerySeter - // add USE INDEX expression. + // UseIndex add USE INDEX expression. // for example: // qs.UseIndex(`idx_name1`,`idx_name2`) // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive UseIndex(indexes ...string) QuerySeter - // add IGNORE INDEX expression. + // IgnoreIndex add IGNORE INDEX expression. // for example: // qs.IgnoreIndex(`idx_name1`,`idx_name2`) // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive IgnoreIndex(indexes ...string) QuerySeter - // set relation model to query together. + // RelatedSel set relation model to query together. // it will query relation models and assign to parent model. // for example: - // // will load all related fields use left join . + // // will load all related Fields use left join . // qs.RelatedSel().One(&user) // // will load related field only profile // qs.RelatedSel("profile").One(&user) // user.Profile.Age = 32 RelatedSel(params ...interface{}) QuerySeter - // Set Distinct + // Distinct Set Distinct // for example: // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). // Distinct(). // All(&permissions) Distinct() QuerySeter - // set FOR UPDATE to query. + // ForUpdate set FOR UPDATE to query. // for example: // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) ForUpdate() QuerySeter - // return QuerySeter execution result number + // Count returns QuerySeter execution result number // for example: // num, err = qs.Filter("profile__age__gt", 28).Count() Count() (int64, error) CountWithCtx(context.Context) (int64, error) - // check result empty or not after QuerySeter executed + // Exist check result empty or not after QuerySeter executed // the same as QuerySeter.Count > 0 Exist() bool ExistWithCtx(context.Context) bool - // execute update with parameters + // Update execute update with parameters // for example: // num, err = qs.Filter("user_name", "slene").Update(Params{ // "Nums": ColValue(Col_Minus, 50), @@ -407,13 +403,13 @@ type QuerySeter interface { // }) // user slene's name will change to slene2 Update(values Params) (int64, error) UpdateWithCtx(ctx context.Context, values Params) (int64, error) - // delete from table + // Delete delete from table // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) DeleteWithCtx(context.Context) (int64, error) - // return an insert queryer. + // PrepareInsert return an insert queryer. // it can be used in times. // example: // i,err := sq.PrepareInsert() @@ -422,21 +418,21 @@ type QuerySeter interface { // err = i.Close() //don't forget call Close PrepareInsert() (Inserter, error) PrepareInsertWithCtx(context.Context) (Inserter, error) - // query all data and map to containers. - // cols means the columns when querying. + // All query all data and map to containers. + // cols means the Columns when querying. // for example: // var users []*User // qs.All(&users) // users[0],users[1],users[2] ... All(container interface{}, cols ...string) (int64, error) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) - // query one row data and map to containers. - // cols means the columns when querying. + // One query one row data and map to containers. + // cols means the Columns when querying. // for example: // var user User // qs.One(&user) //user.UserName == "slene" One(container interface{}, cols ...string) error OneWithCtx(ctx context.Context, container interface{}, cols ...string) error - // query all data and map to []map[string]interface. + // Values query all data and map to []map[string]interface. // expres means condition expression. // it converts data to []map[column]value. // for example: @@ -444,21 +440,21 @@ type QuerySeter interface { // qs.Values(&maps) //maps[0]["UserName"]=="slene" Values(results *[]Params, exprs ...string) (int64, error) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) - // query all data and map to [][]interface + // ValuesList query all data and map to [][]interface // it converts data to [][column_index]value // for example: // var list []ParamsList // qs.ValuesList(&list) // list[0][1] == "slene" ValuesList(results *[]ParamsList, exprs ...string) (int64, error) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) - // query all data and map to []interface. + // ValuesFlat query all data and map to []interface. // it's designed for one column record set, auto change to []value, not [][column]value. // for example: // var list ParamsList // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" ValuesFlat(result *ParamsList, expr string) (int64, error) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. + // RowsToMap query all rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data // name | value @@ -469,7 +465,7 @@ type QuerySeter interface { // "found": 200, // } RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. + // RowsToStruct query all rows into struct with specify key and value column name. // keyCol = "name", valueCol = "value" // table data // name | value @@ -480,7 +476,7 @@ type QuerySeter interface { // Found int // } RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) - // aggregate func. + // Aggregate aggregate func. // for example: // type result struct { // DeptName string @@ -494,7 +490,7 @@ type QuerySeter interface { // QueryM2Mer model to model query struct // all operations are on the m2m table only, will not affect the origin model table type QueryM2Mer interface { - // add models to origin models when creating queryM2M. + // Add adds models to origin models when creating queryM2M. // example: // m2m := orm.QueryM2M(post,"Tag") // m2m.Add(&Tag1{},&Tag2{}) @@ -507,20 +503,20 @@ type QueryM2Mer interface { // make sure the relation is defined in post model struct tag. Add(...interface{}) (int64, error) AddWithCtx(context.Context, ...interface{}) (int64, error) - // remove models following the origin model relationship + // Remove removes models following the origin model relationship // only delete rows from m2m table // for example: // tag3 := &Tag{Id:5,Name: "TestTag3"} // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) RemoveWithCtx(context.Context, ...interface{}) (int64, error) - // check model is existed in relationship of origin model + // Exist checks model is existed in relationship of origin model Exist(interface{}) bool ExistWithCtx(context.Context, interface{}) bool - // clean all models in related of origin model + // Clear cleans all models in related of origin model Clear() (int64, error) ClearWithCtx(context.Context) (int64, error) - // count all related models of origin model + // Count counts all related models of origin model Count() (int64, error) CountWithCtx(context.Context) (int64, error) } @@ -538,32 +534,32 @@ type RawPreparer interface { // sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) // rs := Ormer.Raw(sql, 1) type RawSeter interface { - // execute sql and get result + // Exec execute sql and get result Exec() (sql.Result, error) - // query data and map to container + // QueryRow query data and map to container // for example: // var name string // var id int // rs.QueryRow(&id,&name) // id==2 name=="slene" QueryRow(containers ...interface{}) error - // query data rows and map to container + // QueryRows query data rows and map to container // var ids []int // var names []int // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} QueryRows(containers ...interface{}) (int64, error) SetArgs(...interface{}) RawSeter - // query data to []map[string]interface + // Values query data to []map[string]interface // see QuerySeter's Values Values(container *[]Params, cols ...string) (int64, error) - // query data to [][]interface + // ValuesList query data to [][]interface // see QuerySeter's ValuesList ValuesList(container *[]ParamsList, cols ...string) (int64, error) - // query data to []interface + // ValuesFlat query data to []interface // see QuerySeter's ValuesFlat ValuesFlat(container *ParamsList, cols ...string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. + // RowsToMap query all rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data // name | value @@ -574,7 +570,7 @@ type RawSeter interface { // "found": 200, // } RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. + // RowsToStruct query all rows into struct with specify key and value column name. // keyCol = "name", valueCol = "value" // table data // name | value @@ -586,7 +582,7 @@ type RawSeter interface { // } RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) - // return prepared raw statement for used in times. + // Prepare return prepared raw statement for used in times. // for example: // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) @@ -626,32 +622,32 @@ type dbQuerier interface { // base database struct type dbBaser interface { - Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Read(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error) + ReadValues(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Insert(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(context.Context, dbQuerier, *models.ModelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(context.Context, stmtQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error) - Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + Update(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error) + UpdateBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, Params, *time.Location) (int64, error) - Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Delete(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error) + DeleteBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error) SupportUpdateJoin() bool OperatorSQL(string) string - GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) - GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) + GenerateOperatorSQL(*models.ModelInfo, *models.FieldInfo, string, []interface{}, *time.Location) (string, []interface{}) + GenerateOperatorLeftCol(*models.FieldInfo, string, *string) + PrepareInsert(context.Context, dbQuerier, *models.ModelInfo) (stmtQuerier, string, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) - HasReturningID(*modelInfo, *string) bool + HasReturningID(*models.ModelInfo, *string) bool TimeFromDB(*time.Time, *time.Location) TimeToDB(*time.Time, *time.Location) DbTypes() map[string]string @@ -660,8 +656,8 @@ type dbBaser interface { ShowTablesQuery() string ShowColumnsQuery(string) string IndexExists(context.Context, dbQuerier, string, string) bool - collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(context.Context, dbQuerier, *modelInfo, []string) error + collectFieldValue(*models.ModelInfo, *models.FieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(context.Context, dbQuerier, *models.ModelInfo, []string) error GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string } diff --git a/client/orm/utils.go b/client/orm/utils.go index 8d05c080..40a818d6 100644 --- a/client/orm/utils.go +++ b/client/orm/utils.go @@ -15,305 +15,15 @@ package orm import ( - "fmt" - "math/big" - "reflect" - "strconv" - "strings" - "time" + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/internal/utils" ) -type fn func(string) string +type StrTo = utils.StrTo -var ( - nameStrategyMap = map[string]fn{ - defaultNameStrategy: snakeString, - SnakeAcronymNameStrategy: snakeStringWithAcronym, - } - defaultNameStrategy = "snakeString" - SnakeAcronymNameStrategy = "snakeStringWithAcronym" - nameStrategy = defaultNameStrategy -) - -// StrTo is the target string -type StrTo string - -// Set string -func (f *StrTo) Set(v string) { - if v != "" { - *f = StrTo(v) - } else { - f.Clear() - } -} - -// Clear string -func (f *StrTo) Clear() { - *f = StrTo(rune(0x1E)) -} - -// Exist check string exist -func (f StrTo) Exist() bool { - return string(f) != string(rune(0x1E)) -} - -// Bool string to bool -func (f StrTo) Bool() (bool, error) { - return strconv.ParseBool(f.String()) -} - -// Float32 string to float32 -func (f StrTo) Float32() (float32, error) { - v, err := strconv.ParseFloat(f.String(), 32) - return float32(v), err -} - -// Float64 string to float64 -func (f StrTo) Float64() (float64, error) { - return strconv.ParseFloat(f.String(), 64) -} - -// Int string to int -func (f StrTo) Int() (int, error) { - v, err := strconv.ParseInt(f.String(), 10, 32) - return int(v), err -} - -// Int8 string to int8 -func (f StrTo) Int8() (int8, error) { - v, err := strconv.ParseInt(f.String(), 10, 8) - return int8(v), err -} - -// Int16 string to int16 -func (f StrTo) Int16() (int16, error) { - v, err := strconv.ParseInt(f.String(), 10, 16) - return int16(v), err -} - -// Int32 string to int32 -func (f StrTo) Int32() (int32, error) { - v, err := strconv.ParseInt(f.String(), 10, 32) - return int32(v), err -} - -// Int64 string to int64 -func (f StrTo) Int64() (int64, error) { - v, err := strconv.ParseInt(f.String(), 10, 64) - if err != nil { - i := new(big.Int) - ni, ok := i.SetString(f.String(), 10) // octal - if !ok { - return v, err - } - return ni.Int64(), nil - } - return v, err -} - -// Uint string to uint -func (f StrTo) Uint() (uint, error) { - v, err := strconv.ParseUint(f.String(), 10, 32) - return uint(v), err -} - -// Uint8 string to uint8 -func (f StrTo) Uint8() (uint8, error) { - v, err := strconv.ParseUint(f.String(), 10, 8) - return uint8(v), err -} - -// Uint16 string to uint16 -func (f StrTo) Uint16() (uint16, error) { - v, err := strconv.ParseUint(f.String(), 10, 16) - return uint16(v), err -} - -// Uint32 string to uint32 -func (f StrTo) Uint32() (uint32, error) { - v, err := strconv.ParseUint(f.String(), 10, 32) - return uint32(v), err -} - -// Uint64 string to uint64 -func (f StrTo) Uint64() (uint64, error) { - v, err := strconv.ParseUint(f.String(), 10, 64) - if err != nil { - i := new(big.Int) - ni, ok := i.SetString(f.String(), 10) - if !ok { - return v, err - } - return ni.Uint64(), nil - } - return v, err -} - -// String string to string -func (f StrTo) String() string { - if f.Exist() { - return string(f) - } - return "" -} - -// ToStr interface to string -func ToStr(value interface{}, args ...int) (s string) { - switch v := value.(type) { - case bool: - s = strconv.FormatBool(v) - case float32: - s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) - case float64: - s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) - case int: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int8: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int16: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int32: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int64: - s = strconv.FormatInt(v, argInt(args).Get(0, 10)) - case uint: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint8: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint16: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint32: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint64: - s = strconv.FormatUint(v, argInt(args).Get(0, 10)) - case string: - s = v - case []byte: - s = string(v) - default: - s = fmt.Sprintf("%v", v) - } - return s -} - -// ToInt64 interface to int64 -func ToInt64(value interface{}) (d int64) { - val := reflect.ValueOf(value) - switch value.(type) { - case int, int8, int16, int32, int64: - d = val.Int() - case uint, uint8, uint16, uint32, uint64: - d = int64(val.Uint()) - default: - panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) - } - return -} - -func snakeStringWithAcronym(s string) string { - data := make([]byte, 0, len(s)*2) - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - before := false - after := false - if i > 0 { - before = s[i-1] >= 'a' && s[i-1] <= 'z' - } - if i+1 < num { - after = s[i+1] >= 'a' && s[i+1] <= 'z' - } - if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { - data = append(data, '_') - } - data = append(data, d) - } - return strings.ToLower(string(data)) -} - -// snake string, XxYy to xx_yy , XxYY to xx_y_y -func snakeString(s string) string { - data := make([]byte, 0, len(s)*2) - j := false - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - if i > 0 && d >= 'A' && d <= 'Z' && j { - data = append(data, '_') - } - if d != '_' { - j = true - } - data = append(data, d) - } - return strings.ToLower(string(data)) -} - -// SetNameStrategy set different name strategy func SetNameStrategy(s string) { - if SnakeAcronymNameStrategy != s { - nameStrategy = defaultNameStrategy - } - nameStrategy = s -} - -// camel string, xx_yy to XxYy -func camelString(s string) string { - data := make([]byte, 0, len(s)) - flag, num := true, len(s)-1 - for i := 0; i <= num; i++ { - d := s[i] - if d == '_' { - flag = true - continue - } else if flag { - if d >= 'a' && d <= 'z' { - d = d - 32 - } - flag = false - } - data = append(data, d) - } - return string(data) -} - -type argString []string - -// get string by index from string slice -func (a argString) Get(i int, args ...string) (r string) { - if i >= 0 && i < len(a) { - r = a[i] - } else if len(args) > 0 { - r = args[0] - } - return -} - -type argInt []int - -// get int by index from int slice -func (a argInt) Get(i int, args ...int) (r int) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - -// parse time to string with location -func timeParse(dateString, format string) (time.Time, error) { - tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) - return tp, err -} - -// get pointer indirect type -func indirectType(v reflect.Type) reflect.Type { - switch v.Kind() { - case reflect.Ptr: - return indirectType(v.Elem()) - default: - return v + if models.SnakeAcronymNameStrategy != s { + models.NameStrategy = models.DefaultNameStrategy } + models.NameStrategy = s } diff --git a/go.mod b/go.mod index 94bd9f08..590a03f6 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18 github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec github.com/stretchr/testify v1.8.1 + github.com/valyala/bytebufferpool v1.0.0 go.etcd.io/etcd/client/v3 v3.5.9 go.opentelemetry.io/otel v1.11.2 go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.11.2 diff --git a/go.sum b/go.sum index ae4caa49..995d6898 100644 --- a/go.sum +++ b/go.sum @@ -197,6 +197,8 @@ github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112/go.mod h1:Z4AUp2K github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/ugorji/go v0.0.0-20171122102828-84cb69a8af83/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=