1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
package rate
import (
"context"
"errors"
"math"
"time"
"github.com/anacrolix/stm"
"github.com/anacrolix/stm/stmutil"
)
// numTokens is the integer unit used for token counts and bucket capacity
// throughout this file.
type numTokens = int
// Limiter is a token-bucket rate limiter whose mutable state lives in stm
// variables, so token checks can take part in stm transactions.
type Limiter struct {
// max holds the bucket capacity as a numTokens; NewLimiter sets it to burst.
max *stm.Var
// cur holds the number of tokens currently available, as a numTokens.
cur *stm.Var
// lastAdd holds the time.Time up to which tokens have been credited.
lastAdd *stm.Var
// rate is the configured refill rate; when it equals Inf every request is
// granted without touching the bucket.
rate Limit
}
// Inf is the infinite rate limit. Every returns it for non-positive
// intervals.
const Inf = Limit(math.MaxFloat64)

// Limit is the maximum frequency of events, expressed in events per second.
type Limit float64

// interval returns the spacing between two consecutive tokens at rate l.
//
// NOTE(review): for l == 0 the quotient is infinite and its conversion to
// time.Duration is implementation-dependent; callers should not use a zero
// rate.
func (l Limit) interval() time.Duration {
	return time.Duration(Limit(time.Second) / l)
}

// Every converts a minimum time interval between events into a Limit.
//
// A non-positive interval places no spacing requirement on events and
// therefore yields Inf.
func Every(interval time.Duration) Limit {
	if interval <= 0 {
		return Inf
	}
	// Divide in floating point. Integer Duration division would truncate the
	// result, turning every interval longer than one second into a rate of 0.
	return Limit(time.Second) / Limit(interval)
}
// NewLimiter builds a Limiter whose bucket holds at most burst tokens and
// starts out full. Unless rate is Inf, it also launches a background
// goroutine running tokenGenerator at the spacing given by rate.interval().
func NewLimiter(rate Limit, burst numTokens) *Limiter {
	limiter := &Limiter{
		max:     stm.NewVar(burst),
		cur:     stm.NewVar(burst),
		lastAdd: stm.NewVar(time.Now()),
		rate:    rate,
	}
	if rate == Inf {
		// No generator is started for an infinite rate.
		return limiter
	}
	go limiter.tokenGenerator(rate.interval())
	return limiter
}
// tokenGenerator loops forever, crediting the bucket with one token per
// elapsed interval. Each pass sleeps until one interval after the recorded
// lastAdd time, works out how many whole intervals have gone by, and then
// adds that many tokens (capped at max) in a single transaction, advancing
// lastAdd by exactly the credited intervals.
func (rl *Limiter) tokenGenerator(interval time.Duration) {
	for {
		last := stm.AtomicGet(rl.lastAdd).(time.Time)
		time.Sleep(time.Until(last.Add(interval)))
		elapsed := time.Now().Sub(last)
		earned := numTokens(elapsed / interval)
		if earned >= 1 {
			stm.Atomically(func(tx *stm.Tx) {
				have := tx.Get(rl.cur).(numTokens)
				limit := tx.Get(rl.max).(numTokens)
				// Only proceed while the bucket has room.
				// NOTE(review): presumably tx.Assert retries the transaction
				// until the condition holds - confirm against the stm docs.
				tx.Assert(have < limit)
				refilled := have + earned
				if refilled > limit {
					refilled = limit
				}
				if refilled != have {
					tx.Set(rl.cur, refilled)
				}
				tx.Set(rl.lastAdd, last.Add(time.Duration(earned)*interval))
			})
		}
	}
}
// Allow is shorthand for AllowN(1).
func (rl *Limiter) Allow() bool {
	granted := rl.AllowN(1)
	return granted
}
// AllowN runs takeTokens for n tokens inside its own transaction and reports
// the boolean outcome of that attempt.
func (rl *Limiter) AllowN(n numTokens) bool {
	outcome := stm.Atomically(func(tx *stm.Tx) {
		granted := rl.takeTokens(tx, n)
		tx.Return(granted)
	})
	return outcome.(bool)
}
// AllowStm attempts to take a single token as part of the caller's
// transaction tx and reports whether it succeeded.
func (rl *Limiter) AllowStm(tx *stm.Tx) bool {
	const single numTokens = 1
	return rl.takeTokens(tx, single)
}
// takeTokens tries to deduct n tokens from the bucket within tx. It returns
// true when the rate is Inf (without touching the bucket) or when at least n
// tokens were available and have been removed; otherwise it returns false
// and leaves the bucket unchanged.
func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool {
	if rl.rate == Inf {
		return true
	}
	available := tx.Get(rl.cur).(numTokens)
	if available < n {
		return false
	}
	tx.Set(rl.cur, available-n)
	return true
}
// Wait is shorthand for WaitN(ctx, 1).
func (rl *Limiter) Wait(ctx context.Context) error {
	err := rl.WaitN(ctx, 1)
	return err
}
// WaitN tries to take n tokens in a transaction that also watches ctx.
//
// The transaction checks, in order:
//   - if ctx is already done, it returns ctx.Err();
//   - if n tokens can be taken now, it returns nil;
//   - if n exceeds the bucket capacity, it returns a "burst exceeded" error;
//   - if ctx has a deadline and the current tokens plus those due between
//     lastAdd and the deadline still fall short of n, it returns
//     context.DeadlineExceeded.
//
// If none of these apply, the transaction calls tx.Retry.
func (rl *Limiter) WaitN(ctx context.Context, n int) error {
	doneVar, stop := stmutil.ContextDoneVar(ctx)
	defer stop()
	outcome := stm.Atomically(func(tx *stm.Tx) {
		if tx.Get(doneVar).(bool) {
			tx.Return(ctx.Err())
		}
		if rl.takeTokens(tx, n) {
			tx.Return(nil)
		}
		if capacity := tx.Get(rl.max).(numTokens); n > capacity {
			tx.Return(errors.New("burst exceeded"))
		}
		if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
			have := tx.Get(rl.cur).(numTokens)
			last := tx.Get(rl.lastAdd).(time.Time)
			// Tokens that would be generated between the last credit and
			// the deadline.
			upcoming := numTokens(deadline.Sub(last) / rl.rate.interval())
			if have+upcoming < n {
				tx.Return(context.DeadlineExceeded)
			}
		}
		tx.Retry()
	})
	if outcome == nil {
		return nil
	}
	return outcome.(error)
}
|