diff options
Diffstat (limited to 'rate/ratelimit.go')
-rw-r--r-- | rate/ratelimit.go | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/rate/ratelimit.go b/rate/ratelimit.go new file mode 100644 index 0000000..8525371 --- /dev/null +++ b/rate/ratelimit.go @@ -0,0 +1,129 @@ +package rate + +import ( + "context" + "errors" + "math" + "time" + + "github.com/lukechampine/stm" + "github.com/lukechampine/stm/stmutil" +) + +type numTokens = int + +type Limiter struct { + max *stm.Var + cur *stm.Var + lastAdd *stm.Var + rate Limit +} + +const Inf = Limit(math.MaxFloat64) + +type Limit float64 + +func (l Limit) interval() time.Duration { + return time.Duration(Limit(1*time.Second) / l) +} + +func Every(interval time.Duration) Limit { + if interval == 0 { + return Inf + } + return Limit(time.Second / interval) +} + +func NewLimiter(rate Limit, burst numTokens) *Limiter { + rl := &Limiter{ + max: stm.NewVar(burst), + cur: stm.NewVar(burst), + lastAdd: stm.NewVar(time.Now()), + rate: rate, + } + if rate != Inf { + go rl.tokenGenerator(rate.interval()) + } + return rl +} + +func (rl *Limiter) tokenGenerator(interval time.Duration) { + for { + lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time) + time.Sleep(time.Until(lastAdd.Add(interval))) + now := time.Now() + available := numTokens(now.Sub(lastAdd) / interval) + if available < 1 { + continue + } + stm.Atomically(func(tx *stm.Tx) { + cur := tx.Get(rl.cur).(numTokens) + max := tx.Get(rl.max).(numTokens) + tx.Assert(cur < max) + newCur := cur + available + if newCur > max { + newCur = max + } + if newCur != cur { + tx.Set(rl.cur, newCur) + } + tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available))) + }) + } +} + +func (rl *Limiter) Allow() bool { + return rl.AllowN(1) +} + +func (rl *Limiter) AllowN(n numTokens) bool { + return stm.Atomically(func(tx *stm.Tx) { + tx.Return(rl.takeTokens(tx, n)) + }).(bool) +} + +func (rl *Limiter) AllowStm(tx *stm.Tx) bool { + return rl.takeTokens(tx, 1) +} + +func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool { + if rl.rate == Inf { + return true + } + cur := tx.Get(rl.cur).(numTokens) + if cur >= n { + tx.Set(rl.cur, cur-n) + return true + } + return false +} + +func (rl *Limiter) Wait(ctx context.Context) error { + return rl.WaitN(ctx, 1) +} + +func (rl *Limiter) WaitN(ctx context.Context, n int) error { + ctxDone, cancel := stmutil.ContextDoneVar(ctx) + defer cancel() + if err := stm.Atomically(func(tx *stm.Tx) { + if tx.Get(ctxDone).(bool) { + tx.Return(ctx.Err()) + } + if rl.takeTokens(tx, n) { + tx.Return(nil) + } + if n > tx.Get(rl.max).(numTokens) { + tx.Return(errors.New("burst exceeded")) + } + if dl, ok := ctx.Deadline(); ok { + if tx.Get(rl.cur).(numTokens)+numTokens(dl.Sub(tx.Get(rl.lastAdd).(time.Time))/rl.rate.interval()) < n { + tx.Return(context.DeadlineExceeded) + } + } + tx.Retry() + }); err != nil { + return err.(error) + } + return nil + +} |