From 47f9746e11e2908a23bccedd61cda3163475430e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 7 Sep 2021 09:43:27 +0800 Subject: [PATCH] refactor: fix deepsource analyze --- client/orm/orm.go | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/client/orm/orm.go b/client/orm/orm.go index d6ae0c80..8cb8b9b5 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -108,16 +108,30 @@ var ( ) // get model info and model reflect value -func (*ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { +func (*ormBase) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() - if needPtr && val.Kind() != reflect.Ptr { + mi = getModelInfo(typ) + return +} + +// get need ptr model info and model reflect value +func (*ormBase) getPtrMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { + val := reflect.ValueOf(md) + ind = reflect.Indirect(val) + typ := ind.Type() + if val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } - name := getFullName(typ) + mi = getModelInfo(typ) + return +} + +func getModelInfo(mdTyp reflect.Type) *modelInfo { + name := getFullName(mdTyp) if mi, ok := modelCache.getByFullName(name); ok { - return mi, ind + return mi } panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) } @@ -137,7 +151,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error { } func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) } @@ -147,7 +161,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { } func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) } @@ -158,7 +172,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create @@ -184,7 +198,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) { } func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return id, err @@ -228,7 +242,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac if bulk <= 1 { for i := 0; i < sind.Len(); i++ { ind := reflect.Indirect(sind.Index(i)) - mi, _ := o.getMiInd(ind.Interface(), false) + mi, _ := o.getMiInd(ind.Interface()) id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err @@ -239,7 +253,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac cnt++ } } else { - mi, _ := o.getMiInd(sind.Index(0).Interface(), false) + mi, _ := o.getMiInd(sind.Index(0).Interface()) return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil @@ -251,7 +265,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) ( } func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { return id, err @@ -269,7 +283,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) } @@ -280,14 +294,14 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) return num, err } // create a models to models queryer func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) fi := o.getFieldInfo(mi, name) switch { @@ -384,7 +398,7 @@ func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name str // get QuerySeter for related models to md model func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind)