aboutsummaryrefslogtreecommitdiff
path: root/tx.go
diff options
context:
space:
mode:
Diffstat (limited to 'tx.go')
-rw-r--r--tx.go40
1 files changed, 24 insertions, 16 deletions
diff --git a/tx.go b/tx.go
index 9be08b5..825ff01 100644
--- a/tx.go
+++ b/tx.go
@@ -5,13 +5,22 @@ import (
"sort"
"sync"
"unsafe"
+
+ "github.com/alecthomas/atomic"
)
+type txVar interface {
+ getValue() *atomic.Value[VarValue]
+ changeValue(any)
+ getWatchers() *sync.Map
+ getLock() *sync.Mutex
+}
+
// A Tx represents an atomic transaction.
type Tx struct {
- reads map[*Var]VarValue
- writes map[*Var]interface{}
- watching map[*Var]struct{}
+ reads map[txVar]VarValue
+ writes map[txVar]any
+ watching map[txVar]struct{}
locks txLocks
mu sync.Mutex
cond sync.Cond
@@ -24,7 +33,7 @@ type Tx struct {
// Check that none of the logged values have changed since the transaction began.
func (tx *Tx) inputsChanged() bool {
for v, read := range tx.reads {
- if read.Changed(v.value.Load().(VarValue)) {
+ if read.Changed(v.getValue().Load()) {
return true
}
}
@@ -42,12 +51,12 @@ func (tx *Tx) updateWatchers() {
for v := range tx.watching {
if _, ok := tx.reads[v]; !ok {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
}
for v := range tx.reads {
if _, ok := tx.watching[v]; !ok {
- v.watchers.Store(tx, nil)
+ v.getWatchers().Store(tx, nil)
tx.watching[v] = struct{}{}
}
}
@@ -76,22 +85,22 @@ func (tx *Tx) wait() {
}
// Get returns the value of v as of the start of the transaction.
-func (tx *Tx) Get(v *Var) interface{} {
+func (v *Var[T]) Get(tx *Tx) T {
// If we previously wrote to v, it will be in the write log.
if val, ok := tx.writes[v]; ok {
- return val
+ return val.(T)
}
// If we haven't previously read v, record its version
vv, ok := tx.reads[v]
if !ok {
- vv = v.value.Load().(VarValue)
+ vv = v.getValue().Load()
tx.reads[v] = vv
}
- return vv.Get()
+ return vv.Get().(T)
}
// Set sets the value of a Var for the lifetime of the transaction.
-func (tx *Tx) Set(v *Var, val interface{}) {
+func (v *Var[T]) Set(tx *Tx, val T) {
if v == nil {
panic("nil Var")
}
@@ -105,11 +114,10 @@ type txProfileValue struct {
// Retry aborts the transaction and retries it when a Var changes. You can return from this method
// to satisfy return values, but it should never actually return anything as it panics internally.
-func (tx *Tx) Retry() interface{} {
+func (tx *Tx) Retry() struct{} {
retries.Add(txProfileValue{tx, tx.numRetryValues}, 1)
tx.numRetryValues++
panic(retry)
- panic("unreachable")
}
// Assert is a helper function that retries a transaction if the condition is
@@ -143,7 +151,7 @@ func (tx *Tx) removeRetryProfiles() {
func (tx *Tx) recycle() {
for v := range tx.watching {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
tx.removeRetryProfiles()
// I don't think we can reuse Txs, because the "completed" field should/needs to be set
@@ -164,7 +172,7 @@ func (tx *Tx) resetLocks() {
func (tx *Tx) collectReadLocks() {
for v := range tx.reads {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
@@ -172,7 +180,7 @@ func (tx *Tx) collectAllLocks() {
tx.collectReadLocks()
for v := range tx.writes {
if _, ok := tx.reads[v]; !ok {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
}