diff options
Diffstat (limited to 'rate/ratelimit.go')
-rw-r--r-- | rate/ratelimit.go | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/rate/ratelimit.go b/rate/ratelimit.go index 0278283..69d6932 100644 --- a/rate/ratelimit.go +++ b/rate/ratelimit.go @@ -56,7 +56,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) { if available < 1 { continue } - stm.Atomically(func(tx *stm.Tx) { + stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) { cur := tx.Get(rl.cur).(numTokens) max := tx.Get(rl.max).(numTokens) tx.Assert(cur < max) @@ -68,7 +68,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) { tx.Set(rl.cur, newCur) } tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available))) - }) + })) } } @@ -77,8 +77,8 @@ func (rl *Limiter) Allow() bool { } func (rl *Limiter) AllowN(n numTokens) bool { - return stm.Atomically(func(tx *stm.Tx) { - tx.Return(rl.takeTokens(tx, n)) + return stm.Atomically(func(tx *stm.Tx) interface{} { + return rl.takeTokens(tx, n) }).(bool) } @@ -105,22 +105,23 @@ func (rl *Limiter) Wait(ctx context.Context) error { func (rl *Limiter) WaitN(ctx context.Context, n int) error { ctxDone, cancel := stmutil.ContextDoneVar(ctx) defer cancel() - if err := stm.Atomically(func(tx *stm.Tx) { + if err := stm.Atomically(func(tx *stm.Tx) interface{} { if tx.Get(ctxDone).(bool) { - tx.Return(ctx.Err()) + return ctx.Err() } if rl.takeTokens(tx, n) { - tx.Return(nil) + return nil } if n > tx.Get(rl.max).(numTokens) { - tx.Return(errors.New("burst exceeded")) + return errors.New("burst exceeded") } if dl, ok := ctx.Deadline(); ok { if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n { - tx.Return(context.DeadlineExceeded) + return context.DeadlineExceeded } } tx.Retry() + panic("unreachable") }); err != nil { return err.(error) } |