From 5d3a8dccf84bcdccb5e51eec11db732eb1911277 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 23 Oct 2019 18:27:11 +1100 Subject: Register transaction condition with each read Var only --- funcs.go | 12 ++++-------- tx.go | 16 +++++++++++++++- var.go | 12 ++++++++---- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/funcs.go b/funcs.go index f87efb8..2170004 100644 --- a/funcs.go +++ b/funcs.go @@ -8,6 +8,7 @@ retry: reads: make(map[*Var]uint64), writes: make(map[*Var]interface{}), } + tx.cond.L = &globalLock if catchRetry(fn, tx) { // wait for one of the variables we read to change before retrying tx.wait() @@ -40,14 +41,9 @@ func AtomicGet(v *Var) interface{} { // AtomicSet is a helper function that atomically writes a value. func AtomicSet(v *Var, val interface{}) { - // since we're only doing one operation, we don't need a full transaction - globalLock.Lock() - v.mu.Lock() - v.val = val - v.version++ - v.mu.Unlock() - globalCond.Broadcast() - globalLock.Unlock() + Atomically(func(tx *Tx) { + tx.Set(v, val) + }) } // Compose is a helper function that composes multiple transactions into a diff --git a/tx.go b/tx.go index 127480e..a17cc13 100644 --- a/tx.go +++ b/tx.go @@ -1,9 +1,14 @@ package stm +import ( + "sync" +) + // A Tx represents an atomic transaction. type Tx struct { reads map[*Var]uint64 writes map[*Var]interface{} + cond sync.Cond } // Check that none of the logged values have changed since the transaction began. @@ -25,6 +30,9 @@ func (tx *Tx) commit() { v.mu.Lock() v.val = val v.version++ + for tx := range v.watchers { + tx.cond.Broadcast() + } v.mu.Unlock() } } @@ -32,8 +40,14 @@ func (tx *Tx) commit() { // wait blocks until another transaction modifies any of the Vars read by tx. func (tx *Tx) wait() { globalCond.L.Lock() + for v := range tx.reads { + v.watchers[tx] = struct{}{} + } for tx.verify() { - globalCond.Wait() + tx.cond.Wait() + } + for v := range tx.reads { + delete(v.watchers, tx) } globalCond.L.Unlock() } diff --git a/var.go b/var.go index 3494352..05e86cf 100644 --- a/var.go +++ b/var.go @@ -4,12 +4,16 @@ import "sync" // A Var holds an STM variable. type Var struct { - val interface{} - version uint64 - mu sync.Mutex + val interface{} + version uint64 + mu sync.Mutex + watchers map[*Tx]struct{} } // NewVar returns a new STM variable. func NewVar(val interface{}) *Var { - return &Var{val: val} + return &Var{ + val: val, + watchers: make(map[*Tx]struct{}), + } } -- cgit v1.2.3