-rw-r--r-- | bench_test.go           |  151
-rw-r--r-- | doc.go                  |   77
-rw-r--r-- | doc_test.go             |   77
-rw-r--r-- | external_test.go        |  151
-rw-r--r-- | funcs.go                |  169
-rw-r--r-- | metrics.go              |    7
-rw-r--r-- | rate/rate_test.go       |  480
-rw-r--r-- | rate/ratelimit.go       |  130
-rw-r--r-- | retry.go                |   24
-rw-r--r-- | src/stm.go              | 1105
-rw-r--r-- | stm_test.go             |  274
-rw-r--r-- | stmutil/containers.go   |  192
-rw-r--r-- | stmutil/context.go      |   39
-rw-r--r-- | stmutil/context_test.go |   20
-rw-r--r-- | tests/stm.go            | 1177
-rw-r--r-- | tx.go                   |  231
-rw-r--r-- | var-value.go            |   51
-rw-r--r-- | var.go                  |   76
18 files changed, 2282 insertions, 2149 deletions
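
For orientation before the per-file diffs below: the change removes the package's separate source files and consolidates them into src/stm.go and tests/stm.go. The following is a minimal sketch of the public API as documented in the deleted doc.go and doc_test.go (import path github.com/anacrolix/stm as used by those tests); it is illustrative only and not part of the diff itself.

    package main

    import (
        "fmt"

        "github.com/anacrolix/stm"
    )

    func main() {
        // A shared transactional variable holding an int.
        n := stm.NewVar(3)

        // Spawn three goroutines that each atomically decrement n. Each
        // transaction is rerun as a whole if a conflicting commit wins.
        for i := 0; i < 3; i++ {
            go stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
                n.Set(tx, n.Get(tx)-1)
            }))
        }

        // Block until n reaches zero; tx.Assert retries the transaction
        // until the condition holds.
        stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
            tx.Assert(n.Get(tx) == 0)
        }))

        fmt.Println(stm.AtomicGet(n)) // 0
    }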
diff --git a/bench_test.go b/bench_test.go deleted file mode 100644 index 0d40caf..0000000 --- a/bench_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package stm - -import ( - "sync" - "testing" - - "github.com/anacrolix/missinggo/iter" -) - -func BenchmarkAtomicGet(b *testing.B) { - x := NewVar(0) - for i := 0; i < b.N; i++ { - AtomicGet(x) - } -} - -func BenchmarkAtomicSet(b *testing.B) { - x := NewVar(0) - for i := 0; i < b.N; i++ { - AtomicSet(x, 0) - } -} - -func BenchmarkIncrementSTM(b *testing.B) { - for i := 0; i < b.N; i++ { - // spawn 1000 goroutines that each increment x by 1 - x := NewVar(0) - for i := 0; i < 1000; i++ { - go Atomically(VoidOperation(func(tx *Tx) { - cur := x.Get(tx) - x.Set(tx, cur+1) - })) - } - // wait for x to reach 1000 - Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(x.Get(tx) == 1000) - })) - } -} - -func BenchmarkIncrementMutex(b *testing.B) { - for i := 0; i < b.N; i++ { - var mu sync.Mutex - x := 0 - for i := 0; i < 1000; i++ { - go func() { - mu.Lock() - x++ - mu.Unlock() - }() - } - for { - mu.Lock() - read := x - mu.Unlock() - if read == 1000 { - break - } - } - } -} - -func BenchmarkIncrementChannel(b *testing.B) { - for i := 0; i < b.N; i++ { - c := make(chan int, 1) - c <- 0 - for i := 0; i < 1000; i++ { - go func() { - c <- 1 + <-c - }() - } - for { - read := <-c - if read == 1000 { - break - } - c <- read - } - } -} - -func BenchmarkReadVarSTM(b *testing.B) { - for i := 0; i < b.N; i++ { - var wg sync.WaitGroup - wg.Add(1000) - x := NewVar(0) - for i := 0; i < 1000; i++ { - go func() { - AtomicGet(x) - wg.Done() - }() - } - wg.Wait() - } -} - -func BenchmarkReadVarMutex(b *testing.B) { - for i := 0; i < b.N; i++ { - var mu sync.Mutex - var wg sync.WaitGroup - wg.Add(1000) - x := 0 - for i := 0; i < 1000; i++ { - go func() { - mu.Lock() - _ = x - mu.Unlock() - wg.Done() - }() - } - wg.Wait() - } -} - -func BenchmarkReadVarChannel(b *testing.B) { - for i := 0; i < b.N; i++ { - var wg sync.WaitGroup - wg.Add(1000) - c := make(chan int) - close(c) - for i := 0; i < 1000; i++ { - go func() { - <-c - wg.Done() - }() - } - wg.Wait() - } -} - -func parallelPingPongs(b *testing.B, n int) { - var wg sync.WaitGroup - wg.Add(n) - for range iter.N(n) { - go func() { - defer wg.Done() - testPingPong(b, b.N, func(string) {}) - }() - } - wg.Wait() -} - -func BenchmarkPingPong4(b *testing.B) { - b.ReportAllocs() - parallelPingPongs(b, 4) -} - -func BenchmarkPingPong(b *testing.B) { - b.ReportAllocs() - parallelPingPongs(b, 1) -} @@ -1,77 +0,0 @@ -/* -Package stm provides Software Transactional Memory operations for Go. This is -an alternative to the standard way of writing concurrent code (channels and -mutexes). STM makes it easy to perform arbitrarily complex operations in an -atomic fashion. One of its primary advantages over traditional locking is that -STM transactions are composable, whereas locking functions are not -- the -composition will either deadlock or release the lock between functions (making -it non-atomic). - -To begin, create an STM object that wraps the data you want to access -concurrently. - - x := stm.NewVar[int](3) - -You can then use the Atomically method to atomically read and/or write the the -data. This code atomically decrements x: - - stm.Atomically(func(tx *stm.Tx) { - cur := x.Get(tx) - x.Set(tx, cur-1) - }) - -An important part of STM transactions is retrying. At any point during the -transaction, you can call tx.Retry(), which will abort the transaction, but -not cancel it entirely. 
The call to Atomically will block until another call -to Atomically finishes, at which point the transaction will be rerun. -Specifically, one of the values read by the transaction (via tx.Get) must be -updated before the transaction will be rerun. As an example, this code will -try to decrement x, but will block as long as x is zero: - - stm.Atomically(func(tx *stm.Tx) { - cur := x.Get(tx) - if cur == 0 { - tx.Retry() - } - x.Set(tx, cur-1) - }) - -Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any other -value will cancel the transaction; no values will be changed. However, it is -the responsibility of the caller to catch such panics. - -Multiple transactions can be composed using Select. If the first transaction -calls Retry, the next transaction will be run, and so on. If all of the -transactions call Retry, the call will block and the entire selection will be -retried. For example, this code implements the "decrement-if-nonzero" -transaction above, but for two values. It will first try to decrement x, then -y, and block if both values are zero. - - func dec(v *stm.Var[int]) { - return func(tx *stm.Tx) { - cur := v.Get(tx) - if cur == 0 { - tx.Retry() - } - v.Set(tx, cur-1) - } - } - - // Note that Select does not perform any work itself, but merely - // returns a transaction function. - stm.Atomically(stm.Select(dec(x), dec(y))) - -An important caveat: transactions must be idempotent (they should have the -same effect every time they are invoked). This is because a transaction may be -retried several times before successfully completing, meaning its side effects -may execute more than once. This will almost certainly cause incorrect -behavior. One common way to get around this is to build up a list of impure -operations inside the transaction, and then perform them after the transaction -completes. - -The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but -Haskell can enforce at compile time that STM variables are not modified outside -the STM monad. This is not possible in Go, so be especially careful when using -pointers in your STM code. Remember: modifying a pointer is a side effect! 
-*/ -package stm diff --git a/doc_test.go b/doc_test.go deleted file mode 100644 index ae1af9a..0000000 --- a/doc_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package stm_test - -import ( - "github.com/anacrolix/stm" -) - -func Example() { - // create a shared variable - n := stm.NewVar(3) - - // read a variable - var v int - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - v = n.Get(tx) - })) - // or: - v = stm.AtomicGet(n) - _ = v - - // write to a variable - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - n.Set(tx, 12) - })) - // or: - stm.AtomicSet(n, 12) - - // update a variable - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := n.Get(tx) - n.Set(tx, cur-1) - })) - - // block until a condition is met - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := n.Get(tx) - if cur != 0 { - tx.Retry() - } - n.Set(tx, 10) - })) - // or: - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := n.Get(tx) - tx.Assert(cur == 0) - n.Set(tx, 10) - })) - - // select among multiple (potentially blocking) transactions - stm.Atomically(stm.Select( - // this function blocks forever, so it will be skipped - stm.VoidOperation(func(tx *stm.Tx) { tx.Retry() }), - - // this function will always succeed without blocking - stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 10) }), - - // this function will never run, because the previous - // function succeeded - stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 11) }), - )) - - // since Select is a normal transaction, if the entire select retries - // (blocks), it will be retried as a whole: - x := 0 - stm.Atomically(stm.Select( - // this function will run twice, and succeed the second time - stm.VoidOperation(func(tx *stm.Tx) { tx.Assert(x == 1) }), - - // this function will run once - stm.VoidOperation(func(tx *stm.Tx) { - x = 1 - tx.Retry() - }), - )) - // But wait! Transactions are only retried when one of the Vars they read is - // updated. Since x isn't a stm Var, this code will actually block forever -- - // but you get the idea. 
-} diff --git a/external_test.go b/external_test.go deleted file mode 100644 index 1291cce..0000000 --- a/external_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package stm_test - -import ( - "sync" - "testing" - - "github.com/anacrolix/missinggo/iter" - "github.com/anacrolix/stm" - "github.com/anacrolix/stm/stmutil" -) - -const maxTokens = 25 - -func BenchmarkThunderingHerdCondVar(b *testing.B) { - for i := 0; i < b.N; i++ { - var mu sync.Mutex - consumer := sync.NewCond(&mu) - generator := sync.NewCond(&mu) - done := false - tokens := 0 - var pending sync.WaitGroup - for range iter.N(1000) { - pending.Add(1) - go func() { - mu.Lock() - for { - if tokens > 0 { - tokens-- - generator.Signal() - break - } - consumer.Wait() - } - mu.Unlock() - pending.Done() - }() - } - go func() { - mu.Lock() - for !done { - if tokens < maxTokens { - tokens++ - consumer.Signal() - } else { - generator.Wait() - } - } - mu.Unlock() - }() - pending.Wait() - mu.Lock() - done = true - generator.Signal() - mu.Unlock() - } - -} - -func BenchmarkThunderingHerd(b *testing.B) { - for i := 0; i < b.N; i++ { - done := stm.NewBuiltinEqVar(false) - tokens := stm.NewBuiltinEqVar(0) - pending := stm.NewBuiltinEqVar(0) - for range iter.N(1000) { - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - pending.Set(tx, pending.Get(tx)+1) - })) - go func() { - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - t := tokens.Get(tx) - if t > 0 { - tokens.Set(tx, t-1) - pending.Set(tx, pending.Get(tx)-1) - } else { - tx.Retry() - } - })) - }() - } - go func() { - for stm.Atomically(func(tx *stm.Tx) bool { - if done.Get(tx) { - return false - } - tx.Assert(tokens.Get(tx) < maxTokens) - tokens.Set(tx, tokens.Get(tx)+1) - return true - }) { - } - }() - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(pending.Get(tx) == 0) - })) - stm.AtomicSet(done, true) - } -} - -func BenchmarkInvertedThunderingHerd(b *testing.B) { - for i := 0; i < b.N; i++ { - done := stm.NewBuiltinEqVar(false) - tokens := stm.NewBuiltinEqVar(0) - pending := stm.NewVar(stmutil.NewSet[*stm.Var[bool]]()) - for range iter.N(1000) { - ready := stm.NewVar(false) - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - pending.Set(tx, pending.Get(tx).Add(ready)) - })) - go func() { - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(ready.Get(tx)) - set := pending.Get(tx) - if !set.Contains(ready) { - panic("couldn't find ourselves in pending") - } - pending.Set(tx, set.Delete(ready)) - })) - //b.Log("waiter finished") - }() - } - go func() { - for stm.Atomically(func(tx *stm.Tx) bool { - if done.Get(tx) { - return false - } - tx.Assert(tokens.Get(tx) < maxTokens) - tokens.Set(tx, tokens.Get(tx)+1) - return true - }) { - } - }() - go func() { - for stm.Atomically(func(tx *stm.Tx) bool { - tx.Assert(tokens.Get(tx) > 0) - tokens.Set(tx, tokens.Get(tx)-1) - pending.Get(tx).Range(func(ready *stm.Var[bool]) bool { - if !ready.Get(tx) { - ready.Set(tx, true) - return false - } - return true - }) - return !done.Get(tx) - }) { - } - }() - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(pending.Get(tx).(stmutil.Lenner).Len() == 0) - })) - stm.AtomicSet(done, true) - } -} diff --git a/funcs.go b/funcs.go deleted file mode 100644 index 07d35ec..0000000 --- a/funcs.go +++ /dev/null @@ -1,169 +0,0 @@ -package stm - -import ( - "math/rand" - "runtime/pprof" - "sync" - "time" -) - -var ( - txPool = sync.Pool{New: func() any { - expvars.Add("new txs", 1) - tx := &Tx{ - reads: make(map[txVar]VarValue), - writes: make(map[txVar]any), - watching: 
make(map[txVar]struct{}), - } - tx.cond.L = &tx.mu - return tx - }} - failedCommitsProfile *pprof.Profile -) - -const ( - profileFailedCommits = false - sleepBetweenRetries = false -) - -func init() { - if profileFailedCommits { - failedCommitsProfile = pprof.NewProfile("stmFailedCommits") - } -} - -func newTx() *Tx { - tx := txPool.Get().(*Tx) - tx.tries = 0 - tx.completed = false - return tx -} - -func WouldBlock[R any](fn Operation[R]) (block bool) { - tx := newTx() - tx.reset() - _, block = catchRetry(fn, tx) - if len(tx.watching) != 0 { - panic("shouldn't have installed any watchers") - } - tx.recycle() - return -} - -// Atomically executes the atomic function fn. -func Atomically[R any](op Operation[R]) R { - expvars.Add("atomically", 1) - // run the transaction - tx := newTx() -retry: - tx.tries++ - tx.reset() - if sleepBetweenRetries { - shift := int64(tx.tries - 1) - const maxShift = 30 - if shift > maxShift { - shift = maxShift - } - ns := int64(1) << shift - d := time.Duration(rand.Int63n(ns)) - if d > 100*time.Microsecond { - tx.updateWatchers() - time.Sleep(time.Duration(ns)) - } - } - tx.mu.Lock() - ret, retry := catchRetry(op, tx) - tx.mu.Unlock() - if retry { - expvars.Add("retries", 1) - // wait for one of the variables we read to change before retrying - tx.wait() - goto retry - } - // verify the read log - tx.lockAllVars() - if tx.inputsChanged() { - tx.unlock() - expvars.Add("failed commits", 1) - if profileFailedCommits { - failedCommitsProfile.Add(new(int), 0) - } - goto retry - } - // commit the write log and broadcast that variables have changed - tx.commit() - tx.mu.Lock() - tx.completed = true - tx.cond.Broadcast() - tx.mu.Unlock() - tx.unlock() - expvars.Add("commits", 1) - tx.recycle() - return ret -} - -// AtomicGet is a helper function that atomically reads a value. -func AtomicGet[T any](v *Var[T]) T { - return v.value.Load().Get().(T) -} - -// AtomicSet is a helper function that atomically writes a value. -func AtomicSet[T any](v *Var[T], val T) { - v.mu.Lock() - v.changeValue(val) - v.mu.Unlock() -} - -// Compose is a helper function that composes multiple transactions into a -// single transaction. -func Compose[R any](fns ...Operation[R]) Operation[struct{}] { - return VoidOperation(func(tx *Tx) { - for _, f := range fns { - f(tx) - } - }) -} - -// Select runs the supplied functions in order. Execution stops when a -// function succeeds without calling Retry. If no functions succeed, the -// entire selection will be retried. 
-func Select[R any](fns ...Operation[R]) Operation[R] { - return func(tx *Tx) R { - switch len(fns) { - case 0: - // empty Select blocks forever - tx.Retry() - panic("unreachable") - case 1: - return fns[0](tx) - default: - oldWrites := tx.writes - tx.writes = make(map[txVar]any, len(oldWrites)) - for k, v := range oldWrites { - tx.writes[k] = v - } - ret, retry := catchRetry(fns[0], tx) - if retry { - tx.writes = oldWrites - return Select(fns[1:]...)(tx) - } else { - return ret - } - } - } -} - -type Operation[R any] func(*Tx) R - -func VoidOperation(f func(*Tx)) Operation[struct{}] { - return func(tx *Tx) struct{} { - f(tx) - return struct{}{} - } -} - -func AtomicModify[T any](v *Var[T], f func(T) T) { - Atomically(VoidOperation(func(tx *Tx) { - v.Set(tx, f(v.Get(tx))) - })) -} diff --git a/metrics.go b/metrics.go deleted file mode 100644 index b8563e2..0000000 --- a/metrics.go +++ /dev/null @@ -1,7 +0,0 @@ -package stm - -import ( - "expvar" -) - -var expvars = expvar.NewMap("stm") diff --git a/rate/rate_test.go b/rate/rate_test.go deleted file mode 100644 index 3078b4c..0000000 --- a/rate/rate_test.go +++ /dev/null @@ -1,480 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.7 -// +build go1.7 - -package rate - -import ( - "context" - "math" - "runtime" - "sync" - "sync/atomic" - "testing" - "time" -) - -func TestLimit(t *testing.T) { - if Limit(10) == Inf { - t.Errorf("Limit(10) == Inf should be false") - } -} - -func closeEnough(a, b Limit) bool { - return (math.Abs(float64(a)/float64(b)) - 1.0) < 1e-9 -} - -func TestEvery(t *testing.T) { - cases := []struct { - interval time.Duration - lim Limit - }{ - {0, Inf}, - {-1, Inf}, - {1 * time.Nanosecond, Limit(1e9)}, - {1 * time.Microsecond, Limit(1e6)}, - {1 * time.Millisecond, Limit(1e3)}, - {10 * time.Millisecond, Limit(100)}, - {100 * time.Millisecond, Limit(10)}, - {1 * time.Second, Limit(1)}, - {2 * time.Second, Limit(0.5)}, - {time.Duration(2.5 * float64(time.Second)), Limit(0.4)}, - {4 * time.Second, Limit(0.25)}, - {10 * time.Second, Limit(0.1)}, - {time.Duration(math.MaxInt64), Limit(1e9 / float64(math.MaxInt64))}, - } - for _, tc := range cases { - lim := Every(tc.interval) - if !closeEnough(lim, tc.lim) { - t.Errorf("Every(%v) = %v want %v", tc.interval, lim, tc.lim) - } - } -} - -const ( - d = 100 * time.Millisecond -) - -var ( - t0 = time.Now() - t1 = t0.Add(time.Duration(1) * d) - t2 = t0.Add(time.Duration(2) * d) - t3 = t0.Add(time.Duration(3) * d) - t4 = t0.Add(time.Duration(4) * d) - t5 = t0.Add(time.Duration(5) * d) - t9 = t0.Add(time.Duration(9) * d) -) - -type allow struct { - t time.Time - n int - ok bool -} - -// -//func run(t *testing.T, lim *Limiter, allows []allow) { -// for i, allow := range allows { -// ok := lim.AllowN(allow.t, allow.n) -// if ok != allow.ok { -// t.Errorf("step %d: lim.AllowN(%v, %v) = %v want %v", -// i, allow.t, allow.n, ok, allow.ok) -// } -// } -//} -// -//func TestLimiterBurst1(t *testing.T) { -// run(t, NewLimiter(10, 1), []allow{ -// {t0, 1, true}, -// {t0, 1, false}, -// {t0, 1, false}, -// {t1, 1, true}, -// {t1, 1, false}, -// {t1, 1, false}, -// {t2, 2, false}, // burst size is 1, so n=2 always fails -// {t2, 1, true}, -// {t2, 1, false}, -// }) -//} -// -//func TestLimiterBurst3(t *testing.T) { -// run(t, NewLimiter(10, 3), []allow{ -// {t0, 2, true}, -// {t0, 2, false}, -// {t0, 1, true}, -// {t0, 1, false}, -// {t1, 4, false}, -// {t2, 1, 
true}, -// {t3, 1, true}, -// {t4, 1, true}, -// {t4, 1, true}, -// {t4, 1, false}, -// {t4, 1, false}, -// {t9, 3, true}, -// {t9, 0, true}, -// }) -//} -// -//func TestLimiterJumpBackwards(t *testing.T) { -// run(t, NewLimiter(10, 3), []allow{ -// {t1, 1, true}, // start at t1 -// {t0, 1, true}, // jump back to t0, two tokens remain -// {t0, 1, true}, -// {t0, 1, false}, -// {t0, 1, false}, -// {t1, 1, true}, // got a token -// {t1, 1, false}, -// {t1, 1, false}, -// {t2, 1, true}, // got another token -// {t2, 1, false}, -// {t2, 1, false}, -// }) -//} - -// Ensure that tokensFromDuration doesn't produce -// rounding errors by truncating nanoseconds. -// See golang.org/issues/34861. -func TestLimiter_noTruncationErrors(t *testing.T) { - if !NewLimiter(0.7692307692307693, 1).Allow() { - t.Fatal("expected true") - } -} - -func TestSimultaneousRequests(t *testing.T) { - const ( - limit = 1 - burst = 5 - numRequests = 15 - ) - var ( - wg sync.WaitGroup - numOK = uint32(0) - ) - - // Very slow replenishing bucket. - lim := NewLimiter(limit, burst) - - // Tries to take a token, atomically updates the counter and decreases the wait - // group counter. - f := func() { - defer wg.Done() - if ok := lim.Allow(); ok { - atomic.AddUint32(&numOK, 1) - } - } - - wg.Add(numRequests) - for i := 0; i < numRequests; i++ { - go f() - } - wg.Wait() - if numOK != burst { - t.Errorf("numOK = %d, want %d", numOK, burst) - } -} - -func TestLongRunningQPS(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } - if runtime.GOOS == "openbsd" { - t.Skip("low resolution time.Sleep invalidates test (golang.org/issue/14183)") - return - } - - // The test runs for a few seconds executing many requests and then checks - // that overall number of requests is reasonable. - const ( - limit = 100 - burst = 100 - ) - var numOK = int32(0) - - lim := NewLimiter(limit, burst) - - var wg sync.WaitGroup - f := func() { - if ok := lim.Allow(); ok { - atomic.AddInt32(&numOK, 1) - } - wg.Done() - } - - start := time.Now() - end := start.Add(5 * time.Second) - for time.Now().Before(end) { - wg.Add(1) - go f() - - // This will still offer ~500 requests per second, but won't consume - // outrageous amount of CPU. - time.Sleep(2 * time.Millisecond) - } - wg.Wait() - elapsed := time.Since(start) - ideal := burst + (limit * float64(elapsed) / float64(time.Second)) - - // We should never get more requests than allowed. - if want := int32(ideal + 1); numOK > want { - t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal) - } - // We should get very close to the number of requests allowed. 
- if want := int32(0.999 * ideal); numOK < want { - t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal) - } -} - -type request struct { - t time.Time - n int - act time.Time - ok bool -} - -// dFromDuration converts a duration to a multiple of the global constant d -func dFromDuration(dur time.Duration) int { - // Adding a millisecond to be swallowed by the integer division - // because we don't care about small inaccuracies - return int((dur + time.Millisecond) / d) -} - -// dSince returns multiples of d since t0 -func dSince(t time.Time) int { - return dFromDuration(t.Sub(t0)) -} - -// -//func runReserve(t *testing.T, lim *Limiter, req request) *Reservation { -// return runReserveMax(t, lim, req, InfDuration) -//} -// -//func runReserveMax(t *testing.T, lim *Limiter, req request, maxReserve time.Duration) *Reservation { -// r := lim.reserveN(req.t, req.n, maxReserve) -// if r.ok && (dSince(r.timeToAct) != dSince(req.act)) || r.ok != req.ok { -// t.Errorf("lim.reserveN(t%d, %v, %v) = (t%d, %v) want (t%d, %v)", -// dSince(req.t), req.n, maxReserve, dSince(r.timeToAct), r.ok, dSince(req.act), req.ok) -// } -// return &r -//} -// -//func TestSimpleReserve(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// runReserve(t, lim, request{t0, 2, t2, true}) -// runReserve(t, lim, request{t3, 2, t4, true}) -//} -// -//func TestMix(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t0, 3, t1, false}) // should return false because n > Burst -// runReserve(t, lim, request{t0, 2, t0, true}) -// run(t, lim, []allow{{t1, 2, false}}) // not enought tokens - don't allow -// runReserve(t, lim, request{t1, 2, t2, true}) -// run(t, lim, []allow{{t1, 1, false}}) // negative tokens - don't allow -// run(t, lim, []allow{{t3, 1, true}}) -//} -// -//func TestCancelInvalid(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// r := runReserve(t, lim, request{t0, 3, t3, false}) -// r.CancelAt(t0) // should have no effect -// runReserve(t, lim, request{t0, 2, t2, true}) // did not get extra tokens -//} -// -//func TestCancelLast(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// r := runReserve(t, lim, request{t0, 2, t2, true}) -// r.CancelAt(t1) // got 2 tokens back -// runReserve(t, lim, request{t1, 2, t2, true}) -//} -// -//func TestCancelTooLate(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// r := runReserve(t, lim, request{t0, 2, t2, true}) -// r.CancelAt(t3) // too late to cancel - should have no effect -// runReserve(t, lim, request{t3, 2, t4, true}) -//} -// -//func TestCancel0Tokens(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// r := runReserve(t, lim, request{t0, 1, t1, true}) -// runReserve(t, lim, request{t0, 1, t2, true}) -// r.CancelAt(t0) // got 0 tokens back -// runReserve(t, lim, request{t0, 1, t3, true}) -//} -// -//func TestCancel1Token(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// r := runReserve(t, lim, request{t0, 2, t2, true}) -// runReserve(t, lim, request{t0, 1, t3, true}) -// r.CancelAt(t2) // got 1 token back -// runReserve(t, lim, request{t2, 2, t4, true}) -//} -// -//func TestCancelMulti(t *testing.T) { -// lim := NewLimiter(10, 4) -// -// runReserve(t, lim, request{t0, 4, t0, true}) -// rA := runReserve(t, lim, request{t0, 3, t3, true}) -// 
runReserve(t, lim, request{t0, 1, t4, true}) -// rC := runReserve(t, lim, request{t0, 1, t5, true}) -// rC.CancelAt(t1) // get 1 token back -// rA.CancelAt(t1) // get 2 tokens back, as if C was never reserved -// runReserve(t, lim, request{t1, 3, t5, true}) -//} -// -//func TestReserveJumpBack(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 -// runReserve(t, lim, request{t0, 1, t1, true}) // should violate Limit,Burst -// runReserve(t, lim, request{t2, 2, t3, true}) -//} - -//func TestReserveJumpBackCancel(t *testing.T) { -// lim := NewLimiter(10, 2) -// -// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 -// r := runReserve(t, lim, request{t1, 2, t3, true}) -// runReserve(t, lim, request{t1, 1, t4, true}) -// r.CancelAt(t0) // cancel at t0, get 1 token back -// runReserve(t, lim, request{t1, 2, t4, true}) // should violate Limit,Burst -//} -// -//func TestReserveSetLimit(t *testing.T) { -// lim := NewLimiter(5, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// runReserve(t, lim, request{t0, 2, t4, true}) -// lim.SetLimitAt(t2, 10) -// runReserve(t, lim, request{t2, 1, t4, true}) // violates Limit and Burst -//} -// -//func TestReserveSetBurst(t *testing.T) { -// lim := NewLimiter(5, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// runReserve(t, lim, request{t0, 2, t4, true}) -// lim.SetBurstAt(t3, 4) -// runReserve(t, lim, request{t0, 4, t9, true}) // violates Limit and Burst -//} -// -//func TestReserveSetLimitCancel(t *testing.T) { -// lim := NewLimiter(5, 2) -// -// runReserve(t, lim, request{t0, 2, t0, true}) -// r := runReserve(t, lim, request{t0, 2, t4, true}) -// lim.SetLimitAt(t2, 10) -// r.CancelAt(t2) // 2 tokens back -// runReserve(t, lim, request{t2, 2, t3, true}) -//} -// -//func TestReserveMax(t *testing.T) { -// lim := NewLimiter(10, 2) -// maxT := d -// -// runReserveMax(t, lim, request{t0, 2, t0, true}, maxT) -// runReserveMax(t, lim, request{t0, 1, t1, true}, maxT) // reserve for close future -// runReserveMax(t, lim, request{t0, 1, t2, false}, maxT) // time to act too far in the future -//} - -type wait struct { - name string - ctx context.Context - n int - delay int // in multiples of d - nilErr bool -} - -func runWait(t *testing.T, lim *Limiter, w wait) { - t.Helper() - start := time.Now() - err := lim.WaitN(w.ctx, w.n) - delay := time.Since(start) - if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) { - errString := "<nil>" - if !w.nilErr { - errString = "<non-nil error>" - } - t.Errorf("lim.WaitN(%v, lim, %v) = %v with delay %v ; want %v with delay %v", - w.name, w.n, err, delay, errString, d*time.Duration(w.delay)) - } -} - -func TestWaitSimple(t *testing.T) { - lim := NewLimiter(10, 3) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - runWait(t, lim, wait{"already-cancelled", ctx, 1, 0, false}) - - runWait(t, lim, wait{"exceed-burst-error", context.Background(), 4, 0, false}) - - runWait(t, lim, wait{"act-now", context.Background(), 2, 0, true}) - runWait(t, lim, wait{"act-later", context.Background(), 3, 2, true}) -} - -func TestWaitCancel(t *testing.T) { - lim := NewLimiter(10, 3) - - ctx, cancel := context.WithCancel(context.Background()) - runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) // after this lim.tokens = 1 - go func() { - time.Sleep(d) - cancel() - }() - runWait(t, lim, wait{"will-cancel", ctx, 3, 1, false}) - // should get 3 tokens back, and have lim.tokens = 2 - //t.Logf("tokens:%v 
last:%v lastEvent:%v", lim.tokens, lim.last, lim.lastEvent) - runWait(t, lim, wait{"act-now-after-cancel", context.Background(), 2, 0, true}) -} - -func TestWaitTimeout(t *testing.T) { - lim := NewLimiter(10, 3) - - ctx, cancel := context.WithTimeout(context.Background(), d) - defer cancel() - runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) - runWait(t, lim, wait{"w-timeout-err", ctx, 3, 0, false}) -} - -func TestWaitInf(t *testing.T) { - lim := NewLimiter(Inf, 0) - - runWait(t, lim, wait{"exceed-burst-no-error", context.Background(), 3, 0, true}) -} - -func BenchmarkAllowN(b *testing.B) { - lim := NewLimiter(Every(1*time.Second), 1) - b.ReportAllocs() - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - lim.AllowN(1) - } - }) -} - -func BenchmarkWaitNNoDelay(b *testing.B) { - lim := NewLimiter(Limit(b.N), b.N) - ctx := context.Background() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - lim.WaitN(ctx, 1) - } -} diff --git a/rate/ratelimit.go b/rate/ratelimit.go deleted file mode 100644 index f521f66..0000000 --- a/rate/ratelimit.go +++ /dev/null @@ -1,130 +0,0 @@ -package rate - -import ( - "context" - "errors" - "math" - "time" - - "github.com/anacrolix/stm" - "github.com/anacrolix/stm/stmutil" -) - -type numTokens = int - -type Limiter struct { - max *stm.Var[numTokens] - cur *stm.Var[numTokens] - lastAdd *stm.Var[time.Time] - rate Limit -} - -const Inf = Limit(math.MaxFloat64) - -type Limit float64 - -func (l Limit) interval() time.Duration { - return time.Duration(Limit(1*time.Second) / l) -} - -func Every(interval time.Duration) Limit { - if interval == 0 { - return Inf - } - return Limit(time.Second / interval) -} - -func NewLimiter(rate Limit, burst numTokens) *Limiter { - rl := &Limiter{ - max: stm.NewVar(burst), - cur: stm.NewBuiltinEqVar(burst), - lastAdd: stm.NewVar(time.Now()), - rate: rate, - } - if rate != Inf { - go rl.tokenGenerator(rate.interval()) - } - return rl -} - -func (rl *Limiter) tokenGenerator(interval time.Duration) { - for { - lastAdd := stm.AtomicGet(rl.lastAdd) - time.Sleep(time.Until(lastAdd.Add(interval))) - now := time.Now() - available := numTokens(now.Sub(lastAdd) / interval) - if available < 1 { - continue - } - stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := rl.cur.Get(tx) - max := rl.max.Get(tx) - tx.Assert(cur < max) - newCur := cur + available - if newCur > max { - newCur = max - } - if newCur != cur { - rl.cur.Set(tx, newCur) - } - rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available))) - })) - } -} - -func (rl *Limiter) Allow() bool { - return rl.AllowN(1) -} - -func (rl *Limiter) AllowN(n numTokens) bool { - return stm.Atomically(func(tx *stm.Tx) bool { - return rl.takeTokens(tx, n) - }) -} - -func (rl *Limiter) AllowStm(tx *stm.Tx) bool { - return rl.takeTokens(tx, 1) -} - -func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool { - if rl.rate == Inf { - return true - } - cur := rl.cur.Get(tx) - if cur >= n { - rl.cur.Set(tx, cur-n) - return true - } - return false -} - -func (rl *Limiter) Wait(ctx context.Context) error { - return rl.WaitN(ctx, 1) -} - -func (rl *Limiter) WaitN(ctx context.Context, n int) error { - ctxDone, cancel := stmutil.ContextDoneVar(ctx) - defer cancel() - if err := stm.Atomically(func(tx *stm.Tx) error { - if ctxDone.Get(tx) { - return ctx.Err() - } - if rl.takeTokens(tx, n) { - return nil - } - if n > rl.max.Get(tx) { - return errors.New("burst exceeded") - } - if dl, ok := ctx.Deadline(); ok { - if 
rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n { - return context.DeadlineExceeded - } - } - tx.Retry() - panic("unreachable") - }); err != nil { - return err - } - return nil - -} diff --git a/retry.go b/retry.go deleted file mode 100644 index 1adcfd0..0000000 --- a/retry.go +++ /dev/null @@ -1,24 +0,0 @@ -package stm - -import ( - "runtime/pprof" -) - -var retries = pprof.NewProfile("stmRetries") - -// retry is a sentinel value. When thrown via panic, it indicates that a -// transaction should be retried. -var retry = &struct{}{} - -// catchRetry returns true if fn calls tx.Retry. -func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) { - defer func() { - if r := recover(); r == retry { - gotRetry = true - } else if r != nil { - panic(r) - } - }() - result = fn(tx) - return -} @@ -1,9 +1,1114 @@ +/// Package stm provides Software Transactional Memory operations for Go. This +/// is an alternative to the standard way of writing concurrent code (channels +/// and mutexes). STM makes it easy to perform arbitrarily complex operations +/// in an atomic fashion. One of its primary advantages over traditional +/// locking is that STM transactions are composable, whereas locking functions +/// are not -- the composition will either deadlock or release the lock between +/// functions (making it non-atomic). +/// +/// To begin, create an STM object that wraps the data you want to access +/// concurrently. +/// +/// x := stm.NewVar[int](3) +/// +/// You can then use the Atomically method to atomically read and/or write the +/// the data. This code atomically decrements x: +/// +/// stm.Atomically(func(tx *stm.Tx) { +/// cur := x.Get(tx) +/// x.Set(tx, cur-1) +/// }) +/// +/// An important part of STM transactions is retrying. At any point during the +/// transaction, you can call tx.Retry(), which will abort the transaction, but +/// not cancel it entirely. The call to Atomically will block until another +/// call to Atomically finishes, at which point the transaction will be rerun. +/// Specifically, one of the values read by the transaction (via tx.Get) must be +/// updated before the transaction will be rerun. As an example, this code will +/// try to decrement x, but will block as long as x is zero: +/// +/// stm.Atomically(func(tx *stm.Tx) { +/// cur := x.Get(tx) +/// if cur == 0 { +/// tx.Retry() +/// } +/// x.Set(tx, cur-1) +/// }) +/// +/// Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any +/// other value will cancel the transaction; no values will be changed. +/// However, it is the responsibility of the caller to catch such panics. +/// +/// Multiple transactions can be composed using Select. If the first +/// transaction calls Retry, the next transaction will be run, and so on. If +/// all of the transactions call Retry, the call will block and the entire +/// selection will be retried. For example, this code implements the +/// "decrement-if-nonzero" transaction above, but for two values. It will +/// first try to decrement x, then y, and block if both values are zero. +/// +/// func dec(v *stm.Var[int]) { +/// return func(tx *stm.Tx) { +/// cur := v.Get(tx) +/// if cur == 0 { +/// tx.Retry() +/// } +/// v.Set(tx, cur-1) +/// } +/// } +/// +/// // Note that Select does not perform any work itself, but merely +/// // returns a transaction function. 
+/// stm.Atomically(stm.Select(dec(x), dec(y))) +/// +/// An important caveat: transactions must be idempotent (they should have the +/// same effect every time they are invoked). This is because a transaction may +/// be retried several times before successfully completing, meaning its side +/// effects may execute more than once. This will almost certainly cause +/// incorrect behavior. One common way to get around this is to build up a list +/// of impure operations inside the transaction, and then perform them after the +/// transaction completes. +/// +/// The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but +/// Haskell can enforce at compile time that STM variables are not modified +/// outside the STM monad. This is not possible in Go, so be especially careful +/// when using pointers in your STM code. Remember: modifying a pointer is a +/// side effect! package stm import ( + "context" + "errors" + "expvar" + "fmt" + "math" + "math/rand" + "runtime/pprof" + "sort" + "sync" + "sync/atomic" + "time" + "unsafe" + + "pds" ) +// Package atomic contains type-safe atomic types. +// +// The zero value for the numeric types cannot be used. Use New*. The +// rationale for this behaviour is that copying an atomic integer is not +// reliable. Copying can be prevented by embedding sync.Mutex, but that bloats +// the type. + +// Interface represents atomic operations on a value. +type Interface[T any] interface { + // Load value atomically. + Load() T + // Store value atomically. + Store(value T) + // Swap the previous value with the new value atomically. + Swap(new T) (old T) + // CompareAndSwap the previous value with new if its value is "old". + CompareAndSwap(old, new T) (swapped bool) +} + +var _ Interface[bool] = &Value[bool]{} + +// Value wraps any generic value in atomic load and store operations. +// +// The zero value should be initialised using [Value.Store] before use. +type Value[T any] struct { + value atomic.Value +} + +// New atomic Value. +func New[T any](seed T) *Value[T] { + v := &Value[T]{} + v.value.Store(seed) + return v +} + +// Load value atomically. +// +// Will panic if the value is nil. +func (v *Value[T]) Load() (out T) { + value := v.value.Load() + if value == nil { + panic("nil value in atomic.Value") + } + return value.(T) +} +func (v *Value[T]) Store(value T) { v.value.Store(value) } +func (v *Value[T]) Swap(new T) (old T) { return v.value.Swap(new).(T) } +func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { return v.value.CompareAndSwap(old, new) } + +// atomicint defines the types that atomic integer operations are supported on. +type atomicint interface { + int32 | uint32 | int64 | uint64 +} + +// Int expresses atomic operations on signed or unsigned integer values. +type Int[T atomicint] interface { + Interface[T] + // Add a value and return the new result. + Add(delta T) (new T) +} + +// Currently not supported by Go's generic type system: +// +// ./atomic.go:48:9: cannot use type switch on type parameter value v (variable of type T constrained by atomicint) +// +// // ForInt infers and creates an atomic Int[T] type for a value. +// func ForInt[T atomicint](v T) Int[T] { +// switch v.(type) { +// case int32: +// return NewInt32(v) +// case uint32: +// return NewUint32(v) +// case int64: +// return NewInt64(v) +// case uint64: +// return NewUint64(v) +// } +// panic("can't happen") +// } + +// Int32 atomic value. +// +// Copying creates an alias. The zero value is not usable, use NewInt32. 
+type Int32 struct{ value *int32 } + +// NewInt32 creates a new atomic integer with an initial value. +func NewInt32(value int32) Int32 { return Int32{value: &value} } + +var _ Int[int32] = &Int32{} + +func (i Int32) Add(delta int32) (new int32) { return atomic.AddInt32(i.value, delta) } +func (i Int32) Load() (val int32) { return atomic.LoadInt32(i.value) } +func (i Int32) Store(val int32) { atomic.StoreInt32(i.value, val) } +func (i Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(i.value, new) } +func (i Int32) CompareAndSwap(old, new int32) (swapped bool) { + return atomic.CompareAndSwapInt32(i.value, old, new) +} + +// Uint32 atomic value. +// +// Copying creates an alias. +type Uint32 struct{ value *uint32 } + +var _ Int[uint32] = Uint32{} + +// NewUint32 creates a new atomic integer with an initial value. +func NewUint32(value uint32) Uint32 { return Uint32{value: &value} } + +func (i Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(i.value, delta) } +func (i Uint32) Load() (val uint32) { return atomic.LoadUint32(i.value) } +func (i Uint32) Store(val uint32) { atomic.StoreUint32(i.value, val) } +func (i Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(i.value, new) } +func (i Uint32) CompareAndSwap(old, new uint32) (swapped bool) { + return atomic.CompareAndSwapUint32(i.value, old, new) +} + +// Int64 atomic value. +// +// Copying creates an alias. +type Int64 struct{ value *int64 } + +var _ Int[int64] = Int64{} + +// NewInt64 creates a new atomic integer with an initial value. +func NewInt64(value int64) Int64 { return Int64{value: &value} } + +func (i Int64) Add(delta int64) (new int64) { return atomic.AddInt64(i.value, delta) } +func (i Int64) Load() (val int64) { return atomic.LoadInt64(i.value) } +func (i Int64) Store(val int64) { atomic.StoreInt64(i.value, val) } +func (i Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(i.value, new) } +func (i Int64) CompareAndSwap(old, new int64) (swapped bool) { + return atomic.CompareAndSwapInt64(i.value, old, new) +} + +// Uint64 atomic value. +// +// Copying creates an alias. +type Uint64 struct{ value *uint64 } + +var _ Int[uint64] = Uint64{} + +// NewUint64 creates a new atomic integer with an initial value. 
+func NewUint64(value uint64) Uint64 { return Uint64{value: &value} } + +func (i Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(i.value, delta) } +func (i Uint64) Load() (val uint64) { return atomic.LoadUint64(i.value) } +func (i Uint64) Store(val uint64) { atomic.StoreUint64(i.value, val) } +func (i Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(i.value, new) } +func (i Uint64) CompareAndSwap(old, new uint64) (swapped bool) { + return atomic.CompareAndSwapUint64(i.value, old, new) +} + + func F() { } + +var ( + txPool = sync.Pool{New: func() any { + expvars.Add("new txs", 1) + tx := &Tx{ + reads: make(map[txVar]VarValue), + writes: make(map[txVar]any), + watching: make(map[txVar]struct{}), + } + tx.cond.L = &tx.mu + return tx + }} + failedCommitsProfile *pprof.Profile +) + +const ( + profileFailedCommits = false + sleepBetweenRetries = false +) + +func init() { + if profileFailedCommits { + failedCommitsProfile = pprof.NewProfile("stmFailedCommits") + } +} + +func newTx() *Tx { + tx := txPool.Get().(*Tx) + tx.tries = 0 + tx.completed = false + return tx +} + +func WouldBlock[R any](fn Operation[R]) (block bool) { + tx := newTx() + tx.reset() + _, block = catchRetry(fn, tx) + if len(tx.watching) != 0 { + panic("shouldn't have installed any watchers") + } + tx.recycle() + return +} + +// Atomically executes the atomic function fn. +func Atomically[R any](op Operation[R]) R { + expvars.Add("atomically", 1) + // run the transaction + tx := newTx() +retry: + tx.tries++ + tx.reset() + if sleepBetweenRetries { + shift := int64(tx.tries - 1) + const maxShift = 30 + if shift > maxShift { + shift = maxShift + } + ns := int64(1) << shift + d := time.Duration(rand.Int63n(ns)) + if d > 100*time.Microsecond { + tx.updateWatchers() + time.Sleep(time.Duration(ns)) + } + } + tx.mu.Lock() + ret, retry := catchRetry(op, tx) + tx.mu.Unlock() + if retry { + expvars.Add("retries", 1) + // wait for one of the variables we read to change before retrying + tx.wait() + goto retry + } + // verify the read log + tx.lockAllVars() + if tx.inputsChanged() { + tx.unlock() + expvars.Add("failed commits", 1) + if profileFailedCommits { + failedCommitsProfile.Add(new(int), 0) + } + goto retry + } + // commit the write log and broadcast that variables have changed + tx.commit() + tx.mu.Lock() + tx.completed = true + tx.cond.Broadcast() + tx.mu.Unlock() + tx.unlock() + expvars.Add("commits", 1) + tx.recycle() + return ret +} + +// AtomicGet is a helper function that atomically reads a value. +func AtomicGet[T any](v *Var[T]) T { + return v.value.Load().Get().(T) +} + +// AtomicSet is a helper function that atomically writes a value. +func AtomicSet[T any](v *Var[T], val T) { + v.mu.Lock() + v.changeValue(val) + v.mu.Unlock() +} + +// Compose is a helper function that composes multiple transactions into a +// single transaction. +func Compose[R any](fns ...Operation[R]) Operation[struct{}] { + return VoidOperation(func(tx *Tx) { + for _, f := range fns { + f(tx) + } + }) +} + +// Select runs the supplied functions in order. Execution stops when a +// function succeeds without calling Retry. If no functions succeed, the +// entire selection will be retried. 
+func Select[R any](fns ...Operation[R]) Operation[R] { + return func(tx *Tx) R { + switch len(fns) { + case 0: + // empty Select blocks forever + tx.Retry() + panic("unreachable") + case 1: + return fns[0](tx) + default: + oldWrites := tx.writes + tx.writes = make(map[txVar]any, len(oldWrites)) + for k, v := range oldWrites { + tx.writes[k] = v + } + ret, retry := catchRetry(fns[0], tx) + if retry { + tx.writes = oldWrites + return Select(fns[1:]...)(tx) + } else { + return ret + } + } + } +} + +type Operation[R any] func(*Tx) R + +func VoidOperation(f func(*Tx)) Operation[struct{}] { + return func(tx *Tx) struct{} { + f(tx) + return struct{}{} + } +} + +func AtomicModify[T any](v *Var[T], f func(T) T) { + Atomically(VoidOperation(func(tx *Tx) { + v.Set(tx, f(v.Get(tx))) + })) +} + +var expvars = expvar.NewMap("stm") + +type VarValue interface { + Set(any) VarValue + Get() any + Changed(VarValue) bool +} + +type version uint64 + +type versionedValue[T any] struct { + value T + version version +} + +func (me versionedValue[T]) Set(newValue any) VarValue { + return versionedValue[T]{ + value: newValue.(T), + version: me.version + 1, + } +} + +func (me versionedValue[T]) Get() any { + return me.value +} + +func (me versionedValue[T]) Changed(other VarValue) bool { + return me.version != other.(versionedValue[T]).version +} + +type customVarValue[T any] struct { + value T + changed func(T, T) bool +} + +var _ VarValue = customVarValue[struct{}]{} + +func (me customVarValue[T]) Changed(other VarValue) bool { + return me.changed(me.value, other.(customVarValue[T]).value) +} + +func (me customVarValue[T]) Set(newValue any) VarValue { + return customVarValue[T]{ + value: newValue.(T), + changed: me.changed, + } +} + +func (me customVarValue[T]) Get() any { + return me.value +} + +type txVar interface { + getValue() *Value[VarValue] + changeValue(any) + getWatchers() *sync.Map + getLock() *sync.Mutex +} + +// A Tx represents an atomic transaction. +type Tx struct { + reads map[txVar]VarValue + writes map[txVar]any + watching map[txVar]struct{} + locks txLocks + mu sync.Mutex + cond sync.Cond + waiting bool + completed bool + tries int + numRetryValues int +} + +// Check that none of the logged values have changed since the transaction began. +func (tx *Tx) inputsChanged() bool { + for v, read := range tx.reads { + if read.Changed(v.getValue().Load()) { + return true + } + } + return false +} + +// Writes the values in the transaction log to their respective Vars. +func (tx *Tx) commit() { + for v, val := range tx.writes { + v.changeValue(val) + } +} + +func (tx *Tx) updateWatchers() { + for v := range tx.watching { + if _, ok := tx.reads[v]; !ok { + delete(tx.watching, v) + v.getWatchers().Delete(tx) + } + } + for v := range tx.reads { + if _, ok := tx.watching[v]; !ok { + v.getWatchers().Store(tx, nil) + tx.watching[v] = struct{}{} + } + } +} + +// wait blocks until another transaction modifies any of the Vars read by tx. +func (tx *Tx) wait() { + if len(tx.reads) == 0 { + panic("not waiting on anything") + } + tx.updateWatchers() + tx.mu.Lock() + firstWait := true + for !tx.inputsChanged() { + if !firstWait { + expvars.Add("wakes for unchanged versions", 1) + } + expvars.Add("waits", 1) + tx.waiting = true + tx.cond.Broadcast() + tx.cond.Wait() + tx.waiting = false + firstWait = false + } + tx.mu.Unlock() +} + +// Get returns the value of v as of the start of the transaction. +func (v *Var[T]) Get(tx *Tx) T { + // If we previously wrote to v, it will be in the write log. 
+ if val, ok := tx.writes[v]; ok { + return val.(T) + } + // If we haven't previously read v, record its version + vv, ok := tx.reads[v] + if !ok { + vv = v.getValue().Load() + tx.reads[v] = vv + } + return vv.Get().(T) +} + +// Set sets the value of a Var for the lifetime of the transaction. +func (v *Var[T]) Set(tx *Tx, val T) { + if v == nil { + panic("nil Var") + } + tx.writes[v] = val +} + +type txProfileValue struct { + *Tx + int +} + +// Retry aborts the transaction and retries it when a Var changes. You can return from this method +// to satisfy return values, but it should never actually return anything as it panics internally. +func (tx *Tx) Retry() struct{} { + retries.Add(txProfileValue{tx, tx.numRetryValues}, 1) + tx.numRetryValues++ + panic(retry) +} + +// Assert is a helper function that retries a transaction if the condition is +// not satisfied. +func (tx *Tx) Assert(p bool) { + if !p { + tx.Retry() + } +} + +func (tx *Tx) reset() { + tx.mu.Lock() + for k := range tx.reads { + delete(tx.reads, k) + } + for k := range tx.writes { + delete(tx.writes, k) + } + tx.mu.Unlock() + tx.removeRetryProfiles() + tx.resetLocks() +} + +func (tx *Tx) removeRetryProfiles() { + for tx.numRetryValues > 0 { + tx.numRetryValues-- + retries.Remove(txProfileValue{tx, tx.numRetryValues}) + } +} + +func (tx *Tx) recycle() { + for v := range tx.watching { + delete(tx.watching, v) + v.getWatchers().Delete(tx) + } + tx.removeRetryProfiles() + // I don't think we can reuse Txs, because the "completed" field should/needs to be set + // indefinitely after use. + //txPool.Put(tx) +} + +func (tx *Tx) lockAllVars() { + tx.resetLocks() + tx.collectAllLocks() + tx.sortLocks() + tx.lock() +} + +func (tx *Tx) resetLocks() { + tx.locks.clear() +} + +func (tx *Tx) collectReadLocks() { + for v := range tx.reads { + tx.locks.append(v.getLock()) + } +} + +func (tx *Tx) collectAllLocks() { + tx.collectReadLocks() + for v := range tx.writes { + if _, ok := tx.reads[v]; !ok { + tx.locks.append(v.getLock()) + } + } +} + +func (tx *Tx) sortLocks() { + sort.Sort(&tx.locks) +} + +func (tx *Tx) lock() { + for _, l := range tx.locks.mus { + l.Lock() + } +} + +func (tx *Tx) unlock() { + for _, l := range tx.locks.mus { + l.Unlock() + } +} + +func (tx *Tx) String() string { + return fmt.Sprintf("%[1]T %[1]p", tx) +} + +// Dedicated type avoids reflection in sort.Slice. +type txLocks struct { + mus []*sync.Mutex +} + +func (me txLocks) Len() int { + return len(me.mus) +} + +func (me txLocks) Less(i, j int) bool { + return uintptr(unsafe.Pointer(me.mus[i])) < uintptr(unsafe.Pointer(me.mus[j])) +} + +func (me txLocks) Swap(i, j int) { + me.mus[i], me.mus[j] = me.mus[j], me.mus[i] +} + +func (me *txLocks) clear() { + me.mus = me.mus[:0] +} + +func (me *txLocks) append(mu *sync.Mutex) { + me.mus = append(me.mus, mu) +} + +// Holds an STM variable. +type Var[T any] struct { + value Value[VarValue] + watchers sync.Map + mu sync.Mutex +} + +func (v *Var[T]) getValue() *Value[VarValue] { + return &v.value +} + +func (v *Var[T]) getWatchers() *sync.Map { + return &v.watchers +} + +func (v *Var[T]) getLock() *sync.Mutex { + return &v.mu +} + +func (v *Var[T]) changeValue(new any) { + old := v.value.Load() + newVarValue := old.Set(new) + v.value.Store(newVarValue) + if old.Changed(newVarValue) { + go v.wakeWatchers(newVarValue) + } +} + +func (v *Var[T]) wakeWatchers(new VarValue) { + v.watchers.Range(func(k, _ any) bool { + tx := k.(*Tx) + // We have to lock here to ensure that the Tx is waiting before + // we signal it. 
Otherwise we could signal it before it goes to + // sleep and it will miss the notification. + tx.mu.Lock() + if read := tx.reads[v]; read != nil && read.Changed(new) { + tx.cond.Broadcast() + for !tx.waiting && !tx.completed { + tx.cond.Wait() + } + } + tx.mu.Unlock() + return !v.value.Load().Changed(new) + }) +} + +// Returns a new STM variable. +func NewVar[T any](val T) *Var[T] { + v := &Var[T]{} + v.value.Store(versionedValue[T]{ + value: val, + }) + return v +} + +func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] { + v := &Var[T]{} + v.value.Store(customVarValue[T]{ + value: val, + changed: changed, + }) + return v +} + +func NewBuiltinEqVar[T comparable](val T) *Var[T] { + return NewCustomVar(val, func(a, b T) bool { + return a != b + }) +} + +var retries = pprof.NewProfile("stmRetries") + +// retry is a sentinel value. When thrown via panic, it indicates that a +// transaction should be retried. +var retry = &struct{}{} + +// catchRetry returns true if fn calls tx.Retry. +func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) { + defer func() { + if r := recover(); r == retry { + gotRetry = true + } else if r != nil { + panic(r) + } + }() + result = fn(tx) + return +} + +// This is the type constraint for keys passed through from pds +type KeyConstraint interface { + comparable +} + +type Settish[K KeyConstraint] interface { + Add(K) Settish[K] + Delete(K) Settish[K] + Contains(K) bool + Range(func(K) bool) + Len() int + // iter.Iterable +} + +type mapToSet[K KeyConstraint] struct { + m Mappish[K, struct{}] +} + +type interhash[K KeyConstraint] struct{} + +func (interhash[K]) Hash(x K) uint32 { + return uint32(nilinterhash(unsafe.Pointer(&x), 0)) +} + +func (interhash[K]) Equal(i, j K) bool { + return i == j +} + +func NewSet[K KeyConstraint]() Settish[K] { + return mapToSet[K]{NewMap[K, struct{}]()} +} + +func NewSortedSet[K KeyConstraint](lesser lessFunc[K]) Settish[K] { + return mapToSet[K]{NewSortedMap[K, struct{}](lesser)} +} + +func (s mapToSet[K]) Add(x K) Settish[K] { + s.m = s.m.Set(x, struct{}{}) + return s +} + +func (s mapToSet[K]) Delete(x K) Settish[K] { + s.m = s.m.Delete(x) + return s +} + +func (s mapToSet[K]) Len() int { + return s.m.Len() +} + +func (s mapToSet[K]) Contains(x K) bool { + _, ok := s.m.Get(x) + return ok +} + +func (s mapToSet[K]) Range(f func(K) bool) { + s.m.Range(func(k K, _ struct{}) bool { + return f(k) + }) +} + +/* +func (s mapToSet[K]) Iter(cb iter.Callback) { + s.Range(func(k K) bool { + return cb(k) + }) +} +*/ + +type Map[K KeyConstraint, V any] struct { + *pds.Map[K, V] +} + +func NewMap[K KeyConstraint, V any]() Mappish[K, V] { + return Map[K, V]{pds.NewMap[K, V](interhash[K]{})} +} + +func (m Map[K, V]) Delete(x K) Mappish[K, V] { + m.Map = m.Map.Delete(x) + return m +} + +func (m Map[K, V]) Set(key K, value V) Mappish[K, V] { + m.Map = m.Map.Set(key, value) + return m +} + +func (sm Map[K, V]) Range(f func(K, V) bool) { + iter := sm.Map.Iterator() + for { + k, v, ok := iter.Next() + if !ok { + break + } + if !f(k, v) { + return + } + } +} + +/* +func (sm Map[K, V]) Iter(cb iter.Callback) { + sm.Range(func(key K, _ V) bool { + return cb(key) + }) +} +*/ + +type SortedMap[K KeyConstraint, V any] struct { + *pds.SortedMap[K, V] +} + +func (sm SortedMap[K, V]) Set(key K, value V) Mappish[K, V] { + sm.SortedMap = sm.SortedMap.Set(key, value) + return sm +} + +func (sm SortedMap[K, V]) Delete(key K) Mappish[K, V] { + sm.SortedMap = sm.SortedMap.Delete(key) + return sm +} + +func (sm SortedMap[K, V]) 
Range(f func(key K, value V) bool) { + iter := sm.SortedMap.Iterator() + for { + k, v, ok := iter.Next() + if !ok { + break + } + if !f(k, v) { + return + } + } +} + +/* +func (sm SortedMap[K, V]) Iter(cb iter.Callback) { + sm.Range(func(key K, _ V) bool { + return cb(key) + }) +} +*/ + +type lessFunc[T KeyConstraint] func(l, r T) bool + +type comparer[K KeyConstraint] struct { + less lessFunc[K] +} + +func (me comparer[K]) Compare(i, j K) int { + if me.less(i, j) { + return -1 + } else if me.less(j, i) { + return 1 + } else { + return 0 + } +} + +func NewSortedMap[K KeyConstraint, V any](less lessFunc[K]) Mappish[K, V] { + return SortedMap[K, V]{ + SortedMap: pds.NewSortedMap[K, V](comparer[K]{less}), + } +} + +type Mappish[K, V any] interface { + Set(K, V) Mappish[K, V] + Delete(key K) Mappish[K, V] + Get(key K) (V, bool) + Range(func(K, V) bool) + Len() int + // iter.Iterable +} + +func GetLeft(l, _ any) any { + return l +} + +//go:noescape +//go:linkname nilinterhash runtime.nilinterhash +func nilinterhash(p unsafe.Pointer, h uintptr) uintptr + +func interfaceHash(x any) uint32 { + return uint32(nilinterhash(unsafe.Pointer(&x), 0)) +} + +type Lenner interface { + Len() int +} + +var ( + mu sync.Mutex + ctxVars = map[context.Context]*Var[bool]{} +) + +// Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be +// called when the user is no longer interested in the var. +func ContextDoneVar(ctx context.Context) (*Var[bool], func()) { + mu.Lock() + defer mu.Unlock() + if v, ok := ctxVars[ctx]; ok { + return v, func() {} + } + if ctx.Err() != nil { + // TODO: What if we had read-only Vars? Then we could have a global one for this that we + // just reuse. + v := NewBuiltinEqVar(true) + return v, func() {} + } + v := NewVar(false) + go func() { + <-ctx.Done() + AtomicSet(v, true) + mu.Lock() + delete(ctxVars, ctx) + mu.Unlock() + }() + ctxVars[ctx] = v + return v, func() {} +} + +type numTokens = int + +type Limiter struct { + max *Var[numTokens] + cur *Var[numTokens] + lastAdd *Var[time.Time] + rate Limit +} + +const Inf = Limit(math.MaxFloat64) + +type Limit float64 + +func (l Limit) interval() time.Duration { + return time.Duration(Limit(1*time.Second) / l) +} + +func Every(interval time.Duration) Limit { + if interval == 0 { + return Inf + } + return Limit(time.Second / interval) +} + +func NewLimiter(rate Limit, burst numTokens) *Limiter { + rl := &Limiter{ + max: NewVar(burst), + cur: NewBuiltinEqVar(burst), + lastAdd: NewVar(time.Now()), + rate: rate, + } + if rate != Inf { + go rl.tokenGenerator(rate.interval()) + } + return rl +} + +func (rl *Limiter) tokenGenerator(interval time.Duration) { + for { + lastAdd := AtomicGet(rl.lastAdd) + time.Sleep(time.Until(lastAdd.Add(interval))) + now := time.Now() + available := numTokens(now.Sub(lastAdd) / interval) + if available < 1 { + continue + } + Atomically(VoidOperation(func(tx *Tx) { + cur := rl.cur.Get(tx) + max := rl.max.Get(tx) + tx.Assert(cur < max) + newCur := cur + available + if newCur > max { + newCur = max + } + if newCur != cur { + rl.cur.Set(tx, newCur) + } + rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available))) + })) + } +} + +func (rl *Limiter) Allow() bool { + return rl.AllowN(1) +} + +func (rl *Limiter) AllowN(n numTokens) bool { + return Atomically(func(tx *Tx) bool { + return rl.takeTokens(tx, n) + }) +} + +func (rl *Limiter) AllowStm(tx *Tx) bool { + return rl.takeTokens(tx, 1) +} + +func (rl *Limiter) takeTokens(tx *Tx, n numTokens) bool { + if rl.rate == Inf { + 
return true + } + cur := rl.cur.Get(tx) + if cur >= n { + rl.cur.Set(tx, cur-n) + return true + } + return false +} + +func (rl *Limiter) Wait(ctx context.Context) error { + return rl.WaitN(ctx, 1) +} + +func (rl *Limiter) WaitN(ctx context.Context, n int) error { + ctxDone, cancel := ContextDoneVar(ctx) + defer cancel() + if err := Atomically(func(tx *Tx) error { + if ctxDone.Get(tx) { + return ctx.Err() + } + if rl.takeTokens(tx, n) { + return nil + } + if n > rl.max.Get(tx) { + return errors.New("burst exceeded") + } + if dl, ok := ctx.Deadline(); ok { + if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n { + return context.DeadlineExceeded + } + } + tx.Retry() + panic("unreachable") + }); err != nil { + return err + } + return nil + +} diff --git a/stm_test.go b/stm_test.go deleted file mode 100644 index a98685b..0000000 --- a/stm_test.go +++ /dev/null @@ -1,274 +0,0 @@ -package stm - -import ( - "sync" - "testing" - "time" - - _ "github.com/anacrolix/envpprof" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDecrement(t *testing.T) { - x := NewVar(1000) - for i := 0; i < 500; i++ { - go Atomically(VoidOperation(func(tx *Tx) { - cur := x.Get(tx) - x.Set(tx, cur-1) - })) - } - done := make(chan struct{}) - go func() { - Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(x.Get(tx) == 500) - })) - close(done) - }() - select { - case <-done: - case <-time.After(10 * time.Second): - t.Fatal("decrement did not complete in time") - } -} - -// read-only transaction aren't exempt from calling tx.inputsChanged -func TestReadVerify(t *testing.T) { - read := make(chan struct{}) - x, y := NewVar(1), NewVar(2) - - // spawn a transaction that writes to x - go func() { - <-read - AtomicSet(x, 3) - read <- struct{}{} - // other tx should retry, so we need to read/send again - read <- <-read - }() - - // spawn a transaction that reads x, then y. The other tx will modify x in - // between the reads, causing this tx to retry. - var x2, y2 int - Atomically(VoidOperation(func(tx *Tx) { - x2 = x.Get(tx) - read <- struct{}{} - <-read // wait for other tx to complete - y2 = y.Get(tx) - })) - if x2 == 1 && y2 == 2 { - t.Fatal("read was not verified") - } -} - -func TestRetry(t *testing.T) { - x := NewVar(10) - // spawn 10 transactions, one every 10 milliseconds. This will decrement x - // to 0 over the course of 100 milliseconds. - go func() { - for i := 0; i < 10; i++ { - time.Sleep(10 * time.Millisecond) - Atomically(VoidOperation(func(tx *Tx) { - cur := x.Get(tx) - x.Set(tx, cur-1) - })) - } - }() - // Each time we read x before the above loop has finished, we need to - // retry. This should result in no more than 1 retry per transaction. - retry := 0 - Atomically(VoidOperation(func(tx *Tx) { - cur := x.Get(tx) - if cur != 0 { - retry++ - tx.Retry() - } - })) - if retry > 10 { - t.Fatal("should have retried at most 10 times, got", retry) - } -} - -func TestVerify(t *testing.T) { - // tx.inputsChanged should check more than pointer equality - type foo struct { - i int - } - x := NewVar(&foo{3}) - read := make(chan struct{}) - - // spawn a transaction that modifies x - go func() { - Atomically(VoidOperation(func(tx *Tx) { - <-read - rx := x.Get(tx) - rx.i = 7 - x.Set(tx, rx) - })) - read <- struct{}{} - // other tx should retry, so we need to read/send again - read <- <-read - }() - - // spawn a transaction that reads x, then y. The other tx will modify x in - // between the reads, causing this tx to retry. 
- var i int - Atomically(VoidOperation(func(tx *Tx) { - f := x.Get(tx) - i = f.i - read <- struct{}{} - <-read // wait for other tx to complete - })) - if i == 3 { - t.Fatal("inputsChanged did not retry despite modified Var", i) - } -} - -func TestSelect(t *testing.T) { - // empty Select should panic - require.Panics(t, func() { Atomically(Select[struct{}]()) }) - - // with one arg, Select adds no effect - x := NewVar(2) - Atomically(Select(VoidOperation(func(tx *Tx) { - tx.Assert(x.Get(tx) == 2) - }))) - - picked := Atomically(Select( - // always blocks; should never be selected - func(tx *Tx) int { - tx.Retry() - panic("unreachable") - }, - // always succeeds; should always be selected - func(tx *Tx) int { - return 2 - }, - // always succeeds; should never be selected - func(tx *Tx) int { - return 3 - }, - )) - assert.EqualValues(t, 2, picked) -} - -func TestCompose(t *testing.T) { - nums := make([]int, 100) - fns := make([]Operation[struct{}], 100) - for i := range fns { - fns[i] = func(x int) Operation[struct{}] { - return VoidOperation(func(*Tx) { nums[x] = x }) - }(i) // capture loop var - } - Atomically(Compose(fns...)) - for i := range nums { - if nums[i] != i { - t.Error("Compose failed:", nums[i], i) - } - } -} - -func TestPanic(t *testing.T) { - // normal panics should escape Atomically - assert.PanicsWithValue(t, "foo", func() { - Atomically(func(*Tx) any { - panic("foo") - }) - }) -} - -func TestReadWritten(t *testing.T) { - // reading a variable written in the same transaction should return the - // previously written value - x := NewVar(3) - Atomically(VoidOperation(func(tx *Tx) { - x.Set(tx, 5) - tx.Assert(x.Get(tx) == 5) - })) -} - -func TestAtomicSetRetry(t *testing.T) { - // AtomicSet should cause waiting transactions to retry - x := NewVar(3) - done := make(chan struct{}) - go func() { - Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(x.Get(tx) == 5) - })) - done <- struct{}{} - }() - time.Sleep(10 * time.Millisecond) - AtomicSet(x, 5) - select { - case <-done: - case <-time.After(time.Second): - t.Fatal("AtomicSet did not wake up a waiting transaction") - } -} - -func testPingPong(t testing.TB, n int, afterHit func(string)) { - ball := NewBuiltinEqVar(false) - doneVar := NewVar(false) - hits := NewVar(0) - ready := NewVar(true) // The ball is ready for hitting. 
- var wg sync.WaitGroup - bat := func(from, to bool, noise string) { - defer wg.Done() - for !Atomically(func(tx *Tx) any { - if doneVar.Get(tx) { - return true - } - tx.Assert(ready.Get(tx)) - if ball.Get(tx) == from { - ball.Set(tx, to) - hits.Set(tx, hits.Get(tx)+1) - ready.Set(tx, false) - return false - } - return tx.Retry() - }).(bool) { - afterHit(noise) - AtomicSet(ready, true) - } - } - wg.Add(2) - go bat(false, true, "ping!") - go bat(true, false, "pong!") - Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(hits.Get(tx) >= n) - doneVar.Set(tx, true) - })) - wg.Wait() -} - -func TestPingPong(t *testing.T) { - testPingPong(t, 42, func(s string) { t.Log(s) }) -} - -func TestSleepingBeauty(t *testing.T) { - require.Panics(t, func() { - Atomically(func(tx *Tx) any { - tx.Assert(false) - return nil - }) - }) -} - -//func TestRetryStack(t *testing.T) { -// v := NewVar[int](nil) -// go func() { -// i := 0 -// for { -// AtomicSet(v, i) -// i++ -// } -// }() -// Atomically(func(tx *Tx) any { -// debug.PrintStack() -// ret := func() { -// defer Atomically(nil) -// } -// v.Get(tx) -// tx.Assert(false) -// return ret -// }) -//} diff --git a/stmutil/containers.go b/stmutil/containers.go deleted file mode 100644 index 0cc592d..0000000 --- a/stmutil/containers.go +++ /dev/null @@ -1,192 +0,0 @@ -package stmutil - -import ( - "unsafe" - - "github.com/anacrolix/missinggo/v2/iter" - "github.com/benbjohnson/immutable" -) - -// This is the type constraint for keys passed through from github.com/benbjohnson/immutable. -type KeyConstraint interface { - comparable -} - -type Settish[K KeyConstraint] interface { - Add(K) Settish[K] - Delete(K) Settish[K] - Contains(K) bool - Range(func(K) bool) - iter.Iterable - Len() int -} - -type mapToSet[K KeyConstraint] struct { - m Mappish[K, struct{}] -} - -type interhash[K KeyConstraint] struct{} - -func (interhash[K]) Hash(x K) uint32 { - return uint32(nilinterhash(unsafe.Pointer(&x), 0)) -} - -func (interhash[K]) Equal(i, j K) bool { - return i == j -} - -func NewSet[K KeyConstraint]() Settish[K] { - return mapToSet[K]{NewMap[K, struct{}]()} -} - -func NewSortedSet[K KeyConstraint](lesser lessFunc[K]) Settish[K] { - return mapToSet[K]{NewSortedMap[K, struct{}](lesser)} -} - -func (s mapToSet[K]) Add(x K) Settish[K] { - s.m = s.m.Set(x, struct{}{}) - return s -} - -func (s mapToSet[K]) Delete(x K) Settish[K] { - s.m = s.m.Delete(x) - return s -} - -func (s mapToSet[K]) Len() int { - return s.m.Len() -} - -func (s mapToSet[K]) Contains(x K) bool { - _, ok := s.m.Get(x) - return ok -} - -func (s mapToSet[K]) Range(f func(K) bool) { - s.m.Range(func(k K, _ struct{}) bool { - return f(k) - }) -} - -func (s mapToSet[K]) Iter(cb iter.Callback) { - s.Range(func(k K) bool { - return cb(k) - }) -} - -type Map[K KeyConstraint, V any] struct { - *immutable.Map[K, V] -} - -func NewMap[K KeyConstraint, V any]() Mappish[K, V] { - return Map[K, V]{immutable.NewMap[K, V](interhash[K]{})} -} - -func (m Map[K, V]) Delete(x K) Mappish[K, V] { - m.Map = m.Map.Delete(x) - return m -} - -func (m Map[K, V]) Set(key K, value V) Mappish[K, V] { - m.Map = m.Map.Set(key, value) - return m -} - -func (sm Map[K, V]) Range(f func(K, V) bool) { - iter := sm.Map.Iterator() - for { - k, v, ok := iter.Next() - if !ok { - break - } - if !f(k, v) { - return - } - } -} - -func (sm Map[K, V]) Iter(cb iter.Callback) { - sm.Range(func(key K, _ V) bool { - return cb(key) - }) -} - -type SortedMap[K KeyConstraint, V any] struct { - *immutable.SortedMap[K, V] -} - -func (sm SortedMap[K, V]) Set(key K, 
value V) Mappish[K, V] { - sm.SortedMap = sm.SortedMap.Set(key, value) - return sm -} - -func (sm SortedMap[K, V]) Delete(key K) Mappish[K, V] { - sm.SortedMap = sm.SortedMap.Delete(key) - return sm -} - -func (sm SortedMap[K, V]) Range(f func(key K, value V) bool) { - iter := sm.SortedMap.Iterator() - for { - k, v, ok := iter.Next() - if !ok { - break - } - if !f(k, v) { - return - } - } -} - -func (sm SortedMap[K, V]) Iter(cb iter.Callback) { - sm.Range(func(key K, _ V) bool { - return cb(key) - }) -} - -type lessFunc[T KeyConstraint] func(l, r T) bool - -type comparer[K KeyConstraint] struct { - less lessFunc[K] -} - -func (me comparer[K]) Compare(i, j K) int { - if me.less(i, j) { - return -1 - } else if me.less(j, i) { - return 1 - } else { - return 0 - } -} - -func NewSortedMap[K KeyConstraint, V any](less lessFunc[K]) Mappish[K, V] { - return SortedMap[K, V]{ - SortedMap: immutable.NewSortedMap[K, V](comparer[K]{less}), - } -} - -type Mappish[K, V any] interface { - Set(K, V) Mappish[K, V] - Delete(key K) Mappish[K, V] - Get(key K) (V, bool) - Range(func(K, V) bool) - Len() int - iter.Iterable -} - -func GetLeft(l, _ any) any { - return l -} - -//go:noescape -//go:linkname nilinterhash runtime.nilinterhash -func nilinterhash(p unsafe.Pointer, h uintptr) uintptr - -func interfaceHash(x any) uint32 { - return uint32(nilinterhash(unsafe.Pointer(&x), 0)) -} - -type Lenner interface { - Len() int -} diff --git a/stmutil/context.go b/stmutil/context.go deleted file mode 100644 index 6f8ba9b..0000000 --- a/stmutil/context.go +++ /dev/null @@ -1,39 +0,0 @@ -package stmutil - -import ( - "context" - "sync" - - "github.com/anacrolix/stm" -) - -var ( - mu sync.Mutex - ctxVars = map[context.Context]*stm.Var[bool]{} -) - -// Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be -// called when the user is no longer interested in the var. -func ContextDoneVar(ctx context.Context) (*stm.Var[bool], func()) { - mu.Lock() - defer mu.Unlock() - if v, ok := ctxVars[ctx]; ok { - return v, func() {} - } - if ctx.Err() != nil { - // TODO: What if we had read-only Vars? Then we could have a global one for this that we - // just reuse. 
- v := stm.NewBuiltinEqVar(true) - return v, func() {} - } - v := stm.NewVar(false) - go func() { - <-ctx.Done() - stm.AtomicSet(v, true) - mu.Lock() - delete(ctxVars, ctx) - mu.Unlock() - }() - ctxVars[ctx] = v - return v, func() {} -} diff --git a/stmutil/context_test.go b/stmutil/context_test.go deleted file mode 100644 index 0a6b7c0..0000000 --- a/stmutil/context_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package stmutil - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestContextEquality(t *testing.T) { - ctx := context.Background() - assert.True(t, ctx == context.Background()) - childCtx, cancel := context.WithCancel(ctx) - assert.True(t, childCtx != ctx) - assert.True(t, childCtx != ctx) - assert.Equal(t, context.Background(), ctx) - cancel() - assert.Equal(t, context.Background(), ctx) - assert.NotEqual(t, ctx, childCtx) -} diff --git a/tests/stm.go b/tests/stm.go index a30ac1a..816cdec 100644 --- a/tests/stm.go +++ b/tests/stm.go @@ -1,8 +1,1185 @@ package stm import ( + "context" + "fmt" + "math" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + g "gobang" +) + + + +func TestValue(t *testing.T) { + v := New("hello") + g.TAssertEqual("hello", v.Load()) + v.Store("world") + g.TAssertEqual("world", v.Load()) +} + +func TestValueZeroValue(t *testing.T) { + var v Value[string] + // assert.Panics(t, func() { v.Load() }) + v.Store("world") + g.TAssertEqual("world", v.Load()) +} + +func TestValueSwapZeroValue(t *testing.T) { + // var v Value[string] + // assert.Panics(t, func() { v.Swap("hello") }) +} + +func TestInt32(t *testing.T) { + v := NewInt32(0) + g.TAssertEqual(0, v.Load()) + g.TAssertEqual(true, v.CompareAndSwap(0, 10)) + g.TAssertEqual(false, v.CompareAndSwap(0, 10)) +} + +func BenchmarkInt64Add(b *testing.B) { + v := NewInt64(0) + for i := 0; i < b.N; i++ { + v.Add(1) + } +} + +func BenchmarkIntInterfaceAdd(b *testing.B) { + var v Int[int64] = NewInt64(0) + for i := 0; i < b.N; i++ { + v.Add(1) + } +} + +func BenchmarkStdlibInt64Add(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + atomic.AddInt64(&n, 1) + } +} + +func BenchmarkInterfaceStore(b *testing.B) { + var v Interface[string] = New("hello") + for i := 0; i < b.N; i++ { + v.Store(fmt.Sprint(i)) + } +} + +func BenchmarkValueStore(b *testing.B) { + v := New("hello") + for i := 0; i < b.N; i++ { + v.Store(fmt.Sprint(i)) + } +} + +func BenchmarkStdlibValueStore(b *testing.B) { + v := atomic.Value{} + for i := 0; i < b.N; i++ { + v.Store(fmt.Sprint(i)) + } +} + +func BenchmarkAtomicGet(b *testing.B) { + x := NewVar(0) + for i := 0; i < b.N; i++ { + AtomicGet(x) + } +} + +func BenchmarkAtomicSet(b *testing.B) { + x := NewVar(0) + for i := 0; i < b.N; i++ { + AtomicSet(x, 0) + } +} + +func BenchmarkIncrementSTM(b *testing.B) { + for i := 0; i < b.N; i++ { + // spawn 1000 goroutines that each increment x by 1 + x := NewVar(0) + for i := 0; i < 1000; i++ { + go Atomically(VoidOperation(func(tx *Tx) { + cur := x.Get(tx) + x.Set(tx, cur+1) + })) + } + // wait for x to reach 1000 + Atomically(VoidOperation(func(tx *Tx) { + tx.Assert(x.Get(tx) == 1000) + })) + } +} + +func BenchmarkIncrementMutex(b *testing.B) { + for i := 0; i < b.N; i++ { + var mu sync.Mutex + x := 0 + for i := 0; i < 1000; i++ { + go func() { + mu.Lock() + x++ + mu.Unlock() + }() + } + for { + mu.Lock() + read := x + mu.Unlock() + if read == 1000 { + break + } + } + } +} + +func BenchmarkIncrementChannel(b *testing.B) { + for i := 0; i < b.N; i++ { + c := make(chan int, 1) + c <- 0 + for i := 0; i < 
1000; i++ { + go func() { + c <- 1 + <-c + }() + } + for { + read := <-c + if read == 1000 { + break + } + c <- read + } + } +} + +func BenchmarkReadVarSTM(b *testing.B) { + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + wg.Add(1000) + x := NewVar(0) + for i := 0; i < 1000; i++ { + go func() { + AtomicGet(x) + wg.Done() + }() + } + wg.Wait() + } +} + +func BenchmarkReadVarMutex(b *testing.B) { + for i := 0; i < b.N; i++ { + var mu sync.Mutex + var wg sync.WaitGroup + wg.Add(1000) + x := 0 + for i := 0; i < 1000; i++ { + go func() { + mu.Lock() + _ = x + mu.Unlock() + wg.Done() + }() + } + wg.Wait() + } +} + +func BenchmarkReadVarChannel(b *testing.B) { + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + wg.Add(1000) + c := make(chan int) + close(c) + for i := 0; i < 1000; i++ { + go func() { + <-c + wg.Done() + }() + } + wg.Wait() + } +} + +func parallelPingPongs(b *testing.B, n int) { + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + testPingPong(b, b.N, func(string) {}) + }() + } + wg.Wait() +} + +func BenchmarkPingPong4(b *testing.B) { + b.ReportAllocs() + parallelPingPongs(b, 4) +} + +func BenchmarkPingPong(b *testing.B) { + b.ReportAllocs() + parallelPingPongs(b, 1) +} + +func Example() { + // create a shared variable + n := NewVar(3) + + // read a variable + var v int + Atomically(VoidOperation(func(tx *Tx) { + v = n.Get(tx) + })) + // or: + v = AtomicGet(n) + _ = v + + // write to a variable + Atomically(VoidOperation(func(tx *Tx) { + n.Set(tx, 12) + })) + // or: + AtomicSet(n, 12) + + // update a variable + Atomically(VoidOperation(func(tx *Tx) { + cur := n.Get(tx) + n.Set(tx, cur-1) + })) + + // block until a condition is met + Atomically(VoidOperation(func(tx *Tx) { + cur := n.Get(tx) + if cur != 0 { + tx.Retry() + } + n.Set(tx, 10) + })) + // or: + Atomically(VoidOperation(func(tx *Tx) { + cur := n.Get(tx) + tx.Assert(cur == 0) + n.Set(tx, 10) + })) + + // select among multiple (potentially blocking) transactions + Atomically(Select( + // this function blocks forever, so it will be skipped + VoidOperation(func(tx *Tx) { tx.Retry() }), + + // this function will always succeed without blocking + VoidOperation(func(tx *Tx) { n.Set(tx, 10) }), + + // this function will never run, because the previous + // function succeeded + VoidOperation(func(tx *Tx) { n.Set(tx, 11) }), + )) + + // since Select is a normal transaction, if the entire select retries + // (blocks), it will be retried as a whole: + x := 0 + Atomically(Select( + // this function will run twice, and succeed the second time + VoidOperation(func(tx *Tx) { tx.Assert(x == 1) }), + + // this function will run once + VoidOperation(func(tx *Tx) { + x = 1 + tx.Retry() + }), + )) + // But wait! Transactions are only retried when one of the Vars they read is + // updated. Since x isn't a stm Var, this code will actually block forever -- + // but you get the idea. 
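// A minimal sketch of the same idea with the gate kept in an stm Var
// (`flag` is a hypothetical Var added here for illustration), so the
// blocked branch really is woken once another goroutine flips it:
//
//	flag := NewBuiltinEqVar(false)
//	go AtomicSet(flag, true)
//	Atomically(VoidOperation(func(tx *Tx) {
//		tx.Assert(flag.Get(tx)) // re-run when flag is updated
//	}))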
+} + +const maxTokens = 25 + +func BenchmarkThunderingHerdCondVar(b *testing.B) { + for i := 0; i < b.N; i++ { + var mu sync.Mutex + consumer := sync.NewCond(&mu) + generator := sync.NewCond(&mu) + done := false + tokens := 0 + var pending sync.WaitGroup + for i := 0; i < 1000; i++ { + pending.Add(1) + go func() { + mu.Lock() + for { + if tokens > 0 { + tokens-- + generator.Signal() + break + } + consumer.Wait() + } + mu.Unlock() + pending.Done() + }() + } + go func() { + mu.Lock() + for !done { + if tokens < maxTokens { + tokens++ + consumer.Signal() + } else { + generator.Wait() + } + } + mu.Unlock() + }() + pending.Wait() + mu.Lock() + done = true + generator.Signal() + mu.Unlock() + } + +} + +func BenchmarkThunderingHerd(b *testing.B) { + for i := 0; i < b.N; i++ { + done := NewBuiltinEqVar(false) + tokens := NewBuiltinEqVar(0) + pending := NewBuiltinEqVar(0) + for i := 0; i < 1000; i++ { + Atomically(VoidOperation(func(tx *Tx) { + pending.Set(tx, pending.Get(tx)+1) + })) + go func() { + Atomically(VoidOperation(func(tx *Tx) { + t := tokens.Get(tx) + if t > 0 { + tokens.Set(tx, t-1) + pending.Set(tx, pending.Get(tx)-1) + } else { + tx.Retry() + } + })) + }() + } + go func() { + for Atomically(func(tx *Tx) bool { + if done.Get(tx) { + return false + } + tx.Assert(tokens.Get(tx) < maxTokens) + tokens.Set(tx, tokens.Get(tx)+1) + return true + }) { + } + }() + Atomically(VoidOperation(func(tx *Tx) { + tx.Assert(pending.Get(tx) == 0) + })) + AtomicSet(done, true) + } +} + +func BenchmarkInvertedThunderingHerd(b *testing.B) { + for i := 0; i < b.N; i++ { + done := NewBuiltinEqVar(false) + tokens := NewBuiltinEqVar(0) + pending := NewVar(NewSet[*Var[bool]]()) + for i := 0; i < 1000; i++ { + ready := NewVar(false) + Atomically(VoidOperation(func(tx *Tx) { + pending.Set(tx, pending.Get(tx).Add(ready)) + })) + go func() { + Atomically(VoidOperation(func(tx *Tx) { + tx.Assert(ready.Get(tx)) + set := pending.Get(tx) + if !set.Contains(ready) { + panic("couldn't find ourselves in pending") + } + pending.Set(tx, set.Delete(ready)) + })) + //b.Log("waiter finished") + }() + } + go func() { + for Atomically(func(tx *Tx) bool { + if done.Get(tx) { + return false + } + tx.Assert(tokens.Get(tx) < maxTokens) + tokens.Set(tx, tokens.Get(tx)+1) + return true + }) { + } + }() + go func() { + for Atomically(func(tx *Tx) bool { + tx.Assert(tokens.Get(tx) > 0) + tokens.Set(tx, tokens.Get(tx)-1) + pending.Get(tx).Range(func(ready *Var[bool]) bool { + if !ready.Get(tx) { + ready.Set(tx, true) + return false + } + return true + }) + return !done.Get(tx) + }) { + } + }() + Atomically(VoidOperation(func(tx *Tx) { + tx.Assert(pending.Get(tx).(Lenner).Len() == 0) + })) + AtomicSet(done, true) + } +} + +func TestLimit(t *testing.T) { + if Limit(10) == Inf { + t.Errorf("Limit(10) == Inf should be false") + } +} + +func closeEnough(a, b Limit) bool { + return (math.Abs(float64(a)/float64(b)) - 1.0) < 1e-9 +} + +func TestEvery(t *testing.T) { + cases := []struct { + interval time.Duration + lim Limit + }{ + {0, Inf}, + {-1, Inf}, + {1 * time.Nanosecond, Limit(1e9)}, + {1 * time.Microsecond, Limit(1e6)}, + {1 * time.Millisecond, Limit(1e3)}, + {10 * time.Millisecond, Limit(100)}, + {100 * time.Millisecond, Limit(10)}, + {1 * time.Second, Limit(1)}, + {2 * time.Second, Limit(0.5)}, + {time.Duration(2.5 * float64(time.Second)), Limit(0.4)}, + {4 * time.Second, Limit(0.25)}, + {10 * time.Second, Limit(0.1)}, + {time.Duration(math.MaxInt64), Limit(1e9 / float64(math.MaxInt64))}, + } + for _, tc := range cases { + lim := 
Every(tc.interval) + if !closeEnough(lim, tc.lim) { + t.Errorf("Every(%v) = %v want %v", tc.interval, lim, tc.lim) + } + } +} + +const ( + d = 100 * time.Millisecond +) + +var ( + t0 = time.Now() + t1 = t0.Add(time.Duration(1) * d) + t2 = t0.Add(time.Duration(2) * d) + t3 = t0.Add(time.Duration(3) * d) + t4 = t0.Add(time.Duration(4) * d) + t5 = t0.Add(time.Duration(5) * d) + t9 = t0.Add(time.Duration(9) * d) ) +type allow struct { + t time.Time + n int + ok bool +} + +// +//func run(t *testing.T, lim *Limiter, allows []allow) { +// for i, allow := range allows { +// ok := lim.AllowN(allow.t, allow.n) +// if ok != allow.ok { +// t.Errorf("step %d: lim.AllowN(%v, %v) = %v want %v", +// i, allow.t, allow.n, ok, allow.ok) +// } +// } +//} +// +//func TestLimiterBurst1(t *testing.T) { +// run(t, NewLimiter(10, 1), []allow{ +// {t0, 1, true}, +// {t0, 1, false}, +// {t0, 1, false}, +// {t1, 1, true}, +// {t1, 1, false}, +// {t1, 1, false}, +// {t2, 2, false}, // burst size is 1, so n=2 always fails +// {t2, 1, true}, +// {t2, 1, false}, +// }) +//} +// +//func TestLimiterBurst3(t *testing.T) { +// run(t, NewLimiter(10, 3), []allow{ +// {t0, 2, true}, +// {t0, 2, false}, +// {t0, 1, true}, +// {t0, 1, false}, +// {t1, 4, false}, +// {t2, 1, true}, +// {t3, 1, true}, +// {t4, 1, true}, +// {t4, 1, true}, +// {t4, 1, false}, +// {t4, 1, false}, +// {t9, 3, true}, +// {t9, 0, true}, +// }) +//} +// +//func TestLimiterJumpBackwards(t *testing.T) { +// run(t, NewLimiter(10, 3), []allow{ +// {t1, 1, true}, // start at t1 +// {t0, 1, true}, // jump back to t0, two tokens remain +// {t0, 1, true}, +// {t0, 1, false}, +// {t0, 1, false}, +// {t1, 1, true}, // got a token +// {t1, 1, false}, +// {t1, 1, false}, +// {t2, 1, true}, // got another token +// {t2, 1, false}, +// {t2, 1, false}, +// }) +//} + +// Ensure that tokensFromDuration doesn't produce +// rounding errors by truncating nanoseconds. +// See golang.org/issues/34861. +func TestLimiter_noTruncationErrors(t *testing.T) { + if !NewLimiter(0.7692307692307693, 1).Allow() { + t.Fatal("expected true") + } +} + +func TestSimultaneousRequests(t *testing.T) { + const ( + limit = 1 + burst = 5 + numRequests = 15 + ) + var ( + wg sync.WaitGroup + numOK = uint32(0) + ) + + // Very slow replenishing bucket. + lim := NewLimiter(limit, burst) + + // Tries to take a token, atomically updates the counter and decreases the wait + // group counter. + f := func() { + defer wg.Done() + if ok := lim.Allow(); ok { + atomic.AddUint32(&numOK, 1) + } + } + + wg.Add(numRequests) + for i := 0; i < numRequests; i++ { + go f() + } + wg.Wait() + if numOK != burst { + t.Errorf("numOK = %d, want %d", numOK, burst) + } +} + +func TestLongRunningQPS(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + if runtime.GOOS == "openbsd" { + t.Skip("low resolution time.Sleep invalidates test (golang.org/issue/14183)") + return + } + + // The test runs for a few seconds executing many requests and then checks + // that overall number of requests is reasonable. + const ( + limit = 100 + burst = 100 + ) + var numOK = int32(0) + + lim := NewLimiter(limit, burst) + + var wg sync.WaitGroup + f := func() { + if ok := lim.Allow(); ok { + atomic.AddInt32(&numOK, 1) + } + wg.Done() + } + + start := time.Now() + end := start.Add(5 * time.Second) + for time.Now().Before(end) { + wg.Add(1) + go f() + + // This will still offer ~500 requests per second, but won't consume + // outrageous amount of CPU. 
+ time.Sleep(2 * time.Millisecond) + } + wg.Wait() + elapsed := time.Since(start) + ideal := burst + (limit * float64(elapsed) / float64(time.Second)) + + // We should never get more requests than allowed. + if want := int32(ideal + 1); numOK > want { + t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal) + } + // We should get very close to the number of requests allowed. + if want := int32(0.999 * ideal); numOK < want { + t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal) + } +} + +type request struct { + t time.Time + n int + act time.Time + ok bool +} + +// dFromDuration converts a duration to a multiple of the global constant d +func dFromDuration(dur time.Duration) int { + // Adding a millisecond to be swallowed by the integer division + // because we don't care about small inaccuracies + return int((dur + time.Millisecond) / d) +} + +// dSince returns multiples of d since t0 +func dSince(t time.Time) int { + return dFromDuration(t.Sub(t0)) +} + +// +//func runReserve(t *testing.T, lim *Limiter, req request) *Reservation { +// return runReserveMax(t, lim, req, InfDuration) +//} +// +//func runReserveMax(t *testing.T, lim *Limiter, req request, maxReserve time.Duration) *Reservation { +// r := lim.reserveN(req.t, req.n, maxReserve) +// if r.ok && (dSince(r.timeToAct) != dSince(req.act)) || r.ok != req.ok { +// t.Errorf("lim.reserveN(t%d, %v, %v) = (t%d, %v) want (t%d, %v)", +// dSince(req.t), req.n, maxReserve, dSince(r.timeToAct), r.ok, dSince(req.act), req.ok) +// } +// return &r +//} +// +//func TestSimpleReserve(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// runReserve(t, lim, request{t0, 2, t2, true}) +// runReserve(t, lim, request{t3, 2, t4, true}) +//} +// +//func TestMix(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t0, 3, t1, false}) // should return false because n > Burst +// runReserve(t, lim, request{t0, 2, t0, true}) +// run(t, lim, []allow{{t1, 2, false}}) // not enought tokens - don't allow +// runReserve(t, lim, request{t1, 2, t2, true}) +// run(t, lim, []allow{{t1, 1, false}}) // negative tokens - don't allow +// run(t, lim, []allow{{t3, 1, true}}) +//} +// +//func TestCancelInvalid(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// r := runReserve(t, lim, request{t0, 3, t3, false}) +// r.CancelAt(t0) // should have no effect +// runReserve(t, lim, request{t0, 2, t2, true}) // did not get extra tokens +//} +// +//func TestCancelLast(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// r := runReserve(t, lim, request{t0, 2, t2, true}) +// r.CancelAt(t1) // got 2 tokens back +// runReserve(t, lim, request{t1, 2, t2, true}) +//} +// +//func TestCancelTooLate(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// r := runReserve(t, lim, request{t0, 2, t2, true}) +// r.CancelAt(t3) // too late to cancel - should have no effect +// runReserve(t, lim, request{t3, 2, t4, true}) +//} +// +//func TestCancel0Tokens(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// r := runReserve(t, lim, request{t0, 1, t1, true}) +// runReserve(t, lim, request{t0, 1, t2, true}) +// r.CancelAt(t0) // got 0 tokens back +// runReserve(t, lim, request{t0, 1, t3, true}) +//} +// +//func TestCancel1Token(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t0, 2, t0, 
true}) +// r := runReserve(t, lim, request{t0, 2, t2, true}) +// runReserve(t, lim, request{t0, 1, t3, true}) +// r.CancelAt(t2) // got 1 token back +// runReserve(t, lim, request{t2, 2, t4, true}) +//} +// +//func TestCancelMulti(t *testing.T) { +// lim := NewLimiter(10, 4) +// +// runReserve(t, lim, request{t0, 4, t0, true}) +// rA := runReserve(t, lim, request{t0, 3, t3, true}) +// runReserve(t, lim, request{t0, 1, t4, true}) +// rC := runReserve(t, lim, request{t0, 1, t5, true}) +// rC.CancelAt(t1) // get 1 token back +// rA.CancelAt(t1) // get 2 tokens back, as if C was never reserved +// runReserve(t, lim, request{t1, 3, t5, true}) +//} +// +//func TestReserveJumpBack(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 +// runReserve(t, lim, request{t0, 1, t1, true}) // should violate Limit,Burst +// runReserve(t, lim, request{t2, 2, t3, true}) +//} + +//func TestReserveJumpBackCancel(t *testing.T) { +// lim := NewLimiter(10, 2) +// +// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 +// r := runReserve(t, lim, request{t1, 2, t3, true}) +// runReserve(t, lim, request{t1, 1, t4, true}) +// r.CancelAt(t0) // cancel at t0, get 1 token back +// runReserve(t, lim, request{t1, 2, t4, true}) // should violate Limit,Burst +//} +// +//func TestReserveSetLimit(t *testing.T) { +// lim := NewLimiter(5, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// runReserve(t, lim, request{t0, 2, t4, true}) +// lim.SetLimitAt(t2, 10) +// runReserve(t, lim, request{t2, 1, t4, true}) // violates Limit and Burst +//} +// +//func TestReserveSetBurst(t *testing.T) { +// lim := NewLimiter(5, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// runReserve(t, lim, request{t0, 2, t4, true}) +// lim.SetBurstAt(t3, 4) +// runReserve(t, lim, request{t0, 4, t9, true}) // violates Limit and Burst +//} +// +//func TestReserveSetLimitCancel(t *testing.T) { +// lim := NewLimiter(5, 2) +// +// runReserve(t, lim, request{t0, 2, t0, true}) +// r := runReserve(t, lim, request{t0, 2, t4, true}) +// lim.SetLimitAt(t2, 10) +// r.CancelAt(t2) // 2 tokens back +// runReserve(t, lim, request{t2, 2, t3, true}) +//} +// +//func TestReserveMax(t *testing.T) { +// lim := NewLimiter(10, 2) +// maxT := d +// +// runReserveMax(t, lim, request{t0, 2, t0, true}, maxT) +// runReserveMax(t, lim, request{t0, 1, t1, true}, maxT) // reserve for close future +// runReserveMax(t, lim, request{t0, 1, t2, false}, maxT) // time to act too far in the future +//} + +type wait struct { + name string + ctx context.Context + n int + delay int // in multiples of d + nilErr bool +} + +func runWait(t *testing.T, lim *Limiter, w wait) { + t.Helper() + start := time.Now() + err := lim.WaitN(w.ctx, w.n) + delay := time.Since(start) + if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) { + errString := "<nil>" + if !w.nilErr { + errString = "<non-nil error>" + } + t.Errorf("lim.WaitN(%v, lim, %v) = %v with delay %v ; want %v with delay %v", + w.name, w.n, err, delay, errString, d*time.Duration(w.delay)) + } +} + +func TestWaitSimple(t *testing.T) { + lim := NewLimiter(10, 3) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + runWait(t, lim, wait{"already-cancelled", ctx, 1, 0, false}) + + runWait(t, lim, wait{"exceed-burst-error", context.Background(), 4, 0, false}) + + runWait(t, lim, wait{"act-now", context.Background(), 2, 0, true}) + runWait(t, lim, wait{"act-later", context.Background(), 3, 2, true}) +} + 
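For reference, a rough usage sketch of the limiter these Wait tests exercise. The helper name limiterSketch and the interval/burst values are made up; Limiter, Every, AllowN, Wait and AllowStm are the identifiers from the ratelimit code earlier in this diff, and context/time are already imported by this file:

	func limiterSketch(ctx context.Context) error {
		lim := NewLimiter(Every(100*time.Millisecond), 3) // ~10 tokens/s, burst of 3
		if lim.AllowN(2) {
			// got 2 tokens without blocking
		}
		// Block, honouring ctx, until a single token is available.
		if err := lim.Wait(ctx); err != nil {
			return err
		}
		// Or take a token as part of a larger transaction.
		taken := Atomically(func(tx *Tx) bool {
			return lim.AllowStm(tx)
		})
		_ = taken
		return nil
	}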
+func TestWaitCancel(t *testing.T) { + lim := NewLimiter(10, 3) + + ctx, cancel := context.WithCancel(context.Background()) + runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) // after this lim.tokens = 1 + go func() { + time.Sleep(d) + cancel() + }() + runWait(t, lim, wait{"will-cancel", ctx, 3, 1, false}) + // should get 3 tokens back, and have lim.tokens = 2 + //t.Logf("tokens:%v last:%v lastEvent:%v", lim.tokens, lim.last, lim.lastEvent) + runWait(t, lim, wait{"act-now-after-cancel", context.Background(), 2, 0, true}) +} + +func TestWaitTimeout(t *testing.T) { + lim := NewLimiter(10, 3) + + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() + runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) + runWait(t, lim, wait{"w-timeout-err", ctx, 3, 0, false}) +} + +func TestWaitInf(t *testing.T) { + lim := NewLimiter(Inf, 0) + + runWait(t, lim, wait{"exceed-burst-no-error", context.Background(), 3, 0, true}) +} + +func BenchmarkAllowN(b *testing.B) { + lim := NewLimiter(Every(1*time.Second), 1) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + lim.AllowN(1) + } + }) +} + +func BenchmarkWaitNNoDelay(b *testing.B) { + lim := NewLimiter(Limit(b.N), b.N) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + lim.WaitN(ctx, 1) + } +} + +func TestDecrement(t *testing.T) { + x := NewVar(1000) + for i := 0; i < 500; i++ { + go Atomically(VoidOperation(func(tx *Tx) { + cur := x.Get(tx) + x.Set(tx, cur-1) + })) + } + done := make(chan struct{}) + go func() { + Atomically(VoidOperation(func(tx *Tx) { + tx.Assert(x.Get(tx) == 500) + })) + close(done) + }() + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("decrement did not complete in time") + } +} + +// read-only transaction aren't exempt from calling tx.inputsChanged +func TestReadVerify(t *testing.T) { + read := make(chan struct{}) + x, y := NewVar(1), NewVar(2) + + // spawn a transaction that writes to x + go func() { + <-read + AtomicSet(x, 3) + read <- struct{}{} + // other tx should retry, so we need to read/send again + read <- <-read + }() + + // spawn a transaction that reads x, then y. The other tx will modify x in + // between the reads, causing this tx to retry. + var x2, y2 int + Atomically(VoidOperation(func(tx *Tx) { + x2 = x.Get(tx) + read <- struct{}{} + <-read // wait for other tx to complete + y2 = y.Get(tx) + })) + if x2 == 1 && y2 == 2 { + t.Fatal("read was not verified") + } +} + +func TestRetry(t *testing.T) { + x := NewVar(10) + // spawn 10 transactions, one every 10 milliseconds. This will decrement x + // to 0 over the course of 100 milliseconds. + go func() { + for i := 0; i < 10; i++ { + time.Sleep(10 * time.Millisecond) + Atomically(VoidOperation(func(tx *Tx) { + cur := x.Get(tx) + x.Set(tx, cur-1) + })) + } + }() + // Each time we read x before the above loop has finished, we need to + // retry. This should result in no more than 1 retry per transaction. 
+ retry := 0 + Atomically(VoidOperation(func(tx *Tx) { + cur := x.Get(tx) + if cur != 0 { + retry++ + tx.Retry() + } + })) + if retry > 10 { + t.Fatal("should have retried at most 10 times, got", retry) + } +} + +func TestVerify(t *testing.T) { + // tx.inputsChanged should check more than pointer equality + type foo struct { + i int + } + x := NewVar(&foo{3}) + read := make(chan struct{}) + + // spawn a transaction that modifies x + go func() { + Atomically(VoidOperation(func(tx *Tx) { + <-read + rx := x.Get(tx) + rx.i = 7 + x.Set(tx, rx) + })) + read <- struct{}{} + // other tx should retry, so we need to read/send again + read <- <-read + }() + + // spawn a transaction that reads x, then y. The other tx will modify x in + // between the reads, causing this tx to retry. + var i int + Atomically(VoidOperation(func(tx *Tx) { + f := x.Get(tx) + i = f.i + read <- struct{}{} + <-read // wait for other tx to complete + })) + if i == 3 { + t.Fatal("inputsChanged did not retry despite modified Var", i) + } +} + +func TestSelect(t *testing.T) { + // empty Select should panic + // require.Panics(t, func() { Atomically(Select[struct{}]()) }) + + // with one arg, Select adds no effect + x := NewVar(2) + Atomically(Select(VoidOperation(func(tx *Tx) { + tx.Assert(x.Get(tx) == 2) + }))) + + picked := Atomically(Select( + // always blocks; should never be selected + func(tx *Tx) int { + tx.Retry() + panic("unreachable") + }, + // always succeeds; should always be selected + func(tx *Tx) int { + return 2 + }, + // always succeeds; should never be selected + func(tx *Tx) int { + return 3 + }, + )) + g.TAssertEqual(2, picked) +} + +func TestCompose(t *testing.T) { + nums := make([]int, 100) + fns := make([]Operation[struct{}], 100) + for i := range fns { + fns[i] = func(x int) Operation[struct{}] { + return VoidOperation(func(*Tx) { nums[x] = x }) + }(i) // capture loop var + } + Atomically(Compose(fns...)) + for i := range nums { + if nums[i] != i { + t.Error("Compose failed:", nums[i], i) + } + } +} + +func TestPanic(t *testing.T) { + // normal panics should escape Atomically + /* + assert.PanicsWithValue(t, "foo", func() { + Atomically(func(*Tx) any { + panic("foo") + }) + }) + */ +} + +func TestReadWritten(t *testing.T) { + // reading a variable written in the same transaction should return the + // previously written value + x := NewVar(3) + Atomically(VoidOperation(func(tx *Tx) { + x.Set(tx, 5) + tx.Assert(x.Get(tx) == 5) + })) +} + +func TestAtomicSetRetry(t *testing.T) { + // AtomicSet should cause waiting transactions to retry + x := NewVar(3) + done := make(chan struct{}) + go func() { + Atomically(VoidOperation(func(tx *Tx) { + tx.Assert(x.Get(tx) == 5) + })) + done <- struct{}{} + }() + time.Sleep(10 * time.Millisecond) + AtomicSet(x, 5) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("AtomicSet did not wake up a waiting transaction") + } +} + +func testPingPong(t testing.TB, n int, afterHit func(string)) { + ball := NewBuiltinEqVar(false) + doneVar := NewVar(false) + hits := NewVar(0) + ready := NewVar(true) // The ball is ready for hitting. 
+ var wg sync.WaitGroup + bat := func(from, to bool, noise string) { + defer wg.Done() + for !Atomically(func(tx *Tx) any { + if doneVar.Get(tx) { + return true + } + tx.Assert(ready.Get(tx)) + if ball.Get(tx) == from { + ball.Set(tx, to) + hits.Set(tx, hits.Get(tx)+1) + ready.Set(tx, false) + return false + } + return tx.Retry() + }).(bool) { + afterHit(noise) + AtomicSet(ready, true) + } + } + wg.Add(2) + go bat(false, true, "ping!") + go bat(true, false, "pong!") + Atomically(VoidOperation(func(tx *Tx) { + tx.Assert(hits.Get(tx) >= n) + doneVar.Set(tx, true) + })) + wg.Wait() +} + +func TestPingPong(t *testing.T) { + testPingPong(t, 42, func(s string) { t.Log(s) }) +} + +func TestSleepingBeauty(t *testing.T) { + /* + require.Panics(t, func() { + Atomically(func(tx *Tx) any { + tx.Assert(false) + return nil + }) + }) + */ +} + +//func TestRetryStack(t *testing.T) { +// v := NewVar[int](nil) +// go func() { +// i := 0 +// for { +// AtomicSet(v, i) +// i++ +// } +// }() +// Atomically(func(tx *Tx) any { +// debug.PrintStack() +// ret := func() { +// defer Atomically(nil) +// } +// v.Get(tx) +// tx.Assert(false) +// return ret +// }) +//} + +func TestContextEquality(t *testing.T) { + ctx := context.Background() + g.TAssertEqual(ctx, context.Background()) + childCtx, cancel := context.WithCancel(ctx) + g.TAssertEqual(childCtx != ctx, true) + g.TAssertEqual(childCtx != ctx, true) + g.TAssertEqual(context.Background(), ctx) + cancel() + g.TAssertEqual(context.Background(), ctx) + g.TAssertEqual(ctx != childCtx, true) +} + func MainTest() { @@ -1,231 +0,0 @@ -package stm - -import ( - "fmt" - "sort" - "sync" - "unsafe" - - "github.com/alecthomas/atomic" -) - -type txVar interface { - getValue() *atomic.Value[VarValue] - changeValue(any) - getWatchers() *sync.Map - getLock() *sync.Mutex -} - -// A Tx represents an atomic transaction. -type Tx struct { - reads map[txVar]VarValue - writes map[txVar]any - watching map[txVar]struct{} - locks txLocks - mu sync.Mutex - cond sync.Cond - waiting bool - completed bool - tries int - numRetryValues int -} - -// Check that none of the logged values have changed since the transaction began. -func (tx *Tx) inputsChanged() bool { - for v, read := range tx.reads { - if read.Changed(v.getValue().Load()) { - return true - } - } - return false -} - -// Writes the values in the transaction log to their respective Vars. -func (tx *Tx) commit() { - for v, val := range tx.writes { - v.changeValue(val) - } -} - -func (tx *Tx) updateWatchers() { - for v := range tx.watching { - if _, ok := tx.reads[v]; !ok { - delete(tx.watching, v) - v.getWatchers().Delete(tx) - } - } - for v := range tx.reads { - if _, ok := tx.watching[v]; !ok { - v.getWatchers().Store(tx, nil) - tx.watching[v] = struct{}{} - } - } -} - -// wait blocks until another transaction modifies any of the Vars read by tx. -func (tx *Tx) wait() { - if len(tx.reads) == 0 { - panic("not waiting on anything") - } - tx.updateWatchers() - tx.mu.Lock() - firstWait := true - for !tx.inputsChanged() { - if !firstWait { - expvars.Add("wakes for unchanged versions", 1) - } - expvars.Add("waits", 1) - tx.waiting = true - tx.cond.Broadcast() - tx.cond.Wait() - tx.waiting = false - firstWait = false - } - tx.mu.Unlock() -} - -// Get returns the value of v as of the start of the transaction. -func (v *Var[T]) Get(tx *Tx) T { - // If we previously wrote to v, it will be in the write log. 
- if val, ok := tx.writes[v]; ok { - return val.(T) - } - // If we haven't previously read v, record its version - vv, ok := tx.reads[v] - if !ok { - vv = v.getValue().Load() - tx.reads[v] = vv - } - return vv.Get().(T) -} - -// Set sets the value of a Var for the lifetime of the transaction. -func (v *Var[T]) Set(tx *Tx, val T) { - if v == nil { - panic("nil Var") - } - tx.writes[v] = val -} - -type txProfileValue struct { - *Tx - int -} - -// Retry aborts the transaction and retries it when a Var changes. You can return from this method -// to satisfy return values, but it should never actually return anything as it panics internally. -func (tx *Tx) Retry() struct{} { - retries.Add(txProfileValue{tx, tx.numRetryValues}, 1) - tx.numRetryValues++ - panic(retry) -} - -// Assert is a helper function that retries a transaction if the condition is -// not satisfied. -func (tx *Tx) Assert(p bool) { - if !p { - tx.Retry() - } -} - -func (tx *Tx) reset() { - tx.mu.Lock() - for k := range tx.reads { - delete(tx.reads, k) - } - for k := range tx.writes { - delete(tx.writes, k) - } - tx.mu.Unlock() - tx.removeRetryProfiles() - tx.resetLocks() -} - -func (tx *Tx) removeRetryProfiles() { - for tx.numRetryValues > 0 { - tx.numRetryValues-- - retries.Remove(txProfileValue{tx, tx.numRetryValues}) - } -} - -func (tx *Tx) recycle() { - for v := range tx.watching { - delete(tx.watching, v) - v.getWatchers().Delete(tx) - } - tx.removeRetryProfiles() - // I don't think we can reuse Txs, because the "completed" field should/needs to be set - // indefinitely after use. - //txPool.Put(tx) -} - -func (tx *Tx) lockAllVars() { - tx.resetLocks() - tx.collectAllLocks() - tx.sortLocks() - tx.lock() -} - -func (tx *Tx) resetLocks() { - tx.locks.clear() -} - -func (tx *Tx) collectReadLocks() { - for v := range tx.reads { - tx.locks.append(v.getLock()) - } -} - -func (tx *Tx) collectAllLocks() { - tx.collectReadLocks() - for v := range tx.writes { - if _, ok := tx.reads[v]; !ok { - tx.locks.append(v.getLock()) - } - } -} - -func (tx *Tx) sortLocks() { - sort.Sort(&tx.locks) -} - -func (tx *Tx) lock() { - for _, l := range tx.locks.mus { - l.Lock() - } -} - -func (tx *Tx) unlock() { - for _, l := range tx.locks.mus { - l.Unlock() - } -} - -func (tx *Tx) String() string { - return fmt.Sprintf("%[1]T %[1]p", tx) -} - -// Dedicated type avoids reflection in sort.Slice. 
-type txLocks struct { - mus []*sync.Mutex -} - -func (me txLocks) Len() int { - return len(me.mus) -} - -func (me txLocks) Less(i, j int) bool { - return uintptr(unsafe.Pointer(me.mus[i])) < uintptr(unsafe.Pointer(me.mus[j])) -} - -func (me txLocks) Swap(i, j int) { - me.mus[i], me.mus[j] = me.mus[j], me.mus[i] -} - -func (me *txLocks) clear() { - me.mus = me.mus[:0] -} - -func (me *txLocks) append(mu *sync.Mutex) { - me.mus = append(me.mus, mu) -} diff --git a/var-value.go b/var-value.go deleted file mode 100644 index 3518bf6..0000000 --- a/var-value.go +++ /dev/null @@ -1,51 +0,0 @@ -package stm - -type VarValue interface { - Set(any) VarValue - Get() any - Changed(VarValue) bool -} - -type version uint64 - -type versionedValue[T any] struct { - value T - version version -} - -func (me versionedValue[T]) Set(newValue any) VarValue { - return versionedValue[T]{ - value: newValue.(T), - version: me.version + 1, - } -} - -func (me versionedValue[T]) Get() any { - return me.value -} - -func (me versionedValue[T]) Changed(other VarValue) bool { - return me.version != other.(versionedValue[T]).version -} - -type customVarValue[T any] struct { - value T - changed func(T, T) bool -} - -var _ VarValue = customVarValue[struct{}]{} - -func (me customVarValue[T]) Changed(other VarValue) bool { - return me.changed(me.value, other.(customVarValue[T]).value) -} - -func (me customVarValue[T]) Set(newValue any) VarValue { - return customVarValue[T]{ - value: newValue.(T), - changed: me.changed, - } -} - -func (me customVarValue[T]) Get() any { - return me.value -} @@ -1,76 +0,0 @@ -package stm - -import ( - "sync" - - "github.com/alecthomas/atomic" -) - -// Holds an STM variable. -type Var[T any] struct { - value atomic.Value[VarValue] - watchers sync.Map - mu sync.Mutex -} - -func (v *Var[T]) getValue() *atomic.Value[VarValue] { - return &v.value -} - -func (v *Var[T]) getWatchers() *sync.Map { - return &v.watchers -} - -func (v *Var[T]) getLock() *sync.Mutex { - return &v.mu -} - -func (v *Var[T]) changeValue(new any) { - old := v.value.Load() - newVarValue := old.Set(new) - v.value.Store(newVarValue) - if old.Changed(newVarValue) { - go v.wakeWatchers(newVarValue) - } -} - -func (v *Var[T]) wakeWatchers(new VarValue) { - v.watchers.Range(func(k, _ any) bool { - tx := k.(*Tx) - // We have to lock here to ensure that the Tx is waiting before we signal it. Otherwise we - // could signal it before it goes to sleep and it will miss the notification. - tx.mu.Lock() - if read := tx.reads[v]; read != nil && read.Changed(new) { - tx.cond.Broadcast() - for !tx.waiting && !tx.completed { - tx.cond.Wait() - } - } - tx.mu.Unlock() - return !v.value.Load().Changed(new) - }) -} - -// Returns a new STM variable. -func NewVar[T any](val T) *Var[T] { - v := &Var[T]{} - v.value.Store(versionedValue[T]{ - value: val, - }) - return v -} - -func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] { - v := &Var[T]{} - v.value.Store(customVarValue[T]{ - value: val, - changed: changed, - }) - return v -} - -func NewBuiltinEqVar[T comparable](val T) *Var[T] { - return NewCustomVar(val, func(a, b T) bool { - return a != b - }) -} |
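The var.go removed above (reimplemented near the top of src/stm.go in this diff) distinguishes plain Vars, whose versioned values register a change on every Set, from custom-equality Vars, whose watchers are only woken when the supplied comparison reports a real change. A minimal sketch of the difference, assuming the same package (the score/level/coarse names are made up):

	// Plain Var: every Set bumps an internal version, so watchers always wake.
	score := NewVar(0)

	// Equality-aware Var: storing an equal value wakes nobody.
	level := NewBuiltinEqVar(0)

	// Custom change test: treat moves smaller than 5 as "unchanged".
	coarse := NewCustomVar(0, func(prev, next int) bool {
		d := next - prev
		if d < 0 {
			d = -d
		}
		return d >= 5
	})

	go AtomicSet(level, 1)
	Atomically(VoidOperation(func(tx *Tx) {
		tx.Assert(level.Get(tx) == 1) // woken when level actually changes
		score.Set(tx, score.Get(tx)+1)
		coarse.Set(tx, 3) // |3-0| < 5, so coarse's watchers are not woken
	}))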