1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
package rate
import (
"context"
"errors"
"math"
"time"
"github.com/anacrolix/stm"
"github.com/anacrolix/stm/stmutil"
)
type (
	// numTokens counts tokens held in, or requested from, a Limiter.
	numTokens = int

	// Limiter is a token-bucket rate limiter whose mutable state lives in STM
	// variables, so taking tokens can be composed into larger transactions
	// (see AllowStm).
	Limiter struct {
		max     *stm.Var // bucket capacity (burst); holds a numTokens
		cur     *stm.Var // tokens currently available; holds a numTokens
		lastAdd *stm.Var // time.Time up to which refills have been credited
		rate    Limit    // refill rate in tokens per second; Inf disables limiting
	}
)
// Limit is a token refill rate, expressed in tokens per second.
type Limit float64

// Inf is the largest representable Limit. A Limiter whose rate is Inf hands
// out tokens unconditionally.
const Inf Limit = math.MaxFloat64
// interval returns the spacing between successive tokens at rate l.
//
// The result is clamped to [1ns, math.MaxInt64 ns] so callers can always use
// it as a positive integer divisor:
//   - a zero, negative or NaN rate maps to the maximum duration ("never");
//   - a rate faster than one token per nanosecond maps to 1ns.
//
// The bounds are checked before converting because converting an
// out-of-range float64 (e.g. the +Inf from dividing by a zero rate) to an
// integer type is implementation-defined in Go. A truncated 0 would also make
// callers that divide by the interval panic.
func (l Limit) interval() time.Duration {
	// Written as !(l > 0) so that NaN is also rejected.
	if !(l > 0) {
		return time.Duration(math.MaxInt64)
	}
	d := float64(time.Second) / float64(l)
	switch {
	case d < 1:
		return time.Nanosecond
	case d >= math.MaxInt64:
		return time.Duration(math.MaxInt64)
	}
	return time.Duration(d)
}
// Every converts a minimum time between tokens into a Limit.
//
// A zero or negative interval means "no delay" and returns Inf.
// Otherwise the result is 1/interval in tokens per second, computed in
// floating point so that fractional rates survive. For example, Every(2s) is
// 0.5 and Every(300ms) is about 3.33. Integer Duration division would
// truncate these to 0 and 3.
func Every(interval time.Duration) Limit {
	if interval <= 0 {
		return Inf
	}
	return 1 / Limit(interval.Seconds())
}
// NewLimiter returns a Limiter that refills at rate tokens per second up to a
// capacity of burst tokens. The bucket starts full.
//
// A background goroutine adds tokens, but only when rate is finite and
// positive:
//   - an Inf limiter never consults its tokens;
//   - a zero, negative or NaN rate can never produce a token, so that limiter
//     only ever hands out its initial burst.
//
// NOTE(review): the generator goroutine has no stop mechanism and runs for
// the rest of the process.
func NewLimiter(rate Limit, burst numTokens) *Limiter {
	rl := &Limiter{
		max:     stm.NewVar(burst),
		cur:     stm.NewVar(burst),
		lastAdd: stm.NewVar(time.Now()),
		rate:    rate,
	}
	// rate > 0 is false for NaN as well. A generator started for a
	// non-positive rate would get an interval it cannot meaningfully use and
	// would never produce a token.
	if rate > 0 && rate != Inf {
		go rl.tokenGenerator(rate.interval())
	}
	return rl
}
// tokenGenerator refills rl's bucket at one token per interval. It loops
// forever (there is no stop signal). interval is assumed to be positive.
//
// Tokens accrue only while the bucket is below capacity. When a refill finds
// the bucket already full, the generator waits until a token is taken and
// then restarts its refill clock from that moment.
//
// Blocking while still holding a pre-computed token count, without advancing
// lastAdd, would instead credit every interval that elapsed while the bucket
// was full. The bucket would then refill instantly once drained, allowing up
// to an extra burst beyond the configured rate.
func (rl *Limiter) tokenGenerator(interval time.Duration) {
	for {
		lastAdd := stm.AtomicGet(rl.lastAdd).(time.Time)
		time.Sleep(time.Until(lastAdd.Add(interval)))
		available := numTokens(time.Since(lastAdd) / interval)
		if available < 1 {
			continue
		}
		full := stm.Atomically(func(tx *stm.Tx) interface{} {
			cur := tx.Get(rl.cur).(numTokens)
			capacity := tx.Get(rl.max).(numTokens)
			if cur >= capacity {
				return true
			}
			newCur := cur + available
			if newCur > capacity {
				newCur = capacity
			}
			tx.Set(rl.cur, newCur)
			// Advance by whole intervals only, so the fractional part of the
			// elapsed time carries over to the next refill.
			tx.Set(rl.lastAdd, lastAdd.Add(interval*time.Duration(available)))
			return false
		}).(bool)
		if full {
			// A full bucket does not accrue tokens. Block until one is taken,
			// then restart the refill clock from that moment. time.Now() is
			// re-evaluated on every re-run, so the committed value is the
			// time the bucket was observed below capacity.
			stm.Atomically(stm.VoidOperation(func(tx *stm.Tx) {
				tx.Assert(tx.Get(rl.cur).(numTokens) < tx.Get(rl.max).(numTokens))
				tx.Set(rl.lastAdd, time.Now())
			}))
		}
	}
}
// Allow reports whether a single token could be taken right now, taking it
// if so. It does not wait for tokens to become available.
func (rl *Limiter) Allow() bool {
	const single numTokens = 1
	return rl.AllowN(single)
}
// AllowN reports whether n tokens could be taken right now, taking them in
// their own STM transaction if so. It does not wait for tokens to become
// available.
func (rl *Limiter) AllowN(n numTokens) bool {
	attempt := func(tx *stm.Tx) interface{} {
		return rl.takeTokens(tx, n)
	}
	granted := stm.Atomically(attempt).(bool)
	return granted
}
// AllowStm is the transactional form of Allow. It tries to take one token
// inside the caller's transaction tx and reports whether it succeeded, so the
// check can be combined with other STM operations.
func (rl *Limiter) AllowStm(tx *stm.Tx) bool {
	granted := rl.takeTokens(tx, numTokens(1))
	return granted
}
// takeTokens tries to remove n tokens from rl within tx and reports whether it
// did. It never blocks or retries; the caller decides what to do on false.
//
// An Inf-rate limiter always succeeds without touching its state. A request
// for n <= 0 tokens also succeeds without touching state. Otherwise a
// negative n would pass the cur >= n check and cur-n would *add* tokens,
// pushing the bucket above its burst capacity.
func (rl *Limiter) takeTokens(tx *stm.Tx, n numTokens) bool {
	if rl.rate == Inf || n <= 0 {
		return true
	}
	cur := tx.Get(rl.cur).(numTokens)
	if cur < n {
		return false
	}
	tx.Set(rl.cur, cur-n)
	return true
}
// Wait blocks until a single token has been taken from rl, or until ctx makes
// that impossible. It is WaitN with n == 1.
func (rl *Limiter) Wait(ctx context.Context) error {
	const single = 1
	return rl.WaitN(ctx, single)
}
// WaitN blocks until n tokens have been taken from rl and then returns nil.
//
// It returns early with:
//   - ctx.Err() once ctx is done;
//   - an error if n is larger than the bucket's capacity (burst), since such
//     a request can never be satisfied;
//   - context.DeadlineExceeded if ctx has a deadline and the tokens projected
//     to be available by then (current tokens plus whole intervals between the
//     last refill and the deadline) are still fewer than n.
//
// Otherwise the transaction retries and waits for the STM variables it read
// to change.
func (rl *Limiter) WaitN(ctx context.Context, n int) error {
	doneVar, stopWatching := stmutil.ContextDoneVar(ctx)
	defer stopWatching()
	outcome := stm.Atomically(func(tx *stm.Tx) interface{} {
		switch {
		case tx.Get(doneVar).(bool):
			return ctx.Err()
		case rl.takeTokens(tx, n):
			return nil
		case n > tx.Get(rl.max).(numTokens):
			return errors.New("burst exceeded")
		}
		if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
			have := tx.Get(rl.cur).(numTokens)
			accrued := numTokens(deadline.Sub(tx.Get(rl.lastAdd).(time.Time)) / rl.rate.interval())
			if have+accrued < n {
				return context.DeadlineExceeded
			}
		}
		tx.Retry()
		panic("unreachable")
	})
	if outcome == nil {
		return nil
	}
	return outcome.(error)
}
|