aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bucket.go15
-rw-r--r--cursor.go8
-rw-r--r--error.go4
-rw-r--r--tx.go34
-rw-r--r--tx_test.go45
5 files changed, 89 insertions, 17 deletions
diff --git a/bucket.go b/bucket.go
index c72ddc8..e2bf60c 100644
--- a/bucket.go
+++ b/bucket.go
@@ -55,7 +55,9 @@ func (b *Bucket) Get(key []byte) []byte {
// If the key exist then its previous value will be overwritten.
// Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large.
func (b *Bucket) Put(key []byte, value []byte) error {
- if !b.Writable() {
+ if b.tx.db == nil {
+ return ErrTxClosed
+ } else if !b.Writable() {
return ErrBucketNotWritable
}
@@ -82,7 +84,9 @@ func (b *Bucket) Put(key []byte, value []byte) error {
// If the key does not exist then nothing is done and a nil error is returned.
// Returns an error if the bucket was created from a read-only transaction.
func (b *Bucket) Delete(key []byte) error {
- if !b.Writable() {
+ if b.tx.db == nil {
+ return ErrTxClosed
+ } else if !b.Writable() {
return ErrBucketNotWritable
}
@@ -98,7 +102,9 @@ func (b *Bucket) Delete(key []byte) error {
// NextSequence returns an autoincrementing integer for the bucket.
func (b *Bucket) NextSequence() (int, error) {
- if !b.Writable() {
+ if b.tx.db == nil {
+ return 0, ErrTxClosed
+ } else if !b.Writable() {
return 0, ErrBucketNotWritable
}
@@ -118,6 +124,9 @@ func (b *Bucket) NextSequence() (int, error) {
// If the provided function returns an error then the iteration is stopped and
// the error is returned to the caller.
func (b *Bucket) ForEach(fn func(k, v []byte) error) error {
+ if b.tx.db == nil {
+ return ErrTxClosed
+ }
c := b.Cursor()
for k, v := c.First(); k != nil; k, v = c.Next() {
if err := fn(k, v); err != nil {
diff --git a/cursor.go b/cursor.go
index d66f1c1..55bf568 100644
--- a/cursor.go
+++ b/cursor.go
@@ -16,6 +16,7 @@ type Cursor struct {
// First moves the cursor to the first item in the bucket and returns its key and value.
// If the bucket is empty then a nil key and value are returned.
func (c *Cursor) First() (key []byte, value []byte) {
+ _assert(c.tx.db != nil, "tx closed")
c.stack = c.stack[:0]
p, n := c.tx.pageNode(c.root)
c.stack = append(c.stack, elemRef{page: p, node: n, index: 0})
@@ -26,6 +27,7 @@ func (c *Cursor) First() (key []byte, value []byte) {
// Last moves the cursor to the last item in the bucket and returns its key and value.
// If the bucket is empty then a nil key and value are returned.
func (c *Cursor) Last() (key []byte, value []byte) {
+ _assert(c.tx.db != nil, "tx closed")
c.stack = c.stack[:0]
p, n := c.tx.pageNode(c.root)
ref := elemRef{page: p, node: n}
@@ -38,6 +40,8 @@ func (c *Cursor) Last() (key []byte, value []byte) {
// Next moves the cursor to the next item in the bucket and returns its key and value.
// If the cursor is at the end of the bucket then a nil key and value are returned.
func (c *Cursor) Next() (key []byte, value []byte) {
+ _assert(c.tx.db != nil, "tx closed")
+
// Attempt to move over one element until we're successful.
// Move up the stack as we hit the end of each page in our stack.
for i := len(c.stack) - 1; i >= 0; i-- {
@@ -62,6 +66,8 @@ func (c *Cursor) Next() (key []byte, value []byte) {
// Prev moves the cursor to the previous item in the bucket and returns its key and value.
// If the cursor is at the beginning of the bucket then a nil key and value are returned.
func (c *Cursor) Prev() (key []byte, value []byte) {
+ _assert(c.tx.db != nil, "tx closed")
+
// Attempt to move back one element until we're successful.
// Move up the stack as we hit the beginning of each page in our stack.
for i := len(c.stack) - 1; i >= 0; i-- {
@@ -87,6 +93,8 @@ func (c *Cursor) Prev() (key []byte, value []byte) {
// If the key does not exist then the next key is used. If no keys
// follow, a nil value is returned.
func (c *Cursor) Seek(seek []byte) (key []byte, value []byte) {
+ _assert(c.tx.db != nil, "tx closed")
+
// Start from root page/node and traverse to correct page.
c.stack = c.stack[:0]
c.search(seek, c.root)
diff --git a/error.go b/error.go
index 7f867b2..f88f47e 100644
--- a/error.go
+++ b/error.go
@@ -20,6 +20,10 @@ var (
// read-only transaction.
ErrTxNotWritable = &Error{"tx not writable", nil}
+ // ErrTxClosed is returned when committing or rolling back a transaction
+ // that has already been committed or rolled back.
+ ErrTxClosed = &Error{"tx closed", nil}
+
// ErrBucketNotFound is returned when trying to access a bucket that has
// not been created yet.
ErrBucketNotFound = &Error{"bucket not found", nil}
diff --git a/tx.go b/tx.go
index c18fbf7..8c78184 100644
--- a/tx.go
+++ b/tx.go
@@ -97,7 +97,9 @@ func (t *Tx) Buckets() []*Bucket {
// CreateBucket creates a new bucket.
// Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long.
func (t *Tx) CreateBucket(name string) error {
- if !t.writable {
+ if t.db == nil {
+ return ErrTxClosed
+ } else if !t.writable {
return ErrTxNotWritable
} else if b := t.Bucket(name); b != nil {
return ErrBucketExists
@@ -133,7 +135,9 @@ func (t *Tx) CreateBucketIfNotExists(name string) error {
// DeleteBucket deletes a bucket.
// Returns an error if the bucket cannot be found.
func (t *Tx) DeleteBucket(name string) error {
- if !t.writable {
+ if t.db == nil {
+ return ErrTxClosed
+ } else if !t.writable {
return ErrTxNotWritable
}
@@ -159,10 +163,9 @@ func (t *Tx) Commit() error {
if t.managed {
panic("managed tx commit not allowed")
} else if t.db == nil {
- return nil
+ return ErrTxClosed
} else if !t.writable {
- t.Rollback()
- return nil
+ return ErrTxNotWritable
}
defer t.close()
@@ -196,22 +199,23 @@ func (t *Tx) Commit() error {
}
// Rollback closes the transaction and ignores all previous updates.
-func (t *Tx) Rollback() {
+func (t *Tx) Rollback() error {
if t.managed {
panic("managed tx rollback not allowed")
+ } else if t.db == nil {
+ return ErrTxClosed
}
t.close()
+ return nil
}
func (t *Tx) close() {
- if t.db != nil {
- if t.writable {
- t.db.rwlock.Unlock()
- } else {
- t.db.removeTx(t)
- }
- t.db = nil
+ if t.writable {
+ t.db.rwlock.Unlock()
+ } else {
+ t.db.removeTx(t)
}
+ t.db = nil
}
// allocate returns a contiguous block of memory starting at a given page.
@@ -433,7 +437,9 @@ func (t *Tx) forEachPage(pgid pgid, depth int, fn func(*page, int)) {
// Page returns page information for a given page number.
// This is only available from writable transactions.
func (t *Tx) Page(id int) (*PageInfo, error) {
- if !t.writable {
+ if t.db == nil {
+ return nil, ErrTxClosed
+ } else if !t.writable {
return nil, ErrTxNotWritable
} else if pgid(id) >= t.meta.pgid {
return nil, nil
diff --git a/tx_test.go b/tx_test.go
index afdbb02..fe33869 100644
--- a/tx_test.go
+++ b/tx_test.go
@@ -13,6 +13,33 @@ import (
"github.com/stretchr/testify/assert"
)
+// Ensure that committing a closed transaction returns an error.
+func TestTxCommitClosed(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ tx, _ := db.RWTx()
+ tx.CreateBucket("foo")
+ assert.NoError(t, tx.Commit())
+ assert.Equal(t, tx.Commit(), ErrTxClosed)
+ })
+}
+
+// Ensure that rolling back a closed transaction returns an error.
+func TestTxRollbackClosed(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ tx, _ := db.RWTx()
+ assert.NoError(t, tx.Rollback())
+ assert.Equal(t, tx.Rollback(), ErrTxClosed)
+ })
+}
+
+// Ensure that committing a read-only transaction returns an error.
+func TestTxCommitReadOnly(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ tx, _ := db.Tx()
+ assert.Equal(t, tx.Commit(), ErrTxNotWritable)
+ })
+}
+
// Ensure that the database can retrieve a list of buckets.
func TestTxBuckets(t *testing.T) {
withOpenDB(func(db *DB, path string) {
@@ -41,6 +68,15 @@ func TestTxCreateBucketReadOnly(t *testing.T) {
})
}
+// Ensure that creating a bucket on a closed transaction returns an error.
+func TestTxCreateBucketClosed(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ tx, _ := db.RWTx()
+ tx.Commit()
+ assert.Equal(t, tx.CreateBucket("foo"), ErrTxClosed)
+ })
+}
+
// Ensure that a Tx can retrieve a bucket.
func TestTxBucket(t *testing.T) {
withOpenDB(func(db *DB, path string) {
@@ -204,6 +240,15 @@ func TestTxDeleteBucket(t *testing.T) {
})
}
+// Ensure that deleting a bucket on a closed transaction returns an error.
+func TestTxDeleteBucketClosed(t *testing.T) {
+ withOpenDB(func(db *DB, path string) {
+ tx, _ := db.RWTx()
+ tx.Commit()
+ assert.Equal(t, tx.DeleteBucket("foo"), ErrTxClosed)
+ })
+}
+
// Ensure that deleting a bucket with a read-only transaction returns an error.
func TestTxDeleteBucketReadOnly(t *testing.T) {
withOpenDB(func(db *DB, path string) {