aboutsummaryrefslogtreecommitdiff
path: root/rate/ratelimit.go
diff options
context:
space:
mode:
Diffstat (limited to 'rate/ratelimit.go')
-rw-r--r--rate/ratelimit.go19
1 files changed, 10 insertions, 9 deletions
diff --git a/rate/ratelimit.go b/rate/ratelimit.go
index 0278283..69d6932 100644
--- a/rate/ratelimit.go
+++ b/rate/ratelimit.go
@@ -56,7 +56,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) {
if available < 1 {
continue
}
- stm.Atomically(func(tx *stm.Tx) {
+ stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
cur := tx.Get(rl.cur).(numTokens)
max := tx.Get(rl.max).(numTokens)
tx.Assert(cur < max)
@@ -68,7 +68,7 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) {
tx.Set(rl.cur, newCur)
}
tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available)))
- })
+ }))
}
}
@@ -77,8 +77,8 @@ func (rl *Limiter) Allow() bool {
}
func (rl *Limiter) AllowN(n numTokens) bool {
- return stm.Atomically(func(tx *stm.Tx) {
- tx.Return(rl.takeTokens(tx, n))
+ return stm.Atomically(func(tx *stm.Tx) interface{} {
+ return rl.takeTokens(tx, n)
}).(bool)
}
@@ -105,22 +105,23 @@ func (rl *Limiter) Wait(ctx context.Context) error {
func (rl *Limiter) WaitN(ctx context.Context, n int) error {
ctxDone, cancel := stmutil.ContextDoneVar(ctx)
defer cancel()
- if err := stm.Atomically(func(tx *stm.Tx) {
+ if err := stm.Atomically(func(tx *stm.Tx) interface{} {
if tx.Get(ctxDone).(bool) {
- tx.Return(ctx.Err())
+ return ctx.Err()
}
if rl.takeTokens(tx, n) {
- tx.Return(nil)
+ return nil
}
if n > tx.Get(rl.max).(numTokens) {
- tx.Return(errors.New("burst exceeded"))
+ return errors.New("burst exceeded")
}
if dl, ok := ctx.Deadline(); ok {
if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n {
- tx.Return(context.DeadlineExceeded)
+ return context.DeadlineExceeded
}
}
tx.Retry()
+ panic("unreachable")
}); err != nil {
return err.(error)
}