aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--funcs.go17
-rw-r--r--global.go1
-rw-r--r--tx.go20
-rw-r--r--var.go12
4 files changed, 30 insertions, 20 deletions
diff --git a/funcs.go b/funcs.go
index a7ff6e0..4967e83 100644
--- a/funcs.go
+++ b/funcs.go
@@ -8,6 +8,7 @@ retry:
reads: make(map[*Var]uint64),
writes: make(map[*Var]interface{}),
}
+ tx.cond.L = &globalLock
var ret interface{}
if func() (retry bool) {
defer func() {
@@ -37,10 +38,7 @@ retry:
goto retry
}
// commit the write log and broadcast that variables have changed
- if len(tx.writes) > 0 {
- tx.commit()
- globalCond.Broadcast()
- }
+ tx.commit()
globalLock.Unlock()
return ret
}
@@ -58,14 +56,9 @@ func AtomicGet(v *Var) interface{} {
// AtomicSet is a helper function that atomically writes a value.
func AtomicSet(v *Var, val interface{}) {
- // since we're only doing one operation, we don't need a full transaction
- globalLock.Lock()
- v.mu.Lock()
- v.val = val
- v.version++
- v.mu.Unlock()
- globalCond.Broadcast()
- globalLock.Unlock()
+ Atomically(func(tx *Tx) {
+ tx.Set(v, val)
+ })
}
// Compose is a helper function that composes multiple transactions into a
diff --git a/global.go b/global.go
index 479864e..bc5f18a 100644
--- a/global.go
+++ b/global.go
@@ -5,4 +5,3 @@ import "sync"
// The globalLock serializes transaction verification/committal. globalCond is
// used to signal that at least one Var has changed.
var globalLock sync.Mutex
-var globalCond = sync.NewCond(&globalLock)
diff --git a/tx.go b/tx.go
index 123c490..5dd41ae 100644
--- a/tx.go
+++ b/tx.go
@@ -1,9 +1,14 @@
package stm
+import (
+ "sync"
+)
+
// A Tx represents an atomic transaction.
type Tx struct {
reads map[*Var]uint64
writes map[*Var]interface{}
+ cond sync.Cond
}
// Check that none of the logged values have changed since the transaction began.
@@ -25,17 +30,26 @@ func (tx *Tx) commit() {
v.mu.Lock()
v.val = val
v.version++
+ for tx := range v.watchers {
+ tx.cond.Broadcast()
+ }
v.mu.Unlock()
}
}
// wait blocks until another transaction modifies any of the Vars read by tx.
func (tx *Tx) wait() {
- globalCond.L.Lock()
+ globalLock.Lock()
+ for v := range tx.reads {
+ v.watchers[tx] = struct{}{}
+ }
for tx.verify() {
- globalCond.Wait()
+ tx.cond.Wait()
+ }
+ for v := range tx.reads {
+ delete(v.watchers, tx)
}
- globalCond.L.Unlock()
+ globalLock.Unlock()
}
// Get returns the value of v as of the start of the transaction.
diff --git a/var.go b/var.go
index 3494352..05e86cf 100644
--- a/var.go
+++ b/var.go
@@ -4,12 +4,16 @@ import "sync"
// A Var holds an STM variable.
type Var struct {
- val interface{}
- version uint64
- mu sync.Mutex
+ val interface{}
+ version uint64
+ mu sync.Mutex
+ watchers map[*Tx]struct{}
}
// NewVar returns a new STM variable.
func NewVar(val interface{}) *Var {
- return &Var{val: val}
+ return &Var{
+ val: val,
+ watchers: make(map[*Tx]struct{}),
+ }
}