aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--funcs.go12
-rw-r--r--tx.go16
-rw-r--r--var.go12
3 files changed, 27 insertions, 13 deletions
diff --git a/funcs.go b/funcs.go
index f87efb8..2170004 100644
--- a/funcs.go
+++ b/funcs.go
@@ -8,6 +8,7 @@ retry:
reads: make(map[*Var]uint64),
writes: make(map[*Var]interface{}),
}
+ tx.cond.L = &globalLock
if catchRetry(fn, tx) {
// wait for one of the variables we read to change before retrying
tx.wait()
@@ -40,14 +41,9 @@ func AtomicGet(v *Var) interface{} {
// AtomicSet is a helper function that atomically writes a value.
func AtomicSet(v *Var, val interface{}) {
- // since we're only doing one operation, we don't need a full transaction
- globalLock.Lock()
- v.mu.Lock()
- v.val = val
- v.version++
- v.mu.Unlock()
- globalCond.Broadcast()
- globalLock.Unlock()
+ Atomically(func(tx *Tx) {
+ tx.Set(v, val)
+ })
}
// Compose is a helper function that composes multiple transactions into a
diff --git a/tx.go b/tx.go
index 127480e..a17cc13 100644
--- a/tx.go
+++ b/tx.go
@@ -1,9 +1,14 @@
package stm
+import (
+ "sync"
+)
+
// A Tx represents an atomic transaction.
type Tx struct {
reads map[*Var]uint64
writes map[*Var]interface{}
+ cond sync.Cond
}
// Check that none of the logged values have changed since the transaction began.
@@ -25,6 +30,9 @@ func (tx *Tx) commit() {
v.mu.Lock()
v.val = val
v.version++
+ for tx := range v.watchers {
+ tx.cond.Broadcast()
+ }
v.mu.Unlock()
}
}
@@ -32,8 +40,14 @@ func (tx *Tx) commit() {
// wait blocks until another transaction modifies any of the Vars read by tx.
func (tx *Tx) wait() {
globalCond.L.Lock()
+ for v := range tx.reads {
+ v.watchers[tx] = struct{}{}
+ }
for tx.verify() {
- globalCond.Wait()
+ tx.cond.Wait()
+ }
+ for v := range tx.reads {
+ delete(v.watchers, tx)
}
globalCond.L.Unlock()
}
diff --git a/var.go b/var.go
index 3494352..05e86cf 100644
--- a/var.go
+++ b/var.go
@@ -4,12 +4,16 @@ import "sync"
// A Var holds an STM variable.
type Var struct {
- val interface{}
- version uint64
- mu sync.Mutex
+ val interface{}
+ version uint64
+ mu sync.Mutex
+ watchers map[*Tx]struct{}
}
// NewVar returns a new STM variable.
func NewVar(val interface{}) *Var {
- return &Var{val: val}
+ return &Var{
+ val: val,
+ watchers: make(map[*Tx]struct{}),
+ }
}