diff options
-rw-r--r-- | bucket.go | 15 | ||||
-rw-r--r-- | cursor.go | 8 | ||||
-rw-r--r-- | error.go | 4 | ||||
-rw-r--r-- | tx.go | 34 | ||||
-rw-r--r-- | tx_test.go | 45 |
5 files changed, 89 insertions, 17 deletions
@@ -55,7 +55,9 @@ func (b *Bucket) Get(key []byte) []byte { // If the key exist then its previous value will be overwritten. // Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large. func (b *Bucket) Put(key []byte, value []byte) error { - if !b.Writable() { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { return ErrBucketNotWritable } @@ -82,7 +84,9 @@ func (b *Bucket) Put(key []byte, value []byte) error { // If the key does not exist then nothing is done and a nil error is returned. // Returns an error if the bucket was created from a read-only transaction. func (b *Bucket) Delete(key []byte) error { - if !b.Writable() { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { return ErrBucketNotWritable } @@ -98,7 +102,9 @@ func (b *Bucket) Delete(key []byte) error { // NextSequence returns an autoincrementing integer for the bucket. func (b *Bucket) NextSequence() (int, error) { - if !b.Writable() { + if b.tx.db == nil { + return 0, ErrTxClosed + } else if !b.Writable() { return 0, ErrBucketNotWritable } @@ -118,6 +124,9 @@ func (b *Bucket) NextSequence() (int, error) { // If the provided function returns an error then the iteration is stopped and // the error is returned to the caller. func (b *Bucket) ForEach(fn func(k, v []byte) error) error { + if b.tx.db == nil { + return ErrTxClosed + } c := b.Cursor() for k, v := c.First(); k != nil; k, v = c.Next() { if err := fn(k, v); err != nil { @@ -16,6 +16,7 @@ type Cursor struct { // First moves the cursor to the first item in the bucket and returns its key and value. // If the bucket is empty then a nil key and value are returned. func (c *Cursor) First() (key []byte, value []byte) { + _assert(c.tx.db != nil, "tx closed") c.stack = c.stack[:0] p, n := c.tx.pageNode(c.root) c.stack = append(c.stack, elemRef{page: p, node: n, index: 0}) @@ -26,6 +27,7 @@ func (c *Cursor) First() (key []byte, value []byte) { // Last moves the cursor to the last item in the bucket and returns its key and value. // If the bucket is empty then a nil key and value are returned. func (c *Cursor) Last() (key []byte, value []byte) { + _assert(c.tx.db != nil, "tx closed") c.stack = c.stack[:0] p, n := c.tx.pageNode(c.root) ref := elemRef{page: p, node: n} @@ -38,6 +40,8 @@ func (c *Cursor) Last() (key []byte, value []byte) { // Next moves the cursor to the next item in the bucket and returns its key and value. // If the cursor is at the end of the bucket then a nil key and value are returned. func (c *Cursor) Next() (key []byte, value []byte) { + _assert(c.tx.db != nil, "tx closed") + // Attempt to move over one element until we're successful. // Move up the stack as we hit the end of each page in our stack. for i := len(c.stack) - 1; i >= 0; i-- { @@ -62,6 +66,8 @@ func (c *Cursor) Next() (key []byte, value []byte) { // Prev moves the cursor to the previous item in the bucket and returns its key and value. // If the cursor is at the beginning of the bucket then a nil key and value are returned. func (c *Cursor) Prev() (key []byte, value []byte) { + _assert(c.tx.db != nil, "tx closed") + // Attempt to move back one element until we're successful. // Move up the stack as we hit the beginning of each page in our stack. for i := len(c.stack) - 1; i >= 0; i-- { @@ -87,6 +93,8 @@ func (c *Cursor) Prev() (key []byte, value []byte) { // If the key does not exist then the next key is used. If no keys // follow, a nil value is returned. func (c *Cursor) Seek(seek []byte) (key []byte, value []byte) { + _assert(c.tx.db != nil, "tx closed") + // Start from root page/node and traverse to correct page. c.stack = c.stack[:0] c.search(seek, c.root) @@ -20,6 +20,10 @@ var ( // read-only transaction. ErrTxNotWritable = &Error{"tx not writable", nil} + // ErrTxClosed is returned when committing or rolling back a transaction + // that has already been committed or rolled back. + ErrTxClosed = &Error{"tx closed", nil} + // ErrBucketNotFound is returned when trying to access a bucket that has // not been created yet. ErrBucketNotFound = &Error{"bucket not found", nil} @@ -97,7 +97,9 @@ func (t *Tx) Buckets() []*Bucket { // CreateBucket creates a new bucket. // Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long. func (t *Tx) CreateBucket(name string) error { - if !t.writable { + if t.db == nil { + return ErrTxClosed + } else if !t.writable { return ErrTxNotWritable } else if b := t.Bucket(name); b != nil { return ErrBucketExists @@ -133,7 +135,9 @@ func (t *Tx) CreateBucketIfNotExists(name string) error { // DeleteBucket deletes a bucket. // Returns an error if the bucket cannot be found. func (t *Tx) DeleteBucket(name string) error { - if !t.writable { + if t.db == nil { + return ErrTxClosed + } else if !t.writable { return ErrTxNotWritable } @@ -159,10 +163,9 @@ func (t *Tx) Commit() error { if t.managed { panic("managed tx commit not allowed") } else if t.db == nil { - return nil + return ErrTxClosed } else if !t.writable { - t.Rollback() - return nil + return ErrTxNotWritable } defer t.close() @@ -196,22 +199,23 @@ func (t *Tx) Commit() error { } // Rollback closes the transaction and ignores all previous updates. -func (t *Tx) Rollback() { +func (t *Tx) Rollback() error { if t.managed { panic("managed tx rollback not allowed") + } else if t.db == nil { + return ErrTxClosed } t.close() + return nil } func (t *Tx) close() { - if t.db != nil { - if t.writable { - t.db.rwlock.Unlock() - } else { - t.db.removeTx(t) - } - t.db = nil + if t.writable { + t.db.rwlock.Unlock() + } else { + t.db.removeTx(t) } + t.db = nil } // allocate returns a contiguous block of memory starting at a given page. @@ -433,7 +437,9 @@ func (t *Tx) forEachPage(pgid pgid, depth int, fn func(*page, int)) { // Page returns page information for a given page number. // This is only available from writable transactions. func (t *Tx) Page(id int) (*PageInfo, error) { - if !t.writable { + if t.db == nil { + return nil, ErrTxClosed + } else if !t.writable { return nil, ErrTxNotWritable } else if pgid(id) >= t.meta.pgid { return nil, nil @@ -13,6 +13,33 @@ import ( "github.com/stretchr/testify/assert" ) +// Ensure that committing a closed transaction returns an error. +func TestTxCommitClosed(t *testing.T) { + withOpenDB(func(db *DB, path string) { + tx, _ := db.RWTx() + tx.CreateBucket("foo") + assert.NoError(t, tx.Commit()) + assert.Equal(t, tx.Commit(), ErrTxClosed) + }) +} + +// Ensure that rolling back a closed transaction returns an error. +func TestTxRollbackClosed(t *testing.T) { + withOpenDB(func(db *DB, path string) { + tx, _ := db.RWTx() + assert.NoError(t, tx.Rollback()) + assert.Equal(t, tx.Rollback(), ErrTxClosed) + }) +} + +// Ensure that committing a read-only transaction returns an error. +func TestTxCommitReadOnly(t *testing.T) { + withOpenDB(func(db *DB, path string) { + tx, _ := db.Tx() + assert.Equal(t, tx.Commit(), ErrTxNotWritable) + }) +} + // Ensure that the database can retrieve a list of buckets. func TestTxBuckets(t *testing.T) { withOpenDB(func(db *DB, path string) { @@ -41,6 +68,15 @@ func TestTxCreateBucketReadOnly(t *testing.T) { }) } +// Ensure that creating a bucket on a closed transaction returns an error. +func TestTxCreateBucketClosed(t *testing.T) { + withOpenDB(func(db *DB, path string) { + tx, _ := db.RWTx() + tx.Commit() + assert.Equal(t, tx.CreateBucket("foo"), ErrTxClosed) + }) +} + // Ensure that a Tx can retrieve a bucket. func TestTxBucket(t *testing.T) { withOpenDB(func(db *DB, path string) { @@ -204,6 +240,15 @@ func TestTxDeleteBucket(t *testing.T) { }) } +// Ensure that deleting a bucket on a closed transaction returns an error. +func TestTxDeleteBucketClosed(t *testing.T) { + withOpenDB(func(db *DB, path string) { + tx, _ := db.RWTx() + tx.Commit() + assert.Equal(t, tx.DeleteBucket("foo"), ErrTxClosed) + }) +} + // Ensure that deleting a bucket with a read-only transaction returns an error. func TestTxDeleteBucketReadOnly(t *testing.T) { withOpenDB(func(db *DB, path string) { |