diff options
-rw-r--r-- | .circleci/config.yml | 2 | ||||
-rw-r--r-- | README.md | 33 | ||||
-rw-r--r-- | bench_test.go | 6 | ||||
-rw-r--r-- | cmd/santa-example/main.go | 34 | ||||
-rw-r--r-- | doc.go | 26 | ||||
-rw-r--r-- | doc_test.go | 22 | ||||
-rw-r--r-- | external_test.go | 58 | ||||
-rw-r--r-- | funcs.go | 50 | ||||
-rw-r--r-- | go.mod | 10 | ||||
-rw-r--r-- | go.sum | 26 | ||||
-rw-r--r-- | rate/rate_test.go | 3 | ||||
-rw-r--r-- | rate/ratelimit.go | 34 | ||||
-rw-r--r-- | retry.go | 2 | ||||
-rw-r--r-- | stm_test.go | 77 | ||||
-rw-r--r-- | stmutil/containers.go | 54 | ||||
-rw-r--r-- | stmutil/context.go | 4 | ||||
-rw-r--r-- | tx.go | 40 | ||||
-rw-r--r-- | var-value.go | 40 | ||||
-rw-r--r-- | var.go | 50 |
19 files changed, 286 insertions, 285 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml index 0db2f6b..148a6e6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -3,7 +3,7 @@ jobs: build: machine: true environment: - GO_BRANCH: release-branch.go1.15 + GO_BRANCH: master steps: - run: echo $CIRCLE_WORKING_DIRECTORY - run: echo $PWD @@ -12,10 +12,9 @@ composition will either deadlock or release the lock between functions (making it non-atomic). The `stm` API tries to mimic that of Haskell's [`Control.Concurrent.STM`](https://hackage.haskell.org/package/stm-2.4.4.1/docs/Control-Concurrent-STM.html), but -this is not entirely possible due to Go's type system; we are forced to use -`interface{}` and type assertions. Furthermore, Haskell can enforce at compile -time that STM variables are not modified outside the STM monad. This is not -possible in Go, so be especially careful when using pointers in your STM code. +Haskell can enforce at compile time that STM variables are not modified outside +the STM monad. This is not possible in Go, so be especially careful when using +pointers in your STM code. Remember: modifying a pointer is a side effect! Unlike Haskell, data in Go is not immutable by default, which means you have to be careful when using STM to manage pointers. If two goroutines have access @@ -31,22 +30,22 @@ applications in Go. If you find this package useful, please tell us about it! See the package examples in the Go package docs for examples of common operations. -See [example_santa_test.go](example_santa_test.go) for a more complex example. +See [cmd/santa-example/main.go](cmd/santa-example/main.go) for a more complex example. ## Pointers -Note that `Operation` now returns a value of type `interface{}`, which isn't included in the +Note that `Operation` now returns a value of type `any`, which isn't included in the examples throughout the documentation yet. See the type signatures for `Atomically` and `Operation`. Be very careful when managing pointers inside transactions! (This includes slices, maps, channels, and captured variables.) Here's why: ```go -p := stm.NewVar([]byte{1,2,3}) +p := stm.NewVar[[]byte]([]byte{1,2,3}) stm.Atomically(func(tx *stm.Tx) { - b := tx.Get(p).([]byte) + b := p.Get(tx) b[0] = 7 - tx.Set(p, b) + stm.p.Set(tx, b) }) ``` @@ -57,11 +56,11 @@ Following this advice, we can rewrite the transaction to perform a copy: ```go stm.Atomically(func(tx *stm.Tx) { - b := tx.Get(p).([]byte) + b := p.Get(tx) c := make([]byte, len(b)) copy(c, b) c[0] = 7 - tx.Set(p, c) + p.Set(tx, c) }) ``` @@ -73,11 +72,11 @@ In the same vein, it would be a mistake to do this: type foo struct { i int } -p := stm.NewVar(&foo{i: 2}) +p := stm.NewVar[*foo](&foo{i: 2}) stm.Atomically(func(tx *stm.Tx) { - f := tx.Get(p).(*foo) + f := p.Get(tx) f.i = 7 - tx.Set(p, f) + stm.p.Set(tx, f) }) ``` @@ -88,11 +87,11 @@ the correct approach is to move the `Var` inside the struct: type foo struct { i *stm.Var } -f := foo{i: stm.NewVar(2)} +f := foo{i: stm.NewVar[int](2)} stm.Atomically(func(tx *stm.Tx) { - i := tx.Get(f.i).(int) + i := f.i.Get(tx) i = 7 - tx.Set(f.i, i) + f.i.Set(tx, i) }) ``` diff --git a/bench_test.go b/bench_test.go index 82aaba1..0d40caf 100644 --- a/bench_test.go +++ b/bench_test.go @@ -27,13 +27,13 @@ func BenchmarkIncrementSTM(b *testing.B) { x := NewVar(0) for i := 0; i < 1000; i++ { go Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur+1) + cur := x.Get(tx) + x.Set(tx, cur+1) })) } // wait for x to reach 1000 Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x).(int) == 1000) + tx.Assert(x.Get(tx) == 1000) })) } } diff --git a/cmd/santa-example/main.go b/cmd/santa-example/main.go index 3c2b9ea..dcc8067 100644 --- a/cmd/santa-example/main.go +++ b/cmd/santa-example/main.go @@ -39,15 +39,15 @@ import ( type gate struct { capacity int - remaining *stm.Var + remaining *stm.Var[int] } func (g gate) pass() { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) // wait until gate can hold us tx.Assert(rem > 0) - tx.Set(g.remaining, rem-1) + g.remaining.Set(tx, rem-1) })) } @@ -56,7 +56,7 @@ func (g gate) operate() { stm.AtomicSet(g.remaining, g.capacity) // wait for gate to be full stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) tx.Assert(rem == 0) })) } @@ -70,8 +70,8 @@ func newGate(capacity int) gate { type group struct { capacity int - remaining *stm.Var - gate1, gate2 *stm.Var + remaining *stm.Var[int] + gate1, gate2 *stm.Var[gate] } func newGroup(capacity int) *group { @@ -85,28 +85,28 @@ func newGroup(capacity int) *group { func (g *group) join() (g1, g2 gate) { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) // wait until the group can hold us tx.Assert(rem > 0) - tx.Set(g.remaining, rem-1) + g.remaining.Set(tx, rem-1) // return the group's gates - g1 = tx.Get(g.gate1).(gate) - g2 = tx.Get(g.gate2).(gate) + g1 = g.gate1.Get(tx) + g2 = g.gate2.Get(tx) })) return } func (g *group) await(tx *stm.Tx) (gate, gate) { // wait for group to be empty - rem := tx.Get(g.remaining).(int) + rem := g.remaining.Get(tx) tx.Assert(rem == 0) // get the group's gates - g1 := tx.Get(g.gate1).(gate) - g2 := tx.Get(g.gate2).(gate) + g1 := g.gate1.Get(tx) + g2 := g.gate2.Get(tx) // reset group - tx.Set(g.remaining, g.capacity) - tx.Set(g.gate1, newGate(g.capacity)) - tx.Set(g.gate2, newGate(g.capacity)) + g.remaining.Set(tx, g.capacity) + g.gate1.Set(tx, newGate(g.capacity)) + g.gate2.Set(tx, newGate(g.capacity)) return g1, g2 } @@ -137,7 +137,7 @@ type selection struct { gate1, gate2 gate } -func chooseGroup(g *group, task string, s *selection) stm.Operation { +func chooseGroup(g *group, task string, s *selection) stm.Operation[struct{}] { return stm.VoidOperation(func(tx *stm.Tx) { s.gate1, s.gate2 = g.await(tx) s.task = task @@ -10,14 +10,14 @@ it non-atomic). To begin, create an STM object that wraps the data you want to access concurrently. - x := stm.NewVar(3) + x := stm.NewVar[int](3) You can then use the Atomically method to atomically read and/or write the the data. This code atomically decrements x: stm.Atomically(func(tx *stm.Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur-1) + cur := x.Get(tx) + x.Set(tx, cur-1) }) An important part of STM transactions is retrying. At any point during the @@ -29,11 +29,11 @@ updated before the transaction will be rerun. As an example, this code will try to decrement x, but will block as long as x is zero: stm.Atomically(func(tx *stm.Tx) { - cur := tx.Get(x).(int) + cur := x.Get(tx) if cur == 0 { tx.Retry() } - tx.Set(x, cur-1) + x.Set(tx, cur-1) }) Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any other @@ -47,13 +47,13 @@ retried. For example, this code implements the "decrement-if-nonzero" transaction above, but for two values. It will first try to decrement x, then y, and block if both values are zero. - func dec(v *stm.Var) { + func dec(v *stm.Var[int]) { return func(tx *stm.Tx) { - cur := tx.Get(v).(int) + cur := v.Get(tx) if cur == 0 { tx.Retry() } - tx.Set(v, cur-1) + v.Set(tx, cur-1) } } @@ -69,11 +69,9 @@ behavior. One common way to get around this is to build up a list of impure operations inside the transaction, and then perform them after the transaction completes. -The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but this -is not entirely possible due to Go's type system; we are forced to use -interface{} and type assertions. Furthermore, Haskell can enforce at compile -time that STM variables are not modified outside the STM monad. This is not -possible in Go, so be especially careful when using pointers in your STM code. -Remember: modifying a pointer is a side effect! +The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but +Haskell can enforce at compile time that STM variables are not modified outside +the STM monad. This is not possible in Go, so be especially careful when using +pointers in your STM code. Remember: modifying a pointer is a side effect! */ package stm diff --git a/doc_test.go b/doc_test.go index f6bb863..ae1af9a 100644 --- a/doc_test.go +++ b/doc_test.go @@ -11,38 +11,38 @@ func Example() { // read a variable var v int stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - v = tx.Get(n).(int) + v = n.Get(tx) })) // or: - v = stm.AtomicGet(n).(int) + v = stm.AtomicGet(n) _ = v // write to a variable stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Set(n, 12) + n.Set(tx, 12) })) // or: stm.AtomicSet(n, 12) // update a variable stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(n).(int) - tx.Set(n, cur-1) + cur := n.Get(tx) + n.Set(tx, cur-1) })) // block until a condition is met stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(n).(int) + cur := n.Get(tx) if cur != 0 { tx.Retry() } - tx.Set(n, 10) + n.Set(tx, 10) })) // or: stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(n).(int) + cur := n.Get(tx) tx.Assert(cur == 0) - tx.Set(n, 10) + n.Set(tx, 10) })) // select among multiple (potentially blocking) transactions @@ -51,11 +51,11 @@ func Example() { stm.VoidOperation(func(tx *stm.Tx) { tx.Retry() }), // this function will always succeed without blocking - stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 10) }), + stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 10) }), // this function will never run, because the previous // function succeeded - stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 11) }), + stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 11) }), )) // since Select is a normal transaction, if the entire select retries diff --git a/external_test.go b/external_test.go index ae29ca8..abdf544 100644 --- a/external_test.go +++ b/external_test.go @@ -63,14 +63,14 @@ func BenchmarkThunderingHerd(b *testing.B) { pending := stm.NewBuiltinEqVar(0) for range iter.N(1000) { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Set(pending, tx.Get(pending).(int)+1) + pending.Set(tx, pending.Get(tx)+1) })) go func() { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - t := tx.Get(tokens).(int) + t := tokens.Get(tx) if t > 0 { - tx.Set(tokens, t-1) - tx.Set(pending, tx.Get(pending).(int)-1) + tokens.Set(tx, t-1) + pending.Set(tx, pending.Get(tx)-1) } else { tx.Retry() } @@ -78,18 +78,18 @@ func BenchmarkThunderingHerd(b *testing.B) { }() } go func() { - for stm.Atomically(func(tx *stm.Tx) interface{} { - if tx.Get(done).(bool) { + for stm.Atomically(func(tx *stm.Tx) bool { + if done.Get(tx) { return false } - tx.Assert(tx.Get(tokens).(int) < maxTokens) - tx.Set(tokens, tx.Get(tokens).(int)+1) + tx.Assert(tokens.Get(tx) < maxTokens) + tokens.Set(tx, tokens.Get(tx)+1) return true - }).(bool) { + }) { } }() stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(tx.Get(pending).(int) == 0) + tx.Assert(pending.Get(tx) == 0) })) stm.AtomicSet(done, true) } @@ -103,49 +103,49 @@ func BenchmarkInvertedThunderingHerd(b *testing.B) { for range iter.N(1000) { ready := stm.NewVar(false) stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Set(pending, tx.Get(pending).(stmutil.Settish).Add(ready)) + pending.Set(tx, pending.Get(tx).Add(ready)) })) go func() { stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(tx.Get(ready).(bool)) - set := tx.Get(pending).(stmutil.Settish) + tx.Assert(ready.Get(tx)) + set := pending.Get(tx) if !set.Contains(ready) { panic("couldn't find ourselves in pending") } - tx.Set(pending, set.Delete(ready)) + pending.Set(tx, set.Delete(ready)) })) //b.Log("waiter finished") }() } go func() { - for stm.Atomically(func(tx *stm.Tx) interface{} { - if tx.Get(done).(bool) { + for stm.Atomically(func(tx *stm.Tx) bool { + if done.Get(tx) { return false } - tx.Assert(tx.Get(tokens).(int) < maxTokens) - tx.Set(tokens, tx.Get(tokens).(int)+1) + tx.Assert(tokens.Get(tx) < maxTokens) + tokens.Set(tx, tokens.Get(tx)+1) return true - }).(bool) { + }) { } }() go func() { - for stm.Atomically(func(tx *stm.Tx) interface{} { - tx.Assert(tx.Get(tokens).(int) > 0) - tx.Set(tokens, tx.Get(tokens).(int)-1) - tx.Get(pending).(stmutil.Settish).Range(func(i interface{}) bool { - ready := i.(*stm.Var) - if !tx.Get(ready).(bool) { - tx.Set(ready, true) + for stm.Atomically(func(tx *stm.Tx) bool { + tx.Assert(tokens.Get(tx) > 0) + tokens.Set(tx, tokens.Get(tx)-1) + pending.Get(tx).Range(func(i any) bool { + ready := i.(*stm.Var[bool]) + if !ready.Get(tx) { + ready.Set(tx, true) return false } return true }) - return !tx.Get(done).(bool) - }).(bool) { + return !done.Get(tx) + }) { } }() stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - tx.Assert(tx.Get(pending).(stmutil.Lenner).Len() == 0) + tx.Assert(pending.Get(tx).(stmutil.Lenner).Len() == 0) })) stm.AtomicSet(done, true) } @@ -2,19 +2,18 @@ package stm import ( "math/rand" - "reflect" "runtime/pprof" "sync" "time" ) var ( - txPool = sync.Pool{New: func() interface{} { + txPool = sync.Pool{New: func() any { expvars.Add("new txs", 1) tx := &Tx{ - reads: make(map[*Var]VarValue), - writes: make(map[*Var]interface{}), - watching: make(map[*Var]struct{}), + reads: make(map[txVar]VarValue), + writes: make(map[txVar]any), + watching: make(map[txVar]struct{}), } tx.cond.L = &tx.mu return tx @@ -40,7 +39,7 @@ func newTx() *Tx { return tx } -func WouldBlock(fn Operation) (block bool) { +func WouldBlock[R any](fn Operation[R]) (block bool) { tx := newTx() tx.reset() _, block = catchRetry(fn, tx) @@ -52,7 +51,7 @@ func WouldBlock(fn Operation) (block bool) { } // Atomically executes the atomic function fn. -func Atomically(op Operation) interface{} { +func Atomically[R any](op Operation[R]) R { expvars.Add("atomically", 1) // run the transaction tx := newTx() @@ -104,12 +103,12 @@ retry: } // AtomicGet is a helper function that atomically reads a value. -func AtomicGet(v *Var) interface{} { - return v.value.Load().(VarValue).Get() +func AtomicGet[T any](v *Var[T]) T { + return v.value.Load().Get().(T) } // AtomicSet is a helper function that atomically writes a value. -func AtomicSet(v *Var, val interface{}) { +func AtomicSet[T any](v *Var[T], val T) { v.mu.Lock() v.changeValue(val) v.mu.Unlock() @@ -117,20 +116,19 @@ func AtomicSet(v *Var, val interface{}) { // Compose is a helper function that composes multiple transactions into a // single transaction. -func Compose(fns ...Operation) Operation { - return func(tx *Tx) interface{} { +func Compose[R any](fns ...Operation[R]) Operation[struct{}] { + return VoidOperation(func(tx *Tx) { for _, f := range fns { f(tx) } - return nil - } + }) } // Select runs the supplied functions in order. Execution stops when a // function succeeds without calling Retry. If no functions succeed, the // entire selection will be retried. -func Select(fns ...Operation) Operation { - return func(tx *Tx) interface{} { +func Select[R any](fns ...Operation[R]) Operation[R] { + return func(tx *Tx) R { switch len(fns) { case 0: // empty Select blocks forever @@ -140,7 +138,7 @@ func Select(fns ...Operation) Operation { return fns[0](tx) default: oldWrites := tx.writes - tx.writes = make(map[*Var]interface{}, len(oldWrites)) + tx.writes = make(map[txVar]any, len(oldWrites)) for k, v := range oldWrites { tx.writes[k] = v } @@ -155,23 +153,17 @@ func Select(fns ...Operation) Operation { } } -type Operation func(*Tx) interface{} +type Operation[R any] func(*Tx) R -func VoidOperation(f func(*Tx)) Operation { - return func(tx *Tx) interface{} { +func VoidOperation(f func(*Tx)) Operation[struct{}] { + return func(tx *Tx) struct{} { f(tx) - return nil + return struct{}{} } } -func AtomicModify(v *Var, f interface{}) { - r := reflect.ValueOf(f) +func AtomicModify[T any](v *Var[T], f func(T) T) { Atomically(VoidOperation(func(tx *Tx) { - cur := reflect.ValueOf(tx.Get(v)) - out := r.Call([]reflect.Value{cur}) - if lenOut := len(out); lenOut != 1 { - panic(lenOut) - } - tx.Set(v, out[0].Interface()) + v.Set(tx, f(v.Get(tx))) })) } @@ -1,6 +1,6 @@ module github.com/anacrolix/stm -go 1.13 +go 1.18 require ( github.com/anacrolix/envpprof v1.0.0 @@ -9,3 +9,11 @@ require ( github.com/benbjohnson/immutable v0.2.0 github.com/stretchr/testify v1.3.0 ) + +require ( + github.com/alecthomas/atomic v0.1.0-alpha2 + github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/huandu/xstrings v1.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect +) @@ -2,10 +2,13 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w= -github.com/RoaringBitmap/roaring v0.4.17 h1:oCYFIFEMSQZrLHpywH7919esI1VSrQZ0pJXkZPGIJ78= github.com/RoaringBitmap/roaring v0.4.17/go.mod h1:D3qVegWTmfCaX4Bl5CrBE9hfrSrrXIr8KVNvRsDi1NI= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/alecthomas/assert/v2 v2.0.0-alpha3 h1:pcHeMvQ3OMstAWgaeaXIAL8uzB9xMm2zlxt+/4ml8lk= +github.com/alecthomas/atomic v0.1.0-alpha2 h1:dqwXmax66gXvHhsOS4pGPZKqYOlTkapELkLb3MNdlH8= +github.com/alecthomas/atomic v0.1.0-alpha2/go.mod h1:zD6QGEyw49HIq19caJDc2NMXAy8rNi9ROrxtMXATfyI= +github.com/alecthomas/repr v0.0.0-20210801044451-80ca428c5142 h1:8Uy0oSf5co/NZXje7U1z8Mpep++QJOldL2hs/sBQf48= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/anacrolix/envpprof v0.0.0-20180404065416-323002cec2fa/go.mod h1:KgHhUaQMc8cC0+cEflSgCFNFbKwi5h54gqtVn8yhP7c= @@ -16,12 +19,10 @@ github.com/anacrolix/missinggo v1.1.0/go.mod h1:MBJu3Sk/k3ZfGYcS7z18gwfu72Ey/xop github.com/anacrolix/missinggo/v2 v2.2.0 h1:JUZh/gF/F4hXejj6I71wuO92MQDwQdLM3yRgYqTlmCg= github.com/anacrolix/missinggo/v2 v2.2.0/go.mod h1:o0jgJoYOyaoYQ4E2ZMISVa9c88BbUBVQQW4QeRkNCGY= github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= -github.com/anacrolix/tagflag v1.0.0 h1:NoxBVyke6iEtXfSY/n3lY3jNCBjQDu7aTvwHJxNLJAQ= github.com/anacrolix/tagflag v1.0.0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/benbjohnson/immutable v0.2.0 h1:t0rW3lNFwfQ85IDO1mhMbumxdVSti4nnVaal4r45Oio= github.com/benbjohnson/immutable v0.2.0/go.mod h1:uc6OHo6PN2++n98KHLxW8ef4W42ylHiQSENghE1ezxI= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c h1:FUUopH4brHNO2kJoNN3pV+OBEYmgraLT/KHZrMM69r0= @@ -30,17 +31,14 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh6AyO7hdCn/PkvCZXii8TGj7sbtEbQ= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE= -github.com/glycerine/go-unsnap-stream v0.0.0-20181221182339-f9677308dec2 h1:Ujru1hufTHVb++eG6OuNDKMxZnGIvF6o/u8q/8h2+I4= github.com/glycerine/go-unsnap-stream v0.0.0-20181221182339-f9677308dec2/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE= github.com/glycerine/goconvey v0.0.0-20180728074245-46e3a41ad493/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24= github.com/glycerine/goconvey v0.0.0-20190315024820-982ee783a72e/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24= @@ -51,22 +49,20 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20190309154008-847fc94819f9/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/huandu/xstrings v1.2.0 h1:yPeWdRnmynF7p+lLYz0H2tthW9lqhMJrQV/U7yy4wX0= @@ -77,7 +73,6 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mschoch/smat v0.0.0-20160514031455-90eadee771ae/go.mod h1:qAyveg+e4CE+eKJXWVjKXM4ck2QobLqTDytGJbLLhJg= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -85,27 +80,20 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= -github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829 h1:D+CiwcpGTW6pL6bv6KI3KbyEyCKyS+1JWS2h8PNDnGA= github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f h1:BVwpUVJDADN2ufcGik7W992pyps0wZ888b/y9GXcLTU= github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.2.0 h1:kUZDBDTdBVBYBj5Tmh2NZLlF60mfjA27rM34b+cVwNU= github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1 h1:/K3IL0Z1quvmJ7X0A1AwNEK7CRkVK3YwfOU/QAL4WGg= github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 h1:GHRpF1pTW19a8tTFrMLUcfWwyC0pnifVo2ClaLq+hP8= github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= @@ -119,12 +107,10 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= -github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU= github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.10/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -go.opencensus.io v0.20.2 h1:NAfh7zF0/3/HqtMvJNZ/RFrSlCE6ZTlHmKfhL/Dm1Jk= go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/rate/rate_test.go b/rate/rate_test.go index 7f44d74..3078b4c 100644 --- a/rate/rate_test.go +++ b/rate/rate_test.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.7 // +build go1.7 package rate @@ -403,7 +404,7 @@ func runWait(t *testing.T, lim *Limiter, w wait) { t.Helper() start := time.Now() err := lim.WaitN(w.ctx, w.n) - delay := time.Now().Sub(start) + delay := time.Since(start) if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) { errString := "<nil>" if !w.nilErr { diff --git a/rate/ratelimit.go b/rate/ratelimit.go index a44f383..f521f66 100644 --- a/rate/ratelimit.go +++ b/rate/ratelimit.go @@ -13,9 +13,9 @@ import ( type numTokens = int type Limiter struct { - max *stm.Var - cur *stm.Var - lastAdd *stm.Var + max *stm.Var[numTokens] + cur *stm.Var[numTokens] + lastAdd *stm.Var[time.Time] rate Limit } @@ -49,7 +49,7 @@ func NewLimiter(rate Limit, burst numTokens) *Limiter { func (rl *Limiter) tokenGenerator(interval time.Duration) { for { - lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time) + lastAdd := stm.AtomicGet(rl.lastAdd) time.Sleep(time.Until(lastAdd.Add(interval))) now := time.Now() available := numTokens(now.Sub(lastAdd) / interval) @@ -57,17 +57,17 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) { continue } stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(rl.cur).(numTokens) - max := tx.Get(rl.max).(numTokens) + cur := rl.cur.Get(tx) + max := rl.max.Get(tx) tx.Assert(cur < max) newCur := cur + available if newCur > max { newCur = max } if newCur != cur { - tx.Set(rl.cur, newCur) + rl.cur.Set(tx, newCur) } - tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available))) + rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available))) })) } } @@ -77,9 +77,9 @@ func (rl *Limiter) Allow() bool { } func (rl *Limiter) AllowN(n numTokens) bool { - return stm.Atomically(func(tx *stm.Tx) interface{} { + return stm.Atomically(func(tx *stm.Tx) bool { return rl.takeTokens(tx, n) - }).(bool) + }) } func (rl *Limiter) AllowStm(tx *stm.Tx) bool { @@ -90,9 +90,9 @@ func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool { if rl.rate == Inf { return true } - cur := tx.Get(rl.cur).(numTokens) + cur := rl.cur.Get(tx) if cur >= n { - tx.Set(rl.cur, cur-n) + rl.cur.Set(tx, cur-n) return true } return false @@ -105,25 +105,25 @@ func (rl *Limiter) Wait(ctx context.Context) error { func (rl *Limiter) WaitN(ctx context.Context, n int) error { ctxDone, cancel := stmutil.ContextDoneVar(ctx) defer cancel() - if err := stm.Atomically(func(tx *stm.Tx) interface{} { - if tx.Get(ctxDone).(bool) { + if err := stm.Atomically(func(tx *stm.Tx) error { + if ctxDone.Get(tx) { return ctx.Err() } if rl.takeTokens(tx, n) { return nil } - if n > tx.Get(rl.max).(numTokens) { + if n > rl.max.Get(tx) { return errors.New("burst exceeded") } if dl, ok := ctx.Deadline(); ok { - if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n { + if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n { return context.DeadlineExceeded } } tx.Retry() panic("unreachable") }); err != nil { - return err.(error) + return err } return nil @@ -11,7 +11,7 @@ var retries = pprof.NewProfile("stmRetries") var retry = &struct{}{} // catchRetry returns true if fn calls tx.Retry. -func catchRetry(fn Operation, tx *Tx) (result interface{}, gotRetry bool) { +func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) { defer func() { if r := recover(); r == retry { gotRetry = true diff --git a/stm_test.go b/stm_test.go index 207148c..a98685b 100644 --- a/stm_test.go +++ b/stm_test.go @@ -14,14 +14,14 @@ func TestDecrement(t *testing.T) { x := NewVar(1000) for i := 0; i < 500; i++ { go Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur-1) + cur := x.Get(tx) + x.Set(tx, cur-1) })) } done := make(chan struct{}) go func() { Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x) == 500) + tx.Assert(x.Get(tx) == 500) })) close(done) }() @@ -50,10 +50,10 @@ func TestReadVerify(t *testing.T) { // between the reads, causing this tx to retry. var x2, y2 int Atomically(VoidOperation(func(tx *Tx) { - x2 = tx.Get(x).(int) + x2 = x.Get(tx) read <- struct{}{} <-read // wait for other tx to complete - y2 = tx.Get(y).(int) + y2 = y.Get(tx) })) if x2 == 1 && y2 == 2 { t.Fatal("read was not verified") @@ -68,8 +68,8 @@ func TestRetry(t *testing.T) { for i := 0; i < 10; i++ { time.Sleep(10 * time.Millisecond) Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) - tx.Set(x, cur-1) + cur := x.Get(tx) + x.Set(tx, cur-1) })) } }() @@ -77,7 +77,7 @@ func TestRetry(t *testing.T) { // retry. This should result in no more than 1 retry per transaction. retry := 0 Atomically(VoidOperation(func(tx *Tx) { - cur := tx.Get(x).(int) + cur := x.Get(tx) if cur != 0 { retry++ tx.Retry() @@ -100,9 +100,9 @@ func TestVerify(t *testing.T) { go func() { Atomically(VoidOperation(func(tx *Tx) { <-read - rx := tx.Get(x).(*foo) + rx := x.Get(tx) rx.i = 7 - tx.Set(x, rx) + x.Set(tx, rx) })) read <- struct{}{} // other tx should retry, so we need to read/send again @@ -113,7 +113,7 @@ func TestVerify(t *testing.T) { // between the reads, causing this tx to retry. var i int Atomically(VoidOperation(func(tx *Tx) { - f := tx.Get(x).(*foo) + f := x.Get(tx) i = f.i read <- struct{}{} <-read // wait for other tx to complete @@ -125,36 +125,37 @@ func TestVerify(t *testing.T) { func TestSelect(t *testing.T) { // empty Select should panic - require.Panics(t, func() { Atomically(Select()) }) + require.Panics(t, func() { Atomically(Select[struct{}]()) }) // with one arg, Select adds no effect x := NewVar(2) Atomically(Select(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x).(int) == 2) + tx.Assert(x.Get(tx) == 2) }))) picked := Atomically(Select( // always blocks; should never be selected - VoidOperation(func(tx *Tx) { + func(tx *Tx) int { tx.Retry() - }), + panic("unreachable") + }, // always succeeds; should always be selected - func(tx *Tx) interface{} { + func(tx *Tx) int { return 2 }, // always succeeds; should never be selected - func(tx *Tx) interface{} { + func(tx *Tx) int { return 3 }, - )).(int) + )) assert.EqualValues(t, 2, picked) } func TestCompose(t *testing.T) { nums := make([]int, 100) - fns := make([]Operation, 100) + fns := make([]Operation[struct{}], 100) for i := range fns { - fns[i] = func(x int) Operation { + fns[i] = func(x int) Operation[struct{}] { return VoidOperation(func(*Tx) { nums[x] = x }) }(i) // capture loop var } @@ -169,7 +170,7 @@ func TestCompose(t *testing.T) { func TestPanic(t *testing.T) { // normal panics should escape Atomically assert.PanicsWithValue(t, "foo", func() { - Atomically(func(*Tx) interface{} { + Atomically(func(*Tx) any { panic("foo") }) }) @@ -180,8 +181,8 @@ func TestReadWritten(t *testing.T) { // previously written value x := NewVar(3) Atomically(VoidOperation(func(tx *Tx) { - tx.Set(x, 5) - tx.Assert(tx.Get(x).(int) == 5) + x.Set(tx, 5) + tx.Assert(x.Get(tx) == 5) })) } @@ -191,7 +192,7 @@ func TestAtomicSetRetry(t *testing.T) { done := make(chan struct{}) go func() { Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(x).(int) == 5) + tx.Assert(x.Get(tx) == 5) })) done <- struct{}{} }() @@ -210,17 +211,17 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) { hits := NewVar(0) ready := NewVar(true) // The ball is ready for hitting. var wg sync.WaitGroup - bat := func(from, to interface{}, noise string) { + bat := func(from, to bool, noise string) { defer wg.Done() - for !Atomically(func(tx *Tx) interface{} { - if tx.Get(doneVar).(bool) { + for !Atomically(func(tx *Tx) any { + if doneVar.Get(tx) { return true } - tx.Assert(tx.Get(ready).(bool)) - if tx.Get(ball) == from { - tx.Set(ball, to) - tx.Set(hits, tx.Get(hits).(int)+1) - tx.Set(ready, false) + tx.Assert(ready.Get(tx)) + if ball.Get(tx) == from { + ball.Set(tx, to) + hits.Set(tx, hits.Get(tx)+1) + ready.Set(tx, false) return false } return tx.Retry() @@ -233,8 +234,8 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) { go bat(false, true, "ping!") go bat(true, false, "pong!") Atomically(VoidOperation(func(tx *Tx) { - tx.Assert(tx.Get(hits).(int) >= n) - tx.Set(doneVar, true) + tx.Assert(hits.Get(tx) >= n) + doneVar.Set(tx, true) })) wg.Wait() } @@ -245,7 +246,7 @@ func TestPingPong(t *testing.T) { func TestSleepingBeauty(t *testing.T) { require.Panics(t, func() { - Atomically(func(tx *Tx) interface{} { + Atomically(func(tx *Tx) any { tx.Assert(false) return nil }) @@ -253,7 +254,7 @@ func TestSleepingBeauty(t *testing.T) { } //func TestRetryStack(t *testing.T) { -// v := NewVar(nil) +// v := NewVar[int](nil) // go func() { // i := 0 // for { @@ -261,12 +262,12 @@ func TestSleepingBeauty(t *testing.T) { // i++ // } // }() -// Atomically(func(tx *Tx) interface{} { +// Atomically(func(tx *Tx) any { // debug.PrintStack() // ret := func() { // defer Atomically(nil) // } -// tx.Get(v) +// v.Get(tx) // tx.Assert(false) // return ret // }) diff --git a/stmutil/containers.go b/stmutil/containers.go index e0b532d..c7a4a49 100644 --- a/stmutil/containers.go +++ b/stmutil/containers.go @@ -9,10 +9,10 @@ import ( ) type Settish interface { - Add(interface{}) Settish - Delete(interface{}) Settish - Contains(interface{}) bool - Range(func(interface{}) bool) + Add(any) Settish + Delete(any) Settish + Contains(any) bool + Range(func(any) bool) iter.Iterable Len() int } @@ -23,11 +23,11 @@ type mapToSet struct { type interhash struct{} -func (interhash) Hash(x interface{}) uint32 { +func (interhash) Hash(x any) uint32 { return uint32(nilinterhash(unsafe.Pointer(&x), 0)) } -func (interhash) Equal(i, j interface{}) bool { +func (interhash) Equal(i, j any) bool { return i == j } @@ -39,12 +39,12 @@ func NewSortedSet(lesser lessFunc) Settish { return mapToSet{NewSortedMap(lesser)} } -func (s mapToSet) Add(x interface{}) Settish { +func (s mapToSet) Add(x any) Settish { s.m = s.m.Set(x, nil) return s } -func (s mapToSet) Delete(x interface{}) Settish { +func (s mapToSet) Delete(x any) Settish { s.m = s.m.Delete(x) return s } @@ -53,13 +53,13 @@ func (s mapToSet) Len() int { return s.m.Len() } -func (s mapToSet) Contains(x interface{}) bool { +func (s mapToSet) Contains(x any) bool { _, ok := s.m.Get(x) return ok } -func (s mapToSet) Range(f func(interface{}) bool) { - s.m.Range(func(k, _ interface{}) bool { +func (s mapToSet) Range(f func(any) bool) { + s.m.Range(func(k, _ any) bool { return f(k) }) } @@ -78,17 +78,17 @@ func NewMap() Mappish { var _ Mappish = Map{} -func (m Map) Delete(x interface{}) Mappish { +func (m Map) Delete(x any) Mappish { m.Map = m.Map.Delete(x) return m } -func (m Map) Set(key, value interface{}) Mappish { +func (m Map) Set(key, value any) Mappish { m.Map = m.Map.Set(key, value) return m } -func (sm Map) Range(f func(key, value interface{}) bool) { +func (sm Map) Range(f func(key, value any) bool) { iter := sm.Map.Iterator() for !iter.Done() { if !f(iter.Next()) { @@ -98,7 +98,7 @@ func (sm Map) Range(f func(key, value interface{}) bool) { } func (sm Map) Iter(cb iter.Callback) { - sm.Range(func(key, _ interface{}) bool { + sm.Range(func(key, _ any) bool { return cb(key) }) } @@ -107,17 +107,17 @@ type SortedMap struct { *immutable.SortedMap } -func (sm SortedMap) Set(key, value interface{}) Mappish { +func (sm SortedMap) Set(key, value any) Mappish { sm.SortedMap = sm.SortedMap.Set(key, value) return sm } -func (sm SortedMap) Delete(key interface{}) Mappish { +func (sm SortedMap) Delete(key any) Mappish { sm.SortedMap = sm.SortedMap.Delete(key) return sm } -func (sm SortedMap) Range(f func(key, value interface{}) bool) { +func (sm SortedMap) Range(f func(key, value any) bool) { iter := sm.SortedMap.Iterator() for !iter.Done() { if !f(iter.Next()) { @@ -127,18 +127,18 @@ func (sm SortedMap) Range(f func(key, value interface{}) bool) { } func (sm SortedMap) Iter(cb iter.Callback) { - sm.Range(func(key, _ interface{}) bool { + sm.Range(func(key, _ any) bool { return cb(key) }) } -type lessFunc func(l, r interface{}) bool +type lessFunc func(l, r any) bool type comparer struct { less lessFunc } -func (me comparer) Compare(i, j interface{}) int { +func (me comparer) Compare(i, j any) int { if me.less(i, j) { return -1 } else if me.less(j, i) { @@ -155,15 +155,15 @@ func NewSortedMap(less lessFunc) Mappish { } type Mappish interface { - Set(key, value interface{}) Mappish - Delete(key interface{}) Mappish - Get(key interface{}) (interface{}, bool) - Range(func(_, _ interface{}) bool) + Set(key, value any) Mappish + Delete(key any) Mappish + Get(key any) (any, bool) + Range(func(_, _ any) bool) Len() int iter.Iterable } -func GetLeft(l, _ interface{}) interface{} { +func GetLeft(l, _ any) any { return l } @@ -171,7 +171,7 @@ func GetLeft(l, _ interface{}) interface{} { //go:linkname nilinterhash runtime.nilinterhash func nilinterhash(p unsafe.Pointer, h uintptr) uintptr -func interfaceHash(x interface{}) uint32 { +func interfaceHash(x any) uint32 { return uint32(nilinterhash(unsafe.Pointer(&x), 0)) } diff --git a/stmutil/context.go b/stmutil/context.go index 8a4d58d..6f8ba9b 100644 --- a/stmutil/context.go +++ b/stmutil/context.go @@ -9,12 +9,12 @@ import ( var ( mu sync.Mutex - ctxVars = map[context.Context]*stm.Var{} + ctxVars = map[context.Context]*stm.Var[bool]{} ) // Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be // called when the user is no longer interested in the var. -func ContextDoneVar(ctx context.Context) (*stm.Var, func()) { +func ContextDoneVar(ctx context.Context) (*stm.Var[bool], func()) { mu.Lock() defer mu.Unlock() if v, ok := ctxVars[ctx]; ok { @@ -5,13 +5,22 @@ import ( "sort" "sync" "unsafe" + + "github.com/alecthomas/atomic" ) +type txVar interface { + getValue() *atomic.Value[VarValue] + changeValue(any) + getWatchers() *sync.Map + getLock() *sync.Mutex +} + // A Tx represents an atomic transaction. type Tx struct { - reads map[*Var]VarValue - writes map[*Var]interface{} - watching map[*Var]struct{} + reads map[txVar]VarValue + writes map[txVar]any + watching map[txVar]struct{} locks txLocks mu sync.Mutex cond sync.Cond @@ -24,7 +33,7 @@ type Tx struct { // Check that none of the logged values have changed since the transaction began. func (tx *Tx) inputsChanged() bool { for v, read := range tx.reads { - if read.Changed(v.value.Load().(VarValue)) { + if read.Changed(v.getValue().Load()) { return true } } @@ -42,12 +51,12 @@ func (tx *Tx) updateWatchers() { for v := range tx.watching { if _, ok := tx.reads[v]; !ok { delete(tx.watching, v) - v.watchers.Delete(tx) + v.getWatchers().Delete(tx) } } for v := range tx.reads { if _, ok := tx.watching[v]; !ok { - v.watchers.Store(tx, nil) + v.getWatchers().Store(tx, nil) tx.watching[v] = struct{}{} } } @@ -76,22 +85,22 @@ func (tx *Tx) wait() { } // Get returns the value of v as of the start of the transaction. -func (tx *Tx) Get(v *Var) interface{} { +func (v *Var[T]) Get(tx *Tx) T { // If we previously wrote to v, it will be in the write log. if val, ok := tx.writes[v]; ok { - return val + return val.(T) } // If we haven't previously read v, record its version vv, ok := tx.reads[v] if !ok { - vv = v.value.Load().(VarValue) + vv = v.getValue().Load() tx.reads[v] = vv } - return vv.Get() + return vv.Get().(T) } // Set sets the value of a Var for the lifetime of the transaction. -func (tx *Tx) Set(v *Var, val interface{}) { +func (v *Var[T]) Set(tx *Tx, val T) { if v == nil { panic("nil Var") } @@ -105,11 +114,10 @@ type txProfileValue struct { // Retry aborts the transaction and retries it when a Var changes. You can return from this method // to satisfy return values, but it should never actually return anything as it panics internally. -func (tx *Tx) Retry() interface{} { +func (tx *Tx) Retry() struct{} { retries.Add(txProfileValue{tx, tx.numRetryValues}, 1) tx.numRetryValues++ panic(retry) - panic("unreachable") } // Assert is a helper function that retries a transaction if the condition is @@ -143,7 +151,7 @@ func (tx *Tx) removeRetryProfiles() { func (tx *Tx) recycle() { for v := range tx.watching { delete(tx.watching, v) - v.watchers.Delete(tx) + v.getWatchers().Delete(tx) } tx.removeRetryProfiles() // I don't think we can reuse Txs, because the "completed" field should/needs to be set @@ -164,7 +172,7 @@ func (tx *Tx) resetLocks() { func (tx *Tx) collectReadLocks() { for v := range tx.reads { - tx.locks.append(&v.mu) + tx.locks.append(v.getLock()) } } @@ -172,7 +180,7 @@ func (tx *Tx) collectAllLocks() { tx.collectReadLocks() for v := range tx.writes { if _, ok := tx.reads[v]; !ok { - tx.locks.append(&v.mu) + tx.locks.append(v.getLock()) } } } diff --git a/var-value.go b/var-value.go index ff97104..3518bf6 100644 --- a/var-value.go +++ b/var-value.go @@ -1,51 +1,51 @@ package stm type VarValue interface { - Set(interface{}) VarValue - Get() interface{} + Set(any) VarValue + Get() any Changed(VarValue) bool } type version uint64 -type versionedValue struct { - value interface{} +type versionedValue[T any] struct { + value T version version } -func (me versionedValue) Set(newValue interface{}) VarValue { - return versionedValue{ - value: newValue, +func (me versionedValue[T]) Set(newValue any) VarValue { + return versionedValue[T]{ + value: newValue.(T), version: me.version + 1, } } -func (me versionedValue) Get() interface{} { +func (me versionedValue[T]) Get() any { return me.value } -func (me versionedValue) Changed(other VarValue) bool { - return me.version != other.(versionedValue).version +func (me versionedValue[T]) Changed(other VarValue) bool { + return me.version != other.(versionedValue[T]).version } -type customVarValue struct { - value interface{} - changed func(interface{}, interface{}) bool +type customVarValue[T any] struct { + value T + changed func(T, T) bool } -var _ VarValue = customVarValue{} +var _ VarValue = customVarValue[struct{}]{} -func (me customVarValue) Changed(other VarValue) bool { - return me.changed(me.value, other.(customVarValue).value) +func (me customVarValue[T]) Changed(other VarValue) bool { + return me.changed(me.value, other.(customVarValue[T]).value) } -func (me customVarValue) Set(newValue interface{}) VarValue { - return customVarValue{ - value: newValue, +func (me customVarValue[T]) Set(newValue any) VarValue { + return customVarValue[T]{ + value: newValue.(T), changed: me.changed, } } -func (me customVarValue) Get() interface{} { +func (me customVarValue[T]) Get() any { return me.value } @@ -2,18 +2,31 @@ package stm import ( "sync" - "sync/atomic" + + "github.com/alecthomas/atomic" ) // Holds an STM variable. -type Var struct { - value atomic.Value +type Var[T any] struct { + value atomic.Value[VarValue] watchers sync.Map mu sync.Mutex } -func (v *Var) changeValue(new interface{}) { - old := v.value.Load().(VarValue) +func (v *Var[T]) getValue() *atomic.Value[VarValue] { + return &v.value +} + +func (v *Var[T]) getWatchers() *sync.Map { + return &v.watchers +} + +func (v *Var[T]) getLock() *sync.Mutex { + return &v.mu +} + +func (v *Var[T]) changeValue(new any) { + old := v.value.Load() newVarValue := old.Set(new) v.value.Store(newVarValue) if old.Changed(newVarValue) { @@ -21,8 +34,8 @@ func (v *Var) changeValue(new interface{}) { } } -func (v *Var) wakeWatchers(new VarValue) { - v.watchers.Range(func(k, _ interface{}) bool { +func (v *Var[T]) wakeWatchers(new VarValue) { + v.watchers.Range(func(k, _ any) bool { tx := k.(*Tx) // We have to lock here to ensure that the Tx is waiting before we signal it. Otherwise we // could signal it before it goes to sleep and it will miss the notification. @@ -34,35 +47,30 @@ func (v *Var) wakeWatchers(new VarValue) { } } tx.mu.Unlock() - return !v.value.Load().(VarValue).Changed(new) + return !v.value.Load().Changed(new) }) } -type varSnapshot struct { - val interface{} - version uint64 -} - // Returns a new STM variable. -func NewVar(val interface{}) *Var { - v := &Var{} - v.value.Store(versionedValue{ +func NewVar[T any](val T) *Var[T] { + v := &Var[T]{} + v.value.Store(versionedValue[T]{ value: val, }) return v } -func NewCustomVar(val interface{}, changed func(interface{}, interface{}) bool) *Var { - v := &Var{} - v.value.Store(customVarValue{ +func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] { + v := &Var[T]{} + v.value.Store(customVarValue[T]{ value: val, changed: changed, }) return v } -func NewBuiltinEqVar(val interface{}) *Var { - return NewCustomVar(val, func(a, b interface{}) bool { +func NewBuiltinEqVar[T comparable](val T) *Var[T] { + return NewCustomVar(val, func(a, b T) bool { return a != b }) } |