aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatt Joiner <anacrolix@gmail.com>2019-11-06 16:14:47 +1100
committerMatt Joiner <anacrolix@gmail.com>2019-11-06 16:14:47 +1100
commite749ba3531cf430b66e1d3f310f53ea2972e3aa3 (patch)
tree328dbb9b2078908a8967483a0829ac4acb54cc9a
parentAdd WouldBlock (diff)
downloadstm-e749ba3531cf430b66e1d3f310f53ea2972e3aa3.tar.gz
stm-e749ba3531cf430b66e1d3f310f53ea2972e3aa3.tar.xz
Make returns explicit
-rw-r--r--bench_test.go8
-rw-r--r--cmd/santa-example/main.go18
-rw-r--r--funcs.go70
-rw-r--r--rate/ratelimit.go19
-rw-r--r--retry.go4
-rw-r--r--stm_test.go92
-rw-r--r--tx.go8
7 files changed, 98 insertions, 121 deletions
diff --git a/bench_test.go b/bench_test.go
index 0d176d8..82aaba1 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -26,15 +26,15 @@ func BenchmarkIncrementSTM(b *testing.B) {
// spawn 1000 goroutines that each increment x by 1
x := NewVar(0)
for i := 0; i < 1000; i++ {
- go Atomically(func(tx *Tx) {
+ go Atomically(VoidOperation(func(tx *Tx) {
cur := tx.Get(x).(int)
tx.Set(x, cur+1)
- })
+ }))
}
// wait for x to reach 1000
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
tx.Assert(tx.Get(x).(int) == 1000)
- })
+ }))
}
}
diff --git a/cmd/santa-example/main.go b/cmd/santa-example/main.go
index 9e6eed7..3c2b9ea 100644
--- a/cmd/santa-example/main.go
+++ b/cmd/santa-example/main.go
@@ -43,22 +43,22 @@ type gate struct {
}
func (g gate) pass() {
- stm.Atomically(func(tx *stm.Tx) {
+ stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
rem := tx.Get(g.remaining).(int)
// wait until gate can hold us
tx.Assert(rem > 0)
tx.Set(g.remaining, rem-1)
- })
+ }))
}
func (g gate) operate() {
// open gate, reseting capacity
stm.AtomicSet(g.remaining, g.capacity)
// wait for gate to be full
- stm.Atomically(func(tx *stm.Tx) {
+ stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
rem := tx.Get(g.remaining).(int)
tx.Assert(rem == 0)
- })
+ }))
}
func newGate(capacity int) gate {
@@ -84,7 +84,7 @@ func newGroup(capacity int) *group {
}
func (g *group) join() (g1, g2 gate) {
- stm.Atomically(func(tx *stm.Tx) {
+ stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
rem := tx.Get(g.remaining).(int)
// wait until the group can hold us
tx.Assert(rem > 0)
@@ -92,7 +92,7 @@ func (g *group) join() (g1, g2 gate) {
// return the group's gates
g1 = tx.Get(g.gate1).(gate)
g2 = tx.Get(g.gate2).(gate)
- })
+ }))
return
}
@@ -137,11 +137,11 @@ type selection struct {
gate1, gate2 gate
}
-func chooseGroup(g *group, task string, s *selection) func(*stm.Tx) {
- return func(tx *stm.Tx) {
+func chooseGroup(g *group, task string, s *selection) stm.Operation {
+ return stm.VoidOperation(func(tx *stm.Tx) {
s.gate1, s.gate2 = g.await(tx)
s.task = task
- }
+ })
}
func spawnSanta(elves, reindeer *group) {
diff --git a/funcs.go b/funcs.go
index 83f8cf2..ceb162b 100644
--- a/funcs.go
+++ b/funcs.go
@@ -30,49 +30,25 @@ func newTx() *Tx {
return txPool.Get().(*Tx)
}
-func WouldBlock(fn func(*Tx)) (block bool) {
+func WouldBlock(fn Operation) (block bool) {
tx := newTx()
tx.reset()
- defer func() {
- if r := recover(); r == Retry {
- block = true
- } else if _, ok := r.(_return); ok {
- } else if r != nil {
- panic(r)
- }
- }()
- return catchRetry(fn, tx)
+ _, block = catchRetry(fn, tx)
+ return
}
// Atomically executes the atomic function fn.
-func Atomically(fn func(*Tx)) interface{} {
+func Atomically(op Operation) interface{} {
expvars.Add("atomically", 1)
// run the transaction
tx := newTx()
retry:
tx.reset()
- var ret interface{}
- if func() (retry bool) {
- defer func() {
- r := recover()
- if r == nil {
- return
- }
- if _ret, ok := r.(_return); ok {
- expvars.Add("explicit returns", 1)
- ret = _ret.value
- } else if r == Retry {
- expvars.Add("retries", 1)
- // wait for one of the variables we read to change before retrying
- tx.wait()
- retry = true
- } else {
- panic(r)
- }
- }()
- fn(tx)
- return false
- }() {
+ ret, retry := catchRetry(op, tx)
+ if retry {
+ expvars.Add("retries", 1)
+ // wait for one of the variables we read to change before retrying
+ tx.wait()
goto retry
}
// verify the read log
@@ -107,29 +83,43 @@ func AtomicSet(v *Var, val interface{}) {
// Compose is a helper function that composes multiple transactions into a
// single transaction.
-func Compose(fns ...func(*Tx)) func(*Tx) {
- return func(tx *Tx) {
+func Compose(fns ...Operation) Operation {
+ return func(tx *Tx) interface{} {
for _, f := range fns {
f(tx)
}
+ return nil
}
}
// Select runs the supplied functions in order. Execution stops when a
// function succeeds without calling Retry. If no functions succeed, the
// entire selection will be retried.
-func Select(fns ...func(*Tx)) func(*Tx) {
- return func(tx *Tx) {
+func Select(fns ...Operation) Operation {
+ return func(tx *Tx) interface{} {
switch len(fns) {
case 0:
// empty Select blocks forever
tx.Retry()
+ panic("unreachable")
case 1:
- fns[0](tx)
+ return fns[0](tx)
default:
- if catchRetry(fns[0], tx) {
- Select(fns[1:]...)(tx)
+ ret, retry := catchRetry(fns[0], tx)
+ if retry {
+ return Select(fns[1:]...)(tx)
+ } else {
+ return ret
}
}
}
}
+
+type Operation func(*Tx) interface{}
+
+func VoidOperation(f func(*Tx)) Operation {
+ return func(tx *Tx) interface{} {
+ f(tx)
+ return nil
+ }
+}
diff --git a/rate/ratelimit.go b/rate/ratelimit.go
index 0278283..69d6932 100644
--- a/rate/ratelimit.go
+++ b/rate/ratelimit.go
@@ -56,7 +56,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) {
if available < 1 {
continue
}
- stm.Atomically(func(tx *stm.Tx) {
+ stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
cur := tx.Get(rl.cur).(numTokens)
max := tx.Get(rl.max).(numTokens)
tx.Assert(cur < max)
@@ -68,7 +68,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) {
tx.Set(rl.cur, newCur)
}
tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available)))
- })
+ }))
}
}
@@ -77,8 +77,8 @@ func (rl *Limiter) Allow() bool {
}
func (rl *Limiter) AllowN(n numTokens) bool {
- return stm.Atomically(func(tx *stm.Tx) {
- tx.Return(rl.takeTokens(tx, n))
+ return stm.Atomically(func(tx *stm.Tx) interface{} {
+ return rl.takeTokens(tx, n)
}).(bool)
}
@@ -105,22 +105,23 @@ func (rl *Limiter) Wait(ctx context.Context) error {
func (rl *Limiter) WaitN(ctx context.Context, n int) error {
ctxDone, cancel := stmutil.ContextDoneVar(ctx)
defer cancel()
- if err := stm.Atomically(func(tx *stm.Tx) {
+ if err := stm.Atomically(func(tx *stm.Tx) interface{} {
if tx.Get(ctxDone).(bool) {
- tx.Return(ctx.Err())
+ return ctx.Err()
}
if rl.takeTokens(tx, n) {
- tx.Return(nil)
+ return nil
}
if n > tx.Get(rl.max).(numTokens) {
- tx.Return(errors.New("burst exceeded"))
+ return errors.New("burst exceeded")
}
if dl, ok := ctx.Deadline(); ok {
if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n {
- tx.Return(context.DeadlineExceeded)
+ return context.DeadlineExceeded
}
}
tx.Retry()
+ panic("unreachable")
}); err != nil {
return err.(error)
}
diff --git a/retry.go b/retry.go
index 7800ae9..92efb9e 100644
--- a/retry.go
+++ b/retry.go
@@ -5,7 +5,7 @@ package stm
const Retry = "retry"
// catchRetry returns true if fn calls tx.Retry.
-func catchRetry(fn func(*Tx), tx *Tx) (retry bool) {
+func catchRetry(fn Operation, tx *Tx) (result interface{}, retry bool) {
defer func() {
if r := recover(); r == Retry {
retry = true
@@ -13,6 +13,6 @@ func catchRetry(fn func(*Tx), tx *Tx) (retry bool) {
panic(r)
}
}()
- fn(tx)
+ result = fn(tx)
return
}
diff --git a/stm_test.go b/stm_test.go
index aeb0907..ffb4011 100644
--- a/stm_test.go
+++ b/stm_test.go
@@ -6,21 +6,22 @@ import (
"time"
_ "github.com/anacrolix/envpprof"
+ "github.com/stretchr/testify/assert"
)
func TestDecrement(t *testing.T) {
x := NewVar(1000)
for i := 0; i < 500; i++ {
- go Atomically(func(tx *Tx) {
+ go Atomically(VoidOperation(func(tx *Tx) {
cur := tx.Get(x).(int)
tx.Set(x, cur-1)
- })
+ }))
}
done := make(chan struct{})
go func() {
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
tx.Assert(tx.Get(x) == 500)
- })
+ }))
close(done)
}()
select {
@@ -47,12 +48,12 @@ func TestReadVerify(t *testing.T) {
// spawn a transaction that reads x, then y. The other tx will modify x in
// between the reads, causing this tx to retry.
var x2, y2 int
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
x2 = tx.Get(x).(int)
read <- struct{}{}
<-read // wait for other tx to complete
y2 = tx.Get(y).(int)
- })
+ }))
if x2 == 1 && y2 == 2 {
t.Fatal("read was not verified")
}
@@ -65,22 +66,22 @@ func TestRetry(t *testing.T) {
go func() {
for i := 0; i < 10; i++ {
time.Sleep(10 * time.Millisecond)
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
cur := tx.Get(x).(int)
tx.Set(x, cur-1)
- })
+ }))
}
}()
// Each time we read x before the above loop has finished, we need to
// retry. This should result in no more than 1 retry per transaction.
retry := 0
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
cur := tx.Get(x).(int)
if cur != 0 {
retry++
tx.Retry()
}
- })
+ }))
if retry > 10 {
t.Fatal("should have retried at most 10 times, got", retry)
}
@@ -96,12 +97,12 @@ func TestVerify(t *testing.T) {
// spawn a transaction that modifies x
go func() {
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
<-read
rx := tx.Get(x).(*foo)
rx.i = 7
tx.Set(x, rx)
- })
+ }))
read <- struct{}{}
// other tx should retry, so we need to read/send again
read <- <-read
@@ -110,12 +111,12 @@ func TestVerify(t *testing.T) {
// spawn a transaction that reads x, then y. The other tx will modify x in
// between the reads, causing this tx to retry.
var i int
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
f := tx.Get(x).(*foo)
i = f.i
read <- struct{}{}
<-read // wait for other tx to complete
- })
+ }))
if i == 3 {
t.Fatal("verify did not retry despite modified Var", i)
}
@@ -136,37 +137,33 @@ func TestSelect(t *testing.T) {
// with one arg, Select adds no effect
x := NewVar(2)
- Atomically(Select(func(tx *Tx) {
+ Atomically(Select(VoidOperation(func(tx *Tx) {
tx.Assert(tx.Get(x).(int) == 2)
- }))
+ })))
- var picked int
- Atomically(Select(
+ picked := Atomically(Select(
// always blocks; should never be selected
- func(tx *Tx) {
+ VoidOperation(func(tx *Tx) {
tx.Retry()
- picked = 1
- },
+ }),
// always succeeds; should always be selected
- func(tx *Tx) {
- picked = 2
+ func(tx *Tx) interface{} {
+ return 2
},
// always succeeds; should never be selected
- func(tx *Tx) {
- picked = 3
+ func(tx *Tx) interface{} {
+ return 3
},
- ))
- if picked != 2 {
- t.Fatal("Select selected wrong transaction:", picked)
- }
+ )).(int)
+ assert.EqualValues(t, 2, picked)
}
func TestCompose(t *testing.T) {
nums := make([]int, 100)
- fns := make([]func(*Tx), 100)
+ fns := make([]Operation, 100)
for i := range fns {
- fns[i] = func(x int) func(*Tx) {
- return func(*Tx) { nums[x] = x }
+ fns[i] = func(x int) Operation {
+ return VoidOperation(func(*Tx) { nums[x] = x })
}(i) // capture loop var
}
Atomically(Compose(fns...))
@@ -178,14 +175,11 @@ func TestCompose(t *testing.T) {
}
func TestPanic(t *testing.T) {
- defer func() {
- if recover() == nil {
- t.Fatal("expected panic, got nil")
- }
- }()
// normal panics should escape Atomically
- Atomically(func(*Tx) {
- panic("foo")
+ assert.PanicsWithValue(t, "foo", func() {
+ Atomically(func(*Tx) interface{} {
+ panic("foo")
+ })
})
}
@@ -193,10 +187,10 @@ func TestReadWritten(t *testing.T) {
// reading a variable written in the same transaction should return the
// previously written value
x := NewVar(3)
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
tx.Set(x, 5)
tx.Assert(tx.Get(x).(int) == 5)
- })
+ }))
}
func TestAtomicSetRetry(t *testing.T) {
@@ -204,9 +198,9 @@ func TestAtomicSetRetry(t *testing.T) {
x := NewVar(3)
done := make(chan struct{})
go func() {
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
tx.Assert(tx.Get(x).(int) == 5)
- })
+ }))
done <- struct{}{}
}()
time.Sleep(10 * time.Millisecond)
@@ -226,18 +220,18 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) {
var wg sync.WaitGroup
bat := func(from, to interface{}, noise string) {
defer wg.Done()
- for !Atomically(func(tx *Tx) {
+ for !Atomically(func(tx *Tx) interface{} {
if tx.Get(doneVar).(bool) {
- tx.Return(true)
+ return true
}
tx.Assert(tx.Get(ready).(bool))
if tx.Get(ball) == from {
tx.Set(ball, to)
tx.Set(hits, tx.Get(hits).(int)+1)
tx.Set(ready, false)
- tx.Return(false)
+ return false
}
- tx.Retry()
+ panic(Retry)
}).(bool) {
afterHit(noise)
AtomicSet(ready, true)
@@ -246,10 +240,10 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) {
wg.Add(2)
go bat(false, true, "ping!")
go bat(true, false, "pong!")
- Atomically(func(tx *Tx) {
+ Atomically(VoidOperation(func(tx *Tx) {
tx.Assert(tx.Get(hits).(int) >= n)
tx.Set(doneVar, true)
- })
+ }))
wg.Wait()
}
diff --git a/tx.go b/tx.go
index 30100bf..14003f8 100644
--- a/tx.go
+++ b/tx.go
@@ -84,14 +84,6 @@ func (tx *Tx) Assert(p bool) {
}
}
-func (tx *Tx) Return(v interface{}) {
- panic(_return{v})
-}
-
-type _return struct {
- value interface{}
-}
-
func (tx *Tx) reset() {
for k := range tx.reads {
delete(tx.reads, k)