diff options
-rw-r--r-- | README.md | 22 | ||||
-rw-r--r-- | bench_test.go | 14 | ||||
-rw-r--r-- | cmd/santa-example/main.go | 40 | ||||
-rw-r--r-- | doc.go | 16 | ||||
-rw-r--r-- | doc_test.go | 24 | ||||
-rw-r--r-- | external_test.go | 50 | ||||
-rw-r--r-- | funcs.go | 25 | ||||
-rw-r--r-- | rate/rate_test.go | 1 | ||||
-rw-r--r-- | rate/ratelimit.go | 30 | ||||
-rw-r--r-- | stm_test.go | 74 | ||||
-rw-r--r-- | stmutil/context.go | 6 | ||||
-rw-r--r-- | tx.go | 37 | ||||
-rw-r--r-- | var-value.go | 20 | ||||
-rw-r--r-- | var.go | 32 |
14 files changed, 203 insertions, 188 deletions
@@ -42,11 +42,11 @@ Be very careful when managing pointers inside transactions! (This includes slices, maps, channels, and captured variables.) Here's why: ```go -p := stm.NewVar([]byte{1,2,3}) +p := stm.NewVar[[]byte]([]byte{1,2,3}) stm.Atomically(func(tx *stm.Tx) { - b := tx.Get(p).([]byte) + b := p.Get(tx) b[0] = 7 - tx.Set(p, b) + stm.p.Set(tx, b) }) ``` @@ -57,11 +57,11 @@ Following this advice, we can rewrite the transaction to perform a copy: ```go stm.Atomically(func(tx *stm.Tx) { - b := tx.Get(p).([]byte) + b := p.Get(tx) c := make([]byte, len(b)) copy(c, b) c[0] = 7 - tx.Set(p, c) + p.Set(tx, c) }) ``` @@ -73,11 +73,11 @@ In the same vein, it would be a mistake to do this: type foo struct { i int } -p := stm.NewVar(&foo{i: 2}) +p := stm.NewVar[*foo](&foo{i: 2}) stm.Atomically(func(tx *stm.Tx) { - f := tx.Get(p).(*foo) + f := p.Get(tx) f.i = 7 - tx.Set(p, f) + stm.p.Set(tx, f) }) ``` @@ -88,11 +88,11 @@ the correct approach is to move the `Var` inside the struct: type foo struct { i *stm.Var } -f := foo{i: stm.NewVar(2)} +f := foo{i: stm.NewVar[int](2)} stm.Atomically(func(tx *stm.Tx) { - i := tx.Get(f.i).(int) + i := f.i.Get(tx).(int) i = 7 - tx.Set(f.i, i) + f.i.Set(tx, i) }) ``` diff --git a/bench_test.go b/bench_test.go index 82aaba1..0b84715 100644 --- a/bench_test.go +++ b/bench_test.go @@ -8,14 +8,14 @@ import ( ) func BenchmarkAtomicGet(b *testing.B) { - x := NewVar(0) + x := NewVar[int](0) for i := 0; i < b.N; i++ { AtomicGet(x) } } func BenchmarkAtomicSet(b *testing.B) { - x := NewVar(0) + x := NewVar[int](0) for i := 0; i < b.N; i++ { AtomicSet(x, 0) } @@ -24,16 +24,16 @@ func BenchmarkAtomicSet(b *testing.B) { func BenchmarkIncrementSTM(b *testing.B) { for i := 0; i < b.N; i++ { // spawn 1000 goroutines that each increment x by 1 - x := NewVar(0) + x := NewVar[int](0) for i := 0; i < 1000; i++ { go Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur+1) + cur := x.Get(tx) + x.Set(tx, cur+1) })) } // wait for x to reach 1000 Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x).(int) == 1000) + tx.Assert(x.Get(tx) == 1000) })) } } @@ -83,7 +83,7 @@ func BenchmarkReadVarSTM(b *testing.B) { for i := 0; i < b.N; i++ { var wg sync.WaitGroup wg.Add(1000) - x := NewVar(0) + x := NewVar[int](0) for i := 0; i < 1000; i++ { go func() { AtomicGet(x) diff --git a/cmd/santa-example/main.go b/cmd/santa-example/main.go index 3c2b9ea..72be64f 100644 --- a/cmd/santa-example/main.go +++ b/cmd/santa-example/main.go @@ -39,15 +39,15 @@ import ( type gate struct { capacity int - remaining *stm.Var + remaining *stm.Var[int] } func (g gate) pass() { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) // wait until gate can hold us tx.Assert(rem > 0) - tx.Set(g.remaining, rem-1) + g.remaining.Set(tx, rem-1) })) } @@ -56,7 +56,7 @@ func (g gate) operate() { stm.AtomicSet(g.remaining, g.capacity) // wait for gate to be full stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) tx.Assert(rem == 0) })) } @@ -64,49 +64,49 @@ func (g gate) operate() { func newGate(capacity int) gate { return gate{ capacity: capacity, - remaining: stm.NewVar(0), // gate starts out closed + remaining: stm.NewVar[int](0), // gate starts out closed } } type group struct { capacity int - remaining *stm.Var - gate1, gate2 *stm.Var + remaining *stm.Var[int] + gate1, gate2 *stm.Var[gate] } func newGroup(capacity int) *group { return &group{ capacity: capacity, - remaining: stm.NewVar(capacity), // group starts out with full capacity - gate1: stm.NewVar(newGate(capacity)), - gate2: stm.NewVar(newGate(capacity)), + remaining: stm.NewVar[int](capacity), // group starts out with full capacity + gate1: stm.NewVar[gate](newGate(capacity)), + gate2: stm.NewVar[gate](newGate(capacity)), } } func (g *group) join() (g1, g2 gate) { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) // wait until the group can hold us tx.Assert(rem > 0) - tx.Set(g.remaining, rem-1) + g.remaining.Set(tx, rem-1) // return the group's gates - g1 = tx.Get(g.gate1).(gate) - g2 = tx.Get(g.gate2).(gate) + g1 = g.gate1.Get(tx) + g2 = g.gate2.Get(tx) })) return } func (g *group) await(tx *stm.Tx) (gate, gate) { // wait for group to be empty - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) tx.Assert(rem == 0) // get the group's gates - g1 := tx.Get(g.gate1).(gate) - g2 := tx.Get(g.gate2).(gate) + g1 := g.gate1.Get(tx) + g2 := g.gate2.Get(tx) // reset group - tx.Set(g.remaining, g.capacity) - tx.Set(g.gate1, newGate(g.capacity)) - tx.Set(g.gate2, newGate(g.capacity)) + g.remaining.Set(tx, g.capacity) + g.gate1.Set(tx, newGate(g.capacity)) + g.gate2.Set(tx, newGate(g.capacity)) return g1, g2 } @@ -10,14 +10,14 @@ it non-atomic). To begin, create an STM object that wraps the data you want to access concurrently. - x := stm.NewVar(3) + x := stm.NewVar[int](3) You can then use the Atomically method to atomically read and/or write the the data. This code atomically decrements x: stm.Atomically(func(tx *stm.Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur-1) + cur := x.Get(tx) + x.Set(tx, cur-1) }) An important part of STM transactions is retrying. At any point during the @@ -29,11 +29,11 @@ updated before the transaction will be rerun. As an example, this code will try to decrement x, but will block as long as x is zero: stm.Atomically(func(tx *stm.Tx) { - cur := tx.Get(x).(int) + cur := x.Get(tx) if cur == 0 { tx.Retry() } - tx.Set(x, cur-1) + x.Set(tx, cur-1) }) Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any other @@ -47,13 +47,13 @@ retried. For example, this code implements the "decrement-if-nonzero" transaction above, but for two values. It will first try to decrement x, then y, and block if both values are zero. - func dec(v *stm.Var) { + func dec(v *stm.Var[int]) { return func(tx *stm.Tx) { - cur := tx.Get(v).(int) + cur := v.Get(tx) if cur == 0 { tx.Retry() } - tx.Set(v, cur-1) + v.Set(tx, cur-1) } } diff --git a/doc_test.go b/doc_test.go index f6bb863..670d07b 100644 --- a/doc_test.go +++ b/doc_test.go @@ -6,43 +6,43 @@ import ( func Example() { // create a shared variable - n := stm.NewVar(3) + n := stm.NewVar[int](3) // read a variable var v int stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - v = tx.Get(n).(int) + v = n.Get(tx) })) // or: - v = stm.AtomicGet(n).(int) + v = stm.AtomicGet(n) _ = v // write to a variable stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Set(n, 12) + n.Set(tx, 12) })) // or: stm.AtomicSet(n, 12) // update a variable stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(n).(int) - tx.Set(n, cur-1) + cur := n.Get(tx) + n.Set(tx, cur-1) })) // block until a condition is met stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(n).(int) + cur := n.Get(tx) if cur != 0 { tx.Retry() } - tx.Set(n, 10) + n.Set(tx, 10) })) // or: stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(n).(int) + cur := n.Get(tx) tx.Assert(cur == 0) - tx.Set(n, 10) + n.Set(tx, 10) })) // select among multiple (potentially blocking) transactions @@ -51,11 +51,11 @@ func Example() { stm.VoidOperation(func(tx *stm.Tx) { tx.Retry() }), // this function will always succeed without blocking - stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 10) }), + stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 10) }), // this function will never run, because the previous // function succeeded - stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 11) }), + stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 11) }), )) // since Select is a normal transaction, if the entire select retries diff --git a/external_test.go b/external_test.go index ae29ca8..d0f5ecf 100644 --- a/external_test.go +++ b/external_test.go @@ -63,14 +63,14 @@ func BenchmarkThunderingHerd(b *testing.B) { pending := stm.NewBuiltinEqVar(0) for range iter.N(1000) { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Set(pending, tx.Get(pending).(int)+1) + pending.Set(tx, pending.Get(tx)+1) })) go func() { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - t := tx.Get(tokens).(int) + t := tokens.Get(tx) if t > 0 { - tx.Set(tokens, t-1) - tx.Set(pending, tx.Get(pending).(int)-1) + tokens.Set(tx, t-1) + pending.Set(tx, pending.Get(tx)-1) } else { tx.Retry() } @@ -79,17 +79,17 @@ func BenchmarkThunderingHerd(b *testing.B) { } go func() { for stm.Atomically(func(tx *stm.Tx) interface{} { - if tx.Get(done).(bool) { + if done.Get(tx) { return false } - tx.Assert(tx.Get(tokens).(int) < maxTokens) - tx.Set(tokens, tx.Get(tokens).(int)+1) + tx.Assert(tokens.Get(tx) < maxTokens) + tokens.Set(tx, tokens.Get(tx)+1) return true }).(bool) { } }() stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(tx.Get(pending).(int) == 0) + tx.Assert(pending.Get(tx) == 0) })) stm.AtomicSet(done, true) } @@ -99,53 +99,53 @@ func BenchmarkInvertedThunderingHerd(b *testing.B) { for i := 0; i < b.N; i++ { done := stm.NewBuiltinEqVar(false) tokens := stm.NewBuiltinEqVar(0) - pending := stm.NewVar(stmutil.NewSet()) + pending := stm.NewVar[stmutil.Settish](stmutil.NewSet()) for range iter.N(1000) { - ready := stm.NewVar(false) + ready := stm.NewVar[bool](false) stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Set(pending, tx.Get(pending).(stmutil.Settish).Add(ready)) + pending.Set(tx, pending.Get(tx).Add(ready)) })) go func() { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(tx.Get(ready).(bool)) - set := tx.Get(pending).(stmutil.Settish) + tx.Assert(ready.Get(tx)) + set := pending.Get(tx) if !set.Contains(ready) { panic("couldn't find ourselves in pending") } - tx.Set(pending, set.Delete(ready)) + pending.Set(tx, set.Delete(ready)) })) //b.Log("waiter finished") }() } go func() { for stm.Atomically(func(tx *stm.Tx) interface{} { - if tx.Get(done).(bool) { + if done.Get(tx) { return false } - tx.Assert(tx.Get(tokens).(int) < maxTokens) - tx.Set(tokens, tx.Get(tokens).(int)+1) + tx.Assert(tokens.Get(tx) < maxTokens) + tokens.Set(tx, tokens.Get(tx)+1) return true }).(bool) { } }() go func() { for stm.Atomically(func(tx *stm.Tx) interface{} { - tx.Assert(tx.Get(tokens).(int) > 0) - tx.Set(tokens, tx.Get(tokens).(int)-1) - tx.Get(pending).(stmutil.Settish).Range(func(i interface{}) bool { - ready := i.(*stm.Var) - if !tx.Get(ready).(bool) { - tx.Set(ready, true) + tx.Assert(tokens.Get(tx) > 0) + tokens.Set(tx, tokens.Get(tx)-1) + pending.Get(tx).Range(func(i interface{}) bool { + ready := i.(*stm.Var[bool]) + if !ready.Get(tx) { + ready.Set(tx, true) return false } return true }) - return !tx.Get(done).(bool) + return !done.Get(tx) }).(bool) { } }() stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(tx.Get(pending).(stmutil.Lenner).Len() == 0) + tx.Assert(pending.Get(tx).(stmutil.Lenner).Len() == 0) })) stm.AtomicSet(done, true) } @@ -2,7 +2,6 @@ package stm import ( "math/rand" - "reflect" "runtime/pprof" "sync" "time" @@ -12,9 +11,9 @@ var ( txPool = sync.Pool{New: func() interface{} { expvars.Add("new txs", 1) tx := &Tx{ - reads: make(map[*Var]VarValue), - writes: make(map[*Var]interface{}), - watching: make(map[*Var]struct{}), + reads: make(map[txVar]VarValue), + writes: make(map[txVar]interface{}), + watching: make(map[txVar]struct{}), } tx.cond.L = &tx.mu return tx @@ -104,12 +103,12 @@ retry: } // AtomicGet is a helper function that atomically reads a value. -func AtomicGet(v *Var) interface{} { - return v.value.Load().Get() +func AtomicGet[T any](v *Var[T]) T { + return v.value.Load().Get().(T) } // AtomicSet is a helper function that atomically writes a value. -func AtomicSet(v *Var, val interface{}) { +func AtomicSet[T any](v *Var[T], val interface{}) { v.mu.Lock() v.changeValue(val) v.mu.Unlock() @@ -140,7 +139,7 @@ func Select(fns ...Operation) Operation { return fns[0](tx) default: oldWrites := tx.writes - tx.writes = make(map[*Var]interface{}, len(oldWrites)) + tx.writes = make(map[txVar]interface{}, len(oldWrites)) for k, v := range oldWrites { tx.writes[k] = v } @@ -164,14 +163,8 @@ func VoidOperation(f func(*Tx)) Operation { } } -func AtomicModify(v *Var, f interface{}) { - r := reflect.ValueOf(f) +func AtomicModify[T any](v *Var[T], f func(T) T) { Atomically(VoidOperation(func(tx *Tx) { - cur := reflect.ValueOf(tx.Get(v)) - out := r.Call([]reflect.Value{cur}) - if lenOut := len(out); lenOut != 1 { - panic(lenOut) - } - tx.Set(v, out[0].Interface()) + v.Set(tx, f(v.Get(tx))) })) } diff --git a/rate/rate_test.go b/rate/rate_test.go index 7f44d74..a4e2fac 100644 --- a/rate/rate_test.go +++ b/rate/rate_test.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.7 // +build go1.7 package rate diff --git a/rate/ratelimit.go b/rate/ratelimit.go index a44f383..7fde1ce 100644 --- a/rate/ratelimit.go +++ b/rate/ratelimit.go @@ -13,9 +13,9 @@ import ( type numTokens = int type Limiter struct { - max *stm.Var - cur *stm.Var - lastAdd *stm.Var + max *stm.Var[numTokens] + cur *stm.Var[numTokens] + lastAdd *stm.Var[time.Time] rate Limit } @@ -36,9 +36,9 @@ func Every(interval time.Duration) Limit { func NewLimiter(rate Limit, burst numTokens) *Limiter { rl := &Limiter{ - max: stm.NewVar(burst), + max: stm.NewVar[int](burst), cur: stm.NewBuiltinEqVar(burst), - lastAdd: stm.NewVar(time.Now()), + lastAdd: stm.NewVar[time.Time](time.Now()), rate: rate, } if rate != Inf { @@ -49,7 +49,7 @@ func NewLimiter(rate Limit, burst numTokens) *Limiter { func (rl *Limiter) tokenGenerator(interval time.Duration) { for { - lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time) + lastAdd := stm.AtomicGet(rl.lastAdd) time.Sleep(time.Until(lastAdd.Add(interval))) now := time.Now() available := numTokens(now.Sub(lastAdd) / interval) @@ -57,17 +57,17 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) { continue } stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(rl.cur).(numTokens) - max := tx.Get(rl.max).(numTokens) + cur := rl.cur.Get(tx) + max := rl.max.Get(tx) tx.Assert(cur < max) newCur := cur + available if newCur > max { newCur = max } if newCur != cur { - tx.Set(rl.cur, newCur) + rl.cur.Set(tx, newCur) } - tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available))) + rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available))) })) } } @@ -90,9 +90,9 @@ func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool { if rl.rate == Inf { return true } - cur := tx.Get(rl.cur).(numTokens) + cur := rl.cur.Get(tx) if cur >= n { - tx.Set(rl.cur, cur-n) + rl.cur.Set(tx, cur-n) return true } return false @@ -106,17 +106,17 @@ func (rl *Limiter) WaitN(ctx context.Context, n int) error { ctxDone, cancel := stmutil.ContextDoneVar(ctx) defer cancel() if err := stm.Atomically(func(tx *stm.Tx) interface{} { - if tx.Get(ctxDone).(bool) { + if ctxDone.Get(tx) { return ctx.Err() } if rl.takeTokens(tx, n) { return nil } - if n > tx.Get(rl.max).(numTokens) { + if n > rl.max.Get(tx) { return errors.New("burst exceeded") } if dl, ok := ctx.Deadline(); ok { - if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n { + if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n { return context.DeadlineExceeded } } diff --git a/stm_test.go b/stm_test.go index 5c0ae31..fdf7af2 100644 --- a/stm_test.go +++ b/stm_test.go @@ -11,17 +11,17 @@ import ( ) func TestDecrement(t *testing.T) { - x := NewVar(1000) + x := NewVar[int](1000) for i := 0; i < 500; i++ { go Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur-1) + cur := x.Get(tx) + x.Set(tx, cur-1) })) } done := make(chan struct{}) go func() { Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x) == 500) + tx.Assert(x.Get(tx) == 500) })) close(done) }() @@ -35,7 +35,7 @@ func TestDecrement(t *testing.T) { // read-only transaction aren't exempt from calling tx.inputsChanged func TestReadVerify(t *testing.T) { read := make(chan struct{}) - x, y := NewVar(1), NewVar(2) + x, y := NewVar[int](1), NewVar[int](2) // spawn a transaction that writes to x go func() { @@ -50,10 +50,10 @@ func TestReadVerify(t *testing.T) { // between the reads, causing this tx to retry. var x2, y2 int Atomically(VoidOperation(func(tx *Tx) { - x2 = tx.Get(x).(int) + x2 = x.Get(tx) read <- struct{}{} <-read // wait for other tx to complete - y2 = tx.Get(y).(int) + y2 = y.Get(tx) })) if x2 == 1 && y2 == 2 { t.Fatal("read was not verified") @@ -61,15 +61,15 @@ func TestReadVerify(t *testing.T) { } func TestRetry(t *testing.T) { - x := NewVar(10) + x := NewVar[int](10) // spawn 10 transactions, one every 10 milliseconds. This will decrement x // to 0 over the course of 100 milliseconds. go func() { for i := 0; i < 10; i++ { time.Sleep(10 * time.Millisecond) Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur-1) + cur := x.Get(tx) + x.Set(tx, cur-1) })) } }() @@ -77,7 +77,7 @@ func TestRetry(t *testing.T) { // retry. This should result in no more than 1 retry per transaction. retry := 0 Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) + cur := x.Get(tx) if cur != 0 { retry++ tx.Retry() @@ -93,16 +93,16 @@ func TestVerify(t *testing.T) { type foo struct { i int } - x := NewVar(&foo{3}) + x := NewVar[*foo](&foo{3}) read := make(chan struct{}) // spawn a transaction that modifies x go func() { Atomically(VoidOperation(func(tx *Tx) { <-read - rx := tx.Get(x).(*foo) + rx := x.Get(tx) rx.i = 7 - tx.Set(x, rx) + x.Set(tx, rx) })) read <- struct{}{} // other tx should retry, so we need to read/send again @@ -113,7 +113,7 @@ func TestVerify(t *testing.T) { // between the reads, causing this tx to retry. var i int Atomically(VoidOperation(func(tx *Tx) { - f := tx.Get(x).(*foo) + f := x.Get(tx) i = f.i read <- struct{}{} <-read // wait for other tx to complete @@ -128,9 +128,9 @@ func TestSelect(t *testing.T) { require.Panics(t, func() { Atomically(Select()) }) // with one arg, Select adds no effect - x := NewVar(2) + x := NewVar[int](2) Atomically(Select(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x).(int) == 2) + tx.Assert(x.Get(tx) == 2) }))) picked := Atomically(Select( @@ -146,7 +146,7 @@ func TestSelect(t *testing.T) { func(tx *Tx) interface{} { return 3 }, - )).(int) + )) assert.EqualValues(t, 2, picked) } @@ -178,20 +178,20 @@ func TestPanic(t *testing.T) { func TestReadWritten(t *testing.T) { // reading a variable written in the same transaction should return the // previously written value - x := NewVar(3) + x := NewVar[int](3) Atomically(VoidOperation(func(tx *Tx) { - tx.Set(x, 5) - tx.Assert(tx.Get(x).(int) == 5) + x.Set(tx, 5) + tx.Assert(x.Get(tx) == 5) })) } func TestAtomicSetRetry(t *testing.T) { // AtomicSet should cause waiting transactions to retry - x := NewVar(3) + x := NewVar[int](3) done := make(chan struct{}) go func() { Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x).(int) == 5) + tx.Assert(x.Get(tx) == 5) })) done <- struct{}{} }() @@ -206,21 +206,21 @@ func TestAtomicSetRetry(t *testing.T) { func testPingPong(t testing.TB, n int, afterHit func(string)) { ball := NewBuiltinEqVar(false) - doneVar := NewVar(false) - hits := NewVar(0) - ready := NewVar(true) // The ball is ready for hitting. + doneVar := NewVar[bool](false) + hits := NewVar[int](0) + ready := NewVar[bool](true) // The ball is ready for hitting. var wg sync.WaitGroup - bat := func(from, to interface{}, noise string) { + bat := func(from, to bool, noise string) { defer wg.Done() for !Atomically(func(tx *Tx) interface{} { - if tx.Get(doneVar).(bool) { + if doneVar.Get(tx) { return true } - tx.Assert(tx.Get(ready).(bool)) - if tx.Get(ball) == from { - tx.Set(ball, to) - tx.Set(hits, tx.Get(hits).(int)+1) - tx.Set(ready, false) + tx.Assert(ready.Get(tx)) + if ball.Get(tx) == from { + ball.Set(tx, to) + hits.Set(tx, hits.Get(tx)+1) + ready.Set(tx, false) return false } return tx.Retry() @@ -233,8 +233,8 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) { go bat(false, true, "ping!") go bat(true, false, "pong!") Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(hits).(int) >= n) - tx.Set(doneVar, true) + tx.Assert(hits.Get(tx) >= n) + doneVar.Set(tx, true) })) wg.Wait() } @@ -253,7 +253,7 @@ func TestSleepingBeauty(t *testing.T) { } //func TestRetryStack(t *testing.T) { -// v := NewVar(nil) +// v := NewVar[int](nil) // go func() { // i := 0 // for { @@ -266,7 +266,7 @@ func TestSleepingBeauty(t *testing.T) { // ret := func() { // defer Atomically(nil) // } -// tx.Get(v) +// v.Get(tx) // tx.Assert(false) // return ret // }) diff --git a/stmutil/context.go b/stmutil/context.go index 8a4d58d..9d23e12 100644 --- a/stmutil/context.go +++ b/stmutil/context.go @@ -9,12 +9,12 @@ import ( var ( mu sync.Mutex - ctxVars = map[context.Context]*stm.Var{} + ctxVars = map[context.Context]*stm.Var[bool]{} ) // Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be // called when the user is no longer interested in the var. -func ContextDoneVar(ctx context.Context) (*stm.Var, func()) { +func ContextDoneVar(ctx context.Context) (*stm.Var[bool], func()) { mu.Lock() defer mu.Unlock() if v, ok := ctxVars[ctx]; ok { @@ -26,7 +26,7 @@ func ContextDoneVar(ctx context.Context) (*stm.Var, func()) { v := stm.NewBuiltinEqVar(true) return v, func() {} } - v := stm.NewVar(false) + v := stm.NewVar[bool](false) go func() { <-ctx.Done() stm.AtomicSet(v, true) @@ -5,13 +5,22 @@ import ( "sort" "sync" "unsafe" + + "github.com/alecthomas/atomic" ) +type txVar interface { + getValue() *atomic.Value[VarValue] + changeValue(interface{}) + getWatchers() *sync.Map + getLock() *sync.Mutex +} + // A Tx represents an atomic transaction. type Tx struct { - reads map[*Var]VarValue - writes map[*Var]interface{} - watching map[*Var]struct{} + reads map[txVar]VarValue + writes map[txVar]interface{} + watching map[txVar]struct{} locks txLocks mu sync.Mutex cond sync.Cond @@ -24,7 +33,7 @@ type Tx struct { // Check that none of the logged values have changed since the transaction began. func (tx *Tx) inputsChanged() bool { for v, read := range tx.reads { - if read.Changed(v.value.Load()) { + if read.Changed(v.getValue().Load()) { return true } } @@ -42,12 +51,12 @@ func (tx *Tx) updateWatchers() { for v := range tx.watching { if _, ok := tx.reads[v]; !ok { delete(tx.watching, v) - v.watchers.Delete(tx) + v.getWatchers().Delete(tx) } } for v := range tx.reads { if _, ok := tx.watching[v]; !ok { - v.watchers.Store(tx, nil) + v.getWatchers().Store(tx, nil) tx.watching[v] = struct{}{} } } @@ -76,22 +85,22 @@ func (tx *Tx) wait() { } // Get returns the value of v as of the start of the transaction. -func (tx *Tx) Get(v *Var) interface{} { +func (v *Var[T]) Get(tx *Tx) T { // If we previously wrote to v, it will be in the write log. if val, ok := tx.writes[v]; ok { - return val + return val.(T) } // If we haven't previously read v, record its version vv, ok := tx.reads[v] if !ok { - vv = v.value.Load() + vv = v.getValue().Load() tx.reads[v] = vv } - return vv.Get() + return vv.Get().(T) } // Set sets the value of a Var for the lifetime of the transaction. -func (tx *Tx) Set(v *Var, val interface{}) { +func (v *Var[T]) Set(tx *Tx, val T) { if v == nil { panic("nil Var") } @@ -143,7 +152,7 @@ func (tx *Tx) removeRetryProfiles() { func (tx *Tx) recycle() { for v := range tx.watching { delete(tx.watching, v) - v.watchers.Delete(tx) + v.getWatchers().Delete(tx) } tx.removeRetryProfiles() // I don't think we can reuse Txs, because the "completed" field should/needs to be set @@ -164,7 +173,7 @@ func (tx *Tx) resetLocks() { func (tx *Tx) collectReadLocks() { for v := range tx.reads { - tx.locks.append(&v.mu) + tx.locks.append(v.getLock()) } } @@ -172,7 +181,7 @@ func (tx *Tx) collectAllLocks() { tx.collectReadLocks() for v := range tx.writes { if _, ok := tx.reads[v]; !ok { - tx.locks.append(&v.mu) + tx.locks.append(v.getLock()) } } } diff --git a/var-value.go b/var-value.go index ff97104..966399a 100644 --- a/var-value.go +++ b/var-value.go @@ -28,24 +28,24 @@ func (me versionedValue) Changed(other VarValue) bool { return me.version != other.(versionedValue).version } -type customVarValue struct { - value interface{} - changed func(interface{}, interface{}) bool +type customVarValue[T any] struct { + value T + changed func(T, T) bool } -var _ VarValue = customVarValue{} +var _ VarValue = customVarValue[struct{}]{} -func (me customVarValue) Changed(other VarValue) bool { - return me.changed(me.value, other.(customVarValue).value) +func (me customVarValue[T]) Changed(other VarValue) bool { + return me.changed(me.value, other.(customVarValue[T]).value) } -func (me customVarValue) Set(newValue interface{}) VarValue { - return customVarValue{ - value: newValue, +func (me customVarValue[T]) Set(newValue interface{}) VarValue { + return customVarValue[T]{ + value: newValue.(T), changed: me.changed, } } -func (me customVarValue) Get() interface{} { +func (me customVarValue[T]) Get() interface{} { return me.value } @@ -7,13 +7,25 @@ import ( ) // Holds an STM variable. -type Var struct { +type Var[T any] struct { value atomic.Value[VarValue] watchers sync.Map mu sync.Mutex } -func (v *Var) changeValue(new interface{}) { +func (v *Var[T]) getValue() *atomic.Value[VarValue] { + return &v.value +} + +func (v *Var[T]) getWatchers() *sync.Map { + return &v.watchers +} + +func (v *Var[T]) getLock() *sync.Mutex { + return &v.mu +} + +func (v *Var[T]) changeValue(new interface{}) { old := v.value.Load() newVarValue := old.Set(new) v.value.Store(newVarValue) @@ -22,7 +34,7 @@ func (v *Var) changeValue(new interface{}) { } } -func (v *Var) wakeWatchers(new VarValue) { +func (v *Var[T]) wakeWatchers(new VarValue) { v.watchers.Range(func(k, _ interface{}) bool { tx := k.(*Tx) // We have to lock here to ensure that the Tx is waiting before we signal it. Otherwise we @@ -45,25 +57,25 @@ type varSnapshot struct { } // Returns a new STM variable. -func NewVar(val interface{}) *Var { - v := &Var{} +func NewVar[T any](val interface{}) *Var[T] { + v := &Var[T]{} v.value.Store(versionedValue{ value: val, }) return v } -func NewCustomVar(val interface{}, changed func(interface{}, interface{}) bool) *Var { - v := &Var{} - v.value.Store(customVarValue{ +func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] { + v := &Var[T]{} + v.value.Store(customVarValue[T]{ value: val, changed: changed, }) return v } -func NewBuiltinEqVar(val interface{}) *Var { - return NewCustomVar(val, func(a, b interface{}) bool { +func NewBuiltinEqVar[T comparable](val T) *Var[T] { + return NewCustomVar(val, func(a, b T) bool { return a != b }) } |