diff --git a/orm/db_utils.go b/orm/db_utils.go index 34ddfae9..0279a14a 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -154,7 +154,7 @@ outFor: typ := val.Type() name := getFullName(typ) var value interface{} - if mmi, ok := modelCache.getByFN(name); ok { + if mmi, ok := modelCache.getByFullName(name); ok { if _, vu, exist := getExistPk(mmi, val); exist { value = vu } diff --git a/orm/models.go b/orm/models.go index 5cfc69b1..1d5a4dc2 100644 --- a/orm/models.go +++ b/orm/models.go @@ -29,39 +29,18 @@ const ( var ( modelCache = &_modelCache{ - cache: make(map[string]*modelInfo), - cacheByFN: make(map[string]*modelInfo), - } - supportTag = map[string]int{ - "-": 1, - "null": 1, - "index": 1, - "unique": 1, - "pk": 1, - "auto": 1, - "auto_now": 1, - "auto_now_add": 1, - "size": 2, - "column": 2, - "default": 2, - "rel": 2, - "reverse": 2, - "rel_table": 2, - "rel_through": 2, - "digits": 2, - "decimals": 2, - "on_delete": 2, - "type": 2, + cache: make(map[string]*modelInfo), + cacheByFullName: make(map[string]*modelInfo), } ) // model info collection type _modelCache struct { - sync.RWMutex // only used outsite for bootStrap - orders []string - cache map[string]*modelInfo - cacheByFN map[string]*modelInfo - done bool + sync.RWMutex // only used outsite for bootStrap + orders []string + cache map[string]*modelInfo + cacheByFullName map[string]*modelInfo + done bool } // get all model info @@ -88,9 +67,9 @@ func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { return } -// get model info by field name -func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { - mi, ok = mc.cacheByFN[name] +// get model info by full name +func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { + mi, ok = mc.cacheByFullName[name] return } @@ -98,7 +77,7 @@ func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { mii := mc.cache[table] mc.cache[table] = mi - mc.cacheByFN[mi.fullName] = mi + mc.cacheByFullName[mi.fullName] = mi if mii == nil { mc.orders = append(mc.orders, table) } @@ -109,7 +88,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { func (mc *_modelCache) clean() { mc.orders = make([]string, 0) mc.cache = make(map[string]*modelInfo) - mc.cacheByFN = make(map[string]*modelInfo) + mc.cacheByFullName = make(map[string]*modelInfo) mc.done = false } diff --git a/orm/models_boot.go b/orm/models_boot.go index 364b68e9..7d9746d0 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -26,12 +26,14 @@ import ( // prefix means table name prefix. func registerModel(prefix string, model interface{}) { val := reflect.ValueOf(model) - ind := reflect.Indirect(val) - typ := ind.Type() + typ := reflect.Indirect(val).Type() if val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } + // For this case: + // u := &User{} + // registerModel(&u) if typ.Kind() == reflect.Ptr { panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) } @@ -41,9 +43,9 @@ func registerModel(prefix string, model interface{}) { if prefix != "" { table = prefix + table } - + // models's fullname is pkgpath + struct name name := getFullName(typ) - if _, ok := modelCache.getByFN(name); ok { + if _, ok := modelCache.getByFullName(name); ok { fmt.Printf(" model `%s` repeat register, must be unique\n", name) os.Exit(2) } @@ -110,7 +112,7 @@ func bootStrap() { } name := getFullName(elm) - mii, ok := modelCache.getByFN(name) + mii, ok := modelCache.getByFullName(name) if ok == false || mii.pkg != elm.PkgPath() { err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) goto end @@ -123,7 +125,7 @@ func bootStrap() { msg := fmt.Sprintf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { pn := fi.relThrough[:i] - rmi, ok := modelCache.getByFN(fi.relThrough) + rmi, ok := modelCache.getByFullName(fi.relThrough) if ok == false || pn != rmi.pkg { err = errors.New(msg + " cannot find table") goto end diff --git a/orm/models_info_f.go b/orm/models_info_f.go index be6c9aa4..33db0d4f 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -152,6 +152,10 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN fi = new(fieldInfo) + // if field which CanAddr is the follow type + // A value is addressable if it is an element of a slice, + // an element of an addressable array, a field of an + // addressable struct, or the result of dereferencing a pointer. addrField = field if field.CanAddr() && field.Kind() != reflect.Ptr { addrField = field.Addr() @@ -162,7 +166,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN } } - parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags) + attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) if _, ok := attrs["-"]; ok { return nil, errSkipField @@ -188,7 +192,7 @@ checkType: } fieldType = f.FieldType() if fieldType&IsRelField > 0 { - err = fmt.Errorf("unsupport rel type custom field") + err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") goto end } default: diff --git a/orm/models_info_m.go b/orm/models_info_m.go index bbb82444..2e0905ec 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -29,31 +29,25 @@ type modelInfo struct { model interface{} fields *fields manual bool - addrField reflect.Value + addrField reflect.Value //store the original struct value uniques []string isThrough bool } // new model info -func newModelInfo(val reflect.Value) (info *modelInfo) { - - info = &modelInfo{} - info.fields = newFields() - +func newModelInfo(val reflect.Value) (mi *modelInfo) { + mi = &modelInfo{} + mi.fields = newFields() ind := reflect.Indirect(val) - typ := ind.Type() - - info.addrField = val - - info.name = typ.Name() - info.fullName = getFullName(typ) - - addModelFields(info, ind, "", []int{}) - + mi.addrField = val + mi.name = ind.Type().Name() + mi.fullName = getFullName(ind.Type()) + addModelFields(mi, ind, "", []int{}) return } -func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []int) { +// index: FieldByIndex returns the nested field corresponding to index +func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { var ( err error fi *fieldInfo @@ -63,43 +57,39 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in for i := 0; i < ind.NumField(); i++ { field := ind.Field(i) sf = ind.Type().Field(i) + // if the field is unexported skip if sf.PkgPath != "" { continue } // add anonymous struct fields if sf.Anonymous { - addModelFields(info, field, mName+"."+sf.Name, append(index, i)) + addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) continue } - fi, err = newFieldInfo(info, field, sf, mName) - - if err != nil { - if err == errSkipField { - err = nil - continue - } + fi, err = newFieldInfo(mi, field, sf, mName) + if err == errSkipField { + err = nil + continue + } else if err != nil { break } - - added := info.fields.Add(fi) - if added == false { + //record current field index + fi.fieldIndex = append(index, i) + fi.mi = mi + fi.inModel = true + if mi.fields.Add(fi) == false { err = fmt.Errorf("duplicate column name: %s", fi.column) break } - if fi.pk { - if info.fields.pk != nil { + if mi.fields.pk != nil { err = fmt.Errorf("one model must have one pk field only") break } else { - info.fields.pk = fi + mi.fields.pk = fi } } - - fi.fieldIndex = append(index, i) - fi.mi = info - fi.inModel = true } if err != nil { @@ -110,12 +100,12 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in // combine related model info to new model info. // prepare for relation models query. -func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { - info = new(modelInfo) - info.fields = newFields() - info.table = m1.table + "_" + m2.table + "s" - info.name = camelString(info.table) - info.fullName = m1.pkg + "." + info.name +func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { + mi = new(modelInfo) + mi.fields = newFields() + mi.table = m1.table + "_" + m2.table + "s" + mi.name = camelString(mi.table) + mi.fullName = m1.pkg + "." + mi.name fa := new(fieldInfo) f1 := new(fieldInfo) @@ -126,7 +116,7 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { fa.dbcol = true fa.name = "Id" fa.column = "id" - fa.fullName = info.fullName + "." + fa.name + fa.fullName = mi.fullName + "." + fa.name f1.dbcol = true f2.dbcol = true @@ -134,8 +124,8 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { f2.fieldType = RelForeignKey f1.name = camelString(m1.table) f2.name = camelString(m2.table) - f1.fullName = info.fullName + "." + f1.name - f2.fullName = info.fullName + "." + f2.name + f1.fullName = mi.fullName + "." + f1.name + f2.fullName = mi.fullName + "." + f2.name f1.column = m1.table + "_id" f2.column = m2.table + "_id" f1.rel = true @@ -144,14 +134,14 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { f2.relTable = m2.table f1.relModelInfo = m1 f2.relModelInfo = m2 - f1.mi = info - f2.mi = info + f1.mi = mi + f2.mi = mi - info.fields.Add(fa) - info.fields.Add(f1) - info.fields.Add(f2) - info.fields.pk = fa + mi.fields.Add(fa) + mi.fields.Add(f1) + mi.fields.Add(f2) + mi.fields.pk = fa - info.uniques = []string{f1.column, f2.column} + mi.uniques = []string{f1.column, f2.column} return } diff --git a/orm/models_utils.go b/orm/models_utils.go index 4c4b0f24..2a1f9d53 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -22,25 +22,47 @@ import ( "time" ) +// 1 is attr +// 2 is tag +var supportTag = map[string]int{ + "-": 1, + "null": 1, + "index": 1, + "unique": 1, + "pk": 1, + "auto": 1, + "auto_now": 1, + "auto_now_add": 1, + "size": 2, + "column": 2, + "default": 2, + "rel": 2, + "reverse": 2, + "rel_table": 2, + "rel_through": 2, + "digits": 2, + "decimals": 2, + "on_delete": 2, + "type": 2, +} + // get reflect.Type name with package path. func getFullName(typ reflect.Type) string { return typ.PkgPath() + "." + typ.Name() } -// get table name. method, or field name. auto snaked. +// getTableName get struct table name. +// If the struct implement the TableName, then get the result as tablename +// else use the struct name which will apply snakeString. func getTableName(val reflect.Value) string { - ind := reflect.Indirect(val) - fun := val.MethodByName("TableName") - if fun.IsValid() { + if fun := val.MethodByName("TableName"); fun.IsValid() { vals := fun.Call([]reflect.Value{}) - if len(vals) > 0 { - val := vals[0] - if val.Kind() == reflect.String { - return val.String() - } + // has return and the first val is string + if len(vals) > 0 && vals[0].Kind() == reflect.String { + return vals[0].String() } } - return snakeString(ind.Type().Name()) + return snakeString(reflect.Indirect(val).Type().Name()) } // get table engine, mysiam or innodb. @@ -189,21 +211,25 @@ func getFieldType(val reflect.Value) (ft int, err error) { } // parse struct tag string -func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { - attr := make(map[string]bool) - tag := make(map[string]string) +func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { + attrs = make(map[string]bool) + tags = make(map[string]string) for _, v := range strings.Split(data, defaultStructTagDelim) { + if v == "" { + continue + } v = strings.TrimSpace(v) if t := strings.ToLower(v); supportTag[t] == 1 { - attr[t] = true + attrs[t] = true } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { name := t[:i] if supportTag[name] == 2 { v = v[i+1 : len(v)-1] - tag[name] = v + tags[name] = v } + } else { + DebugLog.Println("unsupport orm tag", v) } } - *attrs = attr - *tags = tag + return } diff --git a/orm/orm.go b/orm/orm.go index 994ed7e3..390d300f 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -104,7 +104,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } name := getFullName(typ) - if mi, ok := modelCache.getByFN(name); ok { + if mi, ok := modelCache.getByFullName(name); ok { return mi, ind } panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name)) @@ -427,7 +427,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { } } else { name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) - if mi, ok := modelCache.getByFN(name); ok { + if mi, ok := modelCache.getByFullName(name); ok { qs = newQuerySet(o, mi) } } diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 3b945833..a968b1a1 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -286,7 +286,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { structMode = true fn := getFullName(typ) - if mi, ok := modelCache.getByFN(fn); ok { + if mi, ok := modelCache.getByFullName(fn); ok { sMi = mi } } else { @@ -355,12 +355,9 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { for i := 0; i < ind.NumField(); i++ { f := ind.Field(i) fe := ind.Type().Field(i) - - var attrs map[string]bool - var tags map[string]string - parseStructTag(fe.Tag.Get("orm"), &attrs, &tags) + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) var col string - if col = tags["column"]; len(col) == 0 { + if col = tags["column"]; col == "" { col = snakeString(fe.Name) } if v, ok := columnsMp[col]; ok { @@ -422,7 +419,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { structMode = true fn := getFullName(typ) - if mi, ok := modelCache.getByFN(fn); ok { + if mi, ok := modelCache.getByFullName(fn); ok { sMi = mi } } else { @@ -499,12 +496,9 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { for i := 0; i < ind.NumField(); i++ { f := ind.Field(i) fe := ind.Type().Field(i) - - var attrs map[string]bool - var tags map[string]string - parseStructTag(fe.Tag.Get("orm"), &attrs, &tags) + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) var col string - if col = tags["column"]; len(col) == 0 { + if col = tags["column"]; col == "" { col = snakeString(fe.Name) } if v, ok := columnsMp[col]; ok { diff --git a/orm/orm_test.go b/orm/orm_test.go index bb1831f9..5e288039 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -227,7 +227,7 @@ func TestModelSyntax(t *testing.T) { user := &User{} ind := reflect.ValueOf(user).Elem() fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFN(fn) + mi, ok := modelCache.getByFullName(fn) throwFail(t, AssertIs(ok, true)) mi, ok = modelCache.get("user")