aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.circleci/config.yml2
-rw-r--r--README.md33
-rw-r--r--bench_test.go6
-rw-r--r--cmd/santa-example/main.go34
-rw-r--r--doc.go26
-rw-r--r--doc_test.go22
-rw-r--r--external_test.go58
-rw-r--r--funcs.go50
-rw-r--r--go.mod10
-rw-r--r--go.sum26
-rw-r--r--rate/rate_test.go3
-rw-r--r--rate/ratelimit.go34
-rw-r--r--retry.go2
-rw-r--r--stm_test.go77
-rw-r--r--stmutil/containers.go54
-rw-r--r--stmutil/context.go4
-rw-r--r--tx.go40
-rw-r--r--var-value.go40
-rw-r--r--var.go50
19 files changed, 286 insertions, 285 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 0db2f6b..148a6e6 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -3,7 +3,7 @@ jobs:
build:
machine: true
environment:
- GO_BRANCH: release-branch.go1.15
+ GO_BRANCH: master
steps:
- run: echo $CIRCLE_WORKING_DIRECTORY
- run: echo $PWD
diff --git a/README.md b/README.md
index 6330f79..07e9952 100644
--- a/README.md
+++ b/README.md
@@ -12,10 +12,9 @@ composition will either deadlock or release the lock between functions (making
it non-atomic).
The `stm` API tries to mimic that of Haskell's [`Control.Concurrent.STM`](https://hackage.haskell.org/package/stm-2.4.4.1/docs/Control-Concurrent-STM.html), but
-this is not entirely possible due to Go's type system; we are forced to use
-`interface{}` and type assertions. Furthermore, Haskell can enforce at compile
-time that STM variables are not modified outside the STM monad. This is not
-possible in Go, so be especially careful when using pointers in your STM code.
+Haskell can enforce at compile time that STM variables are not modified outside
+the STM monad. This is not possible in Go, so be especially careful when using
+pointers in your STM code. Remember: modifying a pointer is a side effect!
Unlike Haskell, data in Go is not immutable by default, which means you have
to be careful when using STM to manage pointers. If two goroutines have access
@@ -31,22 +30,22 @@ applications in Go. If you find this package useful, please tell us about it!
See the package examples in the Go package docs for examples of common operations.
-See [example_santa_test.go](example_santa_test.go) for a more complex example.
+See [cmd/santa-example/main.go](cmd/santa-example/main.go) for a more complex example.
## Pointers
-Note that `Operation` now returns a value of type `interface{}`, which isn't included in the
+Note that `Operation` now returns a value of type `any`, which isn't included in the
examples throughout the documentation yet. See the type signatures for `Atomically` and `Operation`.
Be very careful when managing pointers inside transactions! (This includes
slices, maps, channels, and captured variables.) Here's why:
```go
-p := stm.NewVar([]byte{1,2,3})
+p := stm.NewVar[[]byte]([]byte{1,2,3})
stm.Atomically(func(tx *stm.Tx) {
- b := tx.Get(p).([]byte)
+ b := p.Get(tx)
b[0] = 7
- tx.Set(p, b)
+ stm.p.Set(tx, b)
})
```
@@ -57,11 +56,11 @@ Following this advice, we can rewrite the transaction to perform a copy:
```go
stm.Atomically(func(tx *stm.Tx) {
- b := tx.Get(p).([]byte)
+ b := p.Get(tx)
c := make([]byte, len(b))
copy(c, b)
c[0] = 7
- tx.Set(p, c)
+ p.Set(tx, c)
})
```
@@ -73,11 +72,11 @@ In the same vein, it would be a mistake to do this:
type foo struct {
i int
}
-p := stm.NewVar(&foo{i: 2})
+p := stm.NewVar[*foo](&foo{i: 2})
stm.Atomically(func(tx *stm.Tx) {
- f := tx.Get(p).(*foo)
+ f := p.Get(tx)
f.i = 7
- tx.Set(p, f)
+ stm.p.Set(tx, f)
})
```
@@ -88,11 +87,11 @@ the correct approach is to move the `Var` inside the struct:
type foo struct {
i *stm.Var
}
-f := foo{i: stm.NewVar(2)}
+f := foo{i: stm.NewVar[int](2)}
stm.Atomically(func(tx *stm.Tx) {
- i := tx.Get(f.i).(int)
+ i := f.i.Get(tx)
i = 7
- tx.Set(f.i, i)
+ f.i.Set(tx, i)
})
```
diff --git a/bench_test.go b/bench_test.go
index 82aaba1..0d40caf 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -27,13 +27,13 @@ func BenchmarkIncrementSTM(b *testing.B) {
x := NewVar(0)
for i := 0; i < 1000; i++ {
go Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur+1)
+ cur := x.Get(tx)
+ x.Set(tx, cur+1)
}))
}
// wait for x to reach 1000
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x).(int) == 1000)
+ tx.Assert(x.Get(tx) == 1000)
}))
}
}
diff --git a/cmd/santa-example/main.go b/cmd/santa-example/main.go
index 3c2b9ea..dcc8067 100644
--- a/cmd/santa-example/main.go
+++ b/cmd/santa-example/main.go
@@ -39,15 +39,15 @@ import (
type gate struct {
capacity int
- remaining *stm.Var
+ remaining *stm.Var[int]
}
func (g gate) pass() {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
// wait until gate can hold us
tx.Assert(rem > 0)
- tx.Set(g.remaining, rem-1)
+ g.remaining.Set(tx, rem-1)
}))
}
@@ -56,7 +56,7 @@ func (g gate) operate() {
stm.AtomicSet(g.remaining, g.capacity)
// wait for gate to be full
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
tx.Assert(rem == 0)
}))
}
@@ -70,8 +70,8 @@ func newGate(capacity int) gate {
type group struct {
capacity int
- remaining *stm.Var
- gate1, gate2 *stm.Var
+ remaining *stm.Var[int]
+ gate1, gate2 *stm.Var[gate]
}
func newGroup(capacity int) *group {
@@ -85,28 +85,28 @@ func newGroup(capacity int) *group {
func (g *group) join() (g1, g2 gate) {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
// wait until the group can hold us
tx.Assert(rem > 0)
- tx.Set(g.remaining, rem-1)
+ g.remaining.Set(tx, rem-1)
// return the group's gates
- g1 = tx.Get(g.gate1).(gate)
- g2 = tx.Get(g.gate2).(gate)
+ g1 = g.gate1.Get(tx)
+ g2 = g.gate2.Get(tx)
}))
return
}
func (g *group) await(tx *stm.Tx) (gate, gate) {
// wait for group to be empty
- rem := tx.Get(g.remaining).(int)
+ rem := g.remaining.Get(tx)
tx.Assert(rem == 0)
// get the group's gates
- g1 := tx.Get(g.gate1).(gate)
- g2 := tx.Get(g.gate2).(gate)
+ g1 := g.gate1.Get(tx)
+ g2 := g.gate2.Get(tx)
// reset group
- tx.Set(g.remaining, g.capacity)
- tx.Set(g.gate1, newGate(g.capacity))
- tx.Set(g.gate2, newGate(g.capacity))
+ g.remaining.Set(tx, g.capacity)
+ g.gate1.Set(tx, newGate(g.capacity))
+ g.gate2.Set(tx, newGate(g.capacity))
return g1, g2
}
@@ -137,7 +137,7 @@ type selection struct {
gate1, gate2 gate
}
-func chooseGroup(g *group, task string, s *selection) stm.Operation {
+func chooseGroup(g *group, task string, s *selection) stm.Operation[struct{}] {
return stm.VoidOperation(func(tx *stm.Tx) {
s.gate1, s.gate2 = g.await(tx)
s.task = task
diff --git a/doc.go b/doc.go
index db03501..c704c33 100644
--- a/doc.go
+++ b/doc.go
@@ -10,14 +10,14 @@ it non-atomic).
To begin, create an STM object that wraps the data you want to access
concurrently.
- x := stm.NewVar(3)
+ x := stm.NewVar[int](3)
You can then use the Atomically method to atomically read and/or write the the
data. This code atomically decrements x:
stm.Atomically(func(tx *stm.Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur-1)
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
})
An important part of STM transactions is retrying. At any point during the
@@ -29,11 +29,11 @@ updated before the transaction will be rerun. As an example, this code will
try to decrement x, but will block as long as x is zero:
stm.Atomically(func(tx *stm.Tx) {
- cur := tx.Get(x).(int)
+ cur := x.Get(tx)
if cur == 0 {
tx.Retry()
}
- tx.Set(x, cur-1)
+ x.Set(tx, cur-1)
})
Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any other
@@ -47,13 +47,13 @@ retried. For example, this code implements the "decrement-if-nonzero"
transaction above, but for two values. It will first try to decrement x, then
y, and block if both values are zero.
- func dec(v *stm.Var) {
+ func dec(v *stm.Var[int]) {
return func(tx *stm.Tx) {
- cur := tx.Get(v).(int)
+ cur := v.Get(tx)
if cur == 0 {
tx.Retry()
}
- tx.Set(v, cur-1)
+ v.Set(tx, cur-1)
}
}
@@ -69,11 +69,9 @@ behavior. One common way to get around this is to build up a list of impure
operations inside the transaction, and then perform them after the transaction
completes.
-The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but this
-is not entirely possible due to Go's type system; we are forced to use
-interface{} and type assertions. Furthermore, Haskell can enforce at compile
-time that STM variables are not modified outside the STM monad. This is not
-possible in Go, so be especially careful when using pointers in your STM code.
-Remember: modifying a pointer is a side effect!
+The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but
+Haskell can enforce at compile time that STM variables are not modified outside
+the STM monad. This is not possible in Go, so be especially careful when using
+pointers in your STM code. Remember: modifying a pointer is a side effect!
*/
package stm
diff --git a/doc_test.go b/doc_test.go
index f6bb863..ae1af9a 100644
--- a/doc_test.go
+++ b/doc_test.go
@@ -11,38 +11,38 @@ func Example() {
// read a variable
var v int
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- v = tx.Get(n).(int)
+ v = n.Get(tx)
}))
// or:
- v = stm.AtomicGet(n).(int)
+ v = stm.AtomicGet(n)
_ = v
// write to a variable
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Set(n, 12)
+ n.Set(tx, 12)
}))
// or:
stm.AtomicSet(n, 12)
// update a variable
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(n).(int)
- tx.Set(n, cur-1)
+ cur := n.Get(tx)
+ n.Set(tx, cur-1)
}))
// block until a condition is met
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(n).(int)
+ cur := n.Get(tx)
if cur != 0 {
tx.Retry()
}
- tx.Set(n, 10)
+ n.Set(tx, 10)
}))
// or:
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(n).(int)
+ cur := n.Get(tx)
tx.Assert(cur == 0)
- tx.Set(n, 10)
+ n.Set(tx, 10)
}))
// select among multiple (potentially blocking) transactions
@@ -51,11 +51,11 @@ func Example() {
stm.VoidOperation(func(tx *stm.Tx) { tx.Retry() }),
// this function will always succeed without blocking
- stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 10) }),
+ stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 10) }),
// this function will never run, because the previous
// function succeeded
- stm.VoidOperation(func(tx *stm.Tx) { tx.Set(n, 11) }),
+ stm.VoidOperation(func(tx *stm.Tx) { n.Set(tx, 11) }),
))
// since Select is a normal transaction, if the entire select retries
diff --git a/external_test.go b/external_test.go
index ae29ca8..abdf544 100644
--- a/external_test.go
+++ b/external_test.go
@@ -63,14 +63,14 @@ func BenchmarkThunderingHerd(b *testing.B) {
pending := stm.NewBuiltinEqVar(0)
for range iter.N(1000) {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Set(pending, tx.Get(pending).(int)+1)
+ pending.Set(tx, pending.Get(tx)+1)
}))
go func() {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- t := tx.Get(tokens).(int)
+ t := tokens.Get(tx)
if t > 0 {
- tx.Set(tokens, t-1)
- tx.Set(pending, tx.Get(pending).(int)-1)
+ tokens.Set(tx, t-1)
+ pending.Set(tx, pending.Get(tx)-1)
} else {
tx.Retry()
}
@@ -78,18 +78,18 @@ func BenchmarkThunderingHerd(b *testing.B) {
}()
}
go func() {
- for stm.Atomically(func(tx *stm.Tx) interface{} {
- if tx.Get(done).(bool) {
+ for stm.Atomically(func(tx *stm.Tx) bool {
+ if done.Get(tx) {
return false
}
- tx.Assert(tx.Get(tokens).(int) < maxTokens)
- tx.Set(tokens, tx.Get(tokens).(int)+1)
+ tx.Assert(tokens.Get(tx) < maxTokens)
+ tokens.Set(tx, tokens.Get(tx)+1)
return true
- }).(bool) {
+ }) {
}
}()
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(tx.Get(pending).(int) == 0)
+ tx.Assert(pending.Get(tx) == 0)
}))
stm.AtomicSet(done, true)
}
@@ -103,49 +103,49 @@ func BenchmarkInvertedThunderingHerd(b *testing.B) {
for range iter.N(1000) {
ready := stm.NewVar(false)
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Set(pending, tx.Get(pending).(stmutil.Settish).Add(ready))
+ pending.Set(tx, pending.Get(tx).Add(ready))
}))
go func() {
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(tx.Get(ready).(bool))
- set := tx.Get(pending).(stmutil.Settish)
+ tx.Assert(ready.Get(tx))
+ set := pending.Get(tx)
if !set.Contains(ready) {
panic("couldn't find ourselves in pending")
}
- tx.Set(pending, set.Delete(ready))
+ pending.Set(tx, set.Delete(ready))
}))
//b.Log("waiter finished")
}()
}
go func() {
- for stm.Atomically(func(tx *stm.Tx) interface{} {
- if tx.Get(done).(bool) {
+ for stm.Atomically(func(tx *stm.Tx) bool {
+ if done.Get(tx) {
return false
}
- tx.Assert(tx.Get(tokens).(int) < maxTokens)
- tx.Set(tokens, tx.Get(tokens).(int)+1)
+ tx.Assert(tokens.Get(tx) < maxTokens)
+ tokens.Set(tx, tokens.Get(tx)+1)
return true
- }).(bool) {
+ }) {
}
}()
go func() {
- for stm.Atomically(func(tx *stm.Tx) interface{} {
- tx.Assert(tx.Get(tokens).(int) > 0)
- tx.Set(tokens, tx.Get(tokens).(int)-1)
- tx.Get(pending).(stmutil.Settish).Range(func(i interface{}) bool {
- ready := i.(*stm.Var)
- if !tx.Get(ready).(bool) {
- tx.Set(ready, true)
+ for stm.Atomically(func(tx *stm.Tx) bool {
+ tx.Assert(tokens.Get(tx) > 0)
+ tokens.Set(tx, tokens.Get(tx)-1)
+ pending.Get(tx).Range(func(i any) bool {
+ ready := i.(*stm.Var[bool])
+ if !ready.Get(tx) {
+ ready.Set(tx, true)
return false
}
return true
})
- return !tx.Get(done).(bool)
- }).(bool) {
+ return !done.Get(tx)
+ }) {
}
}()
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- tx.Assert(tx.Get(pending).(stmutil.Lenner).Len() == 0)
+ tx.Assert(pending.Get(tx).(stmutil.Lenner).Len() == 0)
}))
stm.AtomicSet(done, true)
}
diff --git a/funcs.go b/funcs.go
index 8694eb4..07d35ec 100644
--- a/funcs.go
+++ b/funcs.go
@@ -2,19 +2,18 @@ package stm
import (
"math/rand"
- "reflect"
"runtime/pprof"
"sync"
"time"
)
var (
- txPool = sync.Pool{New: func() interface{} {
+ txPool = sync.Pool{New: func() any {
expvars.Add("new txs", 1)
tx := &Tx{
- reads: make(map[*Var]VarValue),
- writes: make(map[*Var]interface{}),
- watching: make(map[*Var]struct{}),
+ reads: make(map[txVar]VarValue),
+ writes: make(map[txVar]any),
+ watching: make(map[txVar]struct{}),
}
tx.cond.L = &tx.mu
return tx
@@ -40,7 +39,7 @@ func newTx() *Tx {
return tx
}
-func WouldBlock(fn Operation) (block bool) {
+func WouldBlock[R any](fn Operation[R]) (block bool) {
tx := newTx()
tx.reset()
_, block = catchRetry(fn, tx)
@@ -52,7 +51,7 @@ func WouldBlock(fn Operation) (block bool) {
}
// Atomically executes the atomic function fn.
-func Atomically(op Operation) interface{} {
+func Atomically[R any](op Operation[R]) R {
expvars.Add("atomically", 1)
// run the transaction
tx := newTx()
@@ -104,12 +103,12 @@ retry:
}
// AtomicGet is a helper function that atomically reads a value.
-func AtomicGet(v *Var) interface{} {
- return v.value.Load().(VarValue).Get()
+func AtomicGet[T any](v *Var[T]) T {
+ return v.value.Load().Get().(T)
}
// AtomicSet is a helper function that atomically writes a value.
-func AtomicSet(v *Var, val interface{}) {
+func AtomicSet[T any](v *Var[T], val T) {
v.mu.Lock()
v.changeValue(val)
v.mu.Unlock()
@@ -117,20 +116,19 @@ func AtomicSet(v *Var, val interface{}) {
// Compose is a helper function that composes multiple transactions into a
// single transaction.
-func Compose(fns ...Operation) Operation {
- return func(tx *Tx) interface{} {
+func Compose[R any](fns ...Operation[R]) Operation[struct{}] {
+ return VoidOperation(func(tx *Tx) {
for _, f := range fns {
f(tx)
}
- return nil
- }
+ })
}
// Select runs the supplied functions in order. Execution stops when a
// function succeeds without calling Retry. If no functions succeed, the
// entire selection will be retried.
-func Select(fns ...Operation) Operation {
- return func(tx *Tx) interface{} {
+func Select[R any](fns ...Operation[R]) Operation[R] {
+ return func(tx *Tx) R {
switch len(fns) {
case 0:
// empty Select blocks forever
@@ -140,7 +138,7 @@ func Select(fns ...Operation) Operation {
return fns[0](tx)
default:
oldWrites := tx.writes
- tx.writes = make(map[*Var]interface{}, len(oldWrites))
+ tx.writes = make(map[txVar]any, len(oldWrites))
for k, v := range oldWrites {
tx.writes[k] = v
}
@@ -155,23 +153,17 @@ func Select(fns ...Operation) Operation {
}
}
-type Operation func(*Tx) interface{}
+type Operation[R any] func(*Tx) R
-func VoidOperation(f func(*Tx)) Operation {
- return func(tx *Tx) interface{} {
+func VoidOperation(f func(*Tx)) Operation[struct{}] {
+ return func(tx *Tx) struct{} {
f(tx)
- return nil
+ return struct{}{}
}
}
-func AtomicModify(v *Var, f interface{}) {
- r := reflect.ValueOf(f)
+func AtomicModify[T any](v *Var[T], f func(T) T) {
Atomically(VoidOperation(func(tx *Tx) {
- cur := reflect.ValueOf(tx.Get(v))
- out := r.Call([]reflect.Value{cur})
- if lenOut := len(out); lenOut != 1 {
- panic(lenOut)
- }
- tx.Set(v, out[0].Interface())
+ v.Set(tx, f(v.Get(tx)))
}))
}
diff --git a/go.mod b/go.mod
index 4ef81a7..ed11f22 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/anacrolix/stm
-go 1.13
+go 1.18
require (
github.com/anacrolix/envpprof v1.0.0
@@ -9,3 +9,11 @@ require (
github.com/benbjohnson/immutable v0.2.0
github.com/stretchr/testify v1.3.0
)
+
+require (
+ github.com/alecthomas/atomic v0.1.0-alpha2
+ github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/huandu/xstrings v1.2.0 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+)
diff --git a/go.sum b/go.sum
index 360ee68..f0522d7 100644
--- a/go.sum
+++ b/go.sum
@@ -2,10 +2,13 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w=
-github.com/RoaringBitmap/roaring v0.4.17 h1:oCYFIFEMSQZrLHpywH7919esI1VSrQZ0pJXkZPGIJ78=
github.com/RoaringBitmap/roaring v0.4.17/go.mod h1:D3qVegWTmfCaX4Bl5CrBE9hfrSrrXIr8KVNvRsDi1NI=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
+github.com/alecthomas/assert/v2 v2.0.0-alpha3 h1:pcHeMvQ3OMstAWgaeaXIAL8uzB9xMm2zlxt+/4ml8lk=
+github.com/alecthomas/atomic v0.1.0-alpha2 h1:dqwXmax66gXvHhsOS4pGPZKqYOlTkapELkLb3MNdlH8=
+github.com/alecthomas/atomic v0.1.0-alpha2/go.mod h1:zD6QGEyw49HIq19caJDc2NMXAy8rNi9ROrxtMXATfyI=
+github.com/alecthomas/repr v0.0.0-20210801044451-80ca428c5142 h1:8Uy0oSf5co/NZXje7U1z8Mpep++QJOldL2hs/sBQf48=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/anacrolix/envpprof v0.0.0-20180404065416-323002cec2fa/go.mod h1:KgHhUaQMc8cC0+cEflSgCFNFbKwi5h54gqtVn8yhP7c=
@@ -16,12 +19,10 @@ github.com/anacrolix/missinggo v1.1.0/go.mod h1:MBJu3Sk/k3ZfGYcS7z18gwfu72Ey/xop
github.com/anacrolix/missinggo/v2 v2.2.0 h1:JUZh/gF/F4hXejj6I71wuO92MQDwQdLM3yRgYqTlmCg=
github.com/anacrolix/missinggo/v2 v2.2.0/go.mod h1:o0jgJoYOyaoYQ4E2ZMISVa9c88BbUBVQQW4QeRkNCGY=
github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw=
-github.com/anacrolix/tagflag v1.0.0 h1:NoxBVyke6iEtXfSY/n3lY3jNCBjQDu7aTvwHJxNLJAQ=
github.com/anacrolix/tagflag v1.0.0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/benbjohnson/immutable v0.2.0 h1:t0rW3lNFwfQ85IDO1mhMbumxdVSti4nnVaal4r45Oio=
github.com/benbjohnson/immutable v0.2.0/go.mod h1:uc6OHo6PN2++n98KHLxW8ef4W42ylHiQSENghE1ezxI=
-github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo=
github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c h1:FUUopH4brHNO2kJoNN3pV+OBEYmgraLT/KHZrMM69r0=
@@ -30,17 +31,14 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh6AyO7hdCn/PkvCZXii8TGj7sbtEbQ=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
-github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/glycerine/go-unsnap-stream v0.0.0-20180323001048-9f0cb55181dd/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE=
-github.com/glycerine/go-unsnap-stream v0.0.0-20181221182339-f9677308dec2 h1:Ujru1hufTHVb++eG6OuNDKMxZnGIvF6o/u8q/8h2+I4=
github.com/glycerine/go-unsnap-stream v0.0.0-20181221182339-f9677308dec2/go.mod h1:/20jfyN9Y5QPEAprSgKAUr+glWDY39ZiUEAYOEv5dsE=
github.com/glycerine/goconvey v0.0.0-20180728074245-46e3a41ad493/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24=
github.com/glycerine/goconvey v0.0.0-20190315024820-982ee783a72e/go.mod h1:Ogl1Tioa0aV7gstGFO7KhffUsb9M4ydbEbbxpcEDc24=
@@ -51,22 +49,20 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
-github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
-github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
-github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
+github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gopherjs/gopherjs v0.0.0-20190309154008-847fc94819f9/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
-github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
+github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo=
github.com/huandu/xstrings v1.2.0 h1:yPeWdRnmynF7p+lLYz0H2tthW9lqhMJrQV/U7yy4wX0=
@@ -77,7 +73,6 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
-github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mschoch/smat v0.0.0-20160514031455-90eadee771ae/go.mod h1:qAyveg+e4CE+eKJXWVjKXM4ck2QobLqTDytGJbLLhJg=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
@@ -85,27 +80,20 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
-github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ=
github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
-github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829 h1:D+CiwcpGTW6pL6bv6KI3KbyEyCKyS+1JWS2h8PNDnGA=
github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
-github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f h1:BVwpUVJDADN2ufcGik7W992pyps0wZ888b/y9GXcLTU=
github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
-github.com/prometheus/common v0.2.0 h1:kUZDBDTdBVBYBj5Tmh2NZLlF60mfjA27rM34b+cVwNU=
github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
-github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1 h1:/K3IL0Z1quvmJ7X0A1AwNEK7CRkVK3YwfOU/QAL4WGg=
github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
-github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 h1:GHRpF1pTW19a8tTFrMLUcfWwyC0pnifVo2ClaLq+hP8=
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
@@ -119,12 +107,10 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
-github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU=
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4=
github.com/willf/bitset v1.1.10/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4=
go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk=
-go.opencensus.io v0.20.2 h1:NAfh7zF0/3/HqtMvJNZ/RFrSlCE6ZTlHmKfhL/Dm1Jk=
go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
diff --git a/rate/rate_test.go b/rate/rate_test.go
index 7f44d74..3078b4c 100644
--- a/rate/rate_test.go
+++ b/rate/rate_test.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+//go:build go1.7
// +build go1.7
package rate
@@ -403,7 +404,7 @@ func runWait(t *testing.T, lim *Limiter, w wait) {
t.Helper()
start := time.Now()
err := lim.WaitN(w.ctx, w.n)
- delay := time.Now().Sub(start)
+ delay := time.Since(start)
if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) {
errString := "<nil>"
if !w.nilErr {
diff --git a/rate/ratelimit.go b/rate/ratelimit.go
index a44f383..f521f66 100644
--- a/rate/ratelimit.go
+++ b/rate/ratelimit.go
@@ -13,9 +13,9 @@ import (
type numTokens = int
type Limiter struct {
- max *stm.Var
- cur *stm.Var
- lastAdd *stm.Var
+ max *stm.Var[numTokens]
+ cur *stm.Var[numTokens]
+ lastAdd *stm.Var[time.Time]
rate Limit
}
@@ -49,7 +49,7 @@ func NewLimiter(rate Limit, burst numTokens) *Limiter {
func (rl *Limiter) tokenGenerator(interval time.Duration) {
for {
- lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time)
+ lastAdd := stm.AtomicGet(rl.lastAdd)
time.Sleep(time.Until(lastAdd.Add(interval)))
now := time.Now()
available := numTokens(now.Sub(lastAdd) / interval)
@@ -57,17 +57,17 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) {
continue
}
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(rl.cur).(numTokens)
- max := tx.Get(rl.max).(numTokens)
+ cur := rl.cur.Get(tx)
+ max := rl.max.Get(tx)
tx.Assert(cur < max)
newCur := cur + available
if newCur > max {
newCur = max
}
if newCur != cur {
- tx.Set(rl.cur, newCur)
+ rl.cur.Set(tx, newCur)
}
- tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available)))
+ rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available)))
}))
}
}
@@ -77,9 +77,9 @@ func (rl *Limiter) Allow() bool {
}
func (rl *Limiter) AllowN(n numTokens) bool {
- return stm.Atomically(func(tx *stm.Tx) interface{} {
+ return stm.Atomically(func(tx *stm.Tx) bool {
return rl.takeTokens(tx, n)
- }).(bool)
+ })
}
func (rl *Limiter) AllowStm(tx *stm.Tx) bool {
@@ -90,9 +90,9 @@ func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool {
if rl.rate == Inf {
return true
}
- cur := tx.Get(rl.cur).(numTokens)
+ cur := rl.cur.Get(tx)
if cur >= n {
- tx.Set(rl.cur, cur-n)
+ rl.cur.Set(tx, cur-n)
return true
}
return false
@@ -105,25 +105,25 @@ func (rl *Limiter) Wait(ctx context.Context) error {
func (rl *Limiter) WaitN(ctx context.Context, n int) error {
ctxDone, cancel := stmutil.ContextDoneVar(ctx)
defer cancel()
- if err := stm.Atomically(func(tx *stm.Tx) interface{} {
- if tx.Get(ctxDone).(bool) {
+ if err := stm.Atomically(func(tx *stm.Tx) error {
+ if ctxDone.Get(tx) {
return ctx.Err()
}
if rl.takeTokens(tx, n) {
return nil
}
- if n > tx.Get(rl.max).(numTokens) {
+ if n > rl.max.Get(tx) {
return errors.New("burst exceeded")
}
if dl, ok := ctx.Deadline(); ok {
- if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n {
+ if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n {
return context.DeadlineExceeded
}
}
tx.Retry()
panic("unreachable")
}); err != nil {
- return err.(error)
+ return err
}
return nil
diff --git a/retry.go b/retry.go
index 1997b18..1adcfd0 100644
--- a/retry.go
+++ b/retry.go
@@ -11,7 +11,7 @@ var retries = pprof.NewProfile("stmRetries")
var retry = &struct{}{}
// catchRetry returns true if fn calls tx.Retry.
-func catchRetry(fn Operation, tx *Tx) (result interface{}, gotRetry bool) {
+func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) {
defer func() {
if r := recover(); r == retry {
gotRetry = true
diff --git a/stm_test.go b/stm_test.go
index 207148c..a98685b 100644
--- a/stm_test.go
+++ b/stm_test.go
@@ -14,14 +14,14 @@ func TestDecrement(t *testing.T) {
x := NewVar(1000)
for i := 0; i < 500; i++ {
go Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur-1)
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
}))
}
done := make(chan struct{})
go func() {
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x) == 500)
+ tx.Assert(x.Get(tx) == 500)
}))
close(done)
}()
@@ -50,10 +50,10 @@ func TestReadVerify(t *testing.T) {
// between the reads, causing this tx to retry.
var x2, y2 int
Atomically(VoidOperation(func(tx *Tx) {
- x2 = tx.Get(x).(int)
+ x2 = x.Get(tx)
read <- struct{}{}
<-read // wait for other tx to complete
- y2 = tx.Get(y).(int)
+ y2 = y.Get(tx)
}))
if x2 == 1 && y2 == 2 {
t.Fatal("read was not verified")
@@ -68,8 +68,8 @@ func TestRetry(t *testing.T) {
for i := 0; i < 10; i++ {
time.Sleep(10 * time.Millisecond)
Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
- tx.Set(x, cur-1)
+ cur := x.Get(tx)
+ x.Set(tx, cur-1)
}))
}
}()
@@ -77,7 +77,7 @@ func TestRetry(t *testing.T) {
// retry. This should result in no more than 1 retry per transaction.
retry := 0
Atomically(VoidOperation(func(tx *Tx) {
- cur := tx.Get(x).(int)
+ cur := x.Get(tx)
if cur != 0 {
retry++
tx.Retry()
@@ -100,9 +100,9 @@ func TestVerify(t *testing.T) {
go func() {
Atomically(VoidOperation(func(tx *Tx) {
<-read
- rx := tx.Get(x).(*foo)
+ rx := x.Get(tx)
rx.i = 7
- tx.Set(x, rx)
+ x.Set(tx, rx)
}))
read <- struct{}{}
// other tx should retry, so we need to read/send again
@@ -113,7 +113,7 @@ func TestVerify(t *testing.T) {
// between the reads, causing this tx to retry.
var i int
Atomically(VoidOperation(func(tx *Tx) {
- f := tx.Get(x).(*foo)
+ f := x.Get(tx)
i = f.i
read <- struct{}{}
<-read // wait for other tx to complete
@@ -125,36 +125,37 @@ func TestVerify(t *testing.T) {
func TestSelect(t *testing.T) {
// empty Select should panic
- require.Panics(t, func() { Atomically(Select()) })
+ require.Panics(t, func() { Atomically(Select[struct{}]()) })
// with one arg, Select adds no effect
x := NewVar(2)
Atomically(Select(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x).(int) == 2)
+ tx.Assert(x.Get(tx) == 2)
})))
picked := Atomically(Select(
// always blocks; should never be selected
- VoidOperation(func(tx *Tx) {
+ func(tx *Tx) int {
tx.Retry()
- }),
+ panic("unreachable")
+ },
// always succeeds; should always be selected
- func(tx *Tx) interface{} {
+ func(tx *Tx) int {
return 2
},
// always succeeds; should never be selected
- func(tx *Tx) interface{} {
+ func(tx *Tx) int {
return 3
},
- )).(int)
+ ))
assert.EqualValues(t, 2, picked)
}
func TestCompose(t *testing.T) {
nums := make([]int, 100)
- fns := make([]Operation, 100)
+ fns := make([]Operation[struct{}], 100)
for i := range fns {
- fns[i] = func(x int) Operation {
+ fns[i] = func(x int) Operation[struct{}] {
return VoidOperation(func(*Tx) { nums[x] = x })
}(i) // capture loop var
}
@@ -169,7 +170,7 @@ func TestCompose(t *testing.T) {
func TestPanic(t *testing.T) {
// normal panics should escape Atomically
assert.PanicsWithValue(t, "foo", func() {
- Atomically(func(*Tx) interface{} {
+ Atomically(func(*Tx) any {
panic("foo")
})
})
@@ -180,8 +181,8 @@ func TestReadWritten(t *testing.T) {
// previously written value
x := NewVar(3)
Atomically(VoidOperation(func(tx *Tx) {
- tx.Set(x, 5)
- tx.Assert(tx.Get(x).(int) == 5)
+ x.Set(tx, 5)
+ tx.Assert(x.Get(tx) == 5)
}))
}
@@ -191,7 +192,7 @@ func TestAtomicSetRetry(t *testing.T) {
done := make(chan struct{})
go func() {
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(x).(int) == 5)
+ tx.Assert(x.Get(tx) == 5)
}))
done <- struct{}{}
}()
@@ -210,17 +211,17 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) {
hits := NewVar(0)
ready := NewVar(true) // The ball is ready for hitting.
var wg sync.WaitGroup
- bat := func(from, to interface{}, noise string) {
+ bat := func(from, to bool, noise string) {
defer wg.Done()
- for !Atomically(func(tx *Tx) interface{} {
- if tx.Get(doneVar).(bool) {
+ for !Atomically(func(tx *Tx) any {
+ if doneVar.Get(tx) {
return true
}
- tx.Assert(tx.Get(ready).(bool))
- if tx.Get(ball) == from {
- tx.Set(ball, to)
- tx.Set(hits, tx.Get(hits).(int)+1)
- tx.Set(ready, false)
+ tx.Assert(ready.Get(tx))
+ if ball.Get(tx) == from {
+ ball.Set(tx, to)
+ hits.Set(tx, hits.Get(tx)+1)
+ ready.Set(tx, false)
return false
}
return tx.Retry()
@@ -233,8 +234,8 @@ func testPingPong(t testing.TB, n int, afterHit func(string)) {
go bat(false, true, "ping!")
go bat(true, false, "pong!")
Atomically(VoidOperation(func(tx *Tx) {
- tx.Assert(tx.Get(hits).(int) >= n)
- tx.Set(doneVar, true)
+ tx.Assert(hits.Get(tx) >= n)
+ doneVar.Set(tx, true)
}))
wg.Wait()
}
@@ -245,7 +246,7 @@ func TestPingPong(t *testing.T) {
func TestSleepingBeauty(t *testing.T) {
require.Panics(t, func() {
- Atomically(func(tx *Tx) interface{} {
+ Atomically(func(tx *Tx) any {
tx.Assert(false)
return nil
})
@@ -253,7 +254,7 @@ func TestSleepingBeauty(t *testing.T) {
}
//func TestRetryStack(t *testing.T) {
-// v := NewVar(nil)
+// v := NewVar[int](nil)
// go func() {
// i := 0
// for {
@@ -261,12 +262,12 @@ func TestSleepingBeauty(t *testing.T) {
// i++
// }
// }()
-// Atomically(func(tx *Tx) interface{} {
+// Atomically(func(tx *Tx) any {
// debug.PrintStack()
// ret := func() {
// defer Atomically(nil)
// }
-// tx.Get(v)
+// v.Get(tx)
// tx.Assert(false)
// return ret
// })
diff --git a/stmutil/containers.go b/stmutil/containers.go
index e0b532d..c7a4a49 100644
--- a/stmutil/containers.go
+++ b/stmutil/containers.go
@@ -9,10 +9,10 @@ import (
)
type Settish interface {
- Add(interface{}) Settish
- Delete(interface{}) Settish
- Contains(interface{}) bool
- Range(func(interface{}) bool)
+ Add(any) Settish
+ Delete(any) Settish
+ Contains(any) bool
+ Range(func(any) bool)
iter.Iterable
Len() int
}
@@ -23,11 +23,11 @@ type mapToSet struct {
type interhash struct{}
-func (interhash) Hash(x interface{}) uint32 {
+func (interhash) Hash(x any) uint32 {
return uint32(nilinterhash(unsafe.Pointer(&x), 0))
}
-func (interhash) Equal(i, j interface{}) bool {
+func (interhash) Equal(i, j any) bool {
return i == j
}
@@ -39,12 +39,12 @@ func NewSortedSet(lesser lessFunc) Settish {
return mapToSet{NewSortedMap(lesser)}
}
-func (s mapToSet) Add(x interface{}) Settish {
+func (s mapToSet) Add(x any) Settish {
s.m = s.m.Set(x, nil)
return s
}
-func (s mapToSet) Delete(x interface{}) Settish {
+func (s mapToSet) Delete(x any) Settish {
s.m = s.m.Delete(x)
return s
}
@@ -53,13 +53,13 @@ func (s mapToSet) Len() int {
return s.m.Len()
}
-func (s mapToSet) Contains(x interface{}) bool {
+func (s mapToSet) Contains(x any) bool {
_, ok := s.m.Get(x)
return ok
}
-func (s mapToSet) Range(f func(interface{}) bool) {
- s.m.Range(func(k, _ interface{}) bool {
+func (s mapToSet) Range(f func(any) bool) {
+ s.m.Range(func(k, _ any) bool {
return f(k)
})
}
@@ -78,17 +78,17 @@ func NewMap() Mappish {
var _ Mappish = Map{}
-func (m Map) Delete(x interface{}) Mappish {
+func (m Map) Delete(x any) Mappish {
m.Map = m.Map.Delete(x)
return m
}
-func (m Map) Set(key, value interface{}) Mappish {
+func (m Map) Set(key, value any) Mappish {
m.Map = m.Map.Set(key, value)
return m
}
-func (sm Map) Range(f func(key, value interface{}) bool) {
+func (sm Map) Range(f func(key, value any) bool) {
iter := sm.Map.Iterator()
for !iter.Done() {
if !f(iter.Next()) {
@@ -98,7 +98,7 @@ func (sm Map) Range(f func(key, value interface{}) bool) {
}
func (sm Map) Iter(cb iter.Callback) {
- sm.Range(func(key, _ interface{}) bool {
+ sm.Range(func(key, _ any) bool {
return cb(key)
})
}
@@ -107,17 +107,17 @@ type SortedMap struct {
*immutable.SortedMap
}
-func (sm SortedMap) Set(key, value interface{}) Mappish {
+func (sm SortedMap) Set(key, value any) Mappish {
sm.SortedMap = sm.SortedMap.Set(key, value)
return sm
}
-func (sm SortedMap) Delete(key interface{}) Mappish {
+func (sm SortedMap) Delete(key any) Mappish {
sm.SortedMap = sm.SortedMap.Delete(key)
return sm
}
-func (sm SortedMap) Range(f func(key, value interface{}) bool) {
+func (sm SortedMap) Range(f func(key, value any) bool) {
iter := sm.SortedMap.Iterator()
for !iter.Done() {
if !f(iter.Next()) {
@@ -127,18 +127,18 @@ func (sm SortedMap) Range(f func(key, value interface{}) bool) {
}
func (sm SortedMap) Iter(cb iter.Callback) {
- sm.Range(func(key, _ interface{}) bool {
+ sm.Range(func(key, _ any) bool {
return cb(key)
})
}
-type lessFunc func(l, r interface{}) bool
+type lessFunc func(l, r any) bool
type comparer struct {
less lessFunc
}
-func (me comparer) Compare(i, j interface{}) int {
+func (me comparer) Compare(i, j any) int {
if me.less(i, j) {
return -1
} else if me.less(j, i) {
@@ -155,15 +155,15 @@ func NewSortedMap(less lessFunc) Mappish {
}
type Mappish interface {
- Set(key, value interface{}) Mappish
- Delete(key interface{}) Mappish
- Get(key interface{}) (interface{}, bool)
- Range(func(_, _ interface{}) bool)
+ Set(key, value any) Mappish
+ Delete(key any) Mappish
+ Get(key any) (any, bool)
+ Range(func(_, _ any) bool)
Len() int
iter.Iterable
}
-func GetLeft(l, _ interface{}) interface{} {
+func GetLeft(l, _ any) any {
return l
}
@@ -171,7 +171,7 @@ func GetLeft(l, _ interface{}) interface{} {
//go:linkname nilinterhash runtime.nilinterhash
func nilinterhash(p unsafe.Pointer, h uintptr) uintptr
-func interfaceHash(x interface{}) uint32 {
+func interfaceHash(x any) uint32 {
return uint32(nilinterhash(unsafe.Pointer(&x), 0))
}
diff --git a/stmutil/context.go b/stmutil/context.go
index 8a4d58d..6f8ba9b 100644
--- a/stmutil/context.go
+++ b/stmutil/context.go
@@ -9,12 +9,12 @@ import (
var (
mu sync.Mutex
- ctxVars = map[context.Context]*stm.Var{}
+ ctxVars = map[context.Context]*stm.Var[bool]{}
)
// Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be
// called when the user is no longer interested in the var.
-func ContextDoneVar(ctx context.Context) (*stm.Var, func()) {
+func ContextDoneVar(ctx context.Context) (*stm.Var[bool], func()) {
mu.Lock()
defer mu.Unlock()
if v, ok := ctxVars[ctx]; ok {
diff --git a/tx.go b/tx.go
index 9be08b5..825ff01 100644
--- a/tx.go
+++ b/tx.go
@@ -5,13 +5,22 @@ import (
"sort"
"sync"
"unsafe"
+
+ "github.com/alecthomas/atomic"
)
+type txVar interface {
+ getValue() *atomic.Value[VarValue]
+ changeValue(any)
+ getWatchers() *sync.Map
+ getLock() *sync.Mutex
+}
+
// A Tx represents an atomic transaction.
type Tx struct {
- reads map[*Var]VarValue
- writes map[*Var]interface{}
- watching map[*Var]struct{}
+ reads map[txVar]VarValue
+ writes map[txVar]any
+ watching map[txVar]struct{}
locks txLocks
mu sync.Mutex
cond sync.Cond
@@ -24,7 +33,7 @@ type Tx struct {
// Check that none of the logged values have changed since the transaction began.
func (tx *Tx) inputsChanged() bool {
for v, read := range tx.reads {
- if read.Changed(v.value.Load().(VarValue)) {
+ if read.Changed(v.getValue().Load()) {
return true
}
}
@@ -42,12 +51,12 @@ func (tx *Tx) updateWatchers() {
for v := range tx.watching {
if _, ok := tx.reads[v]; !ok {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
}
for v := range tx.reads {
if _, ok := tx.watching[v]; !ok {
- v.watchers.Store(tx, nil)
+ v.getWatchers().Store(tx, nil)
tx.watching[v] = struct{}{}
}
}
@@ -76,22 +85,22 @@ func (tx *Tx) wait() {
}
// Get returns the value of v as of the start of the transaction.
-func (tx *Tx) Get(v *Var) interface{} {
+func (v *Var[T]) Get(tx *Tx) T {
// If we previously wrote to v, it will be in the write log.
if val, ok := tx.writes[v]; ok {
- return val
+ return val.(T)
}
// If we haven't previously read v, record its version
vv, ok := tx.reads[v]
if !ok {
- vv = v.value.Load().(VarValue)
+ vv = v.getValue().Load()
tx.reads[v] = vv
}
- return vv.Get()
+ return vv.Get().(T)
}
// Set sets the value of a Var for the lifetime of the transaction.
-func (tx *Tx) Set(v *Var, val interface{}) {
+func (v *Var[T]) Set(tx *Tx, val T) {
if v == nil {
panic("nil Var")
}
@@ -105,11 +114,10 @@ type txProfileValue struct {
// Retry aborts the transaction and retries it when a Var changes. You can return from this method
// to satisfy return values, but it should never actually return anything as it panics internally.
-func (tx *Tx) Retry() interface{} {
+func (tx *Tx) Retry() struct{} {
retries.Add(txProfileValue{tx, tx.numRetryValues}, 1)
tx.numRetryValues++
panic(retry)
- panic("unreachable")
}
// Assert is a helper function that retries a transaction if the condition is
@@ -143,7 +151,7 @@ func (tx *Tx) removeRetryProfiles() {
func (tx *Tx) recycle() {
for v := range tx.watching {
delete(tx.watching, v)
- v.watchers.Delete(tx)
+ v.getWatchers().Delete(tx)
}
tx.removeRetryProfiles()
// I don't think we can reuse Txs, because the "completed" field should/needs to be set
@@ -164,7 +172,7 @@ func (tx *Tx) resetLocks() {
func (tx *Tx) collectReadLocks() {
for v := range tx.reads {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
@@ -172,7 +180,7 @@ func (tx *Tx) collectAllLocks() {
tx.collectReadLocks()
for v := range tx.writes {
if _, ok := tx.reads[v]; !ok {
- tx.locks.append(&v.mu)
+ tx.locks.append(v.getLock())
}
}
}
diff --git a/var-value.go b/var-value.go
index ff97104..3518bf6 100644
--- a/var-value.go
+++ b/var-value.go
@@ -1,51 +1,51 @@
package stm
type VarValue interface {
- Set(interface{}) VarValue
- Get() interface{}
+ Set(any) VarValue
+ Get() any
Changed(VarValue) bool
}
type version uint64
-type versionedValue struct {
- value interface{}
+type versionedValue[T any] struct {
+ value T
version version
}
-func (me versionedValue) Set(newValue interface{}) VarValue {
- return versionedValue{
- value: newValue,
+func (me versionedValue[T]) Set(newValue any) VarValue {
+ return versionedValue[T]{
+ value: newValue.(T),
version: me.version + 1,
}
}
-func (me versionedValue) Get() interface{} {
+func (me versionedValue[T]) Get() any {
return me.value
}
-func (me versionedValue) Changed(other VarValue) bool {
- return me.version != other.(versionedValue).version
+func (me versionedValue[T]) Changed(other VarValue) bool {
+ return me.version != other.(versionedValue[T]).version
}
-type customVarValue struct {
- value interface{}
- changed func(interface{}, interface{}) bool
+type customVarValue[T any] struct {
+ value T
+ changed func(T, T) bool
}
-var _ VarValue = customVarValue{}
+var _ VarValue = customVarValue[struct{}]{}
-func (me customVarValue) Changed(other VarValue) bool {
- return me.changed(me.value, other.(customVarValue).value)
+func (me customVarValue[T]) Changed(other VarValue) bool {
+ return me.changed(me.value, other.(customVarValue[T]).value)
}
-func (me customVarValue) Set(newValue interface{}) VarValue {
- return customVarValue{
- value: newValue,
+func (me customVarValue[T]) Set(newValue any) VarValue {
+ return customVarValue[T]{
+ value: newValue.(T),
changed: me.changed,
}
}
-func (me customVarValue) Get() interface{} {
+func (me customVarValue[T]) Get() any {
return me.value
}
diff --git a/var.go b/var.go
index 3f43096..ec6a81e 100644
--- a/var.go
+++ b/var.go
@@ -2,18 +2,31 @@ package stm
import (
"sync"
- "sync/atomic"
+
+ "github.com/alecthomas/atomic"
)
// Holds an STM variable.
-type Var struct {
- value atomic.Value
+type Var[T any] struct {
+ value atomic.Value[VarValue]
watchers sync.Map
mu sync.Mutex
}
-func (v *Var) changeValue(new interface{}) {
- old := v.value.Load().(VarValue)
+func (v *Var[T]) getValue() *atomic.Value[VarValue] {
+ return &v.value
+}
+
+func (v *Var[T]) getWatchers() *sync.Map {
+ return &v.watchers
+}
+
+func (v *Var[T]) getLock() *sync.Mutex {
+ return &v.mu
+}
+
+func (v *Var[T]) changeValue(new any) {
+ old := v.value.Load()
newVarValue := old.Set(new)
v.value.Store(newVarValue)
if old.Changed(newVarValue) {
@@ -21,8 +34,8 @@ func (v *Var) changeValue(new interface{}) {
}
}
-func (v *Var) wakeWatchers(new VarValue) {
- v.watchers.Range(func(k, _ interface{}) bool {
+func (v *Var[T]) wakeWatchers(new VarValue) {
+ v.watchers.Range(func(k, _ any) bool {
tx := k.(*Tx)
// We have to lock here to ensure that the Tx is waiting before we signal it. Otherwise we
// could signal it before it goes to sleep and it will miss the notification.
@@ -34,35 +47,30 @@ func (v *Var) wakeWatchers(new VarValue) {
}
}
tx.mu.Unlock()
- return !v.value.Load().(VarValue).Changed(new)
+ return !v.value.Load().Changed(new)
})
}
-type varSnapshot struct {
- val interface{}
- version uint64
-}
-
// Returns a new STM variable.
-func NewVar(val interface{}) *Var {
- v := &Var{}
- v.value.Store(versionedValue{
+func NewVar[T any](val T) *Var[T] {
+ v := &Var[T]{}
+ v.value.Store(versionedValue[T]{
value: val,
})
return v
}
-func NewCustomVar(val interface{}, changed func(interface{}, interface{}) bool) *Var {
- v := &Var{}
- v.value.Store(customVarValue{
+func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] {
+ v := &Var[T]{}
+ v.value.Store(customVarValue[T]{
value: val,
changed: changed,
})
return v
}
-func NewBuiltinEqVar(val interface{}) *Var {
- return NewCustomVar(val, func(a, b interface{}) bool {
+func NewBuiltinEqVar[T comparable](val T) *Var[T] {
+ return NewCustomVar(val, func(a, b T) bool {
return a != b
})
}