aboutsummaryrefslogtreecommitdiff
path: root/tx.go
diff options
context:
space:
mode:
Diffstat (limited to 'tx.go')
-rw-r--r--tx.go37
1 files changed, 23 insertions, 14 deletions
diff --git a/tx.go b/tx.go
index 3cd0301..0d69498 100644
--- a/tx.go
+++ b/tx.go
@@ -5,13 +5,22 @@ import (
"sort"
"sync"
"unsafe"
+
+ "github.com/alecthomas/atomic"
)
+type txVar interface {
+ getValue() *atomic.Value[VarValue]
+ changeValue(interface{})
+ getWatchers() *sync.Map
+ getLock() *sync.Mutex
+}
+
// A Tx represents an atomic transaction.
type Tx struct {
- reads map[*Var]VarValue
- writes map[*Var]interface{}
- watching map[*Var]struct{}
+ reads map[txVar]VarValue
+ writes map[txVar]interface{}
+ watching map[txVar]struct{}
locks txLocks
mu sync.Mutex
cond sync.Cond
@@ -24,7 +33,7 @@ type Tx struct {
// Check that none of the logged values have changed since the transaction began.
func (tx *Tx) inputsChanged() bool {
for v, read := range tx.reads {
- if read.Changed(v.value.Load()) {
+ if read.Changed(v.getValue().Load()) {
return true
}
}
@@ -42,12 +51,12 @@ func (tx *Tx) updateWatchers() {
for v := range tx.watching {
if _, ok := tx.reads[v]; !ok {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
}
for v := range tx.reads {
if _, ok := tx.watching[v]; !ok {
- v.watchers.Store(tx, nil)
+ v.getWatchers().Store(tx, nil)
tx.watching[v] = struct{}{}
}
}
@@ -76,22 +85,22 @@ func (tx *Tx) wait() {
}
// Get returns the value of v as of the start of the transaction.
-func (tx *Tx) Get(v *Var) interface{} {
+func (v *Var[T]) Get(tx *Tx) T {
// If we previously wrote to v, it will be in the write log.
if val, ok := tx.writes[v]; ok {
- return val
+ return val.(T)
}
// If we haven't previously read v, record its version
vv, ok := tx.reads[v]
if !ok {
- vv = v.value.Load()
+ vv = v.getValue().Load()
tx.reads[v] = vv
}
- return vv.Get()
+ return vv.Get().(T)
}
// Set sets the value of a Var for the lifetime of the transaction.
-func (tx *Tx) Set(v *Var, val interface{}) {
+func (v *Var[T]) Set(tx *Tx, val T) {
if v == nil {
panic("nil Var")
}
@@ -143,7 +152,7 @@ func (tx *Tx) removeRetryProfiles() {
func (tx *Tx) recycle() {
for v := range tx.watching {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
tx.removeRetryProfiles()
// I don't think we can reuse Txs, because the "completed" field should/needs to be set
@@ -164,7 +173,7 @@ func (tx *Tx) resetLocks() {
func (tx *Tx) collectReadLocks() {
for v := range tx.reads {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
@@ -172,7 +181,7 @@ func (tx *Tx) collectAllLocks() {
tx.collectReadLocks()
for v := range tx.writes {
if _, ok := tx.reads[v]; !ok {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
}