aboutsummaryrefslogtreecommitdiff
path: root/rate/ratelimit.go
blob: a44f38359f5065a403212b6891a58cb29f65937a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package rate

import (
	"context"
	"errors"
	"math"
	"time"

	"github.com/anacrolix/stm"
	"github.com/anacrolix/stm/stmutil"
)

type (
	// numTokens counts tokens held by, or requested from, a Limiter.
	numTokens = int

	// Limiter is a token-bucket rate limiter whose state is kept in STM
	// variables, so taking tokens can be composed with other transactions
	// (see AllowStm).
	Limiter struct {
		max     *stm.Var // bucket capacity (numTokens)
		cur     *stm.Var // tokens currently available (numTokens)
		lastAdd *stm.Var // time.Time up to which tokens have been credited
		rate    Limit    // refill rate; Inf disables limiting entirely
	}
)

// Inf is a rate that never limits: takeTokens allows every request without
// consulting the bucket.
const Inf = Limit(math.MaxFloat64)

// Limit is a rate expressed in tokens per second.
type Limit float64

// interval returns the time between successive tokens at rate l.
//
// The result is clamped to [1ns, math.MaxInt64]. Rates above one token per
// nanosecond would otherwise truncate to a zero Duration, and callers divide
// by the interval (integer division by zero panics). Non-positive and NaN
// rates never produce a token, so they map to the largest Duration instead of
// relying on an implementation-dependent float-to-int conversion.
func (l Limit) interval() time.Duration {
	// Written as !(l > 0) rather than l <= 0 so that NaN is caught too.
	if !(l > 0) {
		return time.Duration(math.MaxInt64)
	}
	perToken := float64(time.Second) / float64(l)
	switch {
	case perToken >= math.MaxInt64:
		return time.Duration(math.MaxInt64)
	case perToken < 1:
		return 1
	default:
		return time.Duration(perToken)
	}
}

// Every converts a minimum interval between events into a Limit.
// A non-positive interval yields Inf.
func Every(interval time.Duration) Limit {
	if interval <= 0 {
		return Inf
	}
	// Divide in floating point: Duration/Duration is integer division and
	// would truncate, e.g. Every(2*time.Second) to 0 instead of 0.5.
	return Limit(time.Second) / Limit(interval)
}

// NewLimiter returns a Limiter that starts with burst tokens, holds at most
// burst tokens, and refills at rate tokens per second.
//
// For a finite, positive rate a background goroutine is started to refill the
// bucket. That goroutine never exits and references the Limiter, so the
// Limiter is never garbage collected. With rate Inf every request is allowed.
// With a rate <= 0 (or NaN) no goroutine is started, and the initial burst is
// all that will ever be available.
func NewLimiter(rate Limit, burst numTokens) *Limiter {
	rl := &Limiter{
		max:     stm.NewVar(burst),
		cur:     stm.NewBuiltinEqVar(burst),
		lastAdd: stm.NewVar(time.Now()),
		rate:    rate,
	}
	// A non-positive rate never produces a token, and its interval is negative
	// or ill-defined. A generator started with it would find nothing to add
	// and loop without sleeping.
	if rate > 0 && rate != Inf {
		go rl.tokenGenerator(rate.interval())
	}
	return rl
}

