-rw-r--r--  README.md                  22
-rw-r--r--  bench_test.go              14
-rw-r--r--  cmd/santa-example/main.go  40
-rw-r--r--  doc.go                     16
-rw-r--r--  doc_test.go                24
-rw-r--r--  external_test.go           50
-rw-r--r--  funcs.go                   25
-rw-r--r--  rate/rate_test.go           1
-rw-r--r--  rate/ratelimit.go          30
-rw-r--r--  stm_test.go                74
-rw-r--r--  stmutil/context.go          6
-rw-r--r--  tx.go                      37
-rw-r--r--  var-value.go               20
-rw-r--r--  var.go                     32
14 files changed, 203 insertions, 188 deletions
diff --git a/README.md b/README.md
index 6330f79..22dfdda 100644
--- a/README.md
+++ b/README.md
@@ -42,11 +42,11 @@ Be very careful when managing pointers inside transactions! (This includes
slices, maps, channels, and captured variables.) Here's why:
```go
-p := stm.NewVar([]byte{1,2,3})
+p := stm.NewVar[[]byte]([]byte{1,2,3})
stm.Atomically(func(tx *stm.Tx) {
- b := tx.Get(p).([]byte)
+ b := p.Get(tx)
b[0] = 7
- tx.Set(p, b)
+ p.Set(tx, b)
})
```
@@ -57,11 +57,11 @@ Following this advice, we can rewrite the transaction to perform a copy:
```go
stm.Atomically(func(tx *stm.Tx) {
- b := tx.Get(p).([]byte)
+ b := p.Get(tx)
c := make([]byte, len(b))
copy(c, b)
c[0] = 7
- tx.Set(p, c)
+ p.Set(tx, c)
})
```
@@ -73,11 +73,11 @@ In the same vein, it would be a mistake to do this:
type foo struct {
i int
}
-p := stm.NewVar(&foo{i: 2})
+p := stm.NewVar[*foo](&foo{i: 2})
stm.Atomically(func(tx *stm.Tx) {
- f := tx.Get(p).(*foo)
+ f := p.Get(tx)
f.i = 7
- tx.Set(p, f)
+ p.Set(tx, f)
})
```
@@ -88,11 +88,11 @@ the correct approach is to move the `Var` inside the struct:
type foo struct {
-i *stm.Var
+i *stm.Var[int]
}
-f := foo{i: stm.NewVar(2)}
+f := foo{i: stm.NewVar[int](2)}
stm.Atomically(func(tx *stm.Tx) {
- i := tx.Get(f.i).(int)
+ i := f.i.Get(tx)
i = 7
- tx.Set(f.i, i)
+ f.i.Set(tx, i)
})
```
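The README hunks above switch the examples to the typed API (stm.Var[T], v.Get(tx), v.Set(tx, val)). As a minimal end-to-end sketch of the recommended pattern of keeping the Var inside the struct, assuming the module path is github.com/anacrolix/stm:
```go
package main

import (
	"fmt"

	"github.com/anacrolix/stm" // assumed module path
)

type foo struct {
	i *stm.Var[int] // the Var lives inside the struct, not a struct pointer inside a Var
}

func main() {
	f := foo{i: stm.NewVar[int](2)}
	stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
		cur := f.i.Get(tx) // typed: already an int, no assertion needed
		f.i.Set(tx, cur+5)
	}))
	fmt.Println(stm.AtomicGet(f.i)) // 7
}
```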
diff --git a/bench_test.go b/bench_test.go
index 82aaba1..0b84715 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -8,14 +8,14 @@ import (
)
func BenchmarkAtomicGet(b *testing.B) {
- x := NewVar(0)
+ x := NewVar[int](0)
for i := 0; i < b.N; i++ {
AtomicGet(x)
}
}
func BenchmarkAtomicSet(b *testing.B) {
- x := NewVar(0)
+ x := NewVar[int](0)
for i := 0; i < b.N; i++ {
AtomicSet(x, 0)
}
@@ -24,16 +24,16 @@ func BenchmarkAtomicSet(b *testing.B) {
func BenchmarkIncrementSTM(b *testing.B) {
for i := 0; i < b.N; i++ {
// spawn 1000 goroutines that each increment x by 1
- x := NewVar(0)
+ x := NewVar[int](0)
for i := 0; i < 1000; i++ {
go Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur+1)
+ cur := x.Get(tx)
+ x.Set(tx, cur+1)
}))
}
// wait for x to reach 1000
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x).(int) == 1000)
+ tx.Assert(x.Get(tx) == 1000)
}))
}
}
@@ -83,7 +83,7 @@ func BenchmarkReadVarSTM(b *testing.B) {
for i := 0; i < b.N; i++ {
var wg sync.WaitGroup
wg.Add(1000)
- x := NewVar(0)
+ x := NewVar[int](0)
for i := 0; i < 1000; i++ {
go func() {
AtomicGet(x)
diff --git a/cmd/santa-example/main.go b/cmd/santa-example/main.go
index 3c2b9ea..72be64f 100644
--- a/cmd/santa-example/main.go
+++ b/cmd/santa-example/main.go
@@ -39,15 +39,15 @@ import (
type gate struct {
capacity int
- remaining *stm.Var
+ remaining *stm.Var[int]
}
func (g gate) pass() {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
// wait until gate can hold us
tx.Assert(rem > 0)
- tx.Set(g.remaining, rem-1)
+ g.remaining.Set(tx, rem-1)
}))
}
@@ -56,7 +56,7 @@ func (g gate) operate() {
stm.AtomicSet(g.remaining, g.capacity)
// wait for gate to be full
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
tx.Assert(rem == 0)
}))
}
@@ -64,49 +64,49 @@ func (g gate) operate() {
func newGate(capacity int) gate {
return gate{
capacity: capacity,
- remaining: stm.NewVar(0), // gate starts out closed
+ remaining: stm.NewVar[int](0), // gate starts out closed
}
}
type group struct {
capacity int
- remaining *stm.Var
- gate1, gate2 *stm.Var
+ remaining *stm.Var[int]
+ gate1, gate2 *stm.Var[gate]
}
func newGroup(capacity int) *group {
return &group{
capacity: capacity,
- remaining: stm.NewVar(capacity), // group starts out with full capacity
- gate1: stm.NewVar(newGate(capacity)),
- gate2: stm.NewVar(newGate(capacity)),
+ remaining: stm.NewVar[int](capacity), // group starts out with full capacity
+ gate1: stm.NewVar[gate](newGate(capacity)),
+ gate2: stm.NewVar[gate](newGate(capacity)),
}
}
func (g *group) join() (g1, g2 gate) {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
// wait until the group can hold us
tx.Assert(rem > 0)
- tx.Set(g.remaining, rem-1)
+ g.remaining.Set(tx, rem-1)
// return the group's gates
- g1 = tx.Get(g.gate1).(gate)
- g2 = tx.Get(g.gate2).(gate)
+ g1 = g.gate1.Get(tx)
+ g2 = g.gate2.Get(tx)
}))
return
}
func (g *group) await(tx *stm.Tx) (gate, gate) {
// wait for group to be empty
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
tx.Assert(rem == 0)
// get the group's gates
- g1 := tx.Get(g.gate1).(gate)
- g2 := tx.Get(g.gate2).(gate)
+ g1 := g.gate1.Get(tx)
+ g2 := g.gate2.Get(tx)
// reset group
- tx.Set(g.remaining, g.capacity)
- tx.Set(g.gate1, newGate(g.capacity))
- tx.Set(g.gate2, newGate(g.capacity))
+ g.remaining.Set(tx, g.capacity)
+ g.gate1.Set(tx, newGate(g.capacity))
+ g.gate2.Set(tx, newGate(g.capacity))
return g1, g2
}
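The hunks above only cover the gate and group plumbing; the santa/reindeer/elf loops live outside them. As a hypothetical sketch (same package) of how an actor might use these types:
```go
// hypothetical actor: join the group, then pass through both gates in turn
func reindeer(g *group) {
	g1, g2 := g.join() // blocks until the group has room; returns its gates
	g1.pass()          // blocks until santa operates gate 1
	// ... deliver toys ...
	g2.pass()          // blocks until santa operates gate 2
}
```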
diff --git a/doc.go b/doc.go
index db03501..748af80 100644
--- a/doc.go
+++ b/doc.go
@@ -10,14 +10,14 @@ it non-atomic).
To begin, create an STM object that wraps the data you want to access
concurrently.
- x := stm.NewVar(3)
+ x := stm.NewVar[int](3)
You can then use the Atomically method to atomically read and/or write the
data. This code atomically decrements x:
stm.Atomically(func(tx *stm.Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur-1)
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
})
An important part of STM transactions is retrying. At any point during the
@@ -29,11 +29,11 @@ updated before the transaction will be rerun. As an example, this code will
try to decrement x, but will block as long as x is zero:
stm.Atomically(func(tx *stm.Tx) {
- cur := tx.Get(x).(int)
+ cur := x.Get(tx)
if cur == 0 {
tx.Retry()
}
- tx.Set(x, cur-1)
+ x.Set(tx, cur-1)
})
Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any other
@@ -47,13 +47,13 @@ retried. For example, this code implements the "decrement-if-nonzero"
transaction above, but for two values. It will first try to decrement x, then
y, and block if both values are zero.
- func dec(v *stm.Var) {
+ func dec(v *stm.Var[int]) {
return func(tx *stm.Tx) {
- cur := tx.Get(v).(int)
+ cur := v.Get(tx)
if cur == 0 {
tx.Retry()
}
- tx.Set(v, cur-1)
+ v.Set(tx, cur-1)
}
}
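The doc's dec snippet is illustrative rather than compilable (it omits a return type). A runnable version of the decrement-if-nonzero idea across two vars, assuming the module path is github.com/anacrolix/stm:
```go
package main

import "github.com/anacrolix/stm" // assumed module path

func dec(v *stm.Var[int]) stm.Operation {
	return stm.VoidOperation(func(tx *stm.Tx) {
		cur := v.Get(tx)
		if cur == 0 {
			tx.Retry()
		}
		v.Set(tx, cur-1)
	})
}

func main() {
	x := stm.NewVar[int](1)
	y := stm.NewVar[int](0)
	// try to decrement x, then y; block only if both are zero
	stm.Atomically(stm.Select(dec(x), dec(y)))
}
```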
diff --git a/doc_test.go b/doc_test.go
index f6bb863..670d07b 100644
--- a/doc_test.go
+++ b/doc_test.go
@@ -6,43 +6,43 @@ import (
func Example() {
// create a shared variable
- n := stm.NewVar(3)
+ n := stm.NewVar[int](3)
// read a variable
var v int
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- v = tx.Get(n).(int)
+ v = n.Get(tx)
}))
// or:
- v = stm.AtomicGet(n).(int)
+ v = stm.AtomicGet(n)
_ = v
// write to a variable
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Set(n, 12)
+ n.Set(tx, 12)
}))
// or:
stm.AtomicSet(n, 12)
// update a variable
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(n).(int)
- tx.Set(n, cur-1)
+ cur := n.Get(tx)
+ n.Set(tx, cur-1)
}))
// block until a condition is met
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(n).(int)
+ cur := n.Get(tx)
if cur != 0 {
tx.Retry()
}
- tx.Set(n, 10)
+ n.Set(tx, 10)
}))
// or:
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(n).(int)
+ cur := n.Get(tx)
tx.Assert(cur == 0)
- tx.Set(n, 10)
+ n.Set(tx, 10)
}))
// select among multiple (potentially blocking) transactions
@@ -51,11 +51,11 @@ func Example() {
stm.VoidOperation(func(tx *stm.Tx) { tx.Retry() }),
// this function will always succeed without blocking
- stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 10) }),
+ stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 10) }),
// this function will never run, because the previous
// function succeeded
- stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 11) }),
+ stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 11) }),
))
// since Select is a normal transaction, if the entire select retries
diff --git a/external_test.go b/external_test.go
index ae29ca8..d0f5ecf 100644
--- a/external_test.go
+++ b/external_test.go
@@ -63,14 +63,14 @@ func BenchmarkThunderingHerd(b *testing.B) {
pending := stm.NewBuiltinEqVar(0)
for range iter.N(1000) {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Set(pending, tx.Get(pending).(int)+1)
+ pending.Set(tx, pending.Get(tx)+1)
}))
go func() {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- t := tx.Get(tokens).(int)
+ t := tokens.Get(tx)
if t > 0 {
- tx.Set(tokens, t-1)
- tx.Set(pending, tx.Get(pending).(int)-1)
+ tokens.Set(tx, t-1)
+ pending.Set(tx, pending.Get(tx)-1)
} else {
tx.Retry()
}
@@ -79,17 +79,17 @@ func BenchmarkThunderingHerd(b *testing.B) {
}
go func() {
for stm.Atomically(func(tx *stm.Tx) interface{} {
- if tx.Get(done).(bool) {
+ if done.Get(tx) {
return false
}
- tx.Assert(tx.Get(tokens).(int) < maxTokens)
- tx.Set(tokens, tx.Get(tokens).(int)+1)
+ tx.Assert(tokens.Get(tx) < maxTokens)
+ tokens.Set(tx, tokens.Get(tx)+1)
return true
}).(bool) {
}
}()
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(tx.Get(pending).(int) == 0)
+ tx.Assert(pending.Get(tx) == 0)
}))
stm.AtomicSet(done, true)
}
@@ -99,53 +99,53 @@ func BenchmarkInvertedThunderingHerd(b *testing.B) {
for i := 0; i < b.N; i++ {
done := stm.NewBuiltinEqVar(false)
tokens := stm.NewBuiltinEqVar(0)
- pending := stm.NewVar(stmutil.NewSet())
+ pending := stm.NewVar[stmutil.Settish](stmutil.NewSet())
for range iter.N(1000) {
- ready := stm.NewVar(false)
+ ready := stm.NewVar[bool](false)
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Set(pending, tx.Get(pending).(stmutil.Settish).Add(ready))
+ pending.Set(tx, pending.Get(tx).Add(ready))
}))
go func() {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(tx.Get(ready).(bool))
- set := tx.Get(pending).(stmutil.Settish)
+ tx.Assert(ready.Get(tx))
+ set := pending.Get(tx)
if !set.Contains(ready) {
panic("couldn't find ourselves in pending")
}
- tx.Set(pending, set.Delete(ready))
+ pending.Set(tx, set.Delete(ready))
}))
//b.Log("waiter finished")
}()
}
go func() {
for stm.Atomically(func(tx *stm.Tx) interface{} {
- if tx.Get(done).(bool) {
+ if done.Get(tx) {
return false
}
- tx.Assert(tx.Get(tokens).(int) < maxTokens)
- tx.Set(tokens, tx.Get(tokens).(int)+1)
+ tx.Assert(tokens.Get(tx) < maxTokens)
+ tokens.Set(tx, tokens.Get(tx)+1)
return true
}).(bool) {
}
}()
go func() {
for stm.Atomically(func(tx *stm.Tx) interface{} {
- tx.Assert(tx.Get(tokens).(int) > 0)
- tx.Set(tokens, tx.Get(tokens).(int)-1)
- tx.Get(pending).(stmutil.Settish).Range(func(i interface{}) bool {
- ready := i.(*stm.Var)
- if !tx.Get(ready).(bool) {
- tx.Set(ready, true)
+ tx.Assert(tokens.Get(tx) > 0)
+ tokens.Set(tx, tokens.Get(tx)-1)
+ pending.Get(tx).Range(func(i interface{}) bool {
+ ready := i.(*stm.Var[bool])
+ if !ready.Get(tx) {
+ ready.Set(tx, true)
return false
}
return true
})
- return !tx.Get(done).(bool)
+ return !done.Get(tx)
}).(bool) {
}
}()
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(tx.Get(pending).(stmutil.Lenner).Len() == 0)
+ tx.Assert(pending.Get(tx).(stmutil.Lenner).Len() == 0)
}))
stm.AtomicSet(done, true)
}
diff --git a/funcs.go b/funcs.go
index 53975e0..3336c71 100644
--- a/funcs.go
+++ b/funcs.go
@@ -2,7 +2,6 @@ package stm
import (
"math/rand"
- "reflect"
"runtime/pprof"
"sync"
"time"
@@ -12,9 +11,9 @@ var (
txPool = sync.Pool{New: func() interface{} {
expvars.Add("new txs", 1)
tx := &Tx{
- reads: make(map[*Var]VarValue),
- writes: make(map[*Var]interface{}),
- watching: make(map[*Var]struct{}),
+ reads: make(map[txVar]VarValue),
+ writes: make(map[txVar]interface{}),
+ watching: make(map[txVar]struct{}),
}
tx.cond.L = &tx.mu
return tx
@@ -104,12 +103,12 @@ retry:
}
// AtomicGet is a helper function that atomically reads a value.
-func AtomicGet(v *Var) interface{} {
- return v.value.Load().Get()
+func AtomicGet[T any](v *Var[T]) T {
+ return v.value.Load().Get().(T)
}
// AtomicSet is a helper function that atomically writes a value.
-func AtomicSet(v *Var, val interface{}) {
+func AtomicSet[T any](v *Var[T], val interface{}) {
v.mu.Lock()
v.changeValue(val)
v.mu.Unlock()
@@ -140,7 +139,7 @@ func Select(fns ...Operation) Operation {
return fns[0](tx)
default:
oldWrites := tx.writes
- tx.writes = make(map[*Var]interface{}, len(oldWrites))
+ tx.writes = make(map[txVar]interface{}, len(oldWrites))
for k, v := range oldWrites {
tx.writes[k] = v
}
@@ -164,14 +163,8 @@ func VoidOperation(f func(*Tx)) Operation {
}
}
-func AtomicModify(v *Var, f interface{}) {
- r := reflect.ValueOf(f)
+func AtomicModify[T any](v *Var[T], f func(T) T) {
Atomically(VoidOperation(func(tx *Tx) {
- cur := reflect.ValueOf(tx.Get(v))
- out := r.Call([]reflect.Value{cur})
- if lenOut := len(out); lenOut != 1 {
- panic(lenOut)
- }
- tx.Set(v, out[0].Interface())
+ v.Set(tx, f(v.Get(tx)))
}))
}
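With generics, AtomicModify drops the reflection round-trip and takes a plain func(T) T. A small usage sketch (module path assumed):
```go
package main

import (
	"fmt"

	"github.com/anacrolix/stm" // assumed module path
)

func main() {
	x := stm.NewVar[int](10)
	// read-modify-write in a single transaction, no type assertions
	stm.AtomicModify(x, func(cur int) int { return cur * 2 })
	fmt.Println(stm.AtomicGet(x)) // 20
}
```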
diff --git a/rate/rate_test.go b/rate/rate_test.go
index 7f44d74..a4e2fac 100644
--- a/rate/rate_test.go
+++ b/rate/rate_test.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+//go:build go1.7
// +build go1.7
package rate
diff --git a/rate/ratelimit.go b/rate/ratelimit.go
index a44f383..7fde1ce 100644
--- a/rate/ratelimit.go
+++ b/rate/ratelimit.go
@@ -13,9 +13,9 @@ import (
type numTokens = int
type Limiter struct {
- max *stm.Var
- cur *stm.Var
- lastAdd *stm.Var
+ max *stm.Var[numTokens]
+ cur *stm.Var[numTokens]
+ lastAdd *stm.Var[time.Time]
rate Limit
}
@@ -36,9 +36,9 @@ func Every(interval time.Duration) Limit {
func NewLimiter(rate Limit, burst numTokens) *Limiter {
rl := &Limiter{
- max: stm.NewVar(burst),
+ max: stm.NewVar[int](burst),
cur: stm.NewBuiltinEqVar(burst),
- lastAdd: stm.NewVar(time.Now()),
+ lastAdd: stm.NewVar[time.Time](time.Now()),
rate: rate,
}
if rate != Inf {
@@ -49,7 +49,7 @@ func NewLimiter(rate Limit, burst numTokens) *Limiter {
func (rl *Limiter) tokenGenerator(interval time.Duration) {
for {
- lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time)
+ lastAdd := stm.AtomicGet(rl.lastAdd)
time.Sleep(time.Until(lastAdd.Add(interval)))
now := time.Now()
available := numTokens(now.Sub(lastAdd) / interval)
@@ -57,17 +57,17 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) {
continue
}
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(rl.cur).(numTokens)
- max := tx.Get(rl.max).(numTokens)
+ cur := rl.cur.Get(tx)
+ max := rl.max.Get(tx)
tx.Assert(cur < max)
newCur := cur + available
if newCur > max {
newCur = max
}
if newCur != cur {
- tx.Set(rl.cur, newCur)
+ rl.cur.Set(tx, newCur)
}
- tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available)))
+ rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available)))
}))
}
}
@@ -90,9 +90,9 @@ func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool {
if rl.rate == Inf {
return true
}
- cur := tx.Get(rl.cur).(numTokens)
+ cur := rl.cur.Get(tx)
if cur >= n {
- tx.Set(rl.cur, cur-n)
+ rl.cur.Set(tx, cur-n)
return true
}
return false
@@ -106,17 +106,17 @@ func (rl *Limiter) WaitN(ctx context.Context, n int) error {
ctxDone, cancel := stmutil.ContextDoneVar(ctx)
defer cancel()
if err := stm.Atomically(func(tx *stm.Tx) interface{} {
- if tx.Get(ctxDone).(bool) {
+ if ctxDone.Get(tx) {
return ctx.Err()
}
if rl.takeTokens(tx, n) {
return nil
}
- if n > tx.Get(rl.max).(numTokens) {
+ if n > rl.max.Get(tx) {
return errors.New("burst exceeded")
}
if dl, ok := ctx.Deadline(); ok {
- if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n {
+ if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n {
return context.DeadlineExceeded
}
}
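A sketch of driving the limiter through the typed API, assuming the package lives at github.com/anacrolix/stm/rate:
```go
package main

import (
	"context"
	"time"

	"github.com/anacrolix/stm/rate" // assumed import path
)

func main() {
	// roughly 10 tokens per second, bursts of up to 5
	rl := rate.NewLimiter(rate.Every(100*time.Millisecond), 5)
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	for i := 0; i < 3; i++ {
		if err := rl.WaitN(ctx, 1); err != nil {
			// cancelled, deadline exceeded, or burst exceeded
			return
		}
		// ... do rate-limited work ...
	}
}
```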
diff --git a/stm_test.go b/stm_test.go
index 5c0ae31..fdf7af2 100644
--- a/stm_test.go
+++ b/stm_test.go
@@ -11,17 +11,17 @@ import (
)
func TestDecrement(t *testing.T) {
- x := NewVar(1000)
+ x := NewVar[int](1000)
for i := 0; i < 500; i++ {
go Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur-1)
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
}))
}
done := make(chan struct{})
go func() {
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x) == 500)
+ tx.Assert(x.Get(tx) == 500)
}))
close(done)
}()
@@ -35,7 +35,7 @@ func TestDecrement(t *testing.T) {
// read-only transaction aren't exempt from calling tx.inputsChanged
func TestReadVerify(t *testing.T) {
read := make(chan struct{})
- x, y := NewVar(1), NewVar(2)
+ x, y := NewVar[int](1), NewVar[int](2)
// spawn a transaction that writes to x
go func() {
@@ -50,10 +50,10 @@ func TestReadVerify(t *testing.T) {
// between the reads, causing this tx to retry.
var x2, y2 int
Atomically(VoidOperation(func(tx *Tx) {
- x2 = tx.Get(x).(int)
+ x2 = x.Get(tx)
read <- struct{}{}
<-read // wait for other tx to complete
- y2 = tx.Get(y).(int)
+ y2 = y.Get(tx)
}))
if x2 == 1 && y2 == 2 {
t.Fatal("read was not verified")
@@ -61,15 +61,15 @@ func TestReadVerify(t *testing.T) {
}
func TestRetry(t *testing.T) {
- x := NewVar(10)
+ x := NewVar[int](10)
// spawn 10 transactions, one every 10 milliseconds. This will decrement x
// to 0 over the course of 100 milliseconds.
go func() {
for i := 0; i < 10; i++ {
time.Sleep(10 * time.Millisecond)
Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur-1)
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
}))
}
}()
@@ -77,7 +77,7 @@ func TestRetry(t *testing.T) {
// retry. This should result in no more than 1 retry per transaction.
retry := 0
Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
+ cur := x.Get(tx)
if cur != 0 {
retry++
tx.Retry()
@@ -93,16 +93,16 @@ func TestVerify(t *testing.T) {
type foo struct {
i int
}
- x := NewVar(&foo{3})
+ x := NewVar[*foo](&foo{3})
read := make(chan struct{})
// spawn a transaction that modifies x
go func() {
Atomically(VoidOperation(func(tx *Tx) {
<-read
- rx := tx.Get(x).(*foo)
+ rx := x.Get(tx)
rx.i = 7
- tx.Set(x, rx)
+ x.Set(tx, rx)
}))
read <- struct{}{}
// other tx should retry, so we need to read/send again
@@ -113,7 +113,7 @@ func TestVerify(t *testing.T) {
// between the reads, causing this tx to retry.
var i int
Atomically(VoidOperation(func(tx *Tx) {
- f := tx.Get(x).(*foo)
+ f := x.Get(tx)
i = f.i
read <- struct{}{}
<-read // wait for other tx to complete
@@ -128,9 +128,9 @@ func TestSelect(t *testing.T) {
require.Panics(t, func() { Atomically(Select()) })
// with one arg, Select adds no effect
- x := NewVar(2)
+ x := NewVar[int](2)
Atomically(Select(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x).(int) == 2)
+ tx.Assert(x.Get(tx) == 2)
})))
picked := Atomically(Select(
@@ -146,7 +146,7 @@ func TestSelect(t *testing.T) {
func(tx *Tx) interface{} {
return 3
},
- )).(int)
+ ))
assert.EqualValues(t, 2, picked)
}
@@ -178,20 +178,20 @@ func TestPanic(t *testing.T) {
func TestReadWritten(t *testing.T) {
// reading a variable written in the same transaction should return the
// previously written value
- x := NewVar(3)
+ x := NewVar[int](3)
Atomically(VoidOperation(func(tx *Tx) {
- tx.Set(x, 5)
- tx.Assert(tx.Get(x).(int) == 5)
+ x.Set(tx, 5)
+ tx.Assert(x.Get(tx) == 5)
}))
}
func TestAtomicSetRetry(t *testing.T) {
// AtomicSet should cause waiting transactions to retry
- x := NewVar(3)
+ x := NewVar[int](3)
done := make(chan struct{})
go func() {
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x).(int) == 5)
+ tx.Assert(x.Get(tx) == 5)
}))
done <- struct{}{}
}()
@@ -206,21 +206,21 @@ func TestAtomicSetRetry(t *testing.T) {
func testPingPong(t testing.TB, n int, afterHit func(string)) {
ball := NewBuiltinEqVar(false)
- doneVar := NewVar(false)
- hits := NewVar(0)
- ready := NewVar(true) // The ball is ready for hitting.
+ doneVar := NewVar[bool](false)
+ hits := NewVar[int](0)
+ ready := NewVar[bool](true) // The ball is ready for hitting.
var wg sync.WaitGroup
- bat := func(from, to interface{}, noise string) {
+ bat := func(from, to bool, noise string) {
defer wg.Done()
for !Atomically(func(tx *Tx) interface{} {
- if tx.Get(doneVar).(bool) {
+ if doneVar.Get(tx) {
return true
}
- tx.Assert(tx.Get(ready).(bool))
- if tx.Get(ball) == from {
- tx.Set(ball, to)
- tx.Set(hits, tx.Get(hits).(int)+1)
- tx.Set(ready, false)
+ tx.Assert(ready.Get(tx))
+ if ball.Get(tx) == from {
+ ball.Set(tx, to)
+ hits.Set(tx, hits.Get(tx)+1)
+ ready.Set(tx, false)
return false
}
return tx.Retry()
@@ -233,8 +233,8 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) {
go bat(false, true, "ping!")
go bat(true, false, "pong!")
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(hits).(int) >= n)
- tx.Set(doneVar, true)
+ tx.Assert(hits.Get(tx) >= n)
+ doneVar.Set(tx, true)
}))
wg.Wait()
}
@@ -253,7 +253,7 @@ func TestSleepingBeauty(t *testing.T) {
}
//func TestRetryStack(t *testing.T) {
-// v := NewVar(nil)
+// v := NewVar[int](nil)
// go func() {
// i := 0
// for {
@@ -266,7 +266,7 @@ func TestSleepingBeauty(t *testing.T) {
// ret := func() {
// defer Atomically(nil)
// }
-// tx.Get(v)
+// v.Get(tx)
// tx.Assert(false)
// return ret
// })
diff --git a/stmutil/context.go b/stmutil/context.go
index 8a4d58d..9d23e12 100644
--- a/stmutil/context.go
+++ b/stmutil/context.go
@@ -9,12 +9,12 @@ import (
var (
mu sync.Mutex
- ctxVars = map[context.Context]*stm.Var{}
+ ctxVars = map[context.Context]*stm.Var[bool]{}
)
// Returns an STM var that contains a bool equal to `ctx.Err() != nil`, and a cancel function to be
// called when the user is no longer interested in the var.
-func ContextDoneVar(ctx context.Context) (*stm.Var, func()) {
+func ContextDoneVar(ctx context.Context) (*stm.Var[bool], func()) {
mu.Lock()
defer mu.Unlock()
if v, ok := ctxVars[ctx]; ok {
@@ -26,7 +26,7 @@ func ContextDoneVar(ctx context.Context) (*stm.Var, func()) {
v := stm.NewBuiltinEqVar(true)
return v, func() {}
}
- v := stm.NewVar(false)
+ v := stm.NewVar[bool](false)
go func() {
<-ctx.Done()
stm.AtomicSet(v, true)
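ContextDoneVar now returns a typed *stm.Var[bool], so transactions can branch on cancellation without an assertion. A sketch (import paths assumed; work is a hypothetical counter):
```go
package main

import (
	"context"
	"time"

	"github.com/anacrolix/stm"         // assumed module path
	"github.com/anacrolix/stm/stmutil" // assumed module path
)

func main() {
	work := stm.NewVar[int](0) // hypothetical: pending work items
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	ctxDone, release := stmutil.ContextDoneVar(ctx)
	defer release()
	// Block until an item is available or the context is done.
	res := stm.Atomically(func(tx *stm.Tx) interface{} {
		if ctxDone.Get(tx) {
			return ctx.Err()
		}
		tx.Assert(work.Get(tx) > 0)
		work.Set(tx, work.Get(tx)-1)
		return nil
	})
	if err, ok := res.(error); ok && err != nil {
		_ = err // cancelled or timed out before any work arrived
	}
}
```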
diff --git a/tx.go b/tx.go
index 3cd0301..0d69498 100644
--- a/tx.go
+++ b/tx.go
@@ -5,13 +5,22 @@ import (
"sort"
"sync"
"unsafe"
+
+ "github.com/alecthomas/atomic"
)
+type txVar interface {
+ getValue() *atomic.Value[VarValue]
+ changeValue(interface{})
+ getWatchers() *sync.Map
+ getLock() *sync.Mutex
+}
+
// A Tx represents an atomic transaction.
type Tx struct {
- reads map[*Var]VarValue
- writes map[*Var]interface{}
- watching map[*Var]struct{}
+ reads map[txVar]VarValue
+ writes map[txVar]interface{}
+ watching map[txVar]struct{}
locks txLocks
mu sync.Mutex
cond sync.Cond
@@ -24,7 +33,7 @@ type Tx struct {
// Check that none of the logged values have changed since the transaction began.
func (tx *Tx) inputsChanged() bool {
for v, read := range tx.reads {
- if read.Changed(v.value.Load()) {
+ if read.Changed(v.getValue().Load()) {
return true
}
}
@@ -42,12 +51,12 @@ func (tx *Tx) updateWatchers() {
for v := range tx.watching {
if _, ok := tx.reads[v]; !ok {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
}
for v := range tx.reads {
if _, ok := tx.watching[v]; !ok {
- v.watchers.Store(tx, nil)
+ v.getWatchers().Store(tx, nil)
tx.watching[v] = struct{}{}
}
}
@@ -76,22 +85,22 @@ func (tx *Tx) wait() {
}
// Get returns the value of v as of the start of the transaction.
-func (tx *Tx) Get(v *Var) interface{} {
+func (v *Var[T]) Get(tx *Tx) T {
// If we previously wrote to v, it will be in the write log.
if val, ok := tx.writes[v]; ok {
- return val
+ return val.(T)
}
// If we haven't previously read v, record its version
vv, ok := tx.reads[v]
if !ok {
- vv = v.value.Load()
+ vv = v.getValue().Load()
tx.reads[v] = vv
}
- return vv.Get()
+ return vv.Get().(T)
}
// Set sets the value of a Var for the lifetime of the transaction.
-func (tx *Tx) Set(v *Var, val interface{}) {
+func (v *Var[T]) Set(tx *Tx, val T) {
if v == nil {
panic("nil Var")
}
@@ -143,7 +152,7 @@ func (tx *Tx) removeRetryProfiles() {
func (tx *Tx) recycle() {
for v := range tx.watching {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
tx.removeRetryProfiles()
// I don't think we can reuse Txs, because the "completed" field should/needs to be set
@@ -164,7 +173,7 @@ func (tx *Tx) resetLocks() {
func (tx *Tx) collectReadLocks() {
for v := range tx.reads {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
@@ -172,7 +181,7 @@ func (tx *Tx) collectAllLocks() {
tx.collectReadLocks()
for v := range tx.writes {
if _, ok := tx.reads[v]; !ok {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
}
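Get and Set move from Tx methods to methods on Var[T]; the unexported txVar interface is what lets one Tx track Vars of different element types in its read, write, and watch maps. For example (module path assumed):
```go
package main

import "github.com/anacrolix/stm" // assumed module path

func main() {
	// Vars of different element types can participate in one transaction;
	// the Tx keys its maps on the txVar interface, not on a concrete Var type.
	name := stm.NewVar[string]("anon")
	hits := stm.NewVar[int](0)
	stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
		hits.Set(tx, hits.Get(tx)+1)
		name.Set(tx, "first visitor")
	}))
}
```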
diff --git a/var-value.go b/var-value.go
index ff97104..966399a 100644
--- a/var-value.go
+++ b/var-value.go
@@ -28,24 +28,24 @@ func (me versionedValue) Changed(other VarValue) bool {
return me.version != other.(versionedValue).version
}
-type customVarValue struct {
- value interface{}
- changed func(interface{}, interface{}) bool
+type customVarValue[T any] struct {
+ value T
+ changed func(T, T) bool
}
-var _ VarValue = customVarValue{}
+var _ VarValue = customVarValue[struct{}]{}
-func (me customVarValue) Changed(other VarValue) bool {
- return me.changed(me.value, other.(customVarValue).value)
+func (me customVarValue[T]) Changed(other VarValue) bool {
+ return me.changed(me.value, other.(customVarValue[T]).value)
}
-func (me customVarValue) Set(newValue interface{}) VarValue {
- return customVarValue{
- value: newValue,
+func (me customVarValue[T]) Set(newValue interface{}) VarValue {
+ return customVarValue[T]{
+ value: newValue.(T),
changed: me.changed,
}
}
-func (me customVarValue) Get() interface{} {
+func (me customVarValue[T]) Get() interface{} {
return me.value
}
diff --git a/var.go b/var.go
index ec85996..40c3c74 100644
--- a/var.go
+++ b/var.go
@@ -7,13 +7,25 @@ import (
)
// Holds an STM variable.
-type Var struct {
+type Var[T any] struct {
value atomic.Value[VarValue]
watchers sync.Map
mu sync.Mutex
}
-func (v *Var) changeValue(new interface{}) {
+func (v *Var[T]) getValue() *atomic.Value[VarValue] {
+ return &v.value
+}
+
+func (v *Var[T]) getWatchers() *sync.Map {
+ return &v.watchers
+}
+
+func (v *Var[T]) getLock() *sync.Mutex {
+ return &v.mu
+}
+
+func (v *Var[T]) changeValue(new interface{}) {
old := v.value.Load()
newVarValue := old.Set(new)
v.value.Store(newVarValue)
@@ -22,7 +34,7 @@ func (v *Var) changeValue(new interface{}) {
}
}
-func (v *Var) wakeWatchers(new VarValue) {
+func (v *Var[T]) wakeWatchers(new VarValue) {
v.watchers.Range(func(k, _ interface{}) bool {
tx := k.(*Tx)
// We have to lock here to ensure that the Tx is waiting before we signal it. Otherwise we
@@ -45,25 +57,25 @@ type varSnapshot struct {
}
// Returns a new STM variable.
-func NewVar(val interface{}) *Var {
- v := &Var{}
+func NewVar[T any](val interface{}) *Var[T] {
+ v := &Var[T]{}
v.value.Store(versionedValue{
value: val,
})
return v
}
-func NewCustomVar(val interface{}, changed func(interface{}, interface{}) bool) *Var {
- v := &Var{}
- v.value.Store(customVarValue{
+func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] {
+ v := &Var[T]{}
+ v.value.Store(customVarValue[T]{
value: val,
changed: changed,
})
return v
}
-func NewBuiltinEqVar(val interface{}) *Var {
- return NewCustomVar(val, func(a, b interface{}) bool {
+func NewBuiltinEqVar[T comparable](val T) *Var[T] {
+ return NewCustomVar(val, func(a, b T) bool {
return a != b
})
}
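NewCustomVar and NewBuiltinEqVar now take typed change predicates. A sketch of when that matters, namely suppressing wakeups when a write does not really change the value (module path assumed):
```go
package main

import (
	"strings"

	"github.com/anacrolix/stm" // assumed module path
)

func main() {
	// NewBuiltinEqVar: watchers are only woken when the value changes under !=.
	n := stm.NewBuiltinEqVar(0)
	stm.AtomicSet(n, 0) // same value: waiting transactions are not rerun

	// NewCustomVar: supply your own change predicate, here case-insensitive.
	s := stm.NewCustomVar("go", func(a, b string) bool {
		return !strings.EqualFold(a, b)
	})
	stm.AtomicSet(s, "GO") // considered unchanged by the predicate
}
```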