aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--db.go23
-rw-r--r--db_test.go17
-rw-r--r--tx.go8
3 files changed, 43 insertions, 5 deletions
diff --git a/db.go b/db.go
index 035a310..0bce184 100644
--- a/db.go
+++ b/db.go
@@ -349,19 +349,26 @@ func (db *DB) removeTx(t *Tx) {
}
}
-// Do executes a function within the context of a read-write transaction.
+// Do executes a function within the context of a read-write managed transaction.
// If no error is returned from the function then the transaction is committed.
// If an error is returned then the entire transaction is rolled back.
// Any error that is returned from the function or returned from the commit is
// returned from the Do() method.
+//
+// Attempting to manually commit or rollback within the function will cause a panic.
func (db *DB) Do(fn func(*Tx) error) error {
t, err := db.RWTx()
if err != nil {
return err
}
+ // Mark as a managed tx so that the inner function cannot manually commit.
+ t.managed = true
+
// If an error is returned from the function then rollback and return error.
- if err := fn(t); err != nil {
+ err = fn(t)
+ t.managed = false
+ if err != nil {
t.Rollback()
return err
}
@@ -369,8 +376,10 @@ func (db *DB) Do(fn func(*Tx) error) error {
return t.Commit()
}
-// With executes a function within the context of a transaction.
+// With executes a function within the context of a managed transaction.
// Any error that is returned from the function is returned from the With() method.
+//
+// Attempting to manually rollback within the function will cause a panic.
func (db *DB) With(fn func(*Tx) error) error {
t, err := db.Tx()
if err != nil {
@@ -378,8 +387,14 @@ func (db *DB) With(fn func(*Tx) error) error {
}
defer t.Rollback()
+ // Mark as a managed tx so that the inner function cannot manually rollback.
+ t.managed = true
+
// If an error is returned from the function then pass it through.
- return fn(t)
+ err = fn(t)
+ t.managed = false
+
+ return err
}
// Copy writes the entire database to a writer.
diff --git a/db_test.go b/db_test.go
index 04abd75..2882ba8 100644
--- a/db_test.go
+++ b/db_test.go
@@ -111,6 +111,23 @@ func TestDBTxBlockWhileClosed(t *testing.T) {
})
}
+// Ensure a panic occurs while trying to commit a managed transaction.
+func TestDBTxBlockWithManualCommitAndRollback(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ db.Do(func(tx *Tx) error {
+ tx.CreateBucket("widgets")
+ assert.Panics(t, func() { tx.Commit() })
+ assert.Panics(t, func() { tx.Rollback() })
+ return nil
+ })
+ db.With(func(tx *Tx) error {
+ assert.Panics(t, func() { tx.Commit() })
+ assert.Panics(t, func() { tx.Rollback() })
+ return nil
+ })
+ })
+}
+
// Ensure that the database can be copied to a file path.
func TestDBCopyFile(t *testing.T) {
withOpenDB(func(db *DB, path string) {
diff --git a/tx.go b/tx.go
index 181444e..c18fbf7 100644
--- a/tx.go
+++ b/tx.go
@@ -18,6 +18,7 @@ type txid uint64
// quickly grow.
type Tx struct {
writable bool
+ managed bool
db *DB
meta *meta
buckets *buckets
@@ -155,7 +156,9 @@ func (t *Tx) DeleteBucket(name string) error {
// Commit writes all changes to disk and updates the meta page.
// Returns an error if a disk write error occurs.
func (t *Tx) Commit() error {
- if t.db == nil {
+ if t.managed {
+ panic("managed tx commit not allowed")
+ } else if t.db == nil {
return nil
} else if !t.writable {
t.Rollback()
@@ -194,6 +197,9 @@ func (t *Tx) Commit() error {
// Rollback closes the transaction and ignores all previous updates.
func (t *Tx) Rollback() {
+ if t.managed {
+ panic("managed tx rollback not allowed")
+ }
t.close()
}