diff options
author | EuAndreh <eu@euandre.org> | 2025-01-22 12:31:30 -0300 |
---|---|---|
committer | EuAndreh <eu@euandre.org> | 2025-01-22 12:31:30 -0300 |
commit | 59d879ef4e654ce53c2450e000ffa435f06c2f0e (patch) | |
tree | 05ae996bf799b1e51f891a5586b3b72fa9bdfe3f /src/stm.go | |
parent | Setup Makefile build skeleton (diff) | |
download | stm-59d879ef4e654ce53c2450e000ffa435f06c2f0e.tar.gz stm-59d879ef4e654ce53c2450e000ffa435f06c2f0e.tar.xz |
Unify code into default repo format
Diffstat (limited to 'src/stm.go')
-rw-r--r-- | src/stm.go | 1105 |
1 files changed, 1105 insertions, 0 deletions
@@ -1,9 +1,1114 @@ +/// Package stm provides Software Transactional Memory operations for Go. This +/// is an alternative to the standard way of writing concurrent code (channels +/// and mutexes). STM makes it easy to perform arbitrarily complex operations +/// in an atomic fashion. One of its primary advantages over traditional +/// locking is that STM transactions are composable, whereas locking functions +/// are not -- the composition will either deadlock or release the lock between +/// functions (making it non-atomic). +/// +/// To begin, create an STM object that wraps the data you want to access +/// concurrently. +/// +/// x := stm.NewVar[int](3) +/// +/// You can then use the Atomically method to atomically read and/or write the +/// the data. This code atomically decrements x: +/// +/// stm.Atomically(func(tx *stm.Tx) { +/// cur := x.Get(tx) +/// x.Set(tx, cur-1) +/// }) +/// +/// An important part of STM transactions is retrying. At any point during the +/// transaction, you can call tx.Retry(), which will abort the transaction, but +/// not cancel it entirely. The call to Atomically will block until another +/// call to Atomically finishes, at which point the transaction will be rerun. +/// Specifically, one of the values read by the transaction (via tx.Get) must be +/// updated before the transaction will be rerun. As an example, this code will +/// try to decrement x, but will block as long as x is zero: +/// +/// stm.Atomically(func(tx *stm.Tx) { +/// cur := x.Get(tx) +/// if cur == 0 { +/// tx.Retry() +/// } +/// x.Set(tx, cur-1) +/// }) +/// +/// Internally, tx.Retry simply calls panic(stm.Retry). Panicking with any +/// other value will cancel the transaction; no values will be changed. +/// However, it is the responsibility of the caller to catch such panics. +/// +/// Multiple transactions can be composed using Select. If the first +/// transaction calls Retry, the next transaction will be run, and so on. If +/// all of the transactions call Retry, the call will block and the entire +/// selection will be retried. For example, this code implements the +/// "decrement-if-nonzero" transaction above, but for two values. It will +/// first try to decrement x, then y, and block if both values are zero. +/// +/// func dec(v *stm.Var[int]) { +/// return func(tx *stm.Tx) { +/// cur := v.Get(tx) +/// if cur == 0 { +/// tx.Retry() +/// } +/// v.Set(tx, cur-1) +/// } +/// } +/// +/// // Note that Select does not perform any work itself, but merely +/// // returns a transaction function. +/// stm.Atomically(stm.Select(dec(x), dec(y))) +/// +/// An important caveat: transactions must be idempotent (they should have the +/// same effect every time they are invoked). This is because a transaction may +/// be retried several times before successfully completing, meaning its side +/// effects may execute more than once. This will almost certainly cause +/// incorrect behavior. One common way to get around this is to build up a list +/// of impure operations inside the transaction, and then perform them after the +/// transaction completes. +/// +/// The stm API tries to mimic that of Haskell's Control.Concurrent.STM, but +/// Haskell can enforce at compile time that STM variables are not modified +/// outside the STM monad. This is not possible in Go, so be especially careful +/// when using pointers in your STM code. Remember: modifying a pointer is a +/// side effect! package stm import ( + "context" + "errors" + "expvar" + "fmt" + "math" + "math/rand" + "runtime/pprof" + "sort" + "sync" + "sync/atomic" + "time" + "unsafe" + + "pds" ) +// Package atomic contains type-safe atomic types. +// +// The zero value for the numeric types cannot be used. Use New*. The +// rationale for this behaviour is that copying an atomic integer is not +// reliable. Copying can be prevented by embedding sync.Mutex, but that bloats +// the type. + +// Interface represents atomic operations on a value. +type Interface[T any] interface { + // Load value atomically. + Load() T + // Store value atomically. + Store(value T) + // Swap the previous value with the new value atomically. + Swap(new T) (old T) + // CompareAndSwap the previous value with new if its value is "old". + CompareAndSwap(old, new T) (swapped bool) +} + +var _ Interface[bool] = &Value[bool]{} + +// Value wraps any generic value in atomic load and store operations. +// +// The zero value should be initialised using [Value.Store] before use. +type Value[T any] struct { + value atomic.Value +} + +// New atomic Value. +func New[T any](seed T) *Value[T] { + v := &Value[T]{} + v.value.Store(seed) + return v +} + +// Load value atomically. +// +// Will panic if the value is nil. +func (v *Value[T]) Load() (out T) { + value := v.value.Load() + if value == nil { + panic("nil value in atomic.Value") + } + return value.(T) +} +func (v *Value[T]) Store(value T) { v.value.Store(value) } +func (v *Value[T]) Swap(new T) (old T) { return v.value.Swap(new).(T) } +func (v *Value[T]) CompareAndSwap(old, new T) (swapped bool) { return v.value.CompareAndSwap(old, new) } + +// atomicint defines the types that atomic integer operations are supported on. +type atomicint interface { + int32 | uint32 | int64 | uint64 +} + +// Int expresses atomic operations on signed or unsigned integer values. +type Int[T atomicint] interface { + Interface[T] + // Add a value and return the new result. + Add(delta T) (new T) +} + +// Currently not supported by Go's generic type system: +// +// ./atomic.go:48:9: cannot use type switch on type parameter value v (variable of type T constrained by atomicint) +// +// // ForInt infers and creates an atomic Int[T] type for a value. +// func ForInt[T atomicint](v T) Int[T] { +// switch v.(type) { +// case int32: +// return NewInt32(v) +// case uint32: +// return NewUint32(v) +// case int64: +// return NewInt64(v) +// case uint64: +// return NewUint64(v) +// } +// panic("can't happen") +// } + +// Int32 atomic value. +// +// Copying creates an alias. The zero value is not usable, use NewInt32. +type Int32 struct{ value *int32 } + +// NewInt32 creates a new atomic integer with an initial value. +func NewInt32(value int32) Int32 { return Int32{value: &value} } + +var _ Int[int32] = &Int32{} + +func (i Int32) Add(delta int32) (new int32) { return atomic.AddInt32(i.value, delta) } +func (i Int32) Load() (val int32) { return atomic.LoadInt32(i.value) } +func (i Int32) Store(val int32) { atomic.StoreInt32(i.value, val) } +func (i Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(i.value, new) } +func (i Int32) CompareAndSwap(old, new int32) (swapped bool) { + return atomic.CompareAndSwapInt32(i.value, old, new) +} + +// Uint32 atomic value. +// +// Copying creates an alias. +type Uint32 struct{ value *uint32 } + +var _ Int[uint32] = Uint32{} + +// NewUint32 creates a new atomic integer with an initial value. +func NewUint32(value uint32) Uint32 { return Uint32{value: &value} } + +func (i Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(i.value, delta) } +func (i Uint32) Load() (val uint32) { return atomic.LoadUint32(i.value) } +func (i Uint32) Store(val uint32) { atomic.StoreUint32(i.value, val) } +func (i Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(i.value, new) } +func (i Uint32) CompareAndSwap(old, new uint32) (swapped bool) { + return atomic.CompareAndSwapUint32(i.value, old, new) +} + +// Int64 atomic value. +// +// Copying creates an alias. +type Int64 struct{ value *int64 } + +var _ Int[int64] = Int64{} + +// NewInt64 creates a new atomic integer with an initial value. +func NewInt64(value int64) Int64 { return Int64{value: &value} } + +func (i Int64) Add(delta int64) (new int64) { return atomic.AddInt64(i.value, delta) } +func (i Int64) Load() (val int64) { return atomic.LoadInt64(i.value) } +func (i Int64) Store(val int64) { atomic.StoreInt64(i.value, val) } +func (i Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(i.value, new) } +func (i Int64) CompareAndSwap(old, new int64) (swapped bool) { + return atomic.CompareAndSwapInt64(i.value, old, new) +} + +// Uint64 atomic value. +// +// Copying creates an alias. +type Uint64 struct{ value *uint64 } + +var _ Int[uint64] = Uint64{} + +// NewUint64 creates a new atomic integer with an initial value. +func NewUint64(value uint64) Uint64 { return Uint64{value: &value} } + +func (i Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(i.value, delta) } +func (i Uint64) Load() (val uint64) { return atomic.LoadUint64(i.value) } +func (i Uint64) Store(val uint64) { atomic.StoreUint64(i.value, val) } +func (i Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(i.value, new) } +func (i Uint64) CompareAndSwap(old, new uint64) (swapped bool) { + return atomic.CompareAndSwapUint64(i.value, old, new) +} + + func F() { } + +var ( + txPool = sync.Pool{New: func() any { + expvars.Add("new txs", 1) + tx := &Tx{ + reads: make(map[txVar]VarValue), + writes: make(map[txVar]any), + watching: make(map[txVar]struct{}), + } + tx.cond.L = &tx.mu + return tx + }} + failedCommitsProfile *pprof.Profile +) + +const ( + profileFailedCommits = false + sleepBetweenRetries = false +) + +func init() { + if profileFailedCommits { + failedCommitsProfile = pprof.NewProfile("stmFailedCommits") + } +} + +func newTx() *Tx { + tx := txPool.Get().(*Tx) + tx.tries = 0 + tx.completed = false + return tx +} + +func WouldBlock[R any](fn Operation[R]) (block bool) { + tx := newTx() + tx.reset() + _, block = catchRetry(fn, tx) + if len(tx.watching) != 0 { + panic("shouldn't have installed any watchers") + } + tx.recycle() + return +} + +// Atomically executes the atomic function fn. +func Atomically[R any](op Operation[R]) R { + expvars.Add("atomically", 1) + // run the transaction + tx := newTx() +retry: + tx.tries++ + tx.reset() + if sleepBetweenRetries { + shift := int64(tx.tries - 1) + const maxShift = 30 + if shift > maxShift { + shift = maxShift + } + ns := int64(1) << shift + d := time.Duration(rand.Int63n(ns)) + if d > 100*time.Microsecond { + tx.updateWatchers() + time.Sleep(time.Duration(ns)) + } + } + tx.mu.Lock() + ret, retry := catchRetry(op, tx) + tx.mu.Unlock() + if retry { + expvars.Add("retries", 1) + // wait for one of the variables we read to change before retrying + tx.wait() + goto retry + } + // verify the read log + tx.lockAllVars() + if tx.inputsChanged() { + tx.unlock() + expvars.Add("failed commits", 1) + if profileFailedCommits { + failedCommitsProfile.Add(new(int), 0) + } + goto retry + } + // commit the write log and broadcast that variables have changed + tx.commit() + tx.mu.Lock() + tx.completed = true + tx.cond.Broadcast() + tx.mu.Unlock() + tx.unlock() + expvars.Add("commits", 1) + tx.recycle() + return ret +} + +// AtomicGet is a helper function that atomically reads a value. +func AtomicGet[T any](v *Var[T]) T { + return v.value.Load().Get().(T) +} + +// AtomicSet is a helper function that atomically writes a value. +func AtomicSet[T any](v *Var[T], val T) { + v.mu.Lock() + v.changeValue(val) + v.mu.Unlock() +} + +// Compose is a helper function that composes multiple transactions into a +// single transaction. +func Compose[R any](fns ...Operation[R]) Operation[struct{}] { + return VoidOperation(func(tx *Tx) { + for _, f := range fns { + f(tx) + } + }) +} + +// Select runs the supplied functions in order. Execution stops when a +// function succeeds without calling Retry. If no functions succeed, the +// entire selection will be retried. +func Select[R any](fns ...Operation[R]) Operation[R] { + return func(tx *Tx) R { + switch len(fns) { + case 0: + // empty Select blocks forever + tx.Retry() + panic("unreachable") + case 1: + return fns[0](tx) + default: + oldWrites := tx.writes + tx.writes = make(map[txVar]any, len(oldWrites)) + for k, v := range oldWrites { + tx.writes[k] = v + } + ret, retry := catchRetry(fns[0], tx) + if retry { + tx.writes = oldWrites + return Select(fns[1:]...)(tx) + } else { + return ret + } + } + } +} + +type Operation[R any] func(*Tx) R + +func VoidOperation(f func(*Tx)) Operation[struct{}] { + return func(tx *Tx) struct{} { + f(tx) + return struct{}{} + } +} + +func AtomicModify[T any](v *Var[T], f func(T) T) { + Atomically(VoidOperation(func(tx *Tx) { + v.Set(tx, f(v.Get(tx))) + })) +} + +var expvars = expvar.NewMap("stm") + +type VarValue interface { + Set(any) VarValue + Get() any + Changed(VarValue) bool +} + +type version uint64 + +type versionedValue[T any] struct { + value T + version version +} + +func (me versionedValue[T]) Set(newValue any) VarValue { + return versionedValue[T]{ + value: newValue.(T), + version: me.version + 1, + } +} + +func (me versionedValue[T]) Get() any { + return me.value +} + +func (me versionedValue[T]) Changed(other VarValue) bool { + return me.version != other.(versionedValue[T]).version +} + +type customVarValue[T any] struct { + value T + changed func(T, T) bool +} + +var _ VarValue = customVarValue[struct{}]{} + +func (me customVarValue[T]) Changed(other VarValue) bool { + return me.changed(me.value, other.(customVarValue[T]).value) +} + +func (me customVarValue[T]) Set(newValue any) VarValue { + return customVarValue[T]{ + value: newValue.(T), + changed: me.changed, + } +} + +func (me customVarValue[T]) Get() any { + return me.value +} + +type txVar interface { + getValue() *Value[VarValue] + changeValue(any) + getWatchers() *sync.Map + getLock() *sync.Mutex +} + +// A Tx represents an atomic transaction. +type Tx struct { + reads map[txVar]VarValue + writes map[txVar]any + watching map[txVar]struct{} + locks txLocks + mu sync.Mutex + cond sync.Cond + waiting bool + completed bool + tries int + numRetryValues int +} + +// Check that none of the logged values have changed since the transaction began. +func (tx *Tx) inputsChanged() bool { + for v, read := range tx.reads { + if read.Changed(v.getValue().Load()) { + return true + } + } + return false +} + +// Writes the values in the transaction log to their respective Vars. +func (tx *Tx) commit() { + for v, val := range tx.writes { + v.changeValue(val) + } +} + +func (tx *Tx) updateWatchers() { + for v := range tx.watching { + if _, ok := tx.reads[v]; !ok { + delete(tx.watching, v) + v.getWatchers().Delete(tx) + } + } + for v := range tx.reads { + if _, ok := tx.watching[v]; !ok { + v.getWatchers().Store(tx, nil) + tx.watching[v] = struct{}{} + } + } +} + +// wait blocks until another transaction modifies any of the Vars read by tx. +func (tx *Tx) wait() { + if len(tx.reads) == 0 { + panic("not waiting on anything") + } + tx.updateWatchers() + tx.mu.Lock() + firstWait := true + for !tx.inputsChanged() { + if !firstWait { + expvars.Add("wakes for unchanged versions", 1) + } + expvars.Add("waits", 1) + tx.waiting = true + tx.cond.Broadcast() + tx.cond.Wait() + tx.waiting = false + firstWait = false + } + tx.mu.Unlock() +} + +// Get returns the value of v as of the start of the transaction. +func (v *Var[T]) Get(tx *Tx) T { + // If we previously wrote to v, it will be in the write log. + if val, ok := tx.writes[v]; ok { + return val.(T) + } + // If we haven't previously read v, record its version + vv, ok := tx.reads[v] + if !ok { + vv = v.getValue().Load() + tx.reads[v] = vv + } + return vv.Get().(T) +} + +// Set sets the value of a Var for the lifetime of the transaction. +func (v *Var[T]) Set(tx *Tx, val T) { + if v == nil { + panic("nil Var") + } + tx.writes[v] = val +} + +type txProfileValue struct { + *Tx + int +} + +// Retry aborts the transaction and retries it when a Var changes. You can return from this method +// to satisfy return values, but it should never actually return anything as it panics internally. +func (tx *Tx) Retry() struct{} { + retries.Add(txProfileValue{tx, tx.numRetryValues}, 1) + tx.numRetryValues++ + panic(retry) +} + +// Assert is a helper function that retries a transaction if the condition is +// not satisfied. +func (tx *Tx) Assert(p bool) { + if !p { + tx.Retry() + } +} + +func (tx *Tx) reset() { + tx.mu.Lock() + for k := range tx.reads { + delete(tx.reads, k) + } + for k := range tx.writes { + delete(tx.writes, k) + } + tx.mu.Unlock() + tx.removeRetryProfiles() + tx.resetLocks() +} + +func (tx *Tx) removeRetryProfiles() { + for tx.numRetryValues > 0 { + tx.numRetryValues-- + retries.Remove(txProfileValue{tx, tx.numRetryValues}) + } +} + +func (tx *Tx) recycle() { + for v := range tx.watching { + delete(tx.watching, v) + v.getWatchers().Delete(tx) + } + tx.removeRetryProfiles() + // I don't think we can reuse Txs, because the "completed" field should/needs to be set + // indefinitely after use. + //txPool.Put(tx) +} + +func (tx *Tx) lockAllVars() { + tx.resetLocks() + tx.collectAllLocks() + tx.sortLocks() + tx.lock() +} + +func (tx *Tx) resetLocks() { + tx.locks.clear() +} + +func (tx *Tx) collectReadLocks() { + for v := range tx.reads { + tx.locks.append(v.getLock()) + } +} + +func (tx *Tx) collectAllLocks() { + tx.collectReadLocks() + for v := range tx.writes { + if _, ok := tx.reads[v]; !ok { + tx.locks.append(v.getLock()) + } + } +} + +func (tx *Tx) sortLocks() { + sort.Sort(&tx.locks) +} + +func (tx *Tx) lock() { + for _, l := range tx.locks.mus { + l.Lock() + } +} + +func (tx *Tx) unlock() { + for _, l := range tx.locks.mus { + l.Unlock() + } +} + +func (tx *Tx) String() string { + return fmt.Sprintf("%[1]T %[1]p", tx) +} + +// Dedicated type avoids reflection in sort.Slice. +type txLocks struct { + mus []*sync.Mutex +} + +func (me txLocks) Len() int { + return len(me.mus) +} + +func (me txLocks) Less(i, j int) bool { + return uintptr(unsafe.Pointer(me.mus[i])) < uintptr(unsafe.Pointer(me.mus[j])) +} + +func (me txLocks) Swap(i, j int) { + me.mus[i], me.mus[j] = me.mus[j], me.mus[i] +} + +func (me *txLocks) clear() { + me.mus = me.mus[:0] +} + +func (me *txLocks) append(mu *sync.Mutex) { + me.mus = append(me.mus, mu) +} + +// Holds an STM variable. +type Var[T any] struct { + value Value[VarValue] + watchers sync.Map + mu sync.Mutex +} + +func (v *Var[T]) getValue() *Value[VarValue] { + return &v.value +} + +func (v *Var[T]) getWatchers() *sync.Map { + return &v.watchers +} + +func (v *Var[T]) getLock() *sync.Mutex { + return &v.mu +} + +func (v *Var[T]) changeValue(new any) { + old := v.value.Load() + newVarValue := old.Set(new) + v.value.Store(newVarValue) + if old.Changed(newVarValue) { + go v.wakeWatchers(newVarValue) + } +} + +func (v *Var[T]) wakeWatchers(new VarValue) { + v.watchers.Range(func(k, _ any) bool { + tx := k.(*Tx) + // We have to lock here to ensure that the Tx is waiting before + // we signal it. Otherwise we could signal it before it goes to + // sleep and it will miss the notification. + tx.mu.Lock() + if read := tx.reads[v]; read != nil && read.Changed(new) { + tx.cond.Broadcast() + for !tx.waiting && !tx.completed { + tx.cond.Wait() + } + } + tx.mu.Unlock() + return !v.value.Load().Changed(new) + }) +} + +// Returns a new STM variable. +func NewVar[T any](val T) *Var[T] { + v := &Var[T]{} + v.value.Store(versionedValue[T]{ + value: val, + }) + return v +} + +func NewCustomVar[T any](val T, changed func(T, T) bool) *Var[T] { + v := &Var[T]{} + v.value.Store(customVarValue[T]{ + value: val, + changed: changed, + }) + return v +} + +func NewBuiltinEqVar[T comparable](val T) *Var[T] { + return NewCustomVar(val, func(a, b T) bool { + return a != b + }) +} + +var retries = pprof.NewProfile("stmRetries") + +// retry is a sentinel value. When thrown via panic, it indicates that a +// transaction should be retried. +var retry = &struct{}{} + +// catchRetry returns true if fn calls tx.Retry. +func catchRetry[R any](fn Operation[R], tx *Tx) (result R, gotRetry bool) { + defer func() { + if r := recover(); r == retry { + gotRetry = true + } else if r != nil { + panic(r) + } + }() + result = fn(tx) + return +} + +// This is the type constraint for keys passed through from pds +type KeyConstraint interface { + comparable +} + +type Settish[K KeyConstraint] interface { + Add(K) Settish[K] + Delete(K) Settish[K] + Contains(K) bool + Range(func(K) bool) + Len() int + // iter.Iterable +} + +type mapToSet[K KeyConstraint] struct { + m Mappish[K, struct{}] +} + +type interhash[K KeyConstraint] struct{} + +func (interhash[K]) Hash(x K) uint32 { + return uint32(nilinterhash(unsafe.Pointer(&x), 0)) +} + +func (interhash[K]) Equal(i, j K) bool { + return i == j +} + +func NewSet[K KeyConstraint]() Settish[K] { + return mapToSet[K]{NewMap[K, struct{}]()} +} + +func NewSortedSet[K KeyConstraint](lesser lessFunc[K]) Settish[K] { + return mapToSet[K]{NewSortedMap[K, struct{}](lesser)} +} + +func (s mapToSet[K]) Add(x K) Settish[K] { + s.m = s.m.Set(x, struct{}{}) + return s +} + +func (s mapToSet[K]) Delete(x K) Settish[K] { + s.m = s.m.Delete(x) + return s +} + +func (s mapToSet[K]) Len() int { + return s.m.Len() +} + +func (s mapToSet[K]) Contains(x K) bool { + _, ok := s.m.Get(x) + return ok +} + +func (s mapToSet[K]) Range(f func(K) bool) { + s.m.Range(func(k K, _ struct{}) bool { + return f(k) + }) +} + +/* +func (s mapToSet[K]) Iter(cb iter.Callback) { + s.Range(func(k K) bool { + return cb(k) + }) +} +*/ + +type Map[K KeyConstraint, V any] struct { + *pds.Map[K, V] +} + +func NewMap[K KeyConstraint, V any]() Mappish[K, V] { + return Map[K, V]{pds.NewMap[K, V](interhash[K]{})} +} + +func (m Map[K, V]) Delete(x K) Mappish[K, V] { + m.Map = m.Map.Delete(x) + return m +} + +func (m Map[K, V]) Set(key K, value V) Mappish[K, V] { + m.Map = m.Map.Set(key, value) + return m +} + +func (sm Map[K, V]) Range(f func(K, V) bool) { + iter := sm.Map.Iterator() + for { + k, v, ok := iter.Next() + if !ok { + break + } + if !f(k, v) { + return + } + } +} + +/* +func (sm Map[K, V]) Iter(cb iter.Callback) { + sm.Range(func(key K, _ V) bool { + return cb(key) + }) +} +*/ + +type SortedMap[K KeyConstraint, V any] struct { + *pds.SortedMap[K, V] +} + +func (sm SortedMap[K, V]) Set(key K, value V) Mappish[K, V] { + sm.SortedMap = sm.SortedMap.Set(key, value) + return sm +} + +func (sm SortedMap[K, V]) Delete(key K) Mappish[K, V] { + sm.SortedMap = sm.SortedMap.Delete(key) + return sm +} + +func (sm SortedMap[K, V]) Range(f func(key K, value V) bool) { + iter := sm.SortedMap.Iterator() + for { + k, v, ok := iter.Next() + if !ok { + break + } + if !f(k, v) { + return + } + } +} + +/* +func (sm SortedMap[K, V]) Iter(cb iter.Callback) { + sm.Range(func(key K, _ V) bool { + return cb(key) + }) +} +*/ + +type lessFunc[T KeyConstraint] func(l, r T) bool + +type comparer[K KeyConstraint] struct { + less lessFunc[K] +} + +func (me comparer[K]) Compare(i, j K) int { + if me.less(i, j) { + return -1 + } else if me.less(j, i) { + return 1 + } else { + return 0 + } +} + +func NewSortedMap[K KeyConstraint, V any](less lessFunc[K]) Mappish[K, V] { + return SortedMap[K, V]{ + SortedMap: pds.NewSortedMap[K, V](comparer[K]{less}), + } +} + +type Mappish[K, V any] interface { + Set(K, V) Mappish[K, V] + Delete(key K) Mappish[K, V] + Get(key K) (V, bool) + Range(func(K, V) bool) + Len() int + // iter.Iterable +} + +func GetLeft(l, _ any) any { + return l +} + +//go:noescape +//go:linkname nilinterhash runtime.nilinterhash +func nilinterhash(p unsafe.Pointer, h uintptr) uintptr + +func interfaceHash(x any) uint32 { + return uint32(nilinterhash(unsafe.Pointer(&x), 0)) +} + +type Lenner interface { + Len() int +} + +var ( + mu sync.Mutex + ctxVars = map[context.Context]*Var[bool]{} +) + +// Returns an STM var that contains a bool equal to `ctx.Err != nil`, and a cancel function to be +// called when the user is no longer interested in the var. +func ContextDoneVar(ctx context.Context) (*Var[bool], func()) { + mu.Lock() + defer mu.Unlock() + if v, ok := ctxVars[ctx]; ok { + return v, func() {} + } + if ctx.Err() != nil { + // TODO: What if we had read-only Vars? Then we could have a global one for this that we + // just reuse. + v := NewBuiltinEqVar(true) + return v, func() {} + } + v := NewVar(false) + go func() { + <-ctx.Done() + AtomicSet(v, true) + mu.Lock() + delete(ctxVars, ctx) + mu.Unlock() + }() + ctxVars[ctx] = v + return v, func() {} +} + +type numTokens = int + +type Limiter struct { + max *Var[numTokens] + cur *Var[numTokens] + lastAdd *Var[time.Time] + rate Limit +} + +const Inf = Limit(math.MaxFloat64) + +type Limit float64 + +func (l Limit) interval() time.Duration { + return time.Duration(Limit(1*time.Second) / l) +} + +func Every(interval time.Duration) Limit { + if interval == 0 { + return Inf + } + return Limit(time.Second / interval) +} + +func NewLimiter(rate Limit, burst numTokens) *Limiter { + rl := &Limiter{ + max: NewVar(burst), + cur: NewBuiltinEqVar(burst), + lastAdd: NewVar(time.Now()), + rate: rate, + } + if rate != Inf { + go rl.tokenGenerator(rate.interval()) + } + return rl +} + +func (rl *Limiter) tokenGenerator(interval time.Duration) { + for { + lastAdd := AtomicGet(rl.lastAdd) + time.Sleep(time.Until(lastAdd.Add(interval))) + now := time.Now() + available := numTokens(now.Sub(lastAdd) / interval) + if available < 1 { + continue + } + Atomically(VoidOperation(func(tx *Tx) { + cur := rl.cur.Get(tx) + max := rl.max.Get(tx) + tx.Assert(cur < max) + newCur := cur + available + if newCur > max { + newCur = max + } + if newCur != cur { + rl.cur.Set(tx, newCur) + } + rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available))) + })) + } +} + +func (rl *Limiter) Allow() bool { + return rl.AllowN(1) +} + +func (rl *Limiter) AllowN(n numTokens) bool { + return Atomically(func(tx *Tx) bool { + return rl.takeTokens(tx, n) + }) +} + +func (rl *Limiter) AllowStm(tx *Tx) bool { + return rl.takeTokens(tx, 1) +} + +func (rl *Limiter) takeTokens(tx *Tx, n numTokens) bool { + if rl.rate == Inf { + return true + } + cur := rl.cur.Get(tx) + if cur >= n { + rl.cur.Set(tx, cur-n) + return true + } + return false +} + +func (rl *Limiter) Wait(ctx context.Context) error { + return rl.WaitN(ctx, 1) +} + +func (rl *Limiter) WaitN(ctx context.Context, n int) error { + ctxDone, cancel := ContextDoneVar(ctx) + defer cancel() + if err := Atomically(func(tx *Tx) error { + if ctxDone.Get(tx) { + return ctx.Err() + } + if rl.takeTokens(tx, n) { + return nil + } + if n > rl.max.Get(tx) { + return errors.New("burst exceeded") + } + if dl, ok := ctx.Deadline(); ok { + if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n { + return context.DeadlineExceeded + } + } + tx.Retry() + panic("unreachable") + }); err != nil { + return err + } + return nil + +} |