aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--funcs.go8
-rw-r--r--global.go6
-rw-r--r--tx.go66
-rw-r--r--var.go5
4 files changed, 66 insertions, 19 deletions
diff --git a/funcs.go b/funcs.go
index bd2cf20..a311c9f 100644
--- a/funcs.go
+++ b/funcs.go
@@ -10,7 +10,7 @@ var (
reads: make(map[*Var]uint64),
writes: make(map[*Var]interface{}),
}
- tx.cond.L = &globalLock
+ tx.cond.L = &tx.mu
return tx
}}
)
@@ -44,14 +44,14 @@ retry:
goto retry
}
// verify the read log
- globalLock.Lock()
+ tx.lockAllVars()
if !tx.verify() {
- globalLock.Unlock()
+ tx.unlock()
goto retry
}
// commit the write log and broadcast that variables have changed
tx.commit()
- globalLock.Unlock()
+ tx.unlock()
tx.recycle()
return ret
}
diff --git a/global.go b/global.go
index bc5f18a..4974334 100644
--- a/global.go
+++ b/global.go
@@ -1,7 +1 @@
package stm
-
-import "sync"
-
-// The globalLock serializes transaction verification/committal. globalCond is
-// used to signal that at least one Var has changed.
-var globalLock sync.Mutex
diff --git a/tx.go b/tx.go
index f6d2fbf..72a58b2 100644
--- a/tx.go
+++ b/tx.go
@@ -1,13 +1,17 @@
package stm
import (
+ "sort"
"sync"
+ "unsafe"
)
// A Tx represents an atomic transaction.
type Tx struct {
reads map[*Var]uint64
writes map[*Var]interface{}
+ locks []*sync.Mutex
+ mu sync.Mutex
cond sync.Cond
}
@@ -26,26 +30,29 @@ func (tx *Tx) verify() bool {
func (tx *Tx) commit() {
for v, val := range tx.writes {
v.changeValue(val)
- for tx := range v.watchers {
+ v.watchers.Range(func(k, _ interface{}) bool {
+ tx := k.(*Tx)
+ tx.mu.Lock()
tx.cond.Broadcast()
- delete(v.watchers, tx)
- }
+ tx.mu.Unlock()
+ return true
+ })
}
}
// wait blocks until another transaction modifies any of the Vars read by tx.
func (tx *Tx) wait() {
- globalLock.Lock()
+ tx.mu.Lock() // probably can around verify
for v := range tx.reads {
- v.watchers[tx] = struct{}{}
+ v.watchers.Store(tx, nil)
}
for tx.verify() {
tx.cond.Wait()
}
for v := range tx.reads {
- delete(v.watchers, tx)
+ v.watchers.Delete(tx)
}
- globalLock.Unlock()
+ tx.mu.Unlock() // move back to verify?
}
// Get returns the value of v as of the start of the transaction.
@@ -98,9 +105,54 @@ func (tx *Tx) reset() {
for k := range tx.writes {
delete(tx.writes, k)
}
+ tx.resetLocks()
}
func (tx *Tx) recycle() {
tx.reset()
txPool.Put(tx)
}
+
+func (tx *Tx) lockAllVars() {
+ tx.resetLocks()
+ tx.collectAllLocks()
+ tx.sortLocks()
+ tx.lock()
+}
+
+func (tx *Tx) resetLocks() {
+ tx.locks = tx.locks[:0]
+}
+
+func (tx *Tx) collectReadLocks() {
+ for v := range tx.reads {
+ tx.locks = append(tx.locks, &v.mu)
+ }
+}
+
+func (tx *Tx) collectAllLocks() {
+ tx.collectReadLocks()
+ for v := range tx.writes {
+ if _, ok := tx.reads[v]; !ok {
+ tx.locks = append(tx.locks, &v.mu)
+ }
+ }
+}
+
+func (tx *Tx) sortLocks() {
+ sort.Slice(tx.locks, func(i, j int) bool {
+ return uintptr(unsafe.Pointer(tx.locks[i])) < uintptr(unsafe.Pointer(tx.locks[j]))
+ })
+}
+
+func (tx *Tx) lock() {
+ for _, l := range tx.locks {
+ l.Lock()
+ }
+}
+
+func (tx *Tx) unlock() {
+ for _, l := range tx.locks {
+ l.Unlock()
+ }
+}
diff --git a/var.go b/var.go
index 2d0cb21..51160e2 100644
--- a/var.go
+++ b/var.go
@@ -1,6 +1,7 @@
package stm
import (
+ "sync"
"sync/atomic"
"unsafe"
)
@@ -8,7 +9,8 @@ import (
// Holds an STM variable.
type Var struct {
state *varSnapshot
- watchers map[*Tx]struct{}
+ watchers sync.Map
+ mu sync.Mutex
}
func (v *Var) addr() *unsafe.Pointer {
@@ -35,6 +37,5 @@ func NewVar(val interface{}) *Var {
state: &varSnapshot{
val: val,
},
- watchers: make(map[*Tx]struct{}),
}
}