aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--db.go14
-rw-r--r--db_test.go60
-rw-r--r--tx.go6
3 files changed, 80 insertions, 0 deletions
diff --git a/db.go b/db.go
index b839e23..4df79e2 100644
--- a/db.go
+++ b/db.go
@@ -440,6 +440,13 @@ func (db *DB) Update(fn func(*Tx) error) error {
return err
}
+ // Make sure the transaction rolls back in the event of a panic.
+ defer func() {
+ if t.db != nil {
+ t.rollback()
+ }
+ }()
+
// Mark as a managed tx so that the inner function cannot manually commit.
t.managed = true
@@ -464,6 +471,13 @@ func (db *DB) View(fn func(*Tx) error) error {
return err
}
+ // Make sure the transaction rolls back in the event of a panic.
+ defer func() {
+ if t.db != nil {
+ t.rollback()
+ }
+ }()
+
// Mark as a managed tx so that the inner function cannot manually rollback.
t.managed = true
diff --git a/db_test.go b/db_test.go
index 9bd1ac1..9849ed6 100644
--- a/db_test.go
+++ b/db_test.go
@@ -262,6 +262,37 @@ func TestDB_Update_ManualCommitAndRollback(t *testing.T) {
})
}
+// Ensure a write transaction that panics does not hold open locks.
+func TestDB_Update_Panic(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ warn("recover: update", r)
+ }
+ }()
+ db.Update(func(tx *Tx) error {
+ tx.CreateBucket([]byte("widgets"))
+ panic("omg")
+ return nil
+ })
+ }()
+
+ // Verify we can update again.
+ err := db.Update(func(tx *Tx) error {
+ _, err := tx.CreateBucket([]byte("widgets"))
+ return err
+ })
+ assert.NoError(t, err)
+
+ // Verify that our change persisted.
+ err = db.Update(func(tx *Tx) error {
+ assert.NotNil(t, tx.Bucket([]byte("widgets")))
+ return nil
+ })
+ })
+}
+
// Ensure a database can return an error through a read-only transactional block.
func TestDB_View_Error(t *testing.T) {
withOpenDB(func(db *DB, path string) {
@@ -272,6 +303,35 @@ func TestDB_View_Error(t *testing.T) {
})
}
+// Ensure a read transaction that panics does not hold open locks.
+func TestDB_View_Panic(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ db.Update(func(tx *Tx) error {
+ tx.CreateBucket([]byte("widgets"))
+ return nil
+ })
+
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ warn("recover: view", r)
+ }
+ }()
+ db.View(func(tx *Tx) error {
+ assert.NotNil(t, tx.Bucket([]byte("widgets")))
+ panic("omg")
+ return nil
+ })
+ }()
+
+ // Verify that we can still use read transactions.
+ db.View(func(tx *Tx) error {
+ assert.NotNil(t, tx.Bucket([]byte("widgets")))
+ return nil
+ })
+ })
+}
+
// Ensure that an error is returned when a database write fails.
func TestDB_Commit_WriteFail(t *testing.T) {
t.Skip("pending") // TODO(benbjohnson)
diff --git a/tx.go b/tx.go
index 5eada65..759913d 100644
--- a/tx.go
+++ b/tx.go
@@ -212,6 +212,9 @@ func (tx *Tx) Rollback() error {
}
func (tx *Tx) rollback() {
+ if tx.db == nil {
+ return
+ }
if tx.writable {
tx.db.freelist.rollback(tx.id())
tx.db.freelist.reload(tx.db.page(tx.db.meta().freelist))
@@ -220,6 +223,9 @@ func (tx *Tx) rollback() {
}
func (tx *Tx) close() {
+ if tx.db == nil {
+ return
+ }
if tx.writable {
// Grab freelist stats.
var freelistFreeN = tx.db.freelist.free_count()