// Package stm provides Software Transactional Memory operations for Go. This
// is an alternative to the standard way of writing concurrent code (channels
// and mutexes). STM makes it easy to perform arbitrarily complex operations
// in an atomic fashion. One of its primary advantages over traditional
// locking is that STM transactions are composable, whereas locking functions
// are not -- the composition will either deadlock or release the lock between
// functions (making it non-atomic).
//
// To begin, create a Var that wraps the data you want to access
// concurrently:
//
//	x := stm.NewVar[int](3)
//
// You can then use Atomically to atomically read and/or write the data. This
// code atomically decrements x:
//
//	stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
//		cur := x.Get(tx)
//		x.Set(tx, cur-1)
//	}))
//
// An important part of STM transactions is retrying. At any point during the
// transaction, you can call tx.Retry(), which will abort the transaction, but
// not cancel it entirely. The call to Atomically will block until another
// call to Atomically finishes, at which point the transaction will be rerun.
// Specifically, one of the values read by the transaction (via Var.Get) must be
// updated before the transaction will be rerun. As an example, this code will
// try to decrement x, but will block as long as x is zero:
//
//	stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
//		cur := x.Get(tx)
//		if cur == 0 {
//			tx.Retry()
//		}
//		x.Set(tx, cur-1)
//	}))
//
// Internally, tx.Retry panics with an unexported sentinel value. Panicking with
// any other value will cancel the transaction; no values will be changed.
// However, it is the responsibility of the caller to catch such panics.
//
// Multiple transactions can be composed using Select. If the first
// transaction calls Retry, the next transaction will be run, and so on.
If /// all of the transactions call Retry, the call will block and the entire /// selection will be retried. For example, this code implements the /// "decrement-if-nonzero" transaction above, but for two values. It will /// first try to decrement x, then y, and block if both values are zero. /// /// func dec(v *stm.Var[int]) { /// return func(tx *stm.Tx) { /// cur := v.Get(tx) /// if cur == 0 { /// tx.Retry() /// } /// v.Set(tx, cur-1) /// } /// } /// /// // Note that Select does not perform any work itself, but merely /// // returns a transaction function. /// stm.Atomically(stm.Select(dec(x), dec(y))) /// /// An important caveat: transactions must be idempotent (they should have the /// same effect every time they are invoked). This is because a transaction may /// be retried several times before successfully completing, meaning its side /// effects may execute more than once. This will almost certainly cause /// incorrect behavior. One common way to get around this is to build up a list /// of impure operations inside the transaction, and then perform them after the /// transaction completes. /// /// The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but /// Haskell can enforce at compile time that STM variables are not modified /// outside the STM monad. This is not possible in Go, so be especially careful /// when using pointers in your STM code. Remember: modifying a pointer is a /// side effect! package stm import ( "context" "errors" "expvar" "fmt" "math" "math/rand" "runtime/pprof" "sort" "sync" "sync/atomic" "time" "unsafe" "pds" ) // Package atomic contains type-safe atomic types. // // The zero value for the numeric types cannot be used. Use New*. The // rationale for this behaviour is that copying an atomic integer is not // reliable. Copying can be prevented by embedding sync.Mutex, but that bloats // the type. // Interface represents atomic operations on a value. type Interface[T any] interface { // Load value atomically. 
Load() T // Store value atomically. Store(value T) // Swap the previous value with the new value atomically. Swap(new T) (old T) // CompareAndSwap the previous value with new if its value is "old". CompareAndSwap(old, new T) (swapped bool) } var _ Interface[bool] = &Value[bool]{} // Value wraps any generic value in atomic load and store operations. // // The zero value should be initialised using [Value.Store] before use. type Value[T any] struct { value atomic.Value } // New atomic Value. func New[T any](seed T) *Value[T] { v := &Value[T]{} v.value.Store(seed) return v } // Load value atomically. // // Will panic if the value is nil. func (v *Value[T]) Load() (out T) { value := v.value.Load() if value == nil { panic("nil value in atomic.Value") } return value.(T) } func (v *Value[T]) Store(value T) { v.value.Store(value) } func (v *Value[T]) Swap(new T) (old T) { return v.value.Swap(new).(T) } func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { return v.value.CompareAndSwap(old, new) } // atomicint defines the types that atomic integer operations are supported on. type atomicint interface { int32 | uint32 | int64 | uint64 } // Int expresses atomic operations on signed or unsigned integer values. type Int[T atomicint] interface { Interface[T] // Add a value and return the new result. Add(delta T) (new T) } // Currently not supported by Go's generic type system: // // ./atomic.go:48:9: cannot use type switch on type parameter value v (variable of type T constrained by atomicint) // // // ForInt infers and creates an atomic Int[T] type for a value. // func ForInt[T atomicint](v T) Int[T] { // switch v.(type) { // case int32: // return NewInt32(v) // case uint32: // return NewUint32(v) // case int64: // return NewInt64(v) // case uint64: // return NewUint64(v) // } // panic("can't happen") // } // Int32 atomic value. // // Copying creates an alias. The zero value is not usable, use NewInt32. 
type Int32 struct{ value *int32 } // NewInt32 creates a new atomic integer with an initial value. func NewInt32(value int32) Int32 { return Int32{value: &value} } var _ Int[int32] = &Int32{} func (i Int32) Add(delta int32) (new int32) { return atomic.AddInt32(i.value, delta) } func (i Int32) Load() (val int32) { return atomic.LoadInt32(i.value) } func (i Int32) Store(val int32) { atomic.StoreInt32(i.value, val) } func (i Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(i.value, new) } func (i Int32) CompareAndSwap(old, new int32) (swapped bool) { return atomic.CompareAndSwapInt32(i.value, old, new) } // Uint32 atomic value. // // Copying creates an alias. type Uint32 struct{ value *uint32 } var _ Int[uint32] = Uint32{} // NewUint32 creates a new atomic integer with an initial value. func NewUint32(value uint32) Uint32 { return Uint32{value: &value} } func (i Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(i.value, delta) } func (i Uint32) Load() (val uint32) { return atomic.LoadUint32(i.value) } func (i Uint32) Store(val uint32) { atomic.StoreUint32(i.value, val) } func (i Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(i.value, new) } func (i Uint32) CompareAndSwap(old, new uint32) (swapped bool) { return atomic.CompareAndSwapUint32(i.value, old, new) } // Int64 atomic value. // // Copying creates an alias. type Int64 struct{ value *int64 } var _ Int[int64] = Int64{} // NewInt64 creates a new atomic integer with an initial value. 
func NewInt64(value int64) Int64 { return Int64{value: &value} } func (i Int64) Add(delta int64) (new int64) { return atomic.AddInt64(i.value, delta) } func (i Int64) Load() (val int64) { return atomic.LoadInt64(i.value) } func (i Int64) Store(val int64) { atomic.StoreInt64(i.value, val) } func (i Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(i.value, new) } func (i Int64) CompareAndSwap(old, new int64) (swapped bool) { return atomic.CompareAndSwapInt64(i.value, old, new) } // Uint64 atomic value. // // Copying creates an alias. type Uint64 struct{ value *uint64 } var _ Int[uint64] = Uint64{} // NewUint64 creates a new atomic integer with an initial value. func NewUint64(value uint64) Uint64 { return Uint64{value: &value} } func (i Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(i.value, delta) } func (i Uint64) Load() (val uint64) { return atomic.LoadUint64(i.value) } func (i Uint64) Store(val uint64) { atomic.StoreUint64(i.value, val) } func (i Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(i.value, new) } func (i Uint64) CompareAndSwap(old, new uint64) (swapped bool) { return atomic.CompareAndSwapUint64(i.value, old, new) } func F() { } var ( txPool = sync.Pool{New: func() any { expvars.Add("new txs", 1) tx := &Tx{ reads: make(map[txVar]VarValue), writes: make(map[txVar]any), watching: make(map[txVar]struct{}), } tx.cond.L = &tx.mu return tx }} failedCommitsProfile *pprof.Profile ) const ( profileFailedCommits = false sleepBetweenRetries = false ) func init() { if profileFailedCommits { failedCommitsProfile = pprof.NewProfile("stmFailedCommits") } } func newTx() *Tx { tx := txPool.Get().(*Tx) tx.tries = 0 tx.completed = false return tx } func WouldBlock[R any](fn Operation[R]) (block bool) { tx := newTx() tx.reset() _, block = catchRetry(fn, tx) if len(tx.watching) != 0 { panic("shouldn't have installed any watchers") } tx.recycle() return } // Atomically executes the atomic function fn. 
func Atomically[R any](op Operation[R]) R { expvars.Add("atomically", 1) // run the transaction tx := newTx() retry: tx.tries++ tx.reset() if sleepBetweenRetries { shift := int64(tx.tries - 1) const maxShift = 30 if shift > maxShift { shift = maxShift } ns := int64(1) << shift d := time.Duration(rand.Int63n(ns)) if d > 100*time.Microsecond { tx.updateWatchers() time.Sleep(time.Duration(ns)) } } tx.mu.Lock() ret, retry := catchRetry(op, tx) tx.mu.Unlock() if retry { expvars.Add("retries", 1) // wait for one of the variables we read to change before retrying tx.wait() goto retry } // verify the read log tx.lockAllVars() if tx.inputsChanged() { tx.unlock() expvars.Add("failed commits", 1) if profileFailedCommits { failedCommitsProfile.Add(new(int), 0) } goto retry } // commit the write log and broadcast that variables have changed tx.commit() tx.mu.Lock() tx.completed = true tx.cond.Broadcast() tx.mu.Unlock() tx.unlock() expvars.Add("commits", 1) tx.recycle() return ret } // AtomicGet is a helper function that atomically reads a value. func Deref[T any](v *Var[T]) T { return v.value.Load().Get().(T) } // AtomicSet is a helper function that atomically writes a value. func AtomicSet[T any](v *Var[T], val T) { v.mu.Lock() v.changeValue(val) v.mu.Unlock() } // Compose is a helper function that composes multiple transactions into a // single transaction. func Compose[R any](fns ...Operation[R]) Operation[struct{}] { return VoidOperation(func(tx *Tx) { for _, f := range fns { f(tx) } }) } // Select runs the supplied functions in order. Execution stops when a // function succeeds without calling Retry. If no functions succeed, the // entire selection will be retried. 
func Select[R any](fns ...Operation[R]) Operation[R] { return func(tx *Tx) R { switch len(fns) { case 0: // empty Select blocks forever tx.Retry() panic("unreachable") case 1: return fns[0](tx) default: oldWrites := tx.writes tx.writes = make(map[txVar]any, len(oldWrites)) for k, v := range oldWrites { tx.writes[k] = v } ret, retry := catchRetry(fns[0], tx) if retry { tx.writes = oldWrites return Select(fns[1:]...)(tx) } else { return ret } } } } type Operation[R any] func(*Tx) R func VoidOperation(f func(*Tx)) Operation[struct{}] { return func(tx *Tx) struct{} { f(tx) return struct{}{} } } func Swap[T any](v *Var[T], f func(T) T) { Atomically(VoidOperation(func(tx *Tx) { v.Set(tx, f(v.Get(tx))) })) } var expvars = expvar.NewMap("stm") type VarValue interface { Set(any) VarValue Get() any Changed(VarValue) bool } type version uint64 type versionedValue[T any] struct { value T version version } func (me versionedValue[T]) Set(newValue any) VarValue { return versionedValue[T]{ value: newValue.(T), version: me.version + 1, } } func (me versionedValue[T]) Get() any { return me.value } func (me versionedValue[T]) Changed(other VarValue) bool { return me.version != other.(versionedValue[T]).version } type customVarValue[T any] struct { value T changed func(T, T) bool } var _ VarValue = customVarValue[struct{}]{} func (me customVarValue[T]) Changed(other VarValue) bool { return me.changed(me.value, other.(customVarValue[T]).value) } func (me customVarValue[T]) Set(newValue any) VarValue { return customVarValue[T]{ value: newValue.(T), changed: me.changed, } } func (me customVarValue[T]) Get() any { return me.value } type txVar interface { getValue() *Value[VarValue] changeValue(any) getWatchers() *sync.Map getLock() *sync.Mutex } // A Tx represents an atomic transaction. 
type Tx struct { reads map[txVar]VarValue writes map[txVar]any watching map[txVar]struct{} locks txLocks mu sync.Mutex cond sync.Cond waiting bool completed bool tries int numRetryValues int } // Check that none of the logged values have changed since the transaction began. func (tx *Tx) inputsChanged() bool { for v, read := range tx.reads { if read.Changed(v.getValue().Load()) { return true } } return false } // Writes the values in the transaction log to their respective Vars. func (tx *Tx) commit() { for v, val := range tx.writes { v.changeValue(val) } } func (tx *Tx) updateWatchers() { for v := range tx.watching { if _, ok := tx.reads[v]; !ok { delete(tx.watching, v) v.getWatchers().Delete(tx) } } for v := range tx.reads { if _, ok := tx.watching[v]; !ok { v.getWatchers().Store(tx, nil) tx.watching[v] = struct{}{} } } } // wait blocks until another transaction modifies any of the Vars read by tx. func (tx *Tx) wait() { if len(tx.reads) == 0 { panic("not waiting on anything") } tx.updateWatchers() tx.mu.Lock() firstWait := true for !tx.inputsChanged() { if !firstWait { expvars.Add("wakes for unchanged versions", 1) } expvars.Add("waits", 1) tx.waiting = true tx.cond.Broadcast() tx.cond.Wait() tx.waiting = false firstWait = false } tx.mu.Unlock() } // Get returns the value of v as of the start of the transaction. func (v *Var[T]) Get(tx *Tx) T { // If we previously wrote to v, it will be in the write log. if val, ok := tx.writes[v]; ok { return val.(T) } // If we haven't previously read v, record its version vv, ok := tx.reads[v] if !ok { vv = v.getValue().Load() tx.reads[v] = vv } return vv.Get().(T) } // Set sets the value of a Var for the lifetime of the transaction. func (v *Var[T]) Set(tx *Tx, val T) { if v == nil { panic("nil Var") } tx.writes[v] = val } type txProfileValue struct { *Tx int } // Retry aborts the transaction and retries it when a Var changes. 
You can return from this method // to satisfy return values, but it should never actually return anything as it panics internally. func (tx *Tx) Retry() struct{} { retries.Add(txProfileValue{tx, tx.numRetryValues}, 1) tx.numRetryValues++ panic(retry) } // Assert is a helper function that retries a transaction if the condition is // not satisfied. func (tx *Tx) Assert(p bool) { if !p { tx.Retry() } } func (tx *Tx) reset() { tx.mu.Lock() for k := range tx.reads { delete(tx.reads, k) } for k := range tx.writes { delete(tx.writes, k) } tx.mu.Unlock() tx.removeRetryProfiles() tx.resetLocks() } func (tx *Tx) removeRetryProfiles() { for tx.numRetryValues > 0 { tx.numRetryValues-- retries.Remove(txProfileValue{tx, tx.numRetryValues}) } } func (tx *Tx) recycle() { for v := range tx.watching { delete(tx.watching, v) v.getWatchers().Delete(tx) } tx.removeRetryProfiles() // I don't think we can reuse Txs, because the "completed" field should/needs to be set // indefinitely after use. //txPool.Put(tx) } func (tx *Tx) lockAllVars() { tx.resetLocks() tx.collectAllLocks() tx.sortLocks() tx.lock() } func (tx *Tx) resetLocks() { tx.locks.clear() } func (tx *Tx) collectReadLocks() { for v := range tx.reads { tx.locks.append(v.getLock()) } } func (tx *Tx) collectAllLocks() { tx.collectReadLocks() for v := range tx.writes { if _, ok := tx.reads[v]; !ok { tx.locks.append(v.getLock()) } } } func (tx *Tx) sortLocks() { sort.Sort(&tx.locks) } func (tx *Tx) lock() { for _, l := range tx.locks.mus { l.Lock() } } func (tx *Tx) unlock() { for _, l := range tx.locks.mus { l.Unlock() } } func (tx *Tx) String() string { return fmt.Sprintf("%[1]T %[1]p", tx) } // Dedicated type avoids reflection in sort.Slice. 
type txLocks struct {
	mus []*sync.Mutex
}

// Len, Less and Swap implement sort.Interface, ordering mutexes by address so
// every transaction acquires Var locks in the same global order (presumably to
// avoid lock-ordering deadlocks between concurrent commits).
// NOTE(review): relies on heap objects not moving, which holds for Go's
// current garbage collector.
func (me txLocks) Len() int {
	return len(me.mus)
}

func (me txLocks) Less(i, j int) bool {
	return uintptr(unsafe.Pointer(me.mus[i])) < uintptr(unsafe.Pointer(me.mus[j]))
}

func (me txLocks) Swap(i, j int) {
	me.mus[i], me.mus[j] = me.mus[j], me.mus[i]
}

// clear empties the list but keeps its backing array for reuse.
func (me *txLocks) clear() {
	me.mus = me.mus[:0]
}

func (me *txLocks) append(mu *sync.Mutex) {
	me.mus = append(me.mus, mu)
}

// Var holds an STM variable.
//
// value is the current committed VarValue. watchers has *Tx keys, registered
// by Tx.updateWatchers, for transactions waiting on this Var. mu is held while
// the value is changed (by AtomicSet, and by commits via Tx.lockAllVars).
type Var[T any] struct {
	value    Value[VarValue]
	watchers sync.Map
	mu       sync.Mutex
}

func (v *Var[T]) getValue() *Value[VarValue] {
	return &v.value
}

func (v *Var[T]) getWatchers() *sync.Map {
	return &v.watchers
}

func (v *Var[T]) getLock() *sync.Mutex {
	return &v.mu
}

// changeValue stores the successor of the current VarValue for new. If the
// VarValue reports a change, it wakes watching transactions on a new
// goroutine. Callers in this package hold v.mu while calling it.
func (v *Var[T]) changeValue(new any) {
	old := v.value.Load()
	newVarValue := old.Set(new)
	v.value.Store(newVarValue)
	if old.Changed(newVarValue) {
		go v.wakeWatchers(newVarValue)
	}
}

// wakeWatchers visits every Tx registered as watching v. For each one whose
// recorded read of v differs from new, it broadcasts on the Tx's condition
// variable. It then waits until that Tx is either waiting or completed,
// which is the handshake with Tx.wait. Iteration stops early once v's value
// has changed again after new; presumably the goroutine started for that
// later change takes over.
func (v *Var[T]) wakeWatchers(new VarValue) {
	v.watchers.Range(func(k, _ any) bool {
		tx := k.(*Tx)
		// We have to lock here to ensure that the Tx is waiting before
		// we signal it. Otherwise we could signal it before it goes to
		// sleep and it will miss the notification.
		tx.mu.Lock()
		if read := tx.reads[v]; read != nil && read.Changed(new) {
			tx.cond.Broadcast()
			for !tx.waiting && !tx.completed {
				tx.cond.Wait()
			}
		}
		tx.mu.Unlock()
		return !v.value.Load().Changed(new)
	})
}

// NewVar returns a new STM variable holding val. Changes are detected by
// version number, so every write counts as a change.
func NewVar[T any](val T) *Var[T] {
	v := &Var[T]{}
	v.value.Store(versionedValue[T]{
		value: val,
	})
	return v
}

// NewCustomVar returns a new STM variable holding val. A write counts as a
// change only when changed(current, other) reports true.
func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] {
	v := &Var[T]{}
	v.value.Store(customVarValue[T]{
		value:   val,
		changed: changed,
	})
	return v
}

// NewBuiltinEqVar returns a new STM variable holding val, where a write
// counts as a change only if the new value is != the old one.
func NewBuiltinEqVar[T comparable](val T) *Var[T] {
	return NewCustomVar(val, func(a, b T) bool {
		return a != b
	})
}

// retries is a pprof profile with one entry per outstanding Tx.Retry call.
// Entries are removed by Tx.removeRetryProfiles.
var retries = pprof.NewProfile("stmRetries")

// retry is a sentinel value. When thrown via panic (see Tx.Retry), it
// indicates that a transaction should be retried; catchRetry recognises it.
var retry = &struct{}{} // catchRetry returns true if fn calls tx.Retry. func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) { defer func() { if r := recover(); r == retry { gotRetry = true } else if r != nil { panic(r) } }() result = fn(tx) return } // This is the type constraint for keys passed through from pds type KeyConstraint interface { comparable } type Settish[K KeyConstraint] interface { Add(K) Settish[K] Delete(K) Settish[K] Contains(K) bool Range(func(K) bool) Len() int // iter.Iterable } type mapToSet[K KeyConstraint] struct { m Mappish[K, struct{}] } type interhash[K KeyConstraint] struct{} func (interhash[K]) Hash(x K) uint32 { return uint32(nilinterhash(unsafe.Pointer(&x), 0)) } func (interhash[K]) Equal(i, j K) bool { return i == j } func NewSet[K KeyConstraint]() Settish[K] { return mapToSet[K]{NewMap[K, struct{}]()} } func NewSortedSet[K KeyConstraint](lesser lessFunc[K]) Settish[K] { return mapToSet[K]{NewSortedMap[K, struct{}](lesser)} } func (s mapToSet[K]) Add(x K) Settish[K] { s.m = s.m.Set(x, struct{}{}) return s } func (s mapToSet[K]) Delete(x K) Settish[K] { s.m = s.m.Delete(x) return s } func (s mapToSet[K]) Len() int { return s.m.Len() } func (s mapToSet[K]) Contains(x K) bool { _, ok := s.m.Get(x) return ok } func (s mapToSet[K]) Range(f func(K) bool) { s.m.Range(func(k K, _ struct{}) bool { return f(k) }) } /* func (s mapToSet[K]) Iter(cb iter.Callback) { s.Range(func(k K) bool { return cb(k) }) } */ type Map[K KeyConstraint, V any] struct { *pds.Map[K, V] } func NewMap[K KeyConstraint, V any]() Mappish[K, V] { return Map[K, V]{pds.NewMap[K, V](interhash[K]{})} } func (m Map[K, V]) Delete(x K) Mappish[K, V] { m.Map = m.Map.Delete(x) return m } func (m Map[K, V]) Set(key K, value V) Mappish[K, V] { m.Map = m.Map.Set(key, value) return m } func (sm Map[K, V]) Range(f func(K, V) bool) { iter := sm.Map.Iterator() for { k, v, ok := iter.Next() if !ok { break } if !f(k, v) { return } } } /* func (sm Map[K, V]) Iter(cb 
iter.Callback) { sm.Range(func(key K, _ V) bool { return cb(key) }) } */ type SortedMap[K KeyConstraint, V any] struct { *pds.SortedMap[K, V] } func (sm SortedMap[K, V]) Set(key K, value V) Mappish[K, V] { sm.SortedMap = sm.SortedMap.Set(key, value) return sm } func (sm SortedMap[K, V]) Delete(key K) Mappish[K, V] { sm.SortedMap = sm.SortedMap.Delete(key) return sm } func (sm SortedMap[K, V]) Range(f func(key K, value V) bool) { iter := sm.SortedMap.Iterator() for { k, v, ok := iter.Next() if !ok { break } if !f(k, v) { return } } } /* func (sm SortedMap[K, V]) Iter(cb iter.Callback) { sm.Range(func(key K, _ V) bool { return cb(key) }) } */ type lessFunc[T KeyConstraint] func(l, r T) bool type comparer[K KeyConstraint] struct { less lessFunc[K] } func (me comparer[K]) Compare(i, j K) int { if me.less(i, j) { return -1 } else if me.less(j, i) { return 1 } else { return 0 } } func NewSortedMap[K KeyConstraint, V any](less lessFunc[K]) Mappish[K, V] { return SortedMap[K, V]{ SortedMap: pds.NewSortedMap[K, V](comparer[K]{less}), } } type Mappish[K, V any] interface { Set(K, V) Mappish[K, V] Delete(key K) Mappish[K, V] Get(key K) (V, bool) Range(func(K, V) bool) Len() int // iter.Iterable } func GetLeft(l, _ any) any { return l } //go:noescape //go:linkname nilinterhash runtime.nilinterhash func nilinterhash(p unsafe.Pointer, h uintptr) uintptr func interfaceHash(x any) uint32 { return uint32(nilinterhash(unsafe.Pointer(&x), 0)) } type Lenner interface { Len() int } var ( mu sync.Mutex ctxVars = map[context.Context]*Var[bool]{} ) // Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be // called when the user is no longer interested in the var. func ContextDoneVar(ctx context.Context) (*Var[bool], func()) { mu.Lock() defer mu.Unlock() if v, ok := ctxVars[ctx]; ok { return v, func() {} } if ctx.Err() != nil { // TODO: What if we had read-only Vars? Then we could have a global one for this that we // just reuse. 
v := NewBuiltinEqVar(true) return v, func() {} } v := NewVar(false) go func() { <-ctx.Done() AtomicSet(v, true) mu.Lock() delete(ctxVars, ctx) mu.Unlock() }() ctxVars[ctx] = v return v, func() {} } type numTokens = int type Limiter struct { max *Var[numTokens] cur *Var[numTokens] lastAdd *Var[time.Time] rate Limit } const Inf = Limit(math.MaxFloat64) type Limit float64 func (l Limit) interval() time.Duration { return time.Duration(Limit(1*time.Second) / l) } func Every(interval time.Duration) Limit { if interval == 0 { return Inf } return Limit(time.Second / interval) } func NewLimiter(rate Limit, burst numTokens) *Limiter { rl := &Limiter{ max: NewVar(burst), cur: NewBuiltinEqVar(burst), lastAdd: NewVar(time.Now()), rate: rate, } if rate != Inf { go rl.tokenGenerator(rate.interval()) } return rl } func (rl *Limiter) tokenGenerator(interval time.Duration) { for { lastAdd := Deref(rl.lastAdd) time.Sleep(time.Until(lastAdd.Add(interval))) now := time.Now() available := numTokens(now.Sub(lastAdd) / interval) if available < 1 { continue } Atomically(VoidOperation(func(tx *Tx) { cur := rl.cur.Get(tx) max := rl.max.Get(tx) tx.Assert(cur < max) newCur := cur + available if newCur > max { newCur = max } if newCur != cur { rl.cur.Set(tx, newCur) } rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available))) })) } } func (rl *Limiter) Allow() bool { return rl.AllowN(1) } func (rl *Limiter) AllowN(n numTokens) bool { return Atomically(func(tx *Tx) bool { return rl.takeTokens(tx, n) }) } func (rl *Limiter) AllowStm(tx *Tx) bool { return rl.takeTokens(tx, 1) } func (rl *Limiter) takeTokens(tx *Tx, n numTokens) bool { if rl.rate == Inf { return true } cur := rl.cur.Get(tx) if cur >= n { rl.cur.Set(tx, cur-n) return true } return false } func (rl *Limiter) Wait(ctx context.Context) error { return rl.WaitN(ctx, 1) } func (rl *Limiter) WaitN(ctx context.Context, n int) error { ctxDone, cancel := ContextDoneVar(ctx) defer cancel() if err := Atomically(func(tx *Tx) error { if 
ctxDone.Get(tx) { return ctx.Err() } if rl.takeTokens(tx, n) { return nil } if n > rl.max.Get(tx) { return errors.New("burst exceeded") } if dl, ok := ctx.Deadline(); ok { if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n { return context.DeadlineExceeded } } tx.Retry() panic("unreachable") }); err != nil { return err } return nil }