diff options
author | Matt Joiner <anacrolix@gmail.com> | 2019-11-04 15:52:42 +1100 |
---|---|---|
committer | Matt Joiner <anacrolix@gmail.com> | 2019-11-04 15:52:42 +1100 |
commit | bbcb2aadd77d362849993806e5c3900b11a1502e (patch) | |
tree | 0c3969c1cb031acdf1a23975b850934d130da87d /tx.go | |
parent | Use atomic pointers for Var data (diff) | |
download | stm-bbcb2aadd77d362849993806e5c3900b11a1502e.tar.gz stm-bbcb2aadd77d362849993806e5c3900b11a1502e.tar.xz |
Remove global lock
Diffstat (limited to 'tx.go')
-rw-r--r-- | tx.go | 66 |
1 files changed, 59 insertions, 7 deletions
@@ -1,13 +1,17 @@ package stm import ( + "sort" "sync" + "unsafe" ) // A Tx represents an atomic transaction. type Tx struct { reads map[*Var]uint64 writes map[*Var]interface{} + locks []*sync.Mutex + mu sync.Mutex cond sync.Cond } @@ -26,26 +30,29 @@ func (tx *Tx) verify() bool { func (tx *Tx) commit() { for v, val := range tx.writes { v.changeValue(val) - for tx := range v.watchers { + v.watchers.Range(func(k, _ interface{}) bool { + tx := k.(*Tx) + tx.mu.Lock() tx.cond.Broadcast() - delete(v.watchers, tx) - } + tx.mu.Unlock() + return true + }) } } // wait blocks until another transaction modifies any of the Vars read by tx. func (tx *Tx) wait() { - globalLock.Lock() + tx.mu.Lock() // probably can around verify for v := range tx.reads { - v.watchers[tx] = struct{}{} + v.watchers.Store(tx, nil) } for tx.verify() { tx.cond.Wait() } for v := range tx.reads { - delete(v.watchers, tx) + v.watchers.Delete(tx) } - globalLock.Unlock() + tx.mu.Unlock() // move back to verify? } // Get returns the value of v as of the start of the transaction. @@ -98,9 +105,54 @@ func (tx *Tx) reset() { for k := range tx.writes { delete(tx.writes, k) } + tx.resetLocks() } func (tx *Tx) recycle() { tx.reset() txPool.Put(tx) } + +func (tx *Tx) lockAllVars() { + tx.resetLocks() + tx.collectAllLocks() + tx.sortLocks() + tx.lock() +} + +func (tx *Tx) resetLocks() { + tx.locks = tx.locks[:0] +} + +func (tx *Tx) collectReadLocks() { + for v := range tx.reads { + tx.locks = append(tx.locks, &v.mu) + } +} + +func (tx *Tx) collectAllLocks() { + tx.collectReadLocks() + for v := range tx.writes { + if _, ok := tx.reads[v]; !ok { + tx.locks = append(tx.locks, &v.mu) + } + } +} + +func (tx *Tx) sortLocks() { + sort.Slice(tx.locks, func(i, j int) bool { + return uintptr(unsafe.Pointer(tx.locks[i])) < uintptr(unsafe.Pointer(tx.locks[j])) + }) +} + +func (tx *Tx) lock() { + for _, l := range tx.locks { + l.Lock() + } +} + +func (tx *Tx) unlock() { + for _, l := range tx.locks { + l.Unlock() + } +} |