aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bench_test.go151
-rw-r--r--doc.go77
-rw-r--r--doc_test.go77
-rw-r--r--external_test.go151
-rw-r--r--funcs.go169
-rw-r--r--metrics.go7
-rw-r--r--rate/rate_test.go480
-rw-r--r--rate/ratelimit.go130
-rw-r--r--retry.go24
-rw-r--r--src/stm.go1105
-rw-r--r--stm_test.go274
-rw-r--r--stmutil/containers.go192
-rw-r--r--stmutil/context.go39
-rw-r--r--stmutil/context_test.go20
-rw-r--r--tests/stm.go1177
-rw-r--r--tx.go231
-rw-r--r--var-value.go51
-rw-r--r--var.go76
18 files changed, 2282 insertions, 2149 deletions
diff --git a/bench_test.go b/bench_test.go
deleted file mode 100644
index 0d40caf..0000000
--- a/bench_test.go
+++ /dev/null
@@ -1,151 +0,0 @@
-package stm
-
-import (
- "sync"
- "testing"
-
- "github.com/anacrolix/missinggo/iter"
-)
-
-func BenchmarkAtomicGet(b *testing.B) {
- x := NewVar(0)
- for i := 0; i < b.N; i++ {
- AtomicGet(x)
- }
-}
-
-func BenchmarkAtomicSet(b *testing.B) {
- x := NewVar(0)
- for i := 0; i < b.N; i++ {
- AtomicSet(x, 0)
- }
-}
-
-func BenchmarkIncrementSTM(b *testing.B) {
- for i := 0; i < b.N; i++ {
- // spawn 1000 goroutines that each increment x by 1
- x := NewVar(0)
- for i := 0; i < 1000; i++ {
- go Atomically(VoidOperation(func(tx *Tx) {
- cur := x.Get(tx)
- x.Set(tx, cur+1)
- }))
- }
- // wait for x to reach 1000
- Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(x.Get(tx) == 1000)
- }))
- }
-}
-
-func BenchmarkIncrementMutex(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var mu sync.Mutex
- x := 0
- for i := 0; i < 1000; i++ {
- go func() {
- mu.Lock()
- x++
- mu.Unlock()
- }()
- }
- for {
- mu.Lock()
- read := x
- mu.Unlock()
- if read == 1000 {
- break
- }
- }
- }
-}
-
-func BenchmarkIncrementChannel(b *testing.B) {
- for i := 0; i < b.N; i++ {
- c := make(chan int, 1)
- c <- 0
- for i := 0; i < 1000; i++ {
- go func() {
- c <- 1 + <-c
- }()
- }
- for {
- read := <-c
- if read == 1000 {
- break
- }
- c <- read
- }
- }
-}
-
-func BenchmarkReadVarSTM(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var wg sync.WaitGroup
- wg.Add(1000)
- x := NewVar(0)
- for i := 0; i < 1000; i++ {
- go func() {
- AtomicGet(x)
- wg.Done()
- }()
- }
- wg.Wait()
- }
-}
-
-func BenchmarkReadVarMutex(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var mu sync.Mutex
- var wg sync.WaitGroup
- wg.Add(1000)
- x := 0
- for i := 0; i < 1000; i++ {
- go func() {
- mu.Lock()
- _ = x
- mu.Unlock()
- wg.Done()
- }()
- }
- wg.Wait()
- }
-}
-
-func BenchmarkReadVarChannel(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var wg sync.WaitGroup
- wg.Add(1000)
- c := make(chan int)
- close(c)
- for i := 0; i < 1000; i++ {
- go func() {
- <-c
- wg.Done()
- }()
- }
- wg.Wait()
- }
-}
-
-func parallelPingPongs(b *testing.B, n int) {
- var wg sync.WaitGroup
- wg.Add(n)
- for range iter.N(n) {
- go func() {
- defer wg.Done()
- testPingPong(b, b.N, func(string) {})
- }()
- }
- wg.Wait()
-}
-
-func BenchmarkPingPong4(b *testing.B) {
- b.ReportAllocs()
- parallelPingPongs(b, 4)
-}
-
-func BenchmarkPingPong(b *testing.B) {
- b.ReportAllocs()
- parallelPingPongs(b, 1)
-}
diff --git a/doc.go b/doc.go
deleted file mode 100644
index c704c33..0000000
--- a/doc.go
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
-Package stm provides Software Transactional Memory operations for Go. This is
-an alternative to the standard way of writing concurrent code (channels and
-mutexes). STM makes it easy to perform arbitrarily complex operations in an
-atomic fashion. One of its primary advantages over traditional locking is that
-STM transactions are composable, whereas locking functions are not -- the
-composition will either deadlock or release the lock between functions (making
-it non-atomic).
-
-To begin, create an STM object that wraps the data you want to access
-concurrently.
-
- x := stm.NewVar[int](3)
-
-You can then use the Atomically method to atomically read and/or write the the
-data. This code atomically decrements x:
-
- stm.Atomically(func(tx *stm.Tx) {
- cur := x.Get(tx)
- x.Set(tx, cur-1)
- })
-
-An important part of STM transactions is retrying. At any point during the
-transaction, you can call tx.Retry(), which will abort the transaction, but
-not cancel it entirely. The call to Atomically will block until another call
-to Atomically finishes, at which point the transaction will be rerun.
-Specifically, one of the values read by the transaction (via tx.Get) must be
-updated before the transaction will be rerun. As an example, this code will
-try to decrement x, but will block as long as x is zero:
-
- stm.Atomically(func(tx *stm.Tx) {
- cur := x.Get(tx)
- if cur == 0 {
- tx.Retry()
- }
- x.Set(tx, cur-1)
- })
-
-Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any other
-value will cancel the transaction; no values will be changed. However, it is
-the responsibility of the caller to catch such panics.
-
-Multiple transactions can be composed using Select. If the first transaction
-calls Retry, the next transaction will be run, and so on. If all of the
-transactions call Retry, the call will block and the entire selection will be
-retried. For example, this code implements the "decrement-if-nonzero"
-transaction above, but for two values. It will first try to decrement x, then
-y, and block if both values are zero.
-
- func dec(v *stm.Var[int]) {
- return func(tx *stm.Tx) {
- cur := v.Get(tx)
- if cur == 0 {
- tx.Retry()
- }
- v.Set(tx, cur-1)
- }
- }
-
- // Note that Select does not perform any work itself, but merely
- // returns a transaction function.
- stm.Atomically(stm.Select(dec(x), dec(y)))
-
-An important caveat: transactions must be idempotent (they should have the
-same effect every time they are invoked). This is because a transaction may be
-retried several times before successfully completing, meaning its side effects
-may execute more than once. This will almost certainly cause incorrect
-behavior. One common way to get around this is to build up a list of impure
-operations inside the transaction, and then perform them after the transaction
-completes.
-
-The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but
-Haskell can enforce at compile time that STM variables are not modified outside
-the STM monad. This is not possible in Go, so be especially careful when using
-pointers in your STM code. Remember: modifying a pointer is a side effect!
-*/
-package stm
diff --git a/doc_test.go b/doc_test.go
deleted file mode 100644
index ae1af9a..0000000
--- a/doc_test.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package stm_test
-
-import (
- "github.com/anacrolix/stm"
-)
-
-func Example() {
- // create a shared variable
- n := stm.NewVar(3)
-
- // read a variable
- var v int
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- v = n.Get(tx)
- }))
- // or:
- v = stm.AtomicGet(n)
- _ = v
-
- // write to a variable
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- n.Set(tx, 12)
- }))
- // or:
- stm.AtomicSet(n, 12)
-
- // update a variable
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := n.Get(tx)
- n.Set(tx, cur-1)
- }))
-
- // block until a condition is met
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := n.Get(tx)
- if cur != 0 {
- tx.Retry()
- }
- n.Set(tx, 10)
- }))
- // or:
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := n.Get(tx)
- tx.Assert(cur == 0)
- n.Set(tx, 10)
- }))
-
- // select among multiple (potentially blocking) transactions
- stm.Atomically(stm.Select(
- // this function blocks forever, so it will be skipped
- stm.VoidOperation(func(tx *stm.Tx) { tx.Retry() }),
-
- // this function will always succeed without blocking
- stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 10) }),
-
- // this function will never run, because the previous
- // function succeeded
- stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 11) }),
- ))
-
- // since Select is a normal transaction, if the entire select retries
- // (blocks), it will be retried as a whole:
- x := 0
- stm.Atomically(stm.Select(
- // this function will run twice, and succeed the second time
- stm.VoidOperation(func(tx *stm.Tx) { tx.Assert(x == 1) }),
-
- // this function will run once
- stm.VoidOperation(func(tx *stm.Tx) {
- x = 1
- tx.Retry()
- }),
- ))
- // But wait! Transactions are only retried when one of the Vars they read is
- // updated. Since x isn't a stm Var, this code will actually block forever --
- // but you get the idea.
-}
diff --git a/external_test.go b/external_test.go
deleted file mode 100644
index 1291cce..0000000
--- a/external_test.go
+++ /dev/null
@@ -1,151 +0,0 @@
-package stm_test
-
-import (
- "sync"
- "testing"
-
- "github.com/anacrolix/missinggo/iter"
- "github.com/anacrolix/stm"
- "github.com/anacrolix/stm/stmutil"
-)
-
-const maxTokens = 25
-
-func BenchmarkThunderingHerdCondVar(b *testing.B) {
- for i := 0; i < b.N; i++ {
- var mu sync.Mutex
- consumer := sync.NewCond(&mu)
- generator := sync.NewCond(&mu)
- done := false
- tokens := 0
- var pending sync.WaitGroup
- for range iter.N(1000) {
- pending.Add(1)
- go func() {
- mu.Lock()
- for {
- if tokens > 0 {
- tokens--
- generator.Signal()
- break
- }
- consumer.Wait()
- }
- mu.Unlock()
- pending.Done()
- }()
- }
- go func() {
- mu.Lock()
- for !done {
- if tokens < maxTokens {
- tokens++
- consumer.Signal()
- } else {
- generator.Wait()
- }
- }
- mu.Unlock()
- }()
- pending.Wait()
- mu.Lock()
- done = true
- generator.Signal()
- mu.Unlock()
- }
-
-}
-
-func BenchmarkThunderingHerd(b *testing.B) {
- for i := 0; i < b.N; i++ {
- done := stm.NewBuiltinEqVar(false)
- tokens := stm.NewBuiltinEqVar(0)
- pending := stm.NewBuiltinEqVar(0)
- for range iter.N(1000) {
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- pending.Set(tx, pending.Get(tx)+1)
- }))
- go func() {
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- t := tokens.Get(tx)
- if t > 0 {
- tokens.Set(tx, t-1)
- pending.Set(tx, pending.Get(tx)-1)
- } else {
- tx.Retry()
- }
- }))
- }()
- }
- go func() {
- for stm.Atomically(func(tx *stm.Tx) bool {
- if done.Get(tx) {
- return false
- }
- tx.Assert(tokens.Get(tx) < maxTokens)
- tokens.Set(tx, tokens.Get(tx)+1)
- return true
- }) {
- }
- }()
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(pending.Get(tx) == 0)
- }))
- stm.AtomicSet(done, true)
- }
-}
-
-func BenchmarkInvertedThunderingHerd(b *testing.B) {
- for i := 0; i < b.N; i++ {
- done := stm.NewBuiltinEqVar(false)
- tokens := stm.NewBuiltinEqVar(0)
- pending := stm.NewVar(stmutil.NewSet[*stm.Var[bool]]())
- for range iter.N(1000) {
- ready := stm.NewVar(false)
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- pending.Set(tx, pending.Get(tx).Add(ready))
- }))
- go func() {
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(ready.Get(tx))
- set := pending.Get(tx)
- if !set.Contains(ready) {
- panic("couldn't find ourselves in pending")
- }
- pending.Set(tx, set.Delete(ready))
- }))
- //b.Log("waiter finished")
- }()
- }
- go func() {
- for stm.Atomically(func(tx *stm.Tx) bool {
- if done.Get(tx) {
- return false
- }
- tx.Assert(tokens.Get(tx) < maxTokens)
- tokens.Set(tx, tokens.Get(tx)+1)
- return true
- }) {
- }
- }()
- go func() {
- for stm.Atomically(func(tx *stm.Tx) bool {
- tx.Assert(tokens.Get(tx) > 0)
- tokens.Set(tx, tokens.Get(tx)-1)
- pending.Get(tx).Range(func(ready *stm.Var[bool]) bool {
- if !ready.Get(tx) {
- ready.Set(tx, true)
- return false
- }
- return true
- })
- return !done.Get(tx)
- }) {
- }
- }()
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(pending.Get(tx).(stmutil.Lenner).Len() == 0)
- }))
- stm.AtomicSet(done, true)
- }
-}
diff --git a/funcs.go b/funcs.go
deleted file mode 100644
index 07d35ec..0000000
--- a/funcs.go
+++ /dev/null
@@ -1,169 +0,0 @@
-package stm
-
-import (
- "math/rand"
- "runtime/pprof"
- "sync"
- "time"
-)
-
-var (
- txPool = sync.Pool{New: func() any {
- expvars.Add("new txs", 1)
- tx := &Tx{
- reads: make(map[txVar]VarValue),
- writes: make(map[txVar]any),
- watching: make(map[txVar]struct{}),
- }
- tx.cond.L = &tx.mu
- return tx
- }}
- failedCommitsProfile *pprof.Profile
-)
-
-const (
- profileFailedCommits = false
- sleepBetweenRetries = false
-)
-
-func init() {
- if profileFailedCommits {
- failedCommitsProfile = pprof.NewProfile("stmFailedCommits")
- }
-}
-
-func newTx() *Tx {
- tx := txPool.Get().(*Tx)
- tx.tries = 0
- tx.completed = false
- return tx
-}
-
-func WouldBlock[R any](fn Operation[R]) (block bool) {
- tx := newTx()
- tx.reset()
- _, block = catchRetry(fn, tx)
- if len(tx.watching) != 0 {
- panic("shouldn't have installed any watchers")
- }
- tx.recycle()
- return
-}
-
-// Atomically executes the atomic function fn.
-func Atomically[R any](op Operation[R]) R {
- expvars.Add("atomically", 1)
- // run the transaction
- tx := newTx()
-retry:
- tx.tries++
- tx.reset()
- if sleepBetweenRetries {
- shift := int64(tx.tries - 1)
- const maxShift = 30
- if shift > maxShift {
- shift = maxShift
- }
- ns := int64(1) << shift
- d := time.Duration(rand.Int63n(ns))
- if d > 100*time.Microsecond {
- tx.updateWatchers()
- time.Sleep(time.Duration(ns))
- }
- }
- tx.mu.Lock()
- ret, retry := catchRetry(op, tx)
- tx.mu.Unlock()
- if retry {
- expvars.Add("retries", 1)
- // wait for one of the variables we read to change before retrying
- tx.wait()
- goto retry
- }
- // verify the read log
- tx.lockAllVars()
- if tx.inputsChanged() {
- tx.unlock()
- expvars.Add("failed commits", 1)
- if profileFailedCommits {
- failedCommitsProfile.Add(new(int), 0)
- }
- goto retry
- }
- // commit the write log and broadcast that variables have changed
- tx.commit()
- tx.mu.Lock()
- tx.completed = true
- tx.cond.Broadcast()
- tx.mu.Unlock()
- tx.unlock()
- expvars.Add("commits", 1)
- tx.recycle()
- return ret
-}
-
-// AtomicGet is a helper function that atomically reads a value.
-func AtomicGet[T any](v *Var[T]) T {
- return v.value.Load().Get().(T)
-}
-
-// AtomicSet is a helper function that atomically writes a value.
-func AtomicSet[T any](v *Var[T], val T) {
- v.mu.Lock()
- v.changeValue(val)
- v.mu.Unlock()
-}
-
-// Compose is a helper function that composes multiple transactions into a
-// single transaction.
-func Compose[R any](fns ...Operation[R]) Operation[struct{}] {
- return VoidOperation(func(tx *Tx) {
- for _, f := range fns {
- f(tx)
- }
- })
-}
-
-// Select runs the supplied functions in order. Execution stops when a
-// function succeeds without calling Retry. If no functions succeed, the
-// entire selection will be retried.
-func Select[R any](fns ...Operation[R]) Operation[R] {
- return func(tx *Tx) R {
- switch len(fns) {
- case 0:
- // empty Select blocks forever
- tx.Retry()
- panic("unreachable")
- case 1:
- return fns[0](tx)
- default:
- oldWrites := tx.writes
- tx.writes = make(map[txVar]any, len(oldWrites))
- for k, v := range oldWrites {
- tx.writes[k] = v
- }
- ret, retry := catchRetry(fns[0], tx)
- if retry {
- tx.writes = oldWrites
- return Select(fns[1:]...)(tx)
- } else {
- return ret
- }
- }
- }
-}
-
-type Operation[R any] func(*Tx) R
-
-func VoidOperation(f func(*Tx)) Operation[struct{}] {
- return func(tx *Tx) struct{} {
- f(tx)
- return struct{}{}
- }
-}
-
-func AtomicModify[T any](v *Var[T], f func(T) T) {
- Atomically(VoidOperation(func(tx *Tx) {
- v.Set(tx, f(v.Get(tx)))
- }))
-}
diff --git a/metrics.go b/metrics.go
deleted file mode 100644
index b8563e2..0000000
--- a/metrics.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package stm
-
-import (
- "expvar"
-)
-
-var expvars = expvar.NewMap("stm")
diff --git a/rate/rate_test.go b/rate/rate_test.go
deleted file mode 100644
index 3078b4c..0000000
--- a/rate/rate_test.go
+++ /dev/null
@@ -1,480 +0,0 @@
-// Copyright 2015 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.7
-// +build go1.7
-
-package rate
-
-import (
- "context"
- "math"
- "runtime"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-)
-
-func TestLimit(t *testing.T) {
- if Limit(10) == Inf {
- t.Errorf("Limit(10) == Inf should be false")
- }
-}
-
-func closeEnough(a, b Limit) bool {
- return (math.Abs(float64(a)/float64(b)) - 1.0) < 1e-9
-}
-
-func TestEvery(t *testing.T) {
- cases := []struct {
- interval time.Duration
- lim Limit
- }{
- {0, Inf},
- {-1, Inf},
- {1 * time.Nanosecond, Limit(1e9)},
- {1 * time.Microsecond, Limit(1e6)},
- {1 * time.Millisecond, Limit(1e3)},
- {10 * time.Millisecond, Limit(100)},
- {100 * time.Millisecond, Limit(10)},
- {1 * time.Second, Limit(1)},
- {2 * time.Second, Limit(0.5)},
- {time.Duration(2.5 * float64(time.Second)), Limit(0.4)},
- {4 * time.Second, Limit(0.25)},
- {10 * time.Second, Limit(0.1)},
- {time.Duration(math.MaxInt64), Limit(1e9 / float64(math.MaxInt64))},
- }
- for _, tc := range cases {
- lim := Every(tc.interval)
- if !closeEnough(lim, tc.lim) {
- t.Errorf("Every(%v) = %v want %v", tc.interval, lim, tc.lim)
- }
- }
-}
-
-const (
- d = 100 * time.Millisecond
-)
-
-var (
- t0 = time.Now()
- t1 = t0.Add(time.Duration(1) * d)
- t2 = t0.Add(time.Duration(2) * d)
- t3 = t0.Add(time.Duration(3) * d)
- t4 = t0.Add(time.Duration(4) * d)
- t5 = t0.Add(time.Duration(5) * d)
- t9 = t0.Add(time.Duration(9) * d)
-)
-
-type allow struct {
- t time.Time
- n int
- ok bool
-}
-
-//
-//func run(t *testing.T, lim *Limiter, allows []allow) {
-// for i, allow := range allows {
-// ok := lim.AllowN(allow.t, allow.n)
-// if ok != allow.ok {
-// t.Errorf("step %d: lim.AllowN(%v, %v) = %v want %v",
-// i, allow.t, allow.n, ok, allow.ok)
-// }
-// }
-//}
-//
-//func TestLimiterBurst1(t *testing.T) {
-// run(t, NewLimiter(10, 1), []allow{
-// {t0, 1, true},
-// {t0, 1, false},
-// {t0, 1, false},
-// {t1, 1, true},
-// {t1, 1, false},
-// {t1, 1, false},
-// {t2, 2, false}, // burst size is 1, so n=2 always fails
-// {t2, 1, true},
-// {t2, 1, false},
-// })
-//}
-//
-//func TestLimiterBurst3(t *testing.T) {
-// run(t, NewLimiter(10, 3), []allow{
-// {t0, 2, true},
-// {t0, 2, false},
-// {t0, 1, true},
-// {t0, 1, false},
-// {t1, 4, false},
-// {t2, 1, true},
-// {t3, 1, true},
-// {t4, 1, true},
-// {t4, 1, true},
-// {t4, 1, false},
-// {t4, 1, false},
-// {t9, 3, true},
-// {t9, 0, true},
-// })
-//}
-//
-//func TestLimiterJumpBackwards(t *testing.T) {
-// run(t, NewLimiter(10, 3), []allow{
-// {t1, 1, true}, // start at t1
-// {t0, 1, true}, // jump back to t0, two tokens remain
-// {t0, 1, true},
-// {t0, 1, false},
-// {t0, 1, false},
-// {t1, 1, true}, // got a token
-// {t1, 1, false},
-// {t1, 1, false},
-// {t2, 1, true}, // got another token
-// {t2, 1, false},
-// {t2, 1, false},
-// })
-//}
-
-// Ensure that tokensFromDuration doesn't produce
-// rounding errors by truncating nanoseconds.
-// See golang.org/issues/34861.
-func TestLimiter_noTruncationErrors(t *testing.T) {
- if !NewLimiter(0.7692307692307693, 1).Allow() {
- t.Fatal("expected true")
- }
-}
-
-func TestSimultaneousRequests(t *testing.T) {
- const (
- limit = 1
- burst = 5
- numRequests = 15
- )
- var (
- wg sync.WaitGroup
- numOK = uint32(0)
- )
-
- // Very slow replenishing bucket.
- lim := NewLimiter(limit, burst)
-
- // Tries to take a token, atomically updates the counter and decreases the wait
- // group counter.
- f := func() {
- defer wg.Done()
- if ok := lim.Allow(); ok {
- atomic.AddUint32(&numOK, 1)
- }
- }
-
- wg.Add(numRequests)
- for i := 0; i < numRequests; i++ {
- go f()
- }
- wg.Wait()
- if numOK != burst {
- t.Errorf("numOK = %d, want %d", numOK, burst)
- }
-}
-
-func TestLongRunningQPS(t *testing.T) {
- if testing.Short() {
- t.Skip("skipping in short mode")
- }
- if runtime.GOOS == "openbsd" {
- t.Skip("low resolution time.Sleep invalidates test (golang.org/issue/14183)")
- return
- }
-
- // The test runs for a few seconds executing many requests and then checks
- // that overall number of requests is reasonable.
- const (
- limit = 100
- burst = 100
- )
- var numOK = int32(0)
-
- lim := NewLimiter(limit, burst)
-
- var wg sync.WaitGroup
- f := func() {
- if ok := lim.Allow(); ok {
- atomic.AddInt32(&numOK, 1)
- }
- wg.Done()
- }
-
- start := time.Now()
- end := start.Add(5 * time.Second)
- for time.Now().Before(end) {
- wg.Add(1)
- go f()
-
- // This will still offer ~500 requests per second, but won't consume
- // outrageous amount of CPU.
- time.Sleep(2 * time.Millisecond)
- }
- wg.Wait()
- elapsed := time.Since(start)
- ideal := burst + (limit * float64(elapsed) / float64(time.Second))
-
- // We should never get more requests than allowed.
- if want := int32(ideal + 1); numOK > want {
- t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal)
- }
- // We should get very close to the number of requests allowed.
- if want := int32(0.999 * ideal); numOK < want {
- t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal)
- }
-}
-
-type request struct {
- t time.Time
- n int
- act time.Time
- ok bool
-}
-
-// dFromDuration converts a duration to a multiple of the global constant d
-func dFromDuration(dur time.Duration) int {
- // Adding a millisecond to be swallowed by the integer division
- // because we don't care about small inaccuracies
- return int((dur + time.Millisecond) / d)
-}
-
-// dSince returns multiples of d since t0
-func dSince(t time.Time) int {
- return dFromDuration(t.Sub(t0))
-}
-
-//
-//func runReserve(t *testing.T, lim *Limiter, req request) *Reservation {
-// return runReserveMax(t, lim, req, InfDuration)
-//}
-//
-//func runReserveMax(t *testing.T, lim *Limiter, req request, maxReserve time.Duration) *Reservation {
-// r := lim.reserveN(req.t, req.n, maxReserve)
-// if r.ok && (dSince(r.timeToAct) != dSince(req.act)) || r.ok != req.ok {
-// t.Errorf("lim.reserveN(t%d, %v, %v) = (t%d, %v) want (t%d, %v)",
-// dSince(req.t), req.n, maxReserve, dSince(r.timeToAct), r.ok, dSince(req.act), req.ok)
-// }
-// return &r
-//}
-//
-//func TestSimpleReserve(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// runReserve(t, lim, request{t0, 2, t2, true})
-// runReserve(t, lim, request{t3, 2, t4, true})
-//}
-//
-//func TestMix(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t0, 3, t1, false}) // should return false because n > Burst
-// runReserve(t, lim, request{t0, 2, t0, true})
-// run(t, lim, []allow{{t1, 2, false}}) // not enought tokens - don't allow
-// runReserve(t, lim, request{t1, 2, t2, true})
-// run(t, lim, []allow{{t1, 1, false}}) // negative tokens - don't allow
-// run(t, lim, []allow{{t3, 1, true}})
-//}
-//
-//func TestCancelInvalid(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// r := runReserve(t, lim, request{t0, 3, t3, false})
-// r.CancelAt(t0) // should have no effect
-// runReserve(t, lim, request{t0, 2, t2, true}) // did not get extra tokens
-//}
-//
-//func TestCancelLast(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// r := runReserve(t, lim, request{t0, 2, t2, true})
-// r.CancelAt(t1) // got 2 tokens back
-// runReserve(t, lim, request{t1, 2, t2, true})
-//}
-//
-//func TestCancelTooLate(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// r := runReserve(t, lim, request{t0, 2, t2, true})
-// r.CancelAt(t3) // too late to cancel - should have no effect
-// runReserve(t, lim, request{t3, 2, t4, true})
-//}
-//
-//func TestCancel0Tokens(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// r := runReserve(t, lim, request{t0, 1, t1, true})
-// runReserve(t, lim, request{t0, 1, t2, true})
-// r.CancelAt(t0) // got 0 tokens back
-// runReserve(t, lim, request{t0, 1, t3, true})
-//}
-//
-//func TestCancel1Token(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// r := runReserve(t, lim, request{t0, 2, t2, true})
-// runReserve(t, lim, request{t0, 1, t3, true})
-// r.CancelAt(t2) // got 1 token back
-// runReserve(t, lim, request{t2, 2, t4, true})
-//}
-//
-//func TestCancelMulti(t *testing.T) {
-// lim := NewLimiter(10, 4)
-//
-// runReserve(t, lim, request{t0, 4, t0, true})
-// rA := runReserve(t, lim, request{t0, 3, t3, true})
-// runReserve(t, lim, request{t0, 1, t4, true})
-// rC := runReserve(t, lim, request{t0, 1, t5, true})
-// rC.CancelAt(t1) // get 1 token back
-// rA.CancelAt(t1) // get 2 tokens back, as if C was never reserved
-// runReserve(t, lim, request{t1, 3, t5, true})
-//}
-//
-//func TestReserveJumpBack(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1
-// runReserve(t, lim, request{t0, 1, t1, true}) // should violate Limit,Burst
-// runReserve(t, lim, request{t2, 2, t3, true})
-//}
-
-//func TestReserveJumpBackCancel(t *testing.T) {
-// lim := NewLimiter(10, 2)
-//
-// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1
-// r := runReserve(t, lim, request{t1, 2, t3, true})
-// runReserve(t, lim, request{t1, 1, t4, true})
-// r.CancelAt(t0) // cancel at t0, get 1 token back
-// runReserve(t, lim, request{t1, 2, t4, true}) // should violate Limit,Burst
-//}
-//
-//func TestReserveSetLimit(t *testing.T) {
-// lim := NewLimiter(5, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// runReserve(t, lim, request{t0, 2, t4, true})
-// lim.SetLimitAt(t2, 10)
-// runReserve(t, lim, request{t2, 1, t4, true}) // violates Limit and Burst
-//}
-//
-//func TestReserveSetBurst(t *testing.T) {
-// lim := NewLimiter(5, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// runReserve(t, lim, request{t0, 2, t4, true})
-// lim.SetBurstAt(t3, 4)
-// runReserve(t, lim, request{t0, 4, t9, true}) // violates Limit and Burst
-//}
-//
-//func TestReserveSetLimitCancel(t *testing.T) {
-// lim := NewLimiter(5, 2)
-//
-// runReserve(t, lim, request{t0, 2, t0, true})
-// r := runReserve(t, lim, request{t0, 2, t4, true})
-// lim.SetLimitAt(t2, 10)
-// r.CancelAt(t2) // 2 tokens back
-// runReserve(t, lim, request{t2, 2, t3, true})
-//}
-//
-//func TestReserveMax(t *testing.T) {
-// lim := NewLimiter(10, 2)
-// maxT := d
-//
-// runReserveMax(t, lim, request{t0, 2, t0, true}, maxT)
-// runReserveMax(t, lim, request{t0, 1, t1, true}, maxT) // reserve for close future
-// runReserveMax(t, lim, request{t0, 1, t2, false}, maxT) // time to act too far in the future
-//}
-
-type wait struct {
- name string
- ctx context.Context
- n int
- delay int // in multiples of d
- nilErr bool
-}
-
-func runWait(t *testing.T, lim *Limiter, w wait) {
- t.Helper()
- start := time.Now()
- err := lim.WaitN(w.ctx, w.n)
- delay := time.Since(start)
- if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) {
- errString := "<nil>"
- if !w.nilErr {
- errString = "<non-nil error>"
- }
- t.Errorf("lim.WaitN(%v, lim, %v) = %v with delay %v ; want %v with delay %v",
- w.name, w.n, err, delay, errString, d*time.Duration(w.delay))
- }
-}
-
-func TestWaitSimple(t *testing.T) {
- lim := NewLimiter(10, 3)
-
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- runWait(t, lim, wait{"already-cancelled", ctx, 1, 0, false})
-
- runWait(t, lim, wait{"exceed-burst-error", context.Background(), 4, 0, false})
-
- runWait(t, lim, wait{"act-now", context.Background(), 2, 0, true})
- runWait(t, lim, wait{"act-later", context.Background(), 3, 2, true})
-}
-
-func TestWaitCancel(t *testing.T) {
- lim := NewLimiter(10, 3)
-
- ctx, cancel := context.WithCancel(context.Background())
- runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) // after this lim.tokens = 1
- go func() {
- time.Sleep(d)
- cancel()
- }()
- runWait(t, lim, wait{"will-cancel", ctx, 3, 1, false})
- // should get 3 tokens back, and have lim.tokens = 2
- //t.Logf("tokens:%v last:%v lastEvent:%v", lim.tokens, lim.last, lim.lastEvent)
- runWait(t, lim, wait{"act-now-after-cancel", context.Background(), 2, 0, true})
-}
-
-func TestWaitTimeout(t *testing.T) {
- lim := NewLimiter(10, 3)
-
- ctx, cancel := context.WithTimeout(context.Background(), d)
- defer cancel()
- runWait(t, lim, wait{"act-now", ctx, 2, 0, true})
- runWait(t, lim, wait{"w-timeout-err", ctx, 3, 0, false})
-}
-
-func TestWaitInf(t *testing.T) {
- lim := NewLimiter(Inf, 0)
-
- runWait(t, lim, wait{"exceed-burst-no-error", context.Background(), 3, 0, true})
-}
-
-func BenchmarkAllowN(b *testing.B) {
- lim := NewLimiter(Every(1*time.Second), 1)
- b.ReportAllocs()
- b.ResetTimer()
- b.RunParallel(func(pb *testing.PB) {
- for pb.Next() {
- lim.AllowN(1)
- }
- })
-}
-
-func BenchmarkWaitNNoDelay(b *testing.B) {
- lim := NewLimiter(Limit(b.N), b.N)
- ctx := context.Background()
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- lim.WaitN(ctx, 1)
- }
-}
diff --git a/rate/ratelimit.go b/rate/ratelimit.go
deleted file mode 100644
index f521f66..0000000
--- a/rate/ratelimit.go
+++ /dev/null
@@ -1,130 +0,0 @@
-package rate
-
-import (
- "context"
- "errors"
- "math"
- "time"
-
- "github.com/anacrolix/stm"
- "github.com/anacrolix/stm/stmutil"
-)
-
-type numTokens = int
-
-type Limiter struct {
- max *stm.Var[numTokens]
- cur *stm.Var[numTokens]
- lastAdd *stm.Var[time.Time]
- rate Limit
-}
-
-const Inf = Limit(math.MaxFloat64)
-
-type Limit float64
-
-func (l Limit) interval() time.Duration {
- return time.Duration(Limit(1*time.Second) / l)
-}
-
-func Every(interval time.Duration) Limit {
- if interval == 0 {
- return Inf
- }
- return Limit(time.Second / interval)
-}
-
-func NewLimiter(rate Limit, burst numTokens) *Limiter {
- rl := &Limiter{
- max: stm.NewVar(burst),
- cur: stm.NewBuiltinEqVar(burst),
- lastAdd: stm.NewVar(time.Now()),
- rate: rate,
- }
- if rate != Inf {
- go rl.tokenGenerator(rate.interval())
- }
- return rl
-}
-
-func (rl *Limiter) tokenGenerator(interval time.Duration) {
- for {
- lastAdd := stm.AtomicGet(rl.lastAdd)
- time.Sleep(time.Until(lastAdd.Add(interval)))
- now := time.Now()
- available := numTokens(now.Sub(lastAdd) / interval)
- if available < 1 {
- continue
- }
- stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := rl.cur.Get(tx)
- max := rl.max.Get(tx)
- tx.Assert(cur < max)
- newCur := cur + available
- if newCur > max {
- newCur = max
- }
- if newCur != cur {
- rl.cur.Set(tx, newCur)
- }
- rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available)))
- }))
- }
-}
-
-func (rl *Limiter) Allow() bool {
- return rl.AllowN(1)
-}
-
-func (rl *Limiter) AllowN(n numTokens) bool {
- return stm.Atomically(func(tx *stm.Tx) bool {
- return rl.takeTokens(tx, n)
- })
-}
-
-func (rl *Limiter) AllowStm(tx *stm.Tx) bool {
- return rl.takeTokens(tx, 1)
-}
-
-func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool {
- if rl.rate == Inf {
- return true
- }
- cur := rl.cur.Get(tx)
- if cur >= n {
- rl.cur.Set(tx, cur-n)
- return true
- }
- return false
-}
-
-func (rl *Limiter) Wait(ctx context.Context) error {
- return rl.WaitN(ctx, 1)
-}
-
-func (rl *Limiter) WaitN(ctx context.Context, n int) error {
- ctxDone, cancel := stmutil.ContextDoneVar(ctx)
- defer cancel()
- if err := stm.Atomically(func(tx *stm.Tx) error {
- if ctxDone.Get(tx) {
- return ctx.Err()
- }
- if rl.takeTokens(tx, n) {
- return nil
- }
- if n > rl.max.Get(tx) {
- return errors.New("burst exceeded")
- }
- if dl, ok := ctx.Deadline(); ok {
- if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n {
- return context.DeadlineExceeded
- }
- }
- tx.Retry()
- panic("unreachable")
- }); err != nil {
- return err
- }
- return nil
-
-}
diff --git a/retry.go b/retry.go
deleted file mode 100644
index 1adcfd0..0000000
--- a/retry.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package stm
-
-import (
- "runtime/pprof"
-)
-
-var retries = pprof.NewProfile("stmRetries")
-
-// retry is a sentinel value. When thrown via panic, it indicates that a
-// transaction should be retried.
-var retry = &struct{}{}
-
-// catchRetry returns true if fn calls tx.Retry.
-func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) {
- defer func() {
- if r := recover(); r == retry {
- gotRetry = true
- } else if r != nil {
- panic(r)
- }
- }()
- result = fn(tx)
- return
-}
diff --git a/src/stm.go b/src/stm.go
index 53b4fe1..0740d0c 100644
--- a/src/stm.go
+++ b/src/stm.go
@@ -1,9 +1,1114 @@
+/// Package stm provides Software Transactional Memory operations for Go. This
+/// is an alternative to the standard way of writing concurrent code (channels
+/// and mutexes). STM makes it easy to perform arbitrarily complex operations
+/// in an atomic fashion. One of its primary advantages over traditional
+/// locking is that STM transactions are composable, whereas locking functions
+/// are not -- the composition will either deadlock or release the lock between
+/// functions (making it non-atomic).
+///
+/// To begin, create an STM object that wraps the data you want to access
+/// concurrently.
+///
+/// x := stm.NewVar[int](3)
+///
+/// You can then use the Atomically method to atomically read and/or write the
+/// the data. This code atomically decrements x:
+///
+/// stm.Atomically(func(tx *stm.Tx) {
+/// cur := x.Get(tx)
+/// x.Set(tx, cur-1)
+/// })
+///
+/// An important part of STM transactions is retrying. At any point during the
+/// transaction, you can call tx.Retry(), which will abort the transaction, but
+/// not cancel it entirely. The call to Atomically will block until another
+/// call to Atomically finishes, at which point the transaction will be rerun.
+/// Specifically, one of the values read by the transaction (via tx.Get) must be
+/// updated before the transaction will be rerun. As an example, this code will
+/// try to decrement x, but will block as long as x is zero:
+///
+/// stm.Atomically(func(tx *stm.Tx) {
+/// cur := x.Get(tx)
+/// if cur == 0 {
+/// tx.Retry()
+/// }
+/// x.Set(tx, cur-1)
+/// })
+///
+/// Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any
+/// other value will cancel the transaction; no values will be changed.
+/// However, it is the responsibility of the caller to catch such panics.
+///
+/// Multiple transactions can be composed using Select. If the first
+/// transaction calls Retry, the next transaction will be run, and so on. If
+/// all of the transactions call Retry, the call will block and the entire
+/// selection will be retried. For example, this code implements the
+/// "decrement-if-nonzero" transaction above, but for two values. It will
+/// first try to decrement x, then y, and block if both values are zero.
+///
+/// func dec(v *stm.Var[int]) {
+/// return func(tx *stm.Tx) {
+/// cur := v.Get(tx)
+/// if cur == 0 {
+/// tx.Retry()
+/// }
+/// v.Set(tx, cur-1)
+/// }
+/// }
+///
+/// // Note that Select does not perform any work itself, but merely
+/// // returns a transaction function.
+/// stm.Atomically(stm.Select(dec(x), dec(y)))
+///
+/// An important caveat: transactions must be idempotent (they should have the
+/// same effect every time they are invoked). This is because a transaction may
+/// be retried several times before successfully completing, meaning its side
+/// effects may execute more than once. This will almost certainly cause
+/// incorrect behavior. One common way to get around this is to build up a list
+/// of impure operations inside the transaction, and then perform them after the
+/// transaction completes.
+///
+/// The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but
+/// Haskell can enforce at compile time that STM variables are not modified
+/// outside the STM monad. This is not possible in Go, so be especially careful
+/// when using pointers in your STM code. Remember: modifying a pointer is a
+/// side effect!
package stm
import (
+ "context"
+ "errors"
+ "expvar"
+ "fmt"
+ "math"
+ "math/rand"
+ "runtime/pprof"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "time"
+ "unsafe"
+
+ "pds"
)
+// Package atomic contains type-safe atomic types.
+//
+// The zero value for the numeric types cannot be used. Use New*. The
+// rationale for this behaviour is that copying an atomic integer is not
+// reliable. Copying can be prevented by embedding sync.Mutex, but that bloats
+// the type.
+
+// Interface represents atomic operations on a value.
+type Interface[T any] interface {
+ // Load value atomically.
+ Load() T
+ // Store value atomically.
+ Store(value T)
+ // Swap the previous value with the new value atomically.
+ Swap(new T) (old T)
+ // CompareAndSwap the previous value with new if its value is "old".
+ CompareAndSwap(old, new T) (swapped bool)
+}
+
+var _ Interface[bool] = &Value[bool]{}
+
+// Value wraps any generic value in atomic load and store operations.
+//
+// The zero value should be initialised using [Value.Store] before use.
+type Value[T any] struct {
+ value atomic.Value
+}
+
+// New atomic Value.
+func New[T any](seed T) *Value[T] {
+ v := &Value[T]{}
+ v.value.Store(seed)
+ return v
+}
+
+// Load value atomically.
+//
+// Will panic if the value is nil.
+func (v *Value[T]) Load() (out T) {
+ value := v.value.Load()
+ if value == nil {
+ panic("nil value in atomic.Value")
+ }
+ return value.(T)
+}
+func (v *Value[T]) Store(value T) { v.value.Store(value) }
+func (v *Value[T]) Swap(new T) (old T) { return v.value.Swap(new).(T) }
+func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { return v.value.CompareAndSwap(old, new) }
+
+// atomicint defines the types that atomic integer operations are supported on.
+type atomicint interface {
+ int32 | uint32 | int64 | uint64
+}
+
+// Int expresses atomic operations on signed or unsigned integer values.
+type Int[T atomicint] interface {
+ Interface[T]
+ // Add a value and return the new result.
+ Add(delta T) (new T)
+}
+
+// Currently not supported by Go's generic type system:
+//
+// ./atomic.go:48:9: cannot use type switch on type parameter value v (variable of type T constrained by atomicint)
+//
+// // ForInt infers and creates an atomic Int[T] type for a value.
+// func ForInt[T atomicint](v T) Int[T] {
+// switch v.(type) {
+// case int32:
+// return NewInt32(v)
+// case uint32:
+// return NewUint32(v)
+// case int64:
+// return NewInt64(v)
+// case uint64:
+// return NewUint64(v)
+// }
+// panic("can't happen")
+// }
+
+// Int32 atomic value.
+//
+// Copying creates an alias. The zero value is not usable, use NewInt32.
+type Int32 struct{ value *int32 }
+
+// NewInt32 creates a new atomic integer with an initial value.
+func NewInt32(value int32) Int32 { return Int32{value: &value} }
+
+var _ Int[int32] = &Int32{}
+
+func (i Int32) Add(delta int32) (new int32) { return atomic.AddInt32(i.value, delta) }
+func (i Int32) Load() (val int32) { return atomic.LoadInt32(i.value) }
+func (i Int32) Store(val int32) { atomic.StoreInt32(i.value, val) }
+func (i Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(i.value, new) }
+func (i Int32) CompareAndSwap(old, new int32) (swapped bool) {
+ return atomic.CompareAndSwapInt32(i.value, old, new)
+}
+
+// Uint32 atomic value.
+//
+// Copying creates an alias.
+type Uint32 struct{ value *uint32 }
+
+var _ Int[uint32] = Uint32{}
+
+// NewUint32 creates a new atomic integer with an initial value.
+func NewUint32(value uint32) Uint32 { return Uint32{value: &value} }
+
+func (i Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(i.value, delta) }
+func (i Uint32) Load() (val uint32) { return atomic.LoadUint32(i.value) }
+func (i Uint32) Store(val uint32) { atomic.StoreUint32(i.value, val) }
+func (i Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(i.value, new) }
+func (i Uint32) CompareAndSwap(old, new uint32) (swapped bool) {
+ return atomic.CompareAndSwapUint32(i.value, old, new)
+}
+
+// Int64 atomic value.
+//
+// Copying creates an alias.
+type Int64 struct{ value *int64 }
+
+var _ Int[int64] = Int64{}
+
+// NewInt64 creates a new atomic integer with an initial value.
+func NewInt64(value int64) Int64 { return Int64{value: &value} }
+
+func (i Int64) Add(delta int64) (new int64) { return atomic.AddInt64(i.value, delta) }
+func (i Int64) Load() (val int64) { return atomic.LoadInt64(i.value) }
+func (i Int64) Store(val int64) { atomic.StoreInt64(i.value, val) }
+func (i Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(i.value, new) }
+func (i Int64) CompareAndSwap(old, new int64) (swapped bool) {
+ return atomic.CompareAndSwapInt64(i.value, old, new)
+}
+
+// Uint64 atomic value.
+//
+// Copying creates an alias.
+type Uint64 struct{ value *uint64 }
+
+var _ Int[uint64] = Uint64{}
+
+// NewUint64 creates a new atomic integer with an initial value.
+func NewUint64(value uint64) Uint64 { return Uint64{value: &value} }
+
+func (i Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(i.value, delta) }
+func (i Uint64) Load() (val uint64) { return atomic.LoadUint64(i.value) }
+func (i Uint64) Store(val uint64) { atomic.StoreUint64(i.value, val) }
+func (i Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(i.value, new) }
+func (i Uint64) CompareAndSwap(old, new uint64) (swapped bool) {
+ return atomic.CompareAndSwapUint64(i.value, old, new)
+}
+
+
func F() {
}
+
+var (
+ txPool = sync.Pool{New: func() any {
+ expvars.Add("new txs", 1)
+ tx := &Tx{
+ reads: make(map[txVar]VarValue),
+ writes: make(map[txVar]any),
+ watching: make(map[txVar]struct{}),
+ }
+ tx.cond.L = &tx.mu
+ return tx
+ }}
+ failedCommitsProfile *pprof.Profile
+)
+
+const (
+ profileFailedCommits = false
+ sleepBetweenRetries = false
+)
+
+func init() {
+ if profileFailedCommits {
+ failedCommitsProfile = pprof.NewProfile("stmFailedCommits")
+ }
+}
+
+func newTx() *Tx {
+ tx := txPool.Get().(*Tx)
+ tx.tries = 0
+ tx.completed = false
+ return tx
+}
+
+func WouldBlock[R any](fn Operation[R]) (block bool) {
+ tx := newTx()
+ tx.reset()
+ _, block = catchRetry(fn, tx)
+ if len(tx.watching) != 0 {
+ panic("shouldn't have installed any watchers")
+ }
+ tx.recycle()
+ return
+}
+
+// Atomically executes the atomic function fn.
+func Atomically[R any](op Operation[R]) R {
+ expvars.Add("atomically", 1)
+ // run the transaction
+ tx := newTx()
+retry:
+ tx.tries++
+ tx.reset()
+ if sleepBetweenRetries {
+ shift := int64(tx.tries - 1)
+ const maxShift = 30
+ if shift > maxShift {
+ shift = maxShift
+ }
+ ns := int64(1) << shift
+ d := time.Duration(rand.Int63n(ns))
+ if d > 100*time.Microsecond {
+ tx.updateWatchers()
+ time.Sleep(time.Duration(ns))
+ }
+ }
+ tx.mu.Lock()
+ ret, retry := catchRetry(op, tx)
+ tx.mu.Unlock()
+ if retry {
+ expvars.Add("retries", 1)
+ // wait for one of the variables we read to change before retrying
+ tx.wait()
+ goto retry
+ }
+ // verify the read log
+ tx.lockAllVars()
+ if tx.inputsChanged() {
+ tx.unlock()
+ expvars.Add("failed commits", 1)
+ if profileFailedCommits {
+ failedCommitsProfile.Add(new(int), 0)
+ }
+ goto retry
+ }
+ // commit the write log and broadcast that variables have changed
+ tx.commit()
+ tx.mu.Lock()
+ tx.completed = true
+ tx.cond.Broadcast()
+ tx.mu.Unlock()
+ tx.unlock()
+ expvars.Add("commits", 1)
+ tx.recycle()
+ return ret
+}
+
+// AtomicGet is a helper function that atomically reads a value.
+func AtomicGet[T any](v *Var[T]) T {
+ return v.value.Load().Get().(T)
+}
+
+// AtomicSet is a helper function that atomically writes a value.
+func AtomicSet[T any](v *Var[T], val T) {
+ v.mu.Lock()
+ v.changeValue(val)
+ v.mu.Unlock()
+}
+
+// Compose is a helper function that composes multiple transactions into a
+// single transaction.
+func Compose[R any](fns ...Operation[R]) Operation[struct{}] {
+ return VoidOperation(func(tx *Tx) {
+ for _, f := range fns {
+ f(tx)
+ }
+ })
+}
+
+// Select runs the supplied functions in order. Execution stops when a
+// function succeeds without calling Retry. If no functions succeed, the
+// entire selection will be retried.
+func Select[R any](fns ...Operation[R]) Operation[R] {
+ return func(tx *Tx) R {
+ switch len(fns) {
+ case 0:
+ // empty Select blocks forever
+ tx.Retry()
+ panic("unreachable")
+ case 1:
+ return fns[0](tx)
+ default:
+ oldWrites := tx.writes
+ tx.writes = make(map[txVar]any, len(oldWrites))
+ for k, v := range oldWrites {
+ tx.writes[k] = v
+ }
+ ret, retry := catchRetry(fns[0], tx)
+ if retry {
+ tx.writes = oldWrites
+ return Select(fns[1:]...)(tx)
+ } else {
+ return ret
+ }
+ }
+ }
+}
+
+type Operation[R any] func(*Tx) R
+
+func VoidOperation(f func(*Tx)) Operation[struct{}] {
+ return func(tx *Tx) struct{} {
+ f(tx)
+ return struct{}{}
+ }
+}
+
+func AtomicModify[T any](v *Var[T], f func(T) T) {
+ Atomically(VoidOperation(func(tx *Tx) {
+ v.Set(tx, f(v.Get(tx)))
+ }))
+}
+
+var expvars = expvar.NewMap("stm")
+
+type VarValue interface {
+ Set(any) VarValue
+ Get() any
+ Changed(VarValue) bool
+}
+
+type version uint64
+
+type versionedValue[T any] struct {
+ value T
+ version version
+}
+
+func (me versionedValue[T]) Set(newValue any) VarValue {
+ return versionedValue[T]{
+ value: newValue.(T),
+ version: me.version + 1,
+ }
+}
+
+func (me versionedValue[T]) Get() any {
+ return me.value
+}
+
+func (me versionedValue[T]) Changed(other VarValue) bool {
+ return me.version != other.(versionedValue[T]).version
+}
+
+type customVarValue[T any] struct {
+ value T
+ changed func(T, T) bool
+}
+
+var _ VarValue = customVarValue[struct{}]{}
+
+func (me customVarValue[T]) Changed(other VarValue) bool {
+ return me.changed(me.value, other.(customVarValue[T]).value)
+}
+
+func (me customVarValue[T]) Set(newValue any) VarValue {
+ return customVarValue[T]{
+ value: newValue.(T),
+ changed: me.changed,
+ }
+}
+
+func (me customVarValue[T]) Get() any {
+ return me.value
+}
+
+type txVar interface {
+ getValue() *Value[VarValue]
+ changeValue(any)
+ getWatchers() *sync.Map
+ getLock() *sync.Mutex
+}
+
+// A Tx represents an atomic transaction.
+type Tx struct {
+ reads map[txVar]VarValue
+ writes map[txVar]any
+ watching map[txVar]struct{}
+ locks txLocks
+ mu sync.Mutex
+ cond sync.Cond
+ waiting bool
+ completed bool
+ tries int
+ numRetryValues int
+}
+
+// Check that none of the logged values have changed since the transaction began.
+func (tx *Tx) inputsChanged() bool {
+ for v, read := range tx.reads {
+ if read.Changed(v.getValue().Load()) {
+ return true
+ }
+ }
+ return false
+}
+
+// Writes the values in the transaction log to their respective Vars.
+func (tx *Tx) commit() {
+ for v, val := range tx.writes {
+ v.changeValue(val)
+ }
+}
+
+func (tx *Tx) updateWatchers() {
+ for v := range tx.watching {
+ if _, ok := tx.reads[v]; !ok {
+ delete(tx.watching, v)
+ v.getWatchers().Delete(tx)
+ }
+ }
+ for v := range tx.reads {
+ if _, ok := tx.watching[v]; !ok {
+ v.getWatchers().Store(tx, nil)
+ tx.watching[v] = struct{}{}
+ }
+ }
+}
+
+// wait blocks until another transaction modifies any of the Vars read by tx.
+func (tx *Tx) wait() {
+ if len(tx.reads) == 0 {
+ panic("not waiting on anything")
+ }
+ tx.updateWatchers()
+ tx.mu.Lock()
+ firstWait := true
+ for !tx.inputsChanged() {
+ if !firstWait {
+ expvars.Add("wakes for unchanged versions", 1)
+ }
+ expvars.Add("waits", 1)
+ tx.waiting = true
+ tx.cond.Broadcast()
+ tx.cond.Wait()
+ tx.waiting = false
+ firstWait = false
+ }
+ tx.mu.Unlock()
+}
+
+// Get returns the value of v as of the start of the transaction.
+func (v *Var[T]) Get(tx *Tx) T {
+ // If we previously wrote to v, it will be in the write log.
+ if val, ok := tx.writes[v]; ok {
+ return val.(T)
+ }
+ // If we haven't previously read v, record its version
+ vv, ok := tx.reads[v]
+ if !ok {
+ vv = v.getValue().Load()
+ tx.reads[v] = vv
+ }
+ return vv.Get().(T)
+}
+
+// Set sets the value of a Var for the lifetime of the transaction.
+func (v *Var[T]) Set(tx *Tx, val T) {
+ if v == nil {
+ panic("nil Var")
+ }
+ tx.writes[v] = val
+}
+
+type txProfileValue struct {
+ *Tx
+ int
+}
+
+// Retry aborts the transaction and retries it when a Var changes. You can return from this method
+// to satisfy return values, but it should never actually return anything as it panics internally.
+func (tx *Tx) Retry() struct{} {
+ retries.Add(txProfileValue{tx, tx.numRetryValues}, 1)
+ tx.numRetryValues++
+ panic(retry)
+}
+
+// Assert is a helper function that retries a transaction if the condition is
+// not satisfied.
+func (tx *Tx) Assert(p bool) {
+ if !p {
+ tx.Retry()
+ }
+}
+
+func (tx *Tx) reset() {
+ tx.mu.Lock()
+ for k := range tx.reads {
+ delete(tx.reads, k)
+ }
+ for k := range tx.writes {
+ delete(tx.writes, k)
+ }
+ tx.mu.Unlock()
+ tx.removeRetryProfiles()
+ tx.resetLocks()
+}
+
+func (tx *Tx) removeRetryProfiles() {
+ for tx.numRetryValues > 0 {
+ tx.numRetryValues--
+ retries.Remove(txProfileValue{tx, tx.numRetryValues})
+ }
+}
+
+func (tx *Tx) recycle() {
+ for v := range tx.watching {
+ delete(tx.watching, v)
+ v.getWatchers().Delete(tx)
+ }
+ tx.removeRetryProfiles()
+ // I don't think we can reuse Txs, because the "completed" field should/needs to be set
+ // indefinitely after use.
+ //txPool.Put(tx)
+}
+
+func (tx *Tx) lockAllVars() {
+ tx.resetLocks()
+ tx.collectAllLocks()
+ tx.sortLocks()
+ tx.lock()
+}
+
+func (tx *Tx) resetLocks() {
+ tx.locks.clear()
+}
+
+func (tx *Tx) collectReadLocks() {
+ for v := range tx.reads {
+ tx.locks.append(v.getLock())
+ }
+}
+
+func (tx *Tx) collectAllLocks() {
+ tx.collectReadLocks()
+ for v := range tx.writes {
+ if _, ok := tx.reads[v]; !ok {
+ tx.locks.append(v.getLock())
+ }
+ }
+}
+
+func (tx *Tx) sortLocks() {
+ sort.Sort(&tx.locks)
+}
+
+func (tx *Tx) lock() {
+ for _, l := range tx.locks.mus {
+ l.Lock()
+ }
+}
+
+func (tx *Tx) unlock() {
+ for _, l := range tx.locks.mus {
+ l.Unlock()
+ }
+}
+
+func (tx *Tx) String() string {
+ return fmt.Sprintf("%[1]T %[1]p", tx)
+}
+
+// Dedicated type avoids reflection in sort.Slice.
+type txLocks struct {
+ mus []*sync.Mutex
+}
+
+func (me txLocks) Len() int {
+ return len(me.mus)
+}
+
+func (me txLocks) Less(i, j int) bool {
+ return uintptr(unsafe.Pointer(me.mus[i])) < uintptr(unsafe.Pointer(me.mus[j]))
+}
+
+func (me txLocks) Swap(i, j int) {
+ me.mus[i], me.mus[j] = me.mus[j], me.mus[i]
+}
+
+func (me *txLocks) clear() {
+ me.mus = me.mus[:0]
+}
+
+func (me *txLocks) append(mu *sync.Mutex) {
+ me.mus = append(me.mus, mu)
+}
+
+// Holds an STM variable.
+type Var[T any] struct {
+ value Value[VarValue]
+ watchers sync.Map
+ mu sync.Mutex
+}
+
+func (v *Var[T]) getValue() *Value[VarValue] {
+ return &v.value
+}
+
+func (v *Var[T]) getWatchers() *sync.Map {
+ return &v.watchers
+}
+
+func (v *Var[T]) getLock() *sync.Mutex {
+ return &v.mu
+}
+
+func (v *Var[T]) changeValue(new any) {
+ old := v.value.Load()
+ newVarValue := old.Set(new)
+ v.value.Store(newVarValue)
+ if old.Changed(newVarValue) {
+ go v.wakeWatchers(newVarValue)
+ }
+}
+
+func (v *Var[T]) wakeWatchers(new VarValue) {
+ v.watchers.Range(func(k, _ any) bool {
+ tx := k.(*Tx)
+ // We have to lock here to ensure that the Tx is waiting before
+ // we signal it. Otherwise we could signal it before it goes to
+ // sleep and it will miss the notification.
+ tx.mu.Lock()
+ if read := tx.reads[v]; read != nil && read.Changed(new) {
+ tx.cond.Broadcast()
+ for !tx.waiting && !tx.completed {
+ tx.cond.Wait()
+ }
+ }
+ tx.mu.Unlock()
+ return !v.value.Load().Changed(new)
+ })
+}
+
+// Returns a new STM variable.
+func NewVar[T any](val T) *Var[T] {
+ v := &Var[T]{}
+ v.value.Store(versionedValue[T]{
+ value: val,
+ })
+ return v
+}
+
+func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] {
+ v := &Var[T]{}
+ v.value.Store(customVarValue[T]{
+ value: val,
+ changed: changed,
+ })
+ return v
+}
+
+func NewBuiltinEqVar[T comparable](val T) *Var[T] {
+ return NewCustomVar(val, func(a, b T) bool {
+ return a != b
+ })
+}
+
+var retries = pprof.NewProfile("stmRetries")
+
+// retry is a sentinel value. When thrown via panic, it indicates that a
+// transaction should be retried.
+var retry = &struct{}{}
+
+// catchRetry returns true if fn calls tx.Retry.
+func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) {
+ defer func() {
+ if r := recover(); r == retry {
+ gotRetry = true
+ } else if r != nil {
+ panic(r)
+ }
+ }()
+ result = fn(tx)
+ return
+}
+
+// This is the type constraint for keys passed through from pds
+type KeyConstraint interface {
+ comparable
+}
+
+type Settish[K KeyConstraint] interface {
+ Add(K) Settish[K]
+ Delete(K) Settish[K]
+ Contains(K) bool
+ Range(func(K) bool)
+ Len() int
+ // iter.Iterable
+}
+
+type mapToSet[K KeyConstraint] struct {
+ m Mappish[K, struct{}]
+}
+
+type interhash[K KeyConstraint] struct{}
+
+func (interhash[K]) Hash(x K) uint32 {
+ return uint32(nilinterhash(unsafe.Pointer(&x), 0))
+}
+
+func (interhash[K]) Equal(i, j K) bool {
+ return i == j
+}
+
+func NewSet[K KeyConstraint]() Settish[K] {
+ return mapToSet[K]{NewMap[K, struct{}]()}
+}
+
+func NewSortedSet[K KeyConstraint](lesser lessFunc[K]) Settish[K] {
+ return mapToSet[K]{NewSortedMap[K, struct{}](lesser)}
+}
+
+func (s mapToSet[K]) Add(x K) Settish[K] {
+ s.m = s.m.Set(x, struct{}{})
+ return s
+}
+
+func (s mapToSet[K]) Delete(x K) Settish[K] {
+ s.m = s.m.Delete(x)
+ return s
+}
+
+func (s mapToSet[K]) Len() int {
+ return s.m.Len()
+}
+
+func (s mapToSet[K]) Contains(x K) bool {
+ _, ok := s.m.Get(x)
+ return ok
+}
+
+func (s mapToSet[K]) Range(f func(K) bool) {
+ s.m.Range(func(k K, _ struct{}) bool {
+ return f(k)
+ })
+}
+
+/*
+func (s mapToSet[K]) Iter(cb iter.Callback) {
+ s.Range(func(k K) bool {
+ return cb(k)
+ })
+}
+*/
+
+type Map[K KeyConstraint, V any] struct {
+ *pds.Map[K, V]
+}
+
+func NewMap[K KeyConstraint, V any]() Mappish[K, V] {
+ return Map[K, V]{pds.NewMap[K, V](interhash[K]{})}
+}
+
+func (m Map[K, V]) Delete(x K) Mappish[K, V] {
+ m.Map = m.Map.Delete(x)
+ return m
+}
+
+func (m Map[K, V]) Set(key K, value V) Mappish[K, V] {
+ m.Map = m.Map.Set(key, value)
+ return m
+}
+
+func (sm Map[K, V]) Range(f func(K, V) bool) {
+ iter := sm.Map.Iterator()
+ for {
+ k, v, ok := iter.Next()
+ if !ok {
+ break
+ }
+ if !f(k, v) {
+ return
+ }
+ }
+}
+
+/*
+func (sm Map[K, V]) Iter(cb iter.Callback) {
+ sm.Range(func(key K, _ V) bool {
+ return cb(key)
+ })
+}
+*/
+
+type SortedMap[K KeyConstraint, V any] struct {
+ *pds.SortedMap[K, V]
+}
+
+func (sm SortedMap[K, V]) Set(key K, value V) Mappish[K, V] {
+ sm.SortedMap = sm.SortedMap.Set(key, value)
+ return sm
+}
+
+func (sm SortedMap[K, V]) Delete(key K) Mappish[K, V] {
+ sm.SortedMap = sm.SortedMap.Delete(key)
+ return sm
+}
+
+func (sm SortedMap[K, V]) Range(f func(key K, value V) bool) {
+ iter := sm.SortedMap.Iterator()
+ for {
+ k, v, ok := iter.Next()
+ if !ok {
+ break
+ }
+ if !f(k, v) {
+ return
+ }
+ }
+}
+
+/*
+func (sm SortedMap[K, V]) Iter(cb iter.Callback) {
+ sm.Range(func(key K, _ V) bool {
+ return cb(key)
+ })
+}
+*/
+
+type lessFunc[T KeyConstraint] func(l, r T) bool
+
+type comparer[K KeyConstraint] struct {
+ less lessFunc[K]
+}
+
+func (me comparer[K]) Compare(i, j K) int {
+ if me.less(i, j) {
+ return -1
+ } else if me.less(j, i) {
+ return 1
+ } else {
+ return 0
+ }
+}
+
+func NewSortedMap[K KeyConstraint, V any](less lessFunc[K]) Mappish[K, V] {
+ return SortedMap[K, V]{
+ SortedMap: pds.NewSortedMap[K, V](comparer[K]{less}),
+ }
+}
+
+type Mappish[K, V any] interface {
+ Set(K, V) Mappish[K, V]
+ Delete(key K) Mappish[K, V]
+ Get(key K) (V, bool)
+ Range(func(K, V) bool)
+ Len() int
+ // iter.Iterable
+}
+
+func GetLeft(l, _ any) any {
+ return l
+}
+
+//go:noescape
+//go:linkname nilinterhash runtime.nilinterhash
+func nilinterhash(p unsafe.Pointer, h uintptr) uintptr
+
+func interfaceHash(x any) uint32 {
+ return uint32(nilinterhash(unsafe.Pointer(&x), 0))
+}
+
+type Lenner interface {
+ Len() int
+}
+
+var (
+ mu sync.Mutex
+ ctxVars = map[context.Context]*Var[bool]{}
+)
+
+// Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be
+// called when the user is no longer interested in the var.
+func ContextDoneVar(ctx context.Context) (*Var[bool], func()) {
+ mu.Lock()
+ defer mu.Unlock()
+ if v, ok := ctxVars[ctx]; ok {
+ return v, func() {}
+ }
+ if ctx.Err() != nil {
+ // TODO: What if we had read-only Vars? Then we could have a global one for this that we
+ // just reuse.
+ v := NewBuiltinEqVar(true)
+ return v, func() {}
+ }
+ v := NewVar(false)
+ go func() {
+ <-ctx.Done()
+ AtomicSet(v, true)
+ mu.Lock()
+ delete(ctxVars, ctx)
+ mu.Unlock()
+ }()
+ ctxVars[ctx] = v
+ return v, func() {}
+}
+
+type numTokens = int
+
+type Limiter struct {
+ max *Var[numTokens]
+ cur *Var[numTokens]
+ lastAdd *Var[time.Time]
+ rate Limit
+}
+
+const Inf = Limit(math.MaxFloat64)
+
+type Limit float64
+
+func (l Limit) interval() time.Duration {
+ return time.Duration(Limit(1*time.Second) / l)
+}
+
+func Every(interval time.Duration) Limit {
+ if interval == 0 {
+ return Inf
+ }
+ return Limit(time.Second / interval)
+}
+
+func NewLimiter(rate Limit, burst numTokens) *Limiter {
+ rl := &Limiter{
+ max: NewVar(burst),
+ cur: NewBuiltinEqVar(burst),
+ lastAdd: NewVar(time.Now()),
+ rate: rate,
+ }
+ if rate != Inf {
+ go rl.tokenGenerator(rate.interval())
+ }
+ return rl
+}
+
+func (rl *Limiter) tokenGenerator(interval time.Duration) {
+ for {
+ lastAdd := AtomicGet(rl.lastAdd)
+ time.Sleep(time.Until(lastAdd.Add(interval)))
+ now := time.Now()
+ available := numTokens(now.Sub(lastAdd) / interval)
+ if available < 1 {
+ continue
+ }
+ Atomically(VoidOperation(func(tx *Tx) {
+ cur := rl.cur.Get(tx)
+ max := rl.max.Get(tx)
+ tx.Assert(cur < max)
+ newCur := cur + available
+ if newCur > max {
+ newCur = max
+ }
+ if newCur != cur {
+ rl.cur.Set(tx, newCur)
+ }
+ rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available)))
+ }))
+ }
+}
+
+func (rl *Limiter) Allow() bool {
+ return rl.AllowN(1)
+}
+
+func (rl *Limiter) AllowN(n numTokens) bool {
+ return Atomically(func(tx *Tx) bool {
+ return rl.takeTokens(tx, n)
+ })
+}
+
+func (rl *Limiter) AllowStm(tx *Tx) bool {
+ return rl.takeTokens(tx, 1)
+}
+
+func (rl *Limiter) takeTokens(tx *Tx, n numTokens) bool {
+ if rl.rate == Inf {
+ return true
+ }
+ cur := rl.cur.Get(tx)
+ if cur >= n {
+ rl.cur.Set(tx, cur-n)
+ return true
+ }
+ return false
+}
+
+func (rl *Limiter) Wait(ctx context.Context) error {
+ return rl.WaitN(ctx, 1)
+}
+
+func (rl *Limiter) WaitN(ctx context.Context, n int) error {
+ ctxDone, cancel := ContextDoneVar(ctx)
+ defer cancel()
+ if err := Atomically(func(tx *Tx) error {
+ if ctxDone.Get(tx) {
+ return ctx.Err()
+ }
+ if rl.takeTokens(tx, n) {
+ return nil
+ }
+ if n > rl.max.Get(tx) {
+ return errors.New("burst exceeded")
+ }
+ if dl, ok := ctx.Deadline(); ok {
+ if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n {
+ return context.DeadlineExceeded
+ }
+ }
+ tx.Retry()
+ panic("unreachable")
+ }); err != nil {
+ return err
+ }
+ return nil
+
+}
diff --git a/stm_test.go b/stm_test.go
deleted file mode 100644
index a98685b..0000000
--- a/stm_test.go
+++ /dev/null
@@ -1,274 +0,0 @@
-package stm
-
-import (
- "sync"
- "testing"
- "time"
-
- _ "github.com/anacrolix/envpprof"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestDecrement(t *testing.T) {
- x := NewVar(1000)
- for i := 0; i < 500; i++ {
- go Atomically(VoidOperation(func(tx *Tx) {
- cur := x.Get(tx)
- x.Set(tx, cur-1)
- }))
- }
- done := make(chan struct{})
- go func() {
- Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(x.Get(tx) == 500)
- }))
- close(done)
- }()
- select {
- case <-done:
- case <-time.After(10 * time.Second):
- t.Fatal("decrement did not complete in time")
- }
-}
-
-// read-only transaction aren't exempt from calling tx.inputsChanged
-func TestReadVerify(t *testing.T) {
- read := make(chan struct{})
- x, y := NewVar(1), NewVar(2)
-
- // spawn a transaction that writes to x
- go func() {
- <-read
- AtomicSet(x, 3)
- read <- struct{}{}
- // other tx should retry, so we need to read/send again
- read <- <-read
- }()
-
- // spawn a transaction that reads x, then y. The other tx will modify x in
- // between the reads, causing this tx to retry.
- var x2, y2 int
- Atomically(VoidOperation(func(tx *Tx) {
- x2 = x.Get(tx)
- read <- struct{}{}
- <-read // wait for other tx to complete
- y2 = y.Get(tx)
- }))
- if x2 == 1 && y2 == 2 {
- t.Fatal("read was not verified")
- }
-}
-
-func TestRetry(t *testing.T) {
- x := NewVar(10)
- // spawn 10 transactions, one every 10 milliseconds. This will decrement x
- // to 0 over the course of 100 milliseconds.
- go func() {
- for i := 0; i < 10; i++ {
- time.Sleep(10 * time.Millisecond)
- Atomically(VoidOperation(func(tx *Tx) {
- cur := x.Get(tx)
- x.Set(tx, cur-1)
- }))
- }
- }()
- // Each time we read x before the above loop has finished, we need to
- // retry. This should result in no more than 1 retry per transaction.
- retry := 0
- Atomically(VoidOperation(func(tx *Tx) {
- cur := x.Get(tx)
- if cur != 0 {
- retry++
- tx.Retry()
- }
- }))
- if retry > 10 {
- t.Fatal("should have retried at most 10 times, got", retry)
- }
-}
-
-func TestVerify(t *testing.T) {
- // tx.inputsChanged should check more than pointer equality
- type foo struct {
- i int
- }
- x := NewVar(&foo{3})
- read := make(chan struct{})
-
- // spawn a transaction that modifies x
- go func() {
- Atomically(VoidOperation(func(tx *Tx) {
- <-read
- rx := x.Get(tx)
- rx.i = 7
- x.Set(tx, rx)
- }))
- read <- struct{}{}
- // other tx should retry, so we need to read/send again
- read <- <-read
- }()
-
- // spawn a transaction that reads x, then y. The other tx will modify x in
- // between the reads, causing this tx to retry.
- var i int
- Atomically(VoidOperation(func(tx *Tx) {
- f := x.Get(tx)
- i = f.i
- read <- struct{}{}
- <-read // wait for other tx to complete
- }))
- if i == 3 {
- t.Fatal("inputsChanged did not retry despite modified Var", i)
- }
-}
-
-func TestSelect(t *testing.T) {
- // empty Select should panic
- require.Panics(t, func() { Atomically(Select[struct{}]()) })
-
- // with one arg, Select adds no effect
- x := NewVar(2)
- Atomically(Select(VoidOperation(func(tx *Tx) {
- tx.Assert(x.Get(tx) == 2)
- })))
-
- picked := Atomically(Select(
- // always blocks; should never be selected
- func(tx *Tx) int {
- tx.Retry()
- panic("unreachable")
- },
- // always succeeds; should always be selected
- func(tx *Tx) int {
- return 2
- },
- // always succeeds; should never be selected
- func(tx *Tx) int {
- return 3
- },
- ))
- assert.EqualValues(t, 2, picked)
-}
-
-func TestCompose(t *testing.T) {
- nums := make([]int, 100)
- fns := make([]Operation[struct{}], 100)
- for i := range fns {
- fns[i] = func(x int) Operation[struct{}] {
- return VoidOperation(func(*Tx) { nums[x] = x })
- }(i) // capture loop var
- }
- Atomically(Compose(fns...))
- for i := range nums {
- if nums[i] != i {
- t.Error("Compose failed:", nums[i], i)
- }
- }
-}
-
-func TestPanic(t *testing.T) {
- // normal panics should escape Atomically
- assert.PanicsWithValue(t, "foo", func() {
- Atomically(func(*Tx) any {
- panic("foo")
- })
- })
-}
-
-func TestReadWritten(t *testing.T) {
- // reading a variable written in the same transaction should return the
- // previously written value
- x := NewVar(3)
- Atomically(VoidOperation(func(tx *Tx) {
- x.Set(tx, 5)
- tx.Assert(x.Get(tx) == 5)
- }))
-}
-
-func TestAtomicSetRetry(t *testing.T) {
- // AtomicSet should cause waiting transactions to retry
- x := NewVar(3)
- done := make(chan struct{})
- go func() {
- Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(x.Get(tx) == 5)
- }))
- done <- struct{}{}
- }()
- time.Sleep(10 * time.Millisecond)
- AtomicSet(x, 5)
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("AtomicSet did not wake up a waiting transaction")
- }
-}
-
-func testPingPong(t testing.TB, n int, afterHit func(string)) {
- ball := NewBuiltinEqVar(false)
- doneVar := NewVar(false)
- hits := NewVar(0)
- ready := NewVar(true) // The ball is ready for hitting.
- var wg sync.WaitGroup
- bat := func(from, to bool, noise string) {
- defer wg.Done()
- for !Atomically(func(tx *Tx) any {
- if doneVar.Get(tx) {
- return true
- }
- tx.Assert(ready.Get(tx))
- if ball.Get(tx) == from {
- ball.Set(tx, to)
- hits.Set(tx, hits.Get(tx)+1)
- ready.Set(tx, false)
- return false
- }
- return tx.Retry()
- }).(bool) {
- afterHit(noise)
- AtomicSet(ready, true)
- }
- }
- wg.Add(2)
- go bat(false, true, "ping!")
- go bat(true, false, "pong!")
- Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(hits.Get(tx) >= n)
- doneVar.Set(tx, true)
- }))
- wg.Wait()
-}
-
-func TestPingPong(t *testing.T) {
- testPingPong(t, 42, func(s string) { t.Log(s) })
-}
-
-func TestSleepingBeauty(t *testing.T) {
- require.Panics(t, func() {
- Atomically(func(tx *Tx) any {
- tx.Assert(false)
- return nil
- })
- })
-}
-
-//func TestRetryStack(t *testing.T) {
-// v := NewVar[int](nil)
-// go func() {
-// i := 0
-// for {
-// AtomicSet(v, i)
-// i++
-// }
-// }()
-// Atomically(func(tx *Tx) any {
-// debug.PrintStack()
-// ret := func() {
-// defer Atomically(nil)
-// }
-// v.Get(tx)
-// tx.Assert(false)
-// return ret
-// })
-//}
diff --git a/stmutil/containers.go b/stmutil/containers.go
deleted file mode 100644
index 0cc592d..0000000
--- a/stmutil/containers.go
+++ /dev/null
@@ -1,192 +0,0 @@
-package stmutil
-
-import (
- "unsafe"
-
- "github.com/anacrolix/missinggo/v2/iter"
- "github.com/benbjohnson/immutable"
-)
-
-// This is the type constraint for keys passed through from github.com/benbjohnson/immutable.
-type KeyConstraint interface {
- comparable
-}
-
-type Settish[K KeyConstraint] interface {
- Add(K) Settish[K]
- Delete(K) Settish[K]
- Contains(K) bool
- Range(func(K) bool)
- iter.Iterable
- Len() int
-}
-
-type mapToSet[K KeyConstraint] struct {
- m Mappish[K, struct{}]
-}
-
-type interhash[K KeyConstraint] struct{}
-
-func (interhash[K]) Hash(x K) uint32 {
- return uint32(nilinterhash(unsafe.Pointer(&x), 0))
-}
-
-func (interhash[K]) Equal(i, j K) bool {
- return i == j
-}
-
-func NewSet[K KeyConstraint]() Settish[K] {
- return mapToSet[K]{NewMap[K, struct{}]()}
-}
-
-func NewSortedSet[K KeyConstraint](lesser lessFunc[K]) Settish[K] {
- return mapToSet[K]{NewSortedMap[K, struct{}](lesser)}
-}
-
-func (s mapToSet[K]) Add(x K) Settish[K] {
- s.m = s.m.Set(x, struct{}{})
- return s
-}
-
-func (s mapToSet[K]) Delete(x K) Settish[K] {
- s.m = s.m.Delete(x)
- return s
-}
-
-func (s mapToSet[K]) Len() int {
- return s.m.Len()
-}
-
-func (s mapToSet[K]) Contains(x K) bool {
- _, ok := s.m.Get(x)
- return ok
-}
-
-func (s mapToSet[K]) Range(f func(K) bool) {
- s.m.Range(func(k K, _ struct{}) bool {
- return f(k)
- })
-}
-
-func (s mapToSet[K]) Iter(cb iter.Callback) {
- s.Range(func(k K) bool {
- return cb(k)
- })
-}
-
-type Map[K KeyConstraint, V any] struct {
- *immutable.Map[K, V]
-}
-
-func NewMap[K KeyConstraint, V any]() Mappish[K, V] {
- return Map[K, V]{immutable.NewMap[K, V](interhash[K]{})}
-}
-
-func (m Map[K, V]) Delete(x K) Mappish[K, V] {
- m.Map = m.Map.Delete(x)
- return m
-}
-
-func (m Map[K, V]) Set(key K, value V) Mappish[K, V] {
- m.Map = m.Map.Set(key, value)
- return m
-}
-
-func (sm Map[K, V]) Range(f func(K, V) bool) {
- iter := sm.Map.Iterator()
- for {
- k, v, ok := iter.Next()
- if !ok {
- break
- }
- if !f(k, v) {
- return
- }
- }
-}
-
-func (sm Map[K, V]) Iter(cb iter.Callback) {
- sm.Range(func(key K, _ V) bool {
- return cb(key)
- })
-}
-
-type SortedMap[K KeyConstraint, V any] struct {
- *immutable.SortedMap[K, V]
-}
-
-func (sm SortedMap[K, V]) Set(key K, value V) Mappish[K, V] {
- sm.SortedMap = sm.SortedMap.Set(key, value)
- return sm
-}
-
-func (sm SortedMap[K, V]) Delete(key K) Mappish[K, V] {
- sm.SortedMap = sm.SortedMap.Delete(key)
- return sm
-}
-
-func (sm SortedMap[K, V]) Range(f func(key K, value V) bool) {
- iter := sm.SortedMap.Iterator()
- for {
- k, v, ok := iter.Next()
- if !ok {
- break
- }
- if !f(k, v) {
- return
- }
- }
-}
-
-func (sm SortedMap[K, V]) Iter(cb iter.Callback) {
- sm.Range(func(key K, _ V) bool {
- return cb(key)
- })
-}
-
-type lessFunc[T KeyConstraint] func(l, r T) bool
-
-type comparer[K KeyConstraint] struct {
- less lessFunc[K]
-}
-
-func (me comparer[K]) Compare(i, j K) int {
- if me.less(i, j) {
- return -1
- } else if me.less(j, i) {
- return 1
- } else {
- return 0
- }
-}
-
-func NewSortedMap[K KeyConstraint, V any](less lessFunc[K]) Mappish[K, V] {
- return SortedMap[K, V]{
- SortedMap: immutable.NewSortedMap[K, V](comparer[K]{less}),
- }
-}
-
-type Mappish[K, V any] interface {
- Set(K, V) Mappish[K, V]
- Delete(key K) Mappish[K, V]
- Get(key K) (V, bool)
- Range(func(K, V) bool)
- Len() int
- iter.Iterable
-}
-
-func GetLeft(l, _ any) any {
- return l
-}
-
-//go:noescape
-//go:linkname nilinterhash runtime.nilinterhash
-func nilinterhash(p unsafe.Pointer, h uintptr) uintptr
-
-func interfaceHash(x any) uint32 {
- return uint32(nilinterhash(unsafe.Pointer(&x), 0))
-}
-
-type Lenner interface {
- Len() int
-}
diff --git a/stmutil/context.go b/stmutil/context.go
deleted file mode 100644
index 6f8ba9b..0000000
--- a/stmutil/context.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package stmutil
-
-import (
- "context"
- "sync"
-
- "github.com/anacrolix/stm"
-)
-
-var (
- mu sync.Mutex
- ctxVars = map[context.Context]*stm.Var[bool]{}
-)
-
-// Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be
-// called when the user is no longer interested in the var.
-func ContextDoneVar(ctx context.Context) (*stm.Var[bool], func()) {
- mu.Lock()
- defer mu.Unlock()
- if v, ok := ctxVars[ctx]; ok {
- return v, func() {}
- }
- if ctx.Err() != nil {
- // TODO: What if we had read-only Vars? Then we could have a global one for this that we
- // just reuse.
- v := stm.NewBuiltinEqVar(true)
- return v, func() {}
- }
- v := stm.NewVar(false)
- go func() {
- <-ctx.Done()
- stm.AtomicSet(v, true)
- mu.Lock()
- delete(ctxVars, ctx)
- mu.Unlock()
- }()
- ctxVars[ctx] = v
- return v, func() {}
-}
diff --git a/stmutil/context_test.go b/stmutil/context_test.go
deleted file mode 100644
index 0a6b7c0..0000000
--- a/stmutil/context_test.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package stmutil
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestContextEquality(t *testing.T) {
- ctx := context.Background()
- assert.True(t, ctx == context.Background())
- childCtx, cancel := context.WithCancel(ctx)
- assert.True(t, childCtx != ctx)
- assert.True(t, childCtx != ctx)
- assert.Equal(t, context.Background(), ctx)
- cancel()
- assert.Equal(t, context.Background(), ctx)
- assert.NotEqual(t, ctx, childCtx)
-}
diff --git a/tests/stm.go b/tests/stm.go
index a30ac1a..816cdec 100644
--- a/tests/stm.go
+++ b/tests/stm.go
@@ -1,8 +1,1185 @@
package stm
import (
+ "context"
+ "fmt"
+ "math"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ g "gobang"
+)
+
+
+
+func TestValue(t *testing.T) {
+ v := New("hello")
+ g.TAssertEqual("hello", v.Load())
+ v.Store("world")
+ g.TAssertEqual("world", v.Load())
+}
+
+func TestValueZeroValue(t *testing.T) {
+ var v Value[string]
+ // assert.Panics(t, func() { v.Load() })
+ v.Store("world")
+ g.TAssertEqual("world", v.Load())
+}
+
+func TestValueSwapZeroValue(t *testing.T) {
+ // var v Value[string]
+ // assert.Panics(t, func() { v.Swap("hello") })
+}
+
+func TestInt32(t *testing.T) {
+ v := NewInt32(0)
+ g.TAssertEqual(0, v.Load())
+ g.TAssertEqual(true, v.CompareAndSwap(0, 10))
+ g.TAssertEqual(false, v.CompareAndSwap(0, 10))
+}
+
+func BenchmarkInt64Add(b *testing.B) {
+ v := NewInt64(0)
+ for i := 0; i < b.N; i++ {
+ v.Add(1)
+ }
+}
+
+func BenchmarkIntInterfaceAdd(b *testing.B) {
+ var v Int[int64] = NewInt64(0)
+ for i := 0; i < b.N; i++ {
+ v.Add(1)
+ }
+}
+
+func BenchmarkStdlibInt64Add(b *testing.B) {
+ var n int64
+ for i := 0; i < b.N; i++ {
+ atomic.AddInt64(&n, 1)
+ }
+}
+
+func BenchmarkInterfaceStore(b *testing.B) {
+ var v Interface[string] = New("hello")
+ for i := 0; i < b.N; i++ {
+ v.Store(fmt.Sprint(i))
+ }
+}
+
+func BenchmarkValueStore(b *testing.B) {
+ v := New("hello")
+ for i := 0; i < b.N; i++ {
+ v.Store(fmt.Sprint(i))
+ }
+}
+
+func BenchmarkStdlibValueStore(b *testing.B) {
+ v := atomic.Value{}
+ for i := 0; i < b.N; i++ {
+ v.Store(fmt.Sprint(i))
+ }
+}
+
+func BenchmarkAtomicGet(b *testing.B) {
+ x := NewVar(0)
+ for i := 0; i < b.N; i++ {
+ AtomicGet(x)
+ }
+}
+
+func BenchmarkAtomicSet(b *testing.B) {
+ x := NewVar(0)
+ for i := 0; i < b.N; i++ {
+ AtomicSet(x, 0)
+ }
+}
+
+func BenchmarkIncrementSTM(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ // spawn 1000 goroutines that each increment x by 1
+ x := NewVar(0)
+ for i := 0; i < 1000; i++ {
+ go Atomically(VoidOperation(func(tx *Tx) {
+ cur := x.Get(tx)
+ x.Set(tx, cur+1)
+ }))
+ }
+ // wait for x to reach 1000
+ Atomically(VoidOperation(func(tx *Tx) {
+ tx.Assert(x.Get(tx) == 1000)
+ }))
+ }
+}
+
+func BenchmarkIncrementMutex(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var mu sync.Mutex
+ x := 0
+ for i := 0; i < 1000; i++ {
+ go func() {
+ mu.Lock()
+ x++
+ mu.Unlock()
+ }()
+ }
+ for {
+ mu.Lock()
+ read := x
+ mu.Unlock()
+ if read == 1000 {
+ break
+ }
+ }
+ }
+}
+
+func BenchmarkIncrementChannel(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ c := make(chan int, 1)
+ c <- 0
+ for i := 0; i < 1000; i++ {
+ go func() {
+ c <- 1 + <-c
+ }()
+ }
+ for {
+ read := <-c
+ if read == 1000 {
+ break
+ }
+ c <- read
+ }
+ }
+}
+
+func BenchmarkReadVarSTM(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var wg sync.WaitGroup
+ wg.Add(1000)
+ x := NewVar(0)
+ for i := 0; i < 1000; i++ {
+ go func() {
+ AtomicGet(x)
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+ }
+}
+
+func BenchmarkReadVarMutex(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var mu sync.Mutex
+ var wg sync.WaitGroup
+ wg.Add(1000)
+ x := 0
+ for i := 0; i < 1000; i++ {
+ go func() {
+ mu.Lock()
+ _ = x
+ mu.Unlock()
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+ }
+}
+
+func BenchmarkReadVarChannel(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var wg sync.WaitGroup
+ wg.Add(1000)
+ c := make(chan int)
+ close(c)
+ for i := 0; i < 1000; i++ {
+ go func() {
+ <-c
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+ }
+}
+
+func parallelPingPongs(b *testing.B, n int) {
+ var wg sync.WaitGroup
+ wg.Add(n)
+ for i := 0; i < n; i++ {
+ go func() {
+ defer wg.Done()
+ testPingPong(b, b.N, func(string) {})
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkPingPong4(b *testing.B) {
+ b.ReportAllocs()
+ parallelPingPongs(b, 4)
+}
+
+func BenchmarkPingPong(b *testing.B) {
+ b.ReportAllocs()
+ parallelPingPongs(b, 1)
+}
+
+func Example() {
+ // create a shared variable
+ n := NewVar(3)
+
+ // read a variable
+ var v int
+ Atomically(VoidOperation(func(tx *Tx) {
+ v = n.Get(tx)
+ }))
+ // or:
+ v = AtomicGet(n)
+ _ = v
+
+ // write to a variable
+ Atomically(VoidOperation(func(tx *Tx) {
+ n.Set(tx, 12)
+ }))
+ // or:
+ AtomicSet(n, 12)
+
+ // update a variable
+ Atomically(VoidOperation(func(tx *Tx) {
+ cur := n.Get(tx)
+ n.Set(tx, cur-1)
+ }))
+
+ // block until a condition is met
+ Atomically(VoidOperation(func(tx *Tx) {
+ cur := n.Get(tx)
+ if cur != 0 {
+ tx.Retry()
+ }
+ n.Set(tx, 10)
+ }))
+ // or:
+ Atomically(VoidOperation(func(tx *Tx) {
+ cur := n.Get(tx)
+ tx.Assert(cur == 0)
+ n.Set(tx, 10)
+ }))
+
+ // select among multiple (potentially blocking) transactions
+ Atomically(Select(
+ // this function blocks forever, so it will be skipped
+ VoidOperation(func(tx *Tx) { tx.Retry() }),
+
+ // this function will always succeed without blocking
+ VoidOperation(func(tx *Tx) { n.Set(tx, 10) }),
+
+ // this function will never run, because the previous
+ // function succeeded
+ VoidOperation(func(tx *Tx) { n.Set(tx, 11) }),
+ ))
+
+ // since Select is a normal transaction, if the entire select retries
+ // (blocks), it will be retried as a whole:
+ x := 0
+ Atomically(Select(
+ // this function will run twice, and succeed the second time
+ VoidOperation(func(tx *Tx) { tx.Assert(x == 1) }),
+
+ // this function will run once
+ VoidOperation(func(tx *Tx) {
+ x = 1
+ tx.Retry()
+ }),
+ ))
+ // But wait! Transactions are only retried when one of the Vars they read is
+ // updated. Since x isn't a stm Var, this code will actually block forever --
+ // but you get the idea.
+}
+
+const maxTokens = 25
+
+func BenchmarkThunderingHerdCondVar(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ var mu sync.Mutex
+ consumer := sync.NewCond(&mu)
+ generator := sync.NewCond(&mu)
+ done := false
+ tokens := 0
+ var pending sync.WaitGroup
+ for i := 0; i < 1000; i++ {
+ pending.Add(1)
+ go func() {
+ mu.Lock()
+ for {
+ if tokens > 0 {
+ tokens--
+ generator.Signal()
+ break
+ }
+ consumer.Wait()
+ }
+ mu.Unlock()
+ pending.Done()
+ }()
+ }
+ go func() {
+ mu.Lock()
+ for !done {
+ if tokens < maxTokens {
+ tokens++
+ consumer.Signal()
+ } else {
+ generator.Wait()
+ }
+ }
+ mu.Unlock()
+ }()
+ pending.Wait()
+ mu.Lock()
+ done = true
+ generator.Signal()
+ mu.Unlock()
+ }
+
+}
+
+func BenchmarkThunderingHerd(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ done := NewBuiltinEqVar(false)
+ tokens := NewBuiltinEqVar(0)
+ pending := NewBuiltinEqVar(0)
+ for i := 0; i < 1000; i++ {
+ Atomically(VoidOperation(func(tx *Tx) {
+ pending.Set(tx, pending.Get(tx)+1)
+ }))
+ go func() {
+ Atomically(VoidOperation(func(tx *Tx) {
+ t := tokens.Get(tx)
+ if t > 0 {
+ tokens.Set(tx, t-1)
+ pending.Set(tx, pending.Get(tx)-1)
+ } else {
+ tx.Retry()
+ }
+ }))
+ }()
+ }
+ go func() {
+ for Atomically(func(tx *Tx) bool {
+ if done.Get(tx) {
+ return false
+ }
+ tx.Assert(tokens.Get(tx) < maxTokens)
+ tokens.Set(tx, tokens.Get(tx)+1)
+ return true
+ }) {
+ }
+ }()
+ Atomically(VoidOperation(func(tx *Tx) {
+ tx.Assert(pending.Get(tx) == 0)
+ }))
+ AtomicSet(done, true)
+ }
+}
+
+func BenchmarkInvertedThunderingHerd(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ done := NewBuiltinEqVar(false)
+ tokens := NewBuiltinEqVar(0)
+ pending := NewVar(NewSet[*Var[bool]]())
+ for i := 0; i < 1000; i++ {
+ ready := NewVar(false)
+ Atomically(VoidOperation(func(tx *Tx) {
+ pending.Set(tx, pending.Get(tx).Add(ready))
+ }))
+ go func() {
+ Atomically(VoidOperation(func(tx *Tx) {
+ tx.Assert(ready.Get(tx))
+ set := pending.Get(tx)
+ if !set.Contains(ready) {
+ panic("couldn't find ourselves in pending")
+ }
+ pending.Set(tx, set.Delete(ready))
+ }))
+ //b.Log("waiter finished")
+ }()
+ }
+ go func() {
+ for Atomically(func(tx *Tx) bool {
+ if done.Get(tx) {
+ return false
+ }
+ tx.Assert(tokens.Get(tx) < maxTokens)
+ tokens.Set(tx, tokens.Get(tx)+1)
+ return true
+ }) {
+ }
+ }()
+ go func() {
+ for Atomically(func(tx *Tx) bool {
+ tx.Assert(tokens.Get(tx) > 0)
+ tokens.Set(tx, tokens.Get(tx)-1)
+ pending.Get(tx).Range(func(ready *Var[bool]) bool {
+ if !ready.Get(tx) {
+ ready.Set(tx, true)
+ return false
+ }
+ return true
+ })
+ return !done.Get(tx)
+ }) {
+ }
+ }()
+ Atomically(VoidOperation(func(tx *Tx) {
+ tx.Assert(pending.Get(tx).(Lenner).Len() == 0)
+ }))
+ AtomicSet(done, true)
+ }
+}
+
+func TestLimit(t *testing.T) {
+ if Limit(10) == Inf {
+ t.Errorf("Limit(10) == Inf should be false")
+ }
+}
+
+func closeEnough(a, b Limit) bool {
+ return (math.Abs(float64(a)/float64(b)) - 1.0) < 1e-9
+}
+
+func TestEvery(t *testing.T) {
+ cases := []struct {
+ interval time.Duration
+ lim Limit
+ }{
+ {0, Inf},
+ {-1, Inf},
+ {1 * time.Nanosecond, Limit(1e9)},
+ {1 * time.Microsecond, Limit(1e6)},
+ {1 * time.Millisecond, Limit(1e3)},
+ {10 * time.Millisecond, Limit(100)},
+ {100 * time.Millisecond, Limit(10)},
+ {1 * time.Second, Limit(1)},
+ {2 * time.Second, Limit(0.5)},
+ {time.Duration(2.5 * float64(time.Second)), Limit(0.4)},
+ {4 * time.Second, Limit(0.25)},
+ {10 * time.Second, Limit(0.1)},
+ {time.Duration(math.MaxInt64), Limit(1e9 / float64(math.MaxInt64))},
+ }
+ for _, tc := range cases {
+ lim := Every(tc.interval)
+ if !closeEnough(lim, tc.lim) {
+ t.Errorf("Every(%v) = %v want %v", tc.interval, lim, tc.lim)
+ }
+ }
+}
+
+const (
+ d = 100 * time.Millisecond
+)
+
+var (
+ t0 = time.Now()
+ t1 = t0.Add(time.Duration(1) * d)
+ t2 = t0.Add(time.Duration(2) * d)
+ t3 = t0.Add(time.Duration(3) * d)
+ t4 = t0.Add(time.Duration(4) * d)
+ t5 = t0.Add(time.Duration(5) * d)
+ t9 = t0.Add(time.Duration(9) * d)
)
+type allow struct {
+ t time.Time
+ n int
+ ok bool
+}
+
+//
+//func run(t *testing.T, lim *Limiter, allows []allow) {
+// for i, allow := range allows {
+// ok := lim.AllowN(allow.t, allow.n)
+// if ok != allow.ok {
+// t.Errorf("step %d: lim.AllowN(%v, %v) = %v want %v",
+// i, allow.t, allow.n, ok, allow.ok)
+// }
+// }
+//}
+//
+//func TestLimiterBurst1(t *testing.T) {
+// run(t, NewLimiter(10, 1), []allow{
+// {t0, 1, true},
+// {t0, 1, false},
+// {t0, 1, false},
+// {t1, 1, true},
+// {t1, 1, false},
+// {t1, 1, false},
+// {t2, 2, false}, // burst size is 1, so n=2 always fails
+// {t2, 1, true},
+// {t2, 1, false},
+// })
+//}
+//
+//func TestLimiterBurst3(t *testing.T) {
+// run(t, NewLimiter(10, 3), []allow{
+// {t0, 2, true},
+// {t0, 2, false},
+// {t0, 1, true},
+// {t0, 1, false},
+// {t1, 4, false},
+// {t2, 1, true},
+// {t3, 1, true},
+// {t4, 1, true},
+// {t4, 1, true},
+// {t4, 1, false},
+// {t4, 1, false},
+// {t9, 3, true},
+// {t9, 0, true},
+// })
+//}
+//
+//func TestLimiterJumpBackwards(t *testing.T) {
+// run(t, NewLimiter(10, 3), []allow{
+// {t1, 1, true}, // start at t1
+// {t0, 1, true}, // jump back to t0, two tokens remain
+// {t0, 1, true},
+// {t0, 1, false},
+// {t0, 1, false},
+// {t1, 1, true}, // got a token
+// {t1, 1, false},
+// {t1, 1, false},
+// {t2, 1, true}, // got another token
+// {t2, 1, false},
+// {t2, 1, false},
+// })
+//}
+
+// Ensure that tokensFromDuration doesn't produce
+// rounding errors by truncating nanoseconds.
+// See golang.org/issues/34861.
+func TestLimiter_noTruncationErrors(t *testing.T) {
+ if !NewLimiter(0.7692307692307693, 1).Allow() {
+ t.Fatal("expected true")
+ }
+}
+
+func TestSimultaneousRequests(t *testing.T) {
+ const (
+ limit = 1
+ burst = 5
+ numRequests = 15
+ )
+ var (
+ wg sync.WaitGroup
+ numOK = uint32(0)
+ )
+
+ // Very slow replenishing bucket.
+ lim := NewLimiter(limit, burst)
+
+ // Tries to take a token, atomically updates the counter and decreases the wait
+ // group counter.
+ f := func() {
+ defer wg.Done()
+ if ok := lim.Allow(); ok {
+ atomic.AddUint32(&numOK, 1)
+ }
+ }
+
+ wg.Add(numRequests)
+ for i := 0; i < numRequests; i++ {
+ go f()
+ }
+ wg.Wait()
+ if numOK != burst {
+ t.Errorf("numOK = %d, want %d", numOK, burst)
+ }
+}
+
+func TestLongRunningQPS(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ if runtime.GOOS == "openbsd" {
+ t.Skip("low resolution time.Sleep invalidates test (golang.org/issue/14183)")
+ return
+ }
+
+ // The test runs for a few seconds executing many requests and then checks
+ // that overall number of requests is reasonable.
+ const (
+ limit = 100
+ burst = 100
+ )
+ var numOK = int32(0)
+
+ lim := NewLimiter(limit, burst)
+
+ var wg sync.WaitGroup
+ f := func() {
+ if ok := lim.Allow(); ok {
+ atomic.AddInt32(&numOK, 1)
+ }
+ wg.Done()
+ }
+
+ start := time.Now()
+ end := start.Add(5 * time.Second)
+ for time.Now().Before(end) {
+ wg.Add(1)
+ go f()
+
+ // This will still offer ~500 requests per second, but won't consume
+ // outrageous amount of CPU.
+ time.Sleep(2 * time.Millisecond)
+ }
+ wg.Wait()
+ elapsed := time.Since(start)
+ ideal := burst + (limit * float64(elapsed) / float64(time.Second))
+
+ // We should never get more requests than allowed.
+ if want := int32(ideal + 1); numOK > want {
+ t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal)
+ }
+ // We should get very close to the number of requests allowed.
+ if want := int32(0.999 * ideal); numOK < want {
+ t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal)
+ }
+}
+
+type request struct {
+ t time.Time
+ n int
+ act time.Time
+ ok bool
+}
+
+// dFromDuration converts a duration to a multiple of the global constant d
+func dFromDuration(dur time.Duration) int {
+ // Adding a millisecond to be swallowed by the integer division
+ // because we don't care about small inaccuracies
+ return int((dur + time.Millisecond) / d)
+}
+
+// dSince returns multiples of d since t0
+func dSince(t time.Time) int {
+ return dFromDuration(t.Sub(t0))
+}
+
+//
+//func runReserve(t *testing.T, lim *Limiter, req request) *Reservation {
+// return runReserveMax(t, lim, req, InfDuration)
+//}
+//
+//func runReserveMax(t *testing.T, lim *Limiter, req request, maxReserve time.Duration) *Reservation {
+// r := lim.reserveN(req.t, req.n, maxReserve)
+// if r.ok && (dSince(r.timeToAct) != dSince(req.act)) || r.ok != req.ok {
+// t.Errorf("lim.reserveN(t%d, %v, %v) = (t%d, %v) want (t%d, %v)",
+// dSince(req.t), req.n, maxReserve, dSince(r.timeToAct), r.ok, dSince(req.act), req.ok)
+// }
+// return &r
+//}
+//
+//func TestSimpleReserve(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// runReserve(t, lim, request{t0, 2, t2, true})
+// runReserve(t, lim, request{t3, 2, t4, true})
+//}
+//
+//func TestMix(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t0, 3, t1, false}) // should return false because n > Burst
+// runReserve(t, lim, request{t0, 2, t0, true})
+// run(t, lim, []allow{{t1, 2, false}}) // not enought tokens - don't allow
+// runReserve(t, lim, request{t1, 2, t2, true})
+// run(t, lim, []allow{{t1, 1, false}}) // negative tokens - don't allow
+// run(t, lim, []allow{{t3, 1, true}})
+//}
+//
+//func TestCancelInvalid(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// r := runReserve(t, lim, request{t0, 3, t3, false})
+// r.CancelAt(t0) // should have no effect
+// runReserve(t, lim, request{t0, 2, t2, true}) // did not get extra tokens
+//}
+//
+//func TestCancelLast(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// r := runReserve(t, lim, request{t0, 2, t2, true})
+// r.CancelAt(t1) // got 2 tokens back
+// runReserve(t, lim, request{t1, 2, t2, true})
+//}
+//
+//func TestCancelTooLate(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// r := runReserve(t, lim, request{t0, 2, t2, true})
+// r.CancelAt(t3) // too late to cancel - should have no effect
+// runReserve(t, lim, request{t3, 2, t4, true})
+//}
+//
+//func TestCancel0Tokens(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// r := runReserve(t, lim, request{t0, 1, t1, true})
+// runReserve(t, lim, request{t0, 1, t2, true})
+// r.CancelAt(t0) // got 0 tokens back
+// runReserve(t, lim, request{t0, 1, t3, true})
+//}
+//
+//func TestCancel1Token(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// r := runReserve(t, lim, request{t0, 2, t2, true})
+// runReserve(t, lim, request{t0, 1, t3, true})
+// r.CancelAt(t2) // got 1 token back
+// runReserve(t, lim, request{t2, 2, t4, true})
+//}
+//
+//func TestCancelMulti(t *testing.T) {
+// lim := NewLimiter(10, 4)
+//
+// runReserve(t, lim, request{t0, 4, t0, true})
+// rA := runReserve(t, lim, request{t0, 3, t3, true})
+// runReserve(t, lim, request{t0, 1, t4, true})
+// rC := runReserve(t, lim, request{t0, 1, t5, true})
+// rC.CancelAt(t1) // get 1 token back
+// rA.CancelAt(t1) // get 2 tokens back, as if C was never reserved
+// runReserve(t, lim, request{t1, 3, t5, true})
+//}
+//
+//func TestReserveJumpBack(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1
+// runReserve(t, lim, request{t0, 1, t1, true}) // should violate Limit,Burst
+// runReserve(t, lim, request{t2, 2, t3, true})
+//}
+
+//func TestReserveJumpBackCancel(t *testing.T) {
+// lim := NewLimiter(10, 2)
+//
+// runReserve(t, lim, request{t1, 2, t1, true}) // start at t1
+// r := runReserve(t, lim, request{t1, 2, t3, true})
+// runReserve(t, lim, request{t1, 1, t4, true})
+// r.CancelAt(t0) // cancel at t0, get 1 token back
+// runReserve(t, lim, request{t1, 2, t4, true}) // should violate Limit,Burst
+//}
+//
+//func TestReserveSetLimit(t *testing.T) {
+// lim := NewLimiter(5, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// runReserve(t, lim, request{t0, 2, t4, true})
+// lim.SetLimitAt(t2, 10)
+// runReserve(t, lim, request{t2, 1, t4, true}) // violates Limit and Burst
+//}
+//
+//func TestReserveSetBurst(t *testing.T) {
+// lim := NewLimiter(5, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// runReserve(t, lim, request{t0, 2, t4, true})
+// lim.SetBurstAt(t3, 4)
+// runReserve(t, lim, request{t0, 4, t9, true}) // violates Limit and Burst
+//}
+//
+//func TestReserveSetLimitCancel(t *testing.T) {
+// lim := NewLimiter(5, 2)
+//
+// runReserve(t, lim, request{t0, 2, t0, true})
+// r := runReserve(t, lim, request{t0, 2, t4, true})
+// lim.SetLimitAt(t2, 10)
+// r.CancelAt(t2) // 2 tokens back
+// runReserve(t, lim, request{t2, 2, t3, true})
+//}
+//
+//func TestReserveMax(t *testing.T) {
+// lim := NewLimiter(10, 2)
+// maxT := d
+//
+// runReserveMax(t, lim, request{t0, 2, t0, true}, maxT)
+// runReserveMax(t, lim, request{t0, 1, t1, true}, maxT) // reserve for close future
+// runReserveMax(t, lim, request{t0, 1, t2, false}, maxT) // time to act too far in the future
+//}
+
+type wait struct {
+ name string
+ ctx context.Context
+ n int
+ delay int // in multiples of d
+ nilErr bool
+}
+
+func runWait(t *testing.T, lim *Limiter, w wait) {
+ t.Helper()
+ start := time.Now()
+ err := lim.WaitN(w.ctx, w.n)
+ delay := time.Since(start)
+ if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) {
+ errString := "<nil>"
+ if !w.nilErr {
+ errString = "<non-nil error>"
+ }
+ t.Errorf("lim.WaitN(%v, lim, %v) = %v with delay %v ; want %v with delay %v",
+ w.name, w.n, err, delay, errString, d*time.Duration(w.delay))
+ }
+}
+
+func TestWaitSimple(t *testing.T) {
+ lim := NewLimiter(10, 3)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ runWait(t, lim, wait{"already-cancelled", ctx, 1, 0, false})
+
+ runWait(t, lim, wait{"exceed-burst-error", context.Background(), 4, 0, false})
+
+ runWait(t, lim, wait{"act-now", context.Background(), 2, 0, true})
+ runWait(t, lim, wait{"act-later", context.Background(), 3, 2, true})
+}
+
+func TestWaitCancel(t *testing.T) {
+ lim := NewLimiter(10, 3)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) // after this lim.tokens = 1
+ go func() {
+ time.Sleep(d)
+ cancel()
+ }()
+ runWait(t, lim, wait{"will-cancel", ctx, 3, 1, false})
+ // should get 3 tokens back, and have lim.tokens = 2
+ //t.Logf("tokens:%v last:%v lastEvent:%v", lim.tokens, lim.last, lim.lastEvent)
+ runWait(t, lim, wait{"act-now-after-cancel", context.Background(), 2, 0, true})
+}
+
+func TestWaitTimeout(t *testing.T) {
+ lim := NewLimiter(10, 3)
+
+ ctx, cancel := context.WithTimeout(context.Background(), d)
+ defer cancel()
+ runWait(t, lim, wait{"act-now", ctx, 2, 0, true})
+ runWait(t, lim, wait{"w-timeout-err", ctx, 3, 0, false})
+}
+
+func TestWaitInf(t *testing.T) {
+ lim := NewLimiter(Inf, 0)
+
+ runWait(t, lim, wait{"exceed-burst-no-error", context.Background(), 3, 0, true})
+}
+
+func BenchmarkAllowN(b *testing.B) {
+ lim := NewLimiter(Every(1*time.Second), 1)
+ b.ReportAllocs()
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ lim.AllowN(1)
+ }
+ })
+}
+
+func BenchmarkWaitNNoDelay(b *testing.B) {
+ lim := NewLimiter(Limit(b.N), b.N)
+ ctx := context.Background()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ lim.WaitN(ctx, 1)
+ }
+}
+
+func TestDecrement(t *testing.T) {
+ x := NewVar(1000)
+ for i := 0; i < 500; i++ {
+ go Atomically(VoidOperation(func(tx *Tx) {
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
+ }))
+ }
+ done := make(chan struct{})
+ go func() {
+ Atomically(VoidOperation(func(tx *Tx) {
+ tx.Assert(x.Get(tx) == 500)
+ }))
+ close(done)
+ }()
+ select {
+ case <-done:
+ case <-time.After(10 * time.Second):
+ t.Fatal("decrement did not complete in time")
+ }
+}
+
+// read-only transaction aren't exempt from calling tx.inputsChanged
+func TestReadVerify(t *testing.T) {
+ read := make(chan struct{})
+ x, y := NewVar(1), NewVar(2)
+
+ // spawn a transaction that writes to x
+ go func() {
+ <-read
+ AtomicSet(x, 3)
+ read <- struct{}{}
+ // other tx should retry, so we need to read/send again
+ read <- <-read
+ }()
+
+ // spawn a transaction that reads x, then y. The other tx will modify x in
+ // between the reads, causing this tx to retry.
+ var x2, y2 int
+ Atomically(VoidOperation(func(tx *Tx) {
+ x2 = x.Get(tx)
+ read <- struct{}{}
+ <-read // wait for other tx to complete
+ y2 = y.Get(tx)
+ }))
+ if x2 == 1 && y2 == 2 {
+ t.Fatal("read was not verified")
+ }
+}
+
+func TestRetry(t *testing.T) {
+ x := NewVar(10)
+ // spawn 10 transactions, one every 10 milliseconds. This will decrement x
+ // to 0 over the course of 100 milliseconds.
+ go func() {
+ for i := 0; i < 10; i++ {
+ time.Sleep(10 * time.Millisecond)
+ Atomically(VoidOperation(func(tx *Tx) {
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
+ }))
+ }
+ }()
+ // Each time we read x before the above loop has finished, we need to
+ // retry. This should result in no more than 1 retry per transaction.
+ retry := 0
+ Atomically(VoidOperation(func(tx *Tx) {
+ cur := x.Get(tx)
+ if cur != 0 {
+ retry++
+ tx.Retry()
+ }
+ }))
+ if retry > 10 {
+ t.Fatal("should have retried at most 10 times, got", retry)
+ }
+}
+
+func TestVerify(t *testing.T) {
+ // tx.inputsChanged should check more than pointer equality
+ type foo struct {
+ i int
+ }
+ x := NewVar(&foo{3})
+ read := make(chan struct{})
+
+ // spawn a transaction that modifies x
+ go func() {
+ Atomically(VoidOperation(func(tx *Tx) {
+ <-read
+ rx := x.Get(tx)
+ rx.i = 7
+ x.Set(tx, rx)
+ }))
+ read <- struct{}{}
+ // other tx should retry, so we need to read/send again
+ read <- <-read
+ }()
+
+ // spawn a transaction that reads x, then y. The other tx will modify x in
+ // between the reads, causing this tx to retry.
+ var i int
+ Atomically(VoidOperation(func(tx *Tx) {
+ f := x.Get(tx)
+ i = f.i
+ read <- struct{}{}
+ <-read // wait for other tx to complete
+ }))
+ if i == 3 {
+ t.Fatal("inputsChanged did not retry despite modified Var", i)
+ }
+}
+
+func TestSelect(t *testing.T) {
+ // empty Select should panic
+ // require.Panics(t, func() { Atomically(Select[struct{}]()) })
+
+ // with one arg, Select adds no effect
+ x := NewVar(2)
+ Atomically(Select(VoidOperation(func(tx *Tx) {
+ tx.Assert(x.Get(tx) == 2)
+ })))
+
+ picked := Atomically(Select(
+ // always blocks; should never be selected
+ func(tx *Tx) int {
+ tx.Retry()
+ panic("unreachable")
+ },
+ // always succeeds; should always be selected
+ func(tx *Tx) int {
+ return 2
+ },
+ // always succeeds; should never be selected
+ func(tx *Tx) int {
+ return 3
+ },
+ ))
+ g.TAssertEqual(2, picked)
+}
+
+func TestCompose(t *testing.T) {
+ nums := make([]int, 100)
+ fns := make([]Operation[struct{}], 100)
+ for i := range fns {
+ fns[i] = func(x int) Operation[struct{}] {
+ return VoidOperation(func(*Tx) { nums[x] = x })
+ }(i) // capture loop var
+ }
+ Atomically(Compose(fns...))
+ for i := range nums {
+ if nums[i] != i {
+ t.Error("Compose failed:", nums[i], i)
+ }
+ }
+}
+
+func TestPanic(t *testing.T) {
+ // normal panics should escape Atomically
+ /*
+ assert.PanicsWithValue(t, "foo", func() {
+ Atomically(func(*Tx) any {
+ panic("foo")
+ })
+ })
+ */
+}
+
+func TestReadWritten(t *testing.T) {
+ // reading a variable written in the same transaction should return the
+ // previously written value
+ x := NewVar(3)
+ Atomically(VoidOperation(func(tx *Tx) {
+ x.Set(tx, 5)
+ tx.Assert(x.Get(tx) == 5)
+ }))
+}
+
+func TestAtomicSetRetry(t *testing.T) {
+ // AtomicSet should cause waiting transactions to retry
+ x := NewVar(3)
+ done := make(chan struct{})
+ go func() {
+ Atomically(VoidOperation(func(tx *Tx) {
+ tx.Assert(x.Get(tx) == 5)
+ }))
+ done <- struct{}{}
+ }()
+ time.Sleep(10 * time.Millisecond)
+ AtomicSet(x, 5)
+ select {
+ case <-done:
+ case <-time.After(time.Second):
+ t.Fatal("AtomicSet did not wake up a waiting transaction")
+ }
+}
+
+func testPingPong(t testing.TB, n int, afterHit func(string)) {
+ ball := NewBuiltinEqVar(false)
+ doneVar := NewVar(false)
+ hits := NewVar(0)
+ ready := NewVar(true) // The ball is ready for hitting.
+ var wg sync.WaitGroup
+ bat := func(from, to bool, noise string) {
+ defer wg.Done()
+ for !Atomically(func(tx *Tx) any {
+ if doneVar.Get(tx) {
+ return true
+ }
+ tx.Assert(ready.Get(tx))
+ if ball.Get(tx) == from {
+ ball.Set(tx, to)
+ hits.Set(tx, hits.Get(tx)+1)
+ ready.Set(tx, false)
+ return false
+ }
+ return tx.Retry()
+ }).(bool) {
+ afterHit(noise)
+ AtomicSet(ready, true)
+ }
+ }
+ wg.Add(2)
+ go bat(false, true, "ping!")
+ go bat(true, false, "pong!")
+ Atomically(VoidOperation(func(tx *Tx) {
+ tx.Assert(hits.Get(tx) >= n)
+ doneVar.Set(tx, true)
+ }))
+ wg.Wait()
+}
+
+func TestPingPong(t *testing.T) {
+ testPingPong(t, 42, func(s string) { t.Log(s) })
+}
+
+func TestSleepingBeauty(t *testing.T) {
+ /*
+ require.Panics(t, func() {
+ Atomically(func(tx *Tx) any {
+ tx.Assert(false)
+ return nil
+ })
+ })
+ */
+}
+
+//func TestRetryStack(t *testing.T) {
+// v := NewVar[int](nil)
+// go func() {
+// i := 0
+// for {
+// AtomicSet(v, i)
+// i++
+// }
+// }()
+// Atomically(func(tx *Tx) any {
+// debug.PrintStack()
+// ret := func() {
+// defer Atomically(nil)
+// }
+// v.Get(tx)
+// tx.Assert(false)
+// return ret
+// })
+//}
+
+func TestContextEquality(t *testing.T) {
+ ctx := context.Background()
+ g.TAssertEqual(ctx, context.Background())
+ childCtx, cancel := context.WithCancel(ctx)
+ g.TAssertEqual(childCtx != ctx, true)
+ g.TAssertEqual(childCtx != ctx, true)
+ g.TAssertEqual(context.Background(), ctx)
+ cancel()
+ g.TAssertEqual(context.Background(), ctx)
+ g.TAssertEqual(ctx != childCtx, true)
+}
+
func MainTest() {
diff --git a/tx.go b/tx.go
deleted file mode 100644
index 825ff01..0000000
--- a/tx.go
+++ /dev/null
@@ -1,231 +0,0 @@
-package stm
-
-import (
- "fmt"
- "sort"
- "sync"
- "unsafe"
-
- "github.com/alecthomas/atomic"
-)
-
-type txVar interface {
- getValue() *atomic.Value[VarValue]
- changeValue(any)
- getWatchers() *sync.Map
- getLock() *sync.Mutex
-}
-
-// A Tx represents an atomic transaction.
-type Tx struct {
- reads map[txVar]VarValue
- writes map[txVar]any
- watching map[txVar]struct{}
- locks txLocks
- mu sync.Mutex
- cond sync.Cond
- waiting bool
- completed bool
- tries int
- numRetryValues int
-}
-
-// Check that none of the logged values have changed since the transaction began.
-func (tx *Tx) inputsChanged() bool {
- for v, read := range tx.reads {
- if read.Changed(v.getValue().Load()) {
- return true
- }
- }
- return false
-}
-
-// Writes the values in the transaction log to their respective Vars.
-func (tx *Tx) commit() {
- for v, val := range tx.writes {
- v.changeValue(val)
- }
-}
-
-func (tx *Tx) updateWatchers() {
- for v := range tx.watching {
- if _, ok := tx.reads[v]; !ok {
- delete(tx.watching, v)
- v.getWatchers().Delete(tx)
- }
- }
- for v := range tx.reads {
- if _, ok := tx.watching[v]; !ok {
- v.getWatchers().Store(tx, nil)
- tx.watching[v] = struct{}{}
- }
- }
-}
-
-// wait blocks until another transaction modifies any of the Vars read by tx.
-func (tx *Tx) wait() {
- if len(tx.reads) == 0 {
- panic("not waiting on anything")
- }
- tx.updateWatchers()
- tx.mu.Lock()
- firstWait := true
- for !tx.inputsChanged() {
- if !firstWait {
- expvars.Add("wakes for unchanged versions", 1)
- }
- expvars.Add("waits", 1)
- tx.waiting = true
- tx.cond.Broadcast()
- tx.cond.Wait()
- tx.waiting = false
- firstWait = false
- }
- tx.mu.Unlock()
-}
-
-// Get returns the value of v as of the start of the transaction.
-func (v *Var[T]) Get(tx *Tx) T {
- // If we previously wrote to v, it will be in the write log.
- if val, ok := tx.writes[v]; ok {
- return val.(T)
- }
- // If we haven't previously read v, record its version
- vv, ok := tx.reads[v]
- if !ok {
- vv = v.getValue().Load()
- tx.reads[v] = vv
- }
- return vv.Get().(T)
-}
-
-// Set sets the value of a Var for the lifetime of the transaction.
-func (v *Var[T]) Set(tx *Tx, val T) {
- if v == nil {
- panic("nil Var")
- }
- tx.writes[v] = val
-}
-
-type txProfileValue struct {
- *Tx
- int
-}
-
-// Retry aborts the transaction and retries it when a Var changes. You can return from this method
-// to satisfy return values, but it should never actually return anything as it panics internally.
-func (tx *Tx) Retry() struct{} {
- retries.Add(txProfileValue{tx, tx.numRetryValues}, 1)
- tx.numRetryValues++
- panic(retry)
-}
-
-// Assert is a helper function that retries a transaction if the condition is
-// not satisfied.
-func (tx *Tx) Assert(p bool) {
- if !p {
- tx.Retry()
- }
-}
-
-func (tx *Tx) reset() {
- tx.mu.Lock()
- for k := range tx.reads {
- delete(tx.reads, k)
- }
- for k := range tx.writes {
- delete(tx.writes, k)
- }
- tx.mu.Unlock()
- tx.removeRetryProfiles()
- tx.resetLocks()
-}
-
-func (tx *Tx) removeRetryProfiles() {
- for tx.numRetryValues > 0 {
- tx.numRetryValues--
- retries.Remove(txProfileValue{tx, tx.numRetryValues})
- }
-}
-
-func (tx *Tx) recycle() {
- for v := range tx.watching {
- delete(tx.watching, v)
- v.getWatchers().Delete(tx)
- }
- tx.removeRetryProfiles()
- // I don't think we can reuse Txs, because the "completed" field should/needs to be set
- // indefinitely after use.
- //txPool.Put(tx)
-}
-
-func (tx *Tx) lockAllVars() {
- tx.resetLocks()
- tx.collectAllLocks()
- tx.sortLocks()
- tx.lock()
-}
-
-func (tx *Tx) resetLocks() {
- tx.locks.clear()
-}
-
-func (tx *Tx) collectReadLocks() {
- for v := range tx.reads {
- tx.locks.append(v.getLock())
- }
-}
-
-func (tx *Tx) collectAllLocks() {
- tx.collectReadLocks()
- for v := range tx.writes {
- if _, ok := tx.reads[v]; !ok {
- tx.locks.append(v.getLock())
- }
- }
-}
-
-func (tx *Tx) sortLocks() {
- sort.Sort(&tx.locks)
-}
-
-func (tx *Tx) lock() {
- for _, l := range tx.locks.mus {
- l.Lock()
- }
-}
-
-func (tx *Tx) unlock() {
- for _, l := range tx.locks.mus {
- l.Unlock()
- }
-}
-
-func (tx *Tx) String() string {
- return fmt.Sprintf("%[1]T %[1]p", tx)
-}
-
-// Dedicated type avoids reflection in sort.Slice.
-type txLocks struct {
- mus []*sync.Mutex
-}
-
-func (me txLocks) Len() int {
- return len(me.mus)
-}
-
-func (me txLocks) Less(i, j int) bool {
- return uintptr(unsafe.Pointer(me.mus[i])) < uintptr(unsafe.Pointer(me.mus[j]))
-}
-
-func (me txLocks) Swap(i, j int) {
- me.mus[i], me.mus[j] = me.mus[j], me.mus[i]
-}
-
-func (me *txLocks) clear() {
- me.mus = me.mus[:0]
-}
-
-func (me *txLocks) append(mu *sync.Mutex) {
- me.mus = append(me.mus, mu)
-}
diff --git a/var-value.go b/var-value.go
deleted file mode 100644
index 3518bf6..0000000
--- a/var-value.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package stm
-
-type VarValue interface {
- Set(any) VarValue
- Get() any
- Changed(VarValue) bool
-}
-
-type version uint64
-
-type versionedValue[T any] struct {
- value T
- version version
-}
-
-func (me versionedValue[T]) Set(newValue any) VarValue {
- return versionedValue[T]{
- value: newValue.(T),
- version: me.version + 1,
- }
-}
-
-func (me versionedValue[T]) Get() any {
- return me.value
-}
-
-func (me versionedValue[T]) Changed(other VarValue) bool {
- return me.version != other.(versionedValue[T]).version
-}
-
-type customVarValue[T any] struct {
- value T
- changed func(T, T) bool
-}
-
-var _ VarValue = customVarValue[struct{}]{}
-
-func (me customVarValue[T]) Changed(other VarValue) bool {
- return me.changed(me.value, other.(customVarValue[T]).value)
-}
-
-func (me customVarValue[T]) Set(newValue any) VarValue {
- return customVarValue[T]{
- value: newValue.(T),
- changed: me.changed,
- }
-}
-
-func (me customVarValue[T]) Get() any {
- return me.value
-}
diff --git a/var.go b/var.go
deleted file mode 100644
index ec6a81e..0000000
--- a/var.go
+++ /dev/null
@@ -1,76 +0,0 @@
-package stm
-
-import (
- "sync"
-
- "github.com/alecthomas/atomic"
-)
-
-// Holds an STM variable.
-type Var[T any] struct {
- value atomic.Value[VarValue]
- watchers sync.Map
- mu sync.Mutex
-}
-
-func (v *Var[T]) getValue() *atomic.Value[VarValue] {
- return &v.value
-}
-
-func (v *Var[T]) getWatchers() *sync.Map {
- return &v.watchers
-}
-
-func (v *Var[T]) getLock() *sync.Mutex {
- return &v.mu
-}
-
-func (v *Var[T]) changeValue(new any) {
- old := v.value.Load()
- newVarValue := old.Set(new)
- v.value.Store(newVarValue)
- if old.Changed(newVarValue) {
- go v.wakeWatchers(newVarValue)
- }
-}
-
-func (v *Var[T]) wakeWatchers(new VarValue) {
- v.watchers.Range(func(k, _ any) bool {
- tx := k.(*Tx)
- // We have to lock here to ensure that the Tx is waiting before we signal it. Otherwise we
- // could signal it before it goes to sleep and it will miss the notification.
- tx.mu.Lock()
- if read := tx.reads[v]; read != nil && read.Changed(new) {
- tx.cond.Broadcast()
- for !tx.waiting && !tx.completed {
- tx.cond.Wait()
- }
- }
- tx.mu.Unlock()
- return !v.value.Load().Changed(new)
- })
-}
-
-// Returns a new STM variable.
-func NewVar[T any](val T) *Var[T] {
- v := &Var[T]{}
- v.value.Store(versionedValue[T]{
- value: val,
- })
- return v
-}
-
-func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] {
- v := &Var[T]{}
- v.value.Store(customVarValue[T]{
- value: val,
- changed: changed,
- })
- return v
-}
-
-func NewBuiltinEqVar[T comparable](val T) *Var[T] {
- return NewCustomVar(val, func(a, b T) bool {
- return a != b
- })
-}