From c5bd3c39960dc6cdc1ab2bfd22ff070a96f9e534 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 21 Mar 2021 23:52:55 +0800 Subject: [PATCH 1/2] Support RollbackUnlessCommit --- client/orm/db_alias.go | 8 ++++++ client/orm/orm.go | 4 +++ client/orm/orm_test.go | 60 ++++++++++++++++++++++++++++++++++-------- client/orm/types.go | 37 +++++++++++++++++--------- 4 files changed, 86 insertions(+), 23 deletions(-) diff --git a/client/orm/db_alias.go b/client/orm/db_alias.go index 29e0904c..72c447b3 100644 --- a/client/orm/db_alias.go +++ b/client/orm/db_alias.go @@ -232,6 +232,14 @@ func (t *TxDB) Rollback() error { return t.tx.Rollback() } +func (t *TxDB) RollbackUnlessCommit() error { + err := t.tx.Rollback() + if err != sql.ErrTxDone { + return err + } + return nil +} + var _ dbQuerier = new(TxDB) var _ txEnder = new(TxDB) diff --git a/client/orm/orm.go b/client/orm/orm.go index 3f342868..fa96de4f 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -593,6 +593,10 @@ func (t *txOrm) Rollback() error { return t.db.(txEnder).Rollback() } +func (t *txOrm) RollbackUnlessCommit() error { + return t.db.(txEnder).RollbackUnlessCommit() +} + // NewOrm create new orm func NewOrm() Ormer { BootStrap() // execute only once diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index f6e7a841..3254a01b 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -159,6 +159,7 @@ func throwFail(t *testing.T, err error, args ...interface{}) { } } +// deprecated using assert.XXX func throwFailNow(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -2248,27 +2249,64 @@ func TestTransaction(t *testing.T) { } err = to.Rollback() - throwFail(t, err) - + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) + assert.Nil(t, err) + assert.Equal(t, int64(0), num) to, err = o.Begin() - throwFail(t, err) + assert.Nil(t, err) tag.Name = "commit" id, err = to.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) + assert.Nil(t, err) + assert.True(t, id > 0) - to.Commit() - throwFail(t, err) + err = to.Commit() + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) + assert.Nil(t, err) + assert.Equal(t, int64(1), num) + +} + +func TestTxOrmRollbackUnlessCommit(t *testing.T) { + o := NewOrm() + var tag Tag + + // test not commited and call RollbackUnlessCommit + to, err := o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err := to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + num, err := o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(0), num) + + // test commit and call RollbackUnlessCommit + + to, err = o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err = to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + + err = to.Commit() + assert.Nil(t, err) + + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + + num, err = o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(1), num) } func TestTransactionIsolationLevel(t *testing.T) { diff --git a/client/orm/types.go b/client/orm/types.go index ab3ddac4..f9f74652 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -110,10 +110,35 @@ type TxBeginner interface { } type TxCommitter interface { + txEnder +} + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { Commit() error Rollback() error + + // RollbackUnlessCommit if the transaction has been committed, do nothing, or transaction will be rollback + // For example: + // ```go + // txOrm := orm.Begin() + // defer txOrm.RollbackUnlessCommit() + // err := txOrm.Insert() // do something + // if err != nil { + // return err + // } + // txOrm.Commit() + // ``` + RollbackUnlessCommit() error } + // Data Manipulation Language type DML interface { // insert model data to database @@ -592,18 +617,6 @@ type dbQuerier interface { // QueryRow(query string, args ...interface{}) *sql.Row // } -// transaction beginner -type txer interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -// transaction ending -type txEnder interface { - Commit() error - Rollback() error -} - // base database struct type dbBaser interface { Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error From 629d59200359cb7552bbea78aa10b7125f80f329 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 22 Mar 2021 00:02:49 +0800 Subject: [PATCH 2/2] Add change log Support RollbackUnlessCommit --- CHANGELOG.md | 1 + client/orm/filter_orm_decorator.go | 16 ++++++++++++++++ client/orm/filter_orm_decorator_test.go | 4 ++++ client/orm/mock/mock_orm.go | 5 +++++ client/orm/mock/mock_orm_test.go | 13 +++++++++++++ client/orm/orm_log.go | 7 +++++++ 6 files changed, 46 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35659217..2a6e9df7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433) - Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) - Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) +- Support `RollbackUnlessCommit` API. [4542](https://github.com/beego/beego/pull/4542) - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) - Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index caf2b3f9..6a9ecc53 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -503,6 +503,22 @@ func (f *filterOrmDecorator) Rollback() error { return f.convertError(res[0]) } +func (f *filterOrmDecorator) RollbackUnlessCommit() error { + inv := &Invocation{ + Method: "RollbackUnlessCommit", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func(c context.Context) []interface{} { + err := f.TxCommitter.RollbackUnlessCommit() + return []interface{}{err} + }, + } + res := f.root(context.Background(), inv) + return f.convertError(res[0]) +} + func (f *filterOrmDecorator) convertError(v interface{}) error { if v == nil { return nil diff --git a/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go index 566499dd..6c3bc72b 100644 --- a/client/orm/filter_orm_decorator_test.go +++ b/client/orm/filter_orm_decorator_test.go @@ -402,6 +402,10 @@ func (f *filterMockOrm) Rollback() error { return errors.New("rollback") } +func (f *filterMockOrm) RollbackUnlessCommit() error { + return errors.New("rollback unless commit") +} + func (f *filterMockOrm) DBStats() *sql.DBStats { return &sql.DBStats{ MaxOpenConnections: -1, diff --git a/client/orm/mock/mock_orm.go b/client/orm/mock/mock_orm.go index 16ae8612..853a4213 100644 --- a/client/orm/mock/mock_orm.go +++ b/client/orm/mock/mock_orm.go @@ -160,3 +160,8 @@ func MockCommit(err error) *Mock { func MockRollback(err error) *Mock { return NewMock(NewSimpleCondition("", "Rollback"), []interface{}{err}, nil) } + +// MockRollbackUnlessCommit support RollbackUnlessCommit +func MockRollbackUnlessCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "RollbackUnlessCommit"), []interface{}{err}, nil) +} \ No newline at end of file diff --git a/client/orm/mock/mock_orm_test.go b/client/orm/mock/mock_orm_test.go index 1b321f01..d34774d0 100644 --- a/client/orm/mock/mock_orm_test.go +++ b/client/orm/mock/mock_orm_test.go @@ -241,6 +241,19 @@ func TestTransactionRollback(t *testing.T) { assert.Equal(t, mock, err) } +func TestTransactionRollbackUnlessCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRollbackUnlessCommit(mock)) + + //u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.RollbackUnlessCommit() + assert.Equal(t, mock, err) +} + func TestTransactionCommit(t *testing.T) { s := StartMock() defer s.Clear() diff --git a/client/orm/orm_log.go b/client/orm/orm_log.go index 6a89f557..da3ef732 100644 --- a/client/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -206,6 +206,13 @@ func (d *dbQueryLog) Rollback() error { return err } +func (d *dbQueryLog) RollbackUnlessCommit() error { + a := time.Now() + err := d.db.(txEnder).RollbackUnlessCommit() + debugLogQueies(d.alias, "tx.RollbackUnlessCommit", "ROLLBACK UNLESS COMMIT", a, err) + return err +} + func (d *dbQueryLog) SetDB(db dbQuerier) { d.db = db }