aboutsummaryrefslogtreecommitdiff
path: root/rate
diff options
context:
space:
mode:
Diffstat (limited to 'rate')
-rw-r--r--rate/rate_test.go3
-rw-r--r--rate/ratelimit.go34
2 files changed, 19 insertions, 18 deletions
diff --git a/rate/rate_test.go b/rate/rate_test.go
index 7f44d74..3078b4c 100644
--- a/rate/rate_test.go
+++ b/rate/rate_test.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+//go:build go1.7
// +build go1.7
package rate
@@ -403,7 +404,7 @@ func runWait(t *testing.T, lim *Limiter, w wait) {
t.Helper()
start := time.Now()
err := lim.WaitN(w.ctx, w.n)
- delay := time.Now().Sub(start)
+ delay := time.Since(start)
if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) {
errString := "<nil>"
if !w.nilErr {
diff --git a/rate/ratelimit.go b/rate/ratelimit.go
index a44f383..f521f66 100644
--- a/rate/ratelimit.go
+++ b/rate/ratelimit.go
@@ -13,9 +13,9 @@ import (
type numTokens = int
type Limiter struct {
- max *stm.Var
- cur *stm.Var
- lastAdd *stm.Var
+ max *stm.Var[numTokens]
+ cur *stm.Var[numTokens]
+ lastAdd *stm.Var[time.Time]
rate Limit
}
@@ -49,7 +49,7 @@ func NewLimiter(rate Limit, burst numTokens) *Limiter {
func (rl *Limiter) tokenGenerator(interval time.Duration) {
for {
- lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time)
+ lastAdd := stm.AtomicGet(rl.lastAdd)
time.Sleep(time.Until(lastAdd.Add(interval)))
now := time.Now()
available := numTokens(now.Sub(lastAdd) / interval)
@@ -57,17 +57,17 @@ func (rl *Limiter) tokenGenerator(interval time.Duration) {
continue
}
stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
- cur := tx.Get(rl.cur).(numTokens)
- max := tx.Get(rl.max).(numTokens)
+ cur := rl.cur.Get(tx)
+ max := rl.max.Get(tx)
tx.Assert(cur < max)
newCur := cur + available
if newCur > max {
newCur = max
}
if newCur != cur {
- tx.Set(rl.cur, newCur)
+ rl.cur.Set(tx, newCur)
}
- tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available)))
+ rl.lastAdd.Set(tx, lastAdd.Add(interval*time.Duration(available)))
}))
}
}
@@ -77,9 +77,9 @@ func (rl *Limiter) Allow() bool {
}
func (rl *Limiter) AllowN(n numTokens) bool {
- return stm.Atomically(func(tx *stm.Tx) interface{} {
+ return stm.Atomically(func(tx *stm.Tx) bool {
return rl.takeTokens(tx, n)
- }).(bool)
+ })
}
func (rl *Limiter) AllowStm(tx *stm.Tx) bool {
@@ -90,9 +90,9 @@ func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool {
if rl.rate == Inf {
return true
}
- cur := tx.Get(rl.cur).(numTokens)
+ cur := rl.cur.Get(tx)
if cur >= n {
- tx.Set(rl.cur, cur-n)
+ rl.cur.Set(tx, cur-n)
return true
}
return false
@@ -105,25 +105,25 @@ func (rl *Limiter) Wait(ctx context.Context) error {
func (rl *Limiter) WaitN(ctx context.Context, n int) error {
ctxDone, cancel := stmutil.ContextDoneVar(ctx)
defer cancel()
- if err := stm.Atomically(func(tx *stm.Tx) interface{} {
- if tx.Get(ctxDone).(bool) {
+ if err := stm.Atomically(func(tx *stm.Tx) error {
+ if ctxDone.Get(tx) {
return ctx.Err()
}
if rl.takeTokens(tx, n) {
return nil
}
- if n > tx.Get(rl.max).(numTokens) {
+ if n > rl.max.Get(tx) {
return errors.New("burst exceeded")
}
if dl, ok := ctx.Deadline(); ok {
- if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n {
+ if rl.cur.Get(tx)+numTokens(dl.Sub(rl.lastAdd.Get(tx))/rl.rate.interval()) < n {
return context.DeadlineExceeded
}
}
tx.Retry()
panic("unreachable")
}); err != nil {
- return err.(error)
+ return err
}
return nil