diff options
Diffstat (limited to 'rate/ratelimit.go')
-rw-r--r-- | rate/ratelimit.go | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/rate/ratelimit.go b/rate/ratelimit.go index a44f383..7fde1ce 100644 --- a/rate/ratelimit.go +++ b/rate/ratelimit.go @@ -13,9 +13,9 @@ import ( type numTokens = int type Limiter struct { - max *stm.Var - cur *stm.Var - lastAdd *stm.Var + max *stm.Var[numTokens] + cur *stm.Var[numTokens] + lastAdd *stm.Var[time.Time] rate Limit } @@ -36,9 +36,9 @@ func Every(interval time.Duration) Limit { func NewLimiter(rate Limit, burst numTokens) *Limiter { rl := &Limiter{ - max: stm.NewVar(burst), + max: stm.NewVar[int](burst), cur: stm.NewBuiltinEqVar(burst), - lastAdd: stm.NewVar(time.Now()), + lastAdd: stm.NewVar[time.Time](time.Now()), rate: rate, } if rate != Inf { @@ -49,7 +49,7 @@ func NewLimiter(rate Limit, burst numTokens) *Limiter { func (rl *Limiter) tokenGenerator(interval time.Duration) { for { - lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time) + lastAdd := stm.AtomicGet(rl.lastAdd) time.Sleep(time.Until(lastAdd.Add(interval))) now := time.Now() available := numTokens(now.Sub(lastAdd) / interval) @@ -57,17 +57,17 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) { continue } stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { - cur := tx.Get(rl.cur).(numTokens) - max := tx.Get(rl.max).(numTokens) + cur := rl.cur.Get(tx) + max := rl.max.Get(tx) tx.Assert(cur < max) newCur := cur + available if newCur > max { newCur = max } if newCur != cur { - tx.Set(rl.cur, newCur) + rl.cur.Set(tx, newCur) } - tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available))) + rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available))) })) } } @@ -90,9 +90,9 @@ func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool { if rl.rate == Inf { return true } - cur := tx.Get(rl.cur).(numTokens) + cur := rl.cur.Get(tx) if cur >= n { - tx.Set(rl.cur, cur-n) + rl.cur.Set(tx, cur-n) return true } return false @@ -106,17 +106,17 @@ func (rl *Limiter) WaitN(ctx context.Context, n int) error { ctxDone, cancel := stmutil.ContextDoneVar(ctx) defer cancel() if err := stm.Atomically(func(tx *stm.Tx) interface{} { - if tx.Get(ctxDone).(bool) { + if ctxDone.Get(tx) { return ctx.Err() } if rl.takeTokens(tx, n) { return nil } - if n > tx.Get(rl.max).(numTokens) { + if n > rl.max.Get(tx) { return errors.New("burst exceeded") } if dl, ok := ctx.Deadline(); ok { - if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n { + if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n { return context.DeadlineExceeded } } |