diff --git a/client/orm/db_alias.go b/client/orm/db_alias.go index 29e0904c..72c447b3 100644 --- a/client/orm/db_alias.go +++ b/client/orm/db_alias.go @@ -232,6 +232,14 @@ func (t *TxDB) Rollback() error { return t.tx.Rollback() } +func (t *TxDB) RollbackUnlessCommit() error { + err := t.tx.Rollback() + if err != sql.ErrTxDone { + return err + } + return nil +} + var _ dbQuerier = new(TxDB) var _ txEnder = new(TxDB) diff --git a/client/orm/orm.go b/client/orm/orm.go index 3f342868..fa96de4f 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -593,6 +593,10 @@ func (t *txOrm) Rollback() error { return t.db.(txEnder).Rollback() } +func (t *txOrm) RollbackUnlessCommit() error { + return t.db.(txEnder).RollbackUnlessCommit() +} + // NewOrm create new orm func NewOrm() Ormer { BootStrap() // execute only once diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index f6e7a841..3254a01b 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -159,6 +159,7 @@ func throwFail(t *testing.T, err error, args ...interface{}) { } } +// deprecated using assert.XXX func throwFailNow(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -2248,27 +2249,64 @@ func TestTransaction(t *testing.T) { } err = to.Rollback() - throwFail(t, err) - + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) + assert.Nil(t, err) + assert.Equal(t, int64(0), num) to, err = o.Begin() - throwFail(t, err) + assert.Nil(t, err) tag.Name = "commit" id, err = to.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) + assert.Nil(t, err) + assert.True(t, id > 0) - to.Commit() - throwFail(t, err) + err = to.Commit() + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) + assert.Nil(t, err) + assert.Equal(t, int64(1), num) + +} + +func TestTxOrmRollbackUnlessCommit(t *testing.T) { + o := NewOrm() + var tag Tag + + // test not commited and call RollbackUnlessCommit + to, err := o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err := to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + num, err := o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(0), num) + + // test commit and call RollbackUnlessCommit + + to, err = o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err = to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + + err = to.Commit() + assert.Nil(t, err) + + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + + num, err = o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(1), num) } func TestTransactionIsolationLevel(t *testing.T) { diff --git a/client/orm/types.go b/client/orm/types.go index ab3ddac4..f9f74652 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -110,10 +110,35 @@ type TxBeginner interface { } type TxCommitter interface { + txEnder +} + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { Commit() error Rollback() error + + // RollbackUnlessCommit if the transaction has been committed, do nothing, or transaction will be rollback + // For example: + // ```go + // txOrm := orm.Begin() + // defer txOrm.RollbackUnlessCommit() + // err := txOrm.Insert() // do something + // if err != nil { + // return err + // } + // txOrm.Commit() + // ``` + RollbackUnlessCommit() error } + // Data Manipulation Language type DML interface { // insert model data to database @@ -592,18 +617,6 @@ type dbQuerier interface { // QueryRow(query string, args ...interface{}) *sql.Row // } -// transaction beginner -type txer interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -// transaction ending -type txEnder interface { - Commit() error - Rollback() error -} - // base database struct type dbBaser interface { Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error