// tokenGenerator adds one token to rl per interval, never holding more than
// rl.max. NewLimiter starts it in its own goroutine, and the loop never exits.
//
// No tokens accrue while the bucket is full. Once the bucket fills, the
// goroutine blocks in STM until a token is taken and then restarts accrual
// from that moment. If idle time spent at capacity were credited on later
// refills, a caller draining a full bucket could receive more than burst
// tokens at once.
func (rl *Limiter) tokenGenerator(interval time.Duration) {
	for {
		lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time)
		time.Sleep(time.Until(lastAdd.Add(interval)))
		now := time.Now()
		available := numTokens(now.Sub(lastAdd) / interval)
		if available < 1 {
			continue
		}
		full := stm.Atomically(func(tx *stm.Tx) interface{} {
			cur := tx.Get(rl.cur).(numTokens)
			capacity := tx.Get(rl.max).(numTokens)
			newCur := cur + available
			// Advance only by whole tokens so partial progress toward the
			// next token is kept.
			next := lastAdd.Add(interval * time.Duration(available))
			if newCur >= capacity {
				// Tokens beyond capacity are discarded, and so is any
				// partial progress: a full bucket cannot accrue.
				newCur = capacity
				next = now
			}
			if newCur != cur {
				tx.Set(rl.cur, newCur)
			}
			tx.Set(rl.lastAdd, next)
			return newCur == capacity
		}).(bool)
		if full {
			// Sleep until a token is taken, then start crediting from then.
			// time.Now() is evaluated on the attempt that commits, which is
			// the first one that sees the bucket below capacity.
			stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
				tx.Assert(tx.Get(rl.cur).(numTokens) < tx.Get(rl.max).(numTokens))
				tx.Set(rl.lastAdd, time.Now())
			}))
		}
	}
}

// Allow reports whether one token could be taken right now, taking it if so.
// It never waits for tokens.
func (rl *Limiter) Allow() bool {
	allowed := rl.AllowN(1)
	return allowed
}

// AllowN reports whether n tokens could be taken right now. If they can, they
// are taken in a single atomic transaction. It never waits for tokens.
func (rl *Limiter) AllowN(n numTokens) bool {
	take := func(tx *stm.Tx) interface{} {
		return rl.takeTokens(tx, n)
	}
	return stm.Atomically(take).(bool)
}

// AllowStm is the transactional form of Allow. It takes one token inside the
// caller's tx, so the decision commits or rolls back together with the rest
// of that transaction.
func (rl *Limiter) AllowStm(tx *stm.Tx) bool {
	const single numTokens = 1
	return rl.takeTokens(tx, single)
}

// takeTokens tries to remove n tokens from rl within tx and reports whether
// it succeeded. On failure rl is left unchanged.
//
// An Inf limiter always succeeds without touching its state, as does a
// request for n <= 0 tokens.
func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool {
	if rl.rate == Inf || n <= 0 {
		// Nothing to take. Subtracting a negative n would otherwise add
		// tokens and let the bucket exceed its burst size.
		return true
	}
	cur := tx.Get(rl.cur).(numTokens)
	if cur < n {
		return false
	}
	tx.Set(rl.cur, cur-n)
	return true
}

// Wait blocks until one token can be taken and takes it. It fails under the
// same conditions as WaitN with n == 1.
func (rl *Limiter) Wait(ctx context.Context) error {
	err := rl.WaitN(ctx, 1)
	return err
}

// WaitN blocks until n tokens can be taken from rl, then takes them
// atomically. It returns:
//   - ctx.Err() once ctx is done, checked before trying to take tokens;
//   - a "burst exceeded" error if n is larger than the bucket capacity,
//     since such a request can never be satisfied;
//   - context.DeadlineExceeded, without waiting further, if ctx has a
//     deadline and the tokens expected by then are still fewer than n.
func (rl *Limiter) WaitN(ctx context.Context, n int) error {
	done, stop := stmutil.ContextDoneVar(ctx)
	defer stop()
	attempt := func(tx *stm.Tx) interface{} {
		switch {
		case tx.Get(done).(bool):
			return ctx.Err()
		case rl.takeTokens(tx, n):
			return nil
		case n > tx.Get(rl.max).(numTokens):
			return errors.New("burst exceeded")
		}
		if deadline, ok := ctx.Deadline(); ok {
			have := tx.Get(rl.cur).(numTokens)
			window := deadline.Sub(tx.Get(rl.lastAdd).(time.Time))
			if have+numTokens(window/rl.rate.interval()) < n {
				return context.DeadlineExceeded
			}
		}
		// Undecided: block until one of the Vars read above changes, then rerun.
		tx.Retry()
		panic("unreachable")
	}
	result := stm.Atomically(attempt)
	if result == nil {
		return nil
	}
	return result.(error)
}