diff --git a/CHANGELOG.md b/CHANGELOG.md index 80e75c70..c3a7c42f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,7 +65,7 @@ - Fix 4734: do not reset id in Delete function. [4738](https://github.com/beego/beego/pull/4738) [4742](https://github.com/beego/beego/pull/4742) - Fix 4699: Remove Remove goyaml2 dependency. [4755](https://github.com/beego/beego/pull/4755) - Fix 4698: Prompt error when config format is incorrect. [4757](https://github.com/beego/beego/pull/4757) - +- Fix 4674: Tx Orm missing debug log [4756](https://github.com/beego/beego/pull/4756) ## Fix Sonar - [4677](https://github.com/beego/beego/pull/4677) diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index a4a215f1..edeaaade 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -462,7 +462,7 @@ func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.T TxStartTime: f.txStartTime, TxName: getTxNameFromCtx(ctx), f: func(c context.Context) []interface{} { - err := doTxTemplate(f, c, opts, task) + err := doTxTemplate(c, f, opts, task) return []interface{}{err} }, } @@ -518,7 +518,7 @@ func (f *filterOrmDecorator) RollbackUnlessCommit() error { return f.convertError(res[0]) } -func (f *filterOrmDecorator) convertError(v interface{}) error { +func (*filterOrmDecorator) convertError(v interface{}) error { if v == nil { return nil } diff --git a/client/orm/orm.go b/client/orm/orm.go index 47e46400..30753acf 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -108,22 +108,36 @@ var ( ) // get model info and model reflect value -func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { +func (*ormBase) getMi(md interface{}) (mi *modelInfo) { + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + mi = getTypeMi(typ) + return +} + +// get need ptr model info and model reflect value +func (*ormBase) getPtrMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() - if needPtr && val.Kind() != reflect.Ptr { + if val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } - name := getFullName(typ) + mi = getTypeMi(typ) + return +} + +func getTypeMi(mdTyp reflect.Type) *modelInfo { + name := getFullName(mdTyp) if mi, ok := modelCache.getByFullName(name); ok { - return mi, ind + return mi } panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) } // get field info from model info by given field name -func (o *ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { +func (*ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) @@ -137,7 +151,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error { } func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) } @@ -147,7 +161,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { } func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) } @@ -158,7 +172,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create @@ -184,7 +198,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) { } func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return id, err @@ -196,7 +210,7 @@ func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, err } // set auto pk field -func (o *ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { +func (*ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) @@ -228,7 +242,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac if bulk <= 1 { for i := 0; i < sind.Len(); i++ { ind := reflect.Indirect(sind.Index(i)) - mi, _ := o.getMiInd(ind.Interface(), false) + mi := o.getMi(ind.Interface()) id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err @@ -239,7 +253,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac cnt++ } } else { - mi, _ := o.getMiInd(sind.Index(0).Interface(), false) + mi := o.getMi(sind.Index(0).Interface()) return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil @@ -251,7 +265,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) ( } func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { return id, err @@ -269,7 +283,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) } @@ -280,14 +294,14 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) return num, err } // create a models to models queryer func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) fi := o.getFieldInfo(mi, name) switch { @@ -318,7 +332,7 @@ func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (in return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } -func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { +func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { _, fi, ind, qs := o.queryRelated(md, name) var relDepth int @@ -384,7 +398,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s // get QuerySeter for related models to md model func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind) @@ -488,7 +502,7 @@ func (o *ormBase) Raw(query string, args ...interface{}) RawSeter { return o.RawWithCtx(context.Background(), query, args...) } -func (o *ormBase) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { +func (o *ormBase) RawWithCtx(_ context.Context, query string, args ...interface{}) RawSeter { return newRawSet(o, query, args) } @@ -536,6 +550,11 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO db: &TxDB{tx: tx}, }, } + + if Debug { + _txOrm.db = newDbQueryLog(o.alias, _txOrm.db) + } + var taskTxOrm TxOrmer = _txOrm return taskTxOrm, nil } @@ -553,10 +572,10 @@ func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, t } func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { - return doTxTemplate(o, ctx, opts, task) + return doTxTemplate(ctx, o, opts, task) } -func doTxTemplate(o TxBeginner, ctx context.Context, opts *sql.TxOptions, +func doTxTemplate(ctx context.Context, o TxBeginner, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) if err != nil { diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 05f16a34..764e5b6d 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -129,7 +129,8 @@ func getCaller(skip int) string { if cur == line { flag = ">>" } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + ls := formatLines(string(lines[o+i])) + code := fmt.Sprintf(" %s %5d: %s", flag, cur, ls) if code != "" { codes = append(codes, code) } @@ -142,6 +143,10 @@ func getCaller(skip int) string { return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) } +func formatLines(s string) string { + return strings.ReplaceAll(s, "\t", " ") +} + // Deprecated: Using stretchr/testify/assert func throwFail(t *testing.T, err error, args ...interface{}) { if err != nil { @@ -212,7 +217,7 @@ func TestSyncDb(t *testing.T) { modelCache.clean() } -func TestRegisterModels(t *testing.T) { +func TestRegisterModels(_ *testing.T) { RegisterModel(new(Data), new(DataNull), new(DataCustom)) RegisterModel(new(User)) RegisterModel(new(Profile)) @@ -245,10 +250,10 @@ func TestModelSyntax(t *testing.T) { user := &User{} ind := reflect.ValueOf(user).Elem() fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFullName(fn) + _, ok := modelCache.getByFullName(fn) throwFail(t, AssertIs(ok, true)) - mi, ok = modelCache.get("user") + mi, ok := modelCache.get("user") throwFail(t, AssertIs(ok, true)) if ok { throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) @@ -883,10 +888,11 @@ func TestCustomField(t *testing.T) { func TestExpr(t *testing.T) { user := &User{} - qs := dORM.QueryTable(user) - qs = dORM.QueryTable((*User)(nil)) - qs = dORM.QueryTable("User") - qs = dORM.QueryTable("user") + var qs QuerySeter + assert.NotPanics(t, func() { qs = dORM.QueryTable(user) }) + assert.NotPanics(t, func() { qs = dORM.QueryTable((*User)(nil)) }) + assert.NotPanics(t, func() { qs = dORM.QueryTable("User") }) + assert.NotPanics(t, func() { qs = dORM.QueryTable("user") }) num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -1719,7 +1725,7 @@ func TestQueryM2M(t *testing.T) { throwFailNow(t, AssertIs(num, 1)) } -func TestQueryRelate(t *testing.T) { +func TestQueryRelate(_ *testing.T) { // post := &Post{Id: 2} // qs := dORM.QueryRelate(post, "Tags") @@ -2069,9 +2075,7 @@ func TestRawPrepare(t *testing.T) { err error pre RawPreparer ) - switch { - case IsMysql || IsSqlite: - + if IsMysql || IsSqlite { pre, err = dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() assert.Nil(t, err) if pre != nil { @@ -2106,9 +2110,7 @@ func TestRawPrepare(t *testing.T) { assert.Nil(t, err) assert.Equal(t, num, int64(3)) } - - case IsPostgres: - + } else if IsPostgres { pre, err = dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() assert.Nil(t, err) if pre != nil { @@ -2238,8 +2240,7 @@ func TestTransaction(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) - switch { - case IsMysql || IsSqlite: + if IsMysql || IsSqlite { res, err := to.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() throwFail(t, err) if err == nil { @@ -2859,12 +2860,10 @@ func TestCondition(t *testing.T) { hasCycle(p.cond) } } - return } hasCycle(cond) // cycleFlag was true,meaning use self as sub cond throwFail(t, AssertIs(!cycleFlag, true)) - return } func TestContextCanceled(t *testing.T) { @@ -2886,3 +2885,58 @@ func TestContextCanceled(t *testing.T) { _, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) throwFail(t, AssertIs(err, context.Canceled)) } + +func TestDebugLog(t *testing.T) { + txCommitFn := func() { + o := NewOrm() + o.DoTx(func(ctx context.Context, txOrm TxOrmer) (txerr error) { + _, txerr = txOrm.QueryTable(&User{}).Count() + return + }) + } + + txRollbackFn := func() { + o := NewOrm() + o.DoTx(func(ctx context.Context, txOrm TxOrmer) (txerr error) { + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + txOrm.Insert(user) + txerr = fmt.Errorf("mock error") + return + }) + } + + Debug = true + output1 := captureDebugLogOutput(txCommitFn) + + assert.Contains(t, output1, "START TRANSACTION") + assert.Contains(t, output1, "COMMIT") + + output2 := captureDebugLogOutput(txRollbackFn) + + assert.Contains(t, output2, "START TRANSACTION") + assert.Contains(t, output2, "ROLLBACK") + + Debug = false + output1 = captureDebugLogOutput(txCommitFn) + assert.EqualValues(t, output1, "") + + output2 = captureDebugLogOutput(txRollbackFn) + assert.EqualValues(t, output2, "") +} + +func captureDebugLogOutput(f func()) string { + var buf bytes.Buffer + DebugLog.SetOutput(&buf) + defer func() { + DebugLog.SetOutput(os.Stderr) + }() + f() + return buf.String() +}