diff options
-rw-r--r-- | bench_test.go | 8 | ||||
-rw-r--r-- | cmd/santa-example/main.go | 18 | ||||
-rw-r--r-- | funcs.go | 70 | ||||
-rw-r--r-- | rate/ratelimit.go | 19 | ||||
-rw-r--r-- | retry.go | 4 | ||||
-rw-r--r-- | stm_test.go | 92 | ||||
-rw-r--r-- | tx.go | 8 |
7 files changed, 98 insertions, 121 deletions
diff --git a/bench_test.go b/bench_test.go index 0d176d8..82aaba1 100644 --- a/bench_test.go +++ b/bench_test.go @@ -26,15 +26,15 @@ func BenchmarkIncrementSTM(b *testing.B) { // spawn 1000 goroutines that each increment x by 1 x := NewVar(0) for i := 0; i < 1000; i++ { - go Atomically(func(tx *Tx) { + go Atomically(VoidOperation(func(tx *Tx) { cur := tx.Get(x).(int) tx.Set(x, cur+1) - }) + })) } // wait for x to reach 1000 - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { tx.Assert(tx.Get(x).(int) == 1000) - }) + })) } } diff --git a/cmd/santa-example/main.go b/cmd/santa-example/main.go index 9e6eed7..3c2b9ea 100644 --- a/cmd/santa-example/main.go +++ b/cmd/santa-example/main.go @@ -43,22 +43,22 @@ type gate struct { } func (g gate) pass() { - stm.Atomically(func(tx *stm.Tx) { + stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { rem := tx.Get(g.remaining).(int) // wait until gate can hold us tx.Assert(rem > 0) tx.Set(g.remaining, rem-1) - }) + })) } func (g gate) operate() { // open gate, reseting capacity stm.AtomicSet(g.remaining, g.capacity) // wait for gate to be full - stm.Atomically(func(tx *stm.Tx) { + stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { rem := tx.Get(g.remaining).(int) tx.Assert(rem == 0) - }) + })) } func newGate(capacity int) gate { @@ -84,7 +84,7 @@ func newGroup(capacity int) *group { } func (g *group) join() (g1, g2 gate) { - stm.Atomically(func(tx *stm.Tx) { + stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { rem := tx.Get(g.remaining).(int) // wait until the group can hold us tx.Assert(rem > 0) @@ -92,7 +92,7 @@ func (g *group) join() (g1, g2 gate) { // return the group's gates g1 = tx.Get(g.gate1).(gate) g2 = tx.Get(g.gate2).(gate) - }) + })) return } @@ -137,11 +137,11 @@ type selection struct { gate1, gate2 gate } -func chooseGroup(g *group, task string, s *selection) func(*stm.Tx) { - return func(tx *stm.Tx) { +func chooseGroup(g *group, task string, s *selection) stm.Operation { + return stm.VoidOperation(func(tx *stm.Tx) { s.gate1, s.gate2 = g.await(tx) s.task = task - } + }) } func spawnSanta(elves, reindeer *group) { @@ -30,49 +30,25 @@ func newTx() *Tx { return txPool.Get().(*Tx) } -func WouldBlock(fn func(*Tx)) (block bool) { +func WouldBlock(fn Operation) (block bool) { tx := newTx() tx.reset() - defer func() { - if r := recover(); r == Retry { - block = true - } else if _, ok := r.(_return); ok { - } else if r != nil { - panic(r) - } - }() - return catchRetry(fn, tx) + _, block = catchRetry(fn, tx) + return } // Atomically executes the atomic function fn. -func Atomically(fn func(*Tx)) interface{} { +func Atomically(op Operation) interface{} { expvars.Add("atomically", 1) // run the transaction tx := newTx() retry: tx.reset() - var ret interface{} - if func() (retry bool) { - defer func() { - r := recover() - if r == nil { - return - } - if _ret, ok := r.(_return); ok { - expvars.Add("explicit returns", 1) - ret = _ret.value - } else if r == Retry { - expvars.Add("retries", 1) - // wait for one of the variables we read to change before retrying - tx.wait() - retry = true - } else { - panic(r) - } - }() - fn(tx) - return false - }() { + ret, retry := catchRetry(op, tx) + if retry { + expvars.Add("retries", 1) + // wait for one of the variables we read to change before retrying + tx.wait() goto retry } // verify the read log @@ -107,29 +83,43 @@ func AtomicSet(v *Var, val interface{}) { // Compose is a helper function that composes multiple transactions into a // single transaction. -func Compose(fns ...func(*Tx)) func(*Tx) { - return func(tx *Tx) { +func Compose(fns ...Operation) Operation { + return func(tx *Tx) interface{} { for _, f := range fns { f(tx) } + return nil } } // Select runs the supplied functions in order. Execution stops when a // function succeeds without calling Retry. If no functions succeed, the // entire selection will be retried. -func Select(fns ...func(*Tx)) func(*Tx) { - return func(tx *Tx) { +func Select(fns ...Operation) Operation { + return func(tx *Tx) interface{} { switch len(fns) { case 0: // empty Select blocks forever tx.Retry() + panic("unreachable") case 1: - fns[0](tx) + return fns[0](tx) default: - if catchRetry(fns[0], tx) { - Select(fns[1:]...)(tx) + ret, retry := catchRetry(fns[0], tx) + if retry { + return Select(fns[1:]...)(tx) + } else { + return ret } } } } + +type Operation func(*Tx) interface{} + +func VoidOperation(f func(*Tx)) Operation { + return func(tx *Tx) interface{} { + f(tx) + return nil + } +} diff --git a/rate/ratelimit.go b/rate/ratelimit.go index 0278283..69d6932 100644 --- a/rate/ratelimit.go +++ b/rate/ratelimit.go @@ -56,7 +56,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) { if available < 1 { continue } - stm.Atomically(func(tx *stm.Tx) { + stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { cur := tx.Get(rl.cur).(numTokens) max := tx.Get(rl.max).(numTokens) tx.Assert(cur < max) @@ -68,7 +68,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) { tx.Set(rl.cur, newCur) } tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available))) - }) + })) } } @@ -77,8 +77,8 @@ func (rl *Limiter) Allow() bool { } func (rl *Limiter) AllowN(n numTokens) bool { - return stm.Atomically(func(tx *stm.Tx) { - tx.Return(rl.takeTokens(tx, n)) + return stm.Atomically(func(tx *stm.Tx) interface{} { + return rl.takeTokens(tx, n) }).(bool) } @@ -105,22 +105,23 @@ func (rl *Limiter) Wait(ctx context.Context) error { func (rl *Limiter) WaitN(ctx context.Context, n int) error { ctxDone, cancel := stmutil.ContextDoneVar(ctx) defer cancel() - if err := stm.Atomically(func(tx *stm.Tx) { + if err := stm.Atomically(func(tx *stm.Tx) interface{} { if tx.Get(ctxDone).(bool) { - tx.Return(ctx.Err()) + return ctx.Err() } if rl.takeTokens(tx, n) { - tx.Return(nil) + return nil } if n > tx.Get(rl.max).(numTokens) { - tx.Return(errors.New("burst exceeded")) + return errors.New("burst exceeded") } if dl, ok := ctx.Deadline(); ok { if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n { - tx.Return(context.DeadlineExceeded) + return context.DeadlineExceeded } } tx.Retry() + panic("unreachable") }); err != nil { return err.(error) } @@ -5,7 +5,7 @@ package stm const Retry = "retry" // catchRetry returns true if fn calls tx.Retry. -func catchRetry(fn func(*Tx), tx *Tx) (retry bool) { +func catchRetry(fn Operation, tx *Tx) (result interface{}, retry bool) { defer func() { if r := recover(); r == Retry { retry = true @@ -13,6 +13,6 @@ func catchRetry(fn func(*Tx), tx *Tx) (retry bool) { panic(r) } }() - fn(tx) + result = fn(tx) return } diff --git a/stm_test.go b/stm_test.go index aeb0907..ffb4011 100644 --- a/stm_test.go +++ b/stm_test.go @@ -6,21 +6,22 @@ import ( "time" _ "github.com/anacrolix/envpprof" + "github.com/stretchr/testify/assert" ) func TestDecrement(t *testing.T) { x := NewVar(1000) for i := 0; i < 500; i++ { - go Atomically(func(tx *Tx) { + go Atomically(VoidOperation(func(tx *Tx) { cur := tx.Get(x).(int) tx.Set(x, cur-1) - }) + })) } done := make(chan struct{}) go func() { - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { tx.Assert(tx.Get(x) == 500) - }) + })) close(done) }() select { @@ -47,12 +48,12 @@ func TestReadVerify(t *testing.T) { // spawn a transaction that reads x, then y. The other tx will modify x in // between the reads, causing this tx to retry. var x2, y2 int - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { x2 = tx.Get(x).(int) read <- struct{}{} <-read // wait for other tx to complete y2 = tx.Get(y).(int) - }) + })) if x2 == 1 && y2 == 2 { t.Fatal("read was not verified") } @@ -65,22 +66,22 @@ func TestRetry(t *testing.T) { go func() { for i := 0; i < 10; i++ { time.Sleep(10 * time.Millisecond) - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { cur := tx.Get(x).(int) tx.Set(x, cur-1) - }) + })) } }() // Each time we read x before the above loop has finished, we need to // retry. This should result in no more than 1 retry per transaction. retry := 0 - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { cur := tx.Get(x).(int) if cur != 0 { retry++ tx.Retry() } - }) + })) if retry > 10 { t.Fatal("should have retried at most 10 times, got", retry) } @@ -96,12 +97,12 @@ func TestVerify(t *testing.T) { // spawn a transaction that modifies x go func() { - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { <-read rx := tx.Get(x).(*foo) rx.i = 7 tx.Set(x, rx) - }) + })) read <- struct{}{} // other tx should retry, so we need to read/send again read <- <-read @@ -110,12 +111,12 @@ func TestVerify(t *testing.T) { // spawn a transaction that reads x, then y. The other tx will modify x in // between the reads, causing this tx to retry. var i int - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { f := tx.Get(x).(*foo) i = f.i read <- struct{}{} <-read // wait for other tx to complete - }) + })) if i == 3 { t.Fatal("verify did not retry despite modified Var", i) } @@ -136,37 +137,33 @@ func TestSelect(t *testing.T) { // with one arg, Select adds no effect x := NewVar(2) - Atomically(Select(func(tx *Tx) { + Atomically(Select(VoidOperation(func(tx *Tx) { tx.Assert(tx.Get(x).(int) == 2) - })) + }))) - var picked int - Atomically(Select( + picked := Atomically(Select( // always blocks; should never be selected - func(tx *Tx) { + VoidOperation(func(tx *Tx) { tx.Retry() - picked = 1 - }, + }), // always succeeds; should always be selected - func(tx *Tx) { - picked = 2 + func(tx *Tx) interface{} { + return 2 }, // always succeeds; should never be selected - func(tx *Tx) { - picked = 3 + func(tx *Tx) interface{} { + return 3 }, - )) - if picked != 2 { - t.Fatal("Select selected wrong transaction:", picked) - } + )).(int) + assert.EqualValues(t, 2, picked) } func TestCompose(t *testing.T) { nums := make([]int, 100) - fns := make([]func(*Tx), 100) + fns := make([]Operation, 100) for i := range fns { - fns[i] = func(x int) func(*Tx) { - return func(*Tx) { nums[x] = x } + fns[i] = func(x int) Operation { + return VoidOperation(func(*Tx) { nums[x] = x }) }(i) // capture loop var } Atomically(Compose(fns...)) @@ -178,14 +175,11 @@ func TestCompose(t *testing.T) { } func TestPanic(t *testing.T) { - defer func() { - if recover() == nil { - t.Fatal("expected panic, got nil") - } - }() // normal panics should escape Atomically - Atomically(func(*Tx) { - panic("foo") + assert.PanicsWithValue(t, "foo", func() { + Atomically(func(*Tx) interface{} { + panic("foo") + }) }) } @@ -193,10 +187,10 @@ func TestReadWritten(t *testing.T) { // reading a variable written in the same transaction should return the // previously written value x := NewVar(3) - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { tx.Set(x, 5) tx.Assert(tx.Get(x).(int) == 5) - }) + })) } func TestAtomicSetRetry(t *testing.T) { @@ -204,9 +198,9 @@ func TestAtomicSetRetry(t *testing.T) { x := NewVar(3) done := make(chan struct{}) go func() { - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { tx.Assert(tx.Get(x).(int) == 5) - }) + })) done <- struct{}{} }() time.Sleep(10 * time.Millisecond) @@ -226,18 +220,18 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) { var wg sync.WaitGroup bat := func(from, to interface{}, noise string) { defer wg.Done() - for !Atomically(func(tx *Tx) { + for !Atomically(func(tx *Tx) interface{} { if tx.Get(doneVar).(bool) { - tx.Return(true) + return true } tx.Assert(tx.Get(ready).(bool)) if tx.Get(ball) == from { tx.Set(ball, to) tx.Set(hits, tx.Get(hits).(int)+1) tx.Set(ready, false) - tx.Return(false) + return false } - tx.Retry() + panic(Retry) }).(bool) { afterHit(noise) AtomicSet(ready, true) @@ -246,10 +240,10 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) { wg.Add(2) go bat(false, true, "ping!") go bat(true, false, "pong!") - Atomically(func(tx *Tx) { + Atomically(VoidOperation(func(tx *Tx) { tx.Assert(tx.Get(hits).(int) >= n) tx.Set(doneVar, true) - }) + })) wg.Wait() } @@ -84,14 +84,6 @@ func (tx *Tx) Assert(p bool) { } } -func (tx *Tx) Return(v interface{}) { - panic(_return{v}) -} - -type _return struct { - value interface{} -} - func (tx *Tx) reset() { for k := range tx.reads { delete(tx.reads, k) |