tower/retry/budget/
tps_budget.rs1use std::{
4 fmt,
5 sync::{
6 atomic::{AtomicIsize, Ordering},
7 Mutex,
8 },
9 time::Duration,
10};
11use tokio::time::Instant;
12
13use super::Budget;
14
15pub struct TpsBudget {
27 generation: Mutex<Generation>,
28 reserve: isize,
30 slots: Box<[AtomicIsize]>,
32 window: Duration,
34 writer: AtomicIsize,
37 deposit_amount: isize,
39 withdraw_amount: isize,
41}
42
/// Bookkeeping for the slot that is currently accumulating writes.
#[derive(Debug)]
struct Generation {
    /// Position in the slot ring that the pending `writer` total belongs to.
    index: usize,
    /// Moment the current slot was opened; the window rolls forward once a
    /// full slot width has elapsed since then.
    time: Instant,
}
50
51impl TpsBudget {
54 pub fn new(ttl: Duration, min_per_sec: u32, retry_percent: f32) -> Self {
70 assert!(ttl >= Duration::from_secs(1));
72 assert!(ttl <= Duration::from_secs(60));
73 assert!(retry_percent >= 0.0);
74 assert!(retry_percent <= 1000.0);
75 assert!(min_per_sec < i32::MAX as u32);
76
77 let (deposit_amount, withdraw_amount) = if retry_percent == 0.0 {
78 (0, 1)
81 } else {
82 (1000, (1000.0 / retry_percent) as isize)
87 };
88 let reserve = (min_per_sec as isize)
89 .saturating_mul(ttl.as_secs() as isize) .saturating_mul(withdraw_amount);
91
92 let windows = 10u32;
94 let mut slots = Vec::with_capacity(windows as usize);
95 for _ in 0..windows {
96 slots.push(AtomicIsize::new(0));
97 }
98
99 TpsBudget {
100 generation: Mutex::new(Generation {
101 index: 0,
102 time: Instant::now(),
103 }),
104 reserve,
105 slots: slots.into_boxed_slice(),
106 window: ttl / windows,
107 writer: AtomicIsize::new(0),
108 deposit_amount,
109 withdraw_amount,
110 }
111 }
112
113 fn expire(&self) {
114 let mut gen = self.generation.lock().expect("generation lock");
115
116 let now = Instant::now();
117 let diff = now.saturating_duration_since(gen.time);
118 if diff < self.window {
119 return;
121 }
122
123 let to_commit = self.writer.swap(0, Ordering::SeqCst);
124 self.slots[gen.index].store(to_commit, Ordering::SeqCst);
125
126 let mut diff = diff;
127 let mut idx = (gen.index + 1) % self.slots.len();
128 while diff > self.window {
129 self.slots[idx].store(0, Ordering::SeqCst);
130 diff -= self.window;
131 idx = (idx + 1) % self.slots.len();
132 }
133
134 gen.index = idx;
135 gen.time = now;
136 }
137
138 fn sum(&self) -> isize {
139 let current = self.writer.load(Ordering::SeqCst);
140 let windowed_sum: isize = self
141 .slots
142 .iter()
143 .map(|slot| slot.load(Ordering::SeqCst))
144 .fold(0, isize::saturating_add);
146
147 current
148 .saturating_add(windowed_sum)
149 .saturating_add(self.reserve)
150 }
151
152 fn put(&self, amt: isize) {
153 self.expire();
154 self.writer.fetch_add(amt, Ordering::SeqCst);
155 }
156
157 fn try_get(&self, amt: isize) -> bool {
158 debug_assert!(amt >= 0);
159
160 self.expire();
161
162 let sum = self.sum();
163 if sum >= amt {
164 self.writer.fetch_add(-amt, Ordering::SeqCst);
165 true
166 } else {
167 false
168 }
169 }
170}
171
172impl Budget for TpsBudget {
173 fn deposit(&self) {
174 self.put(self.deposit_amount)
175 }
176
177 fn withdraw(&self) -> bool {
178 self.try_get(self.withdraw_amount)
179 }
180}
181
182impl Default for TpsBudget {
183 fn default() -> Self {
184 TpsBudget::new(Duration::from_secs(10), 10, 0.2)
185 }
186}
187
188impl fmt::Debug for TpsBudget {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 f.debug_struct("Budget")
191 .field("deposit", &self.deposit_amount)
192 .field("withdraw", &self.withdraw_amount)
193 .field("balance", &self.sum())
194 .finish()
195 }
196}
197
#[cfg(test)]
mod tests {
    use crate::retry::budget::Budget;

    use super::*;
    use tokio::time;

    #[test]
    fn tps_empty() {
        let bgt = TpsBudget::new(Duration::from_secs(1), 0, 1.0);
        assert!(!bgt.withdraw());
    }

    #[tokio::test]
    async fn tps_leaky() {
        time::pause();

        let bgt = TpsBudget::new(Duration::from_secs(1), 0, 1.0);
        bgt.deposit();

        time::advance(Duration::from_secs(3)).await;

        assert!(!bgt.withdraw());
    }

    #[tokio::test]
    async fn tps_slots() {
        time::pause();

        let bgt = TpsBudget::new(Duration::from_secs(1), 0, 0.5);
        bgt.deposit();
        bgt.deposit();
        time::advance(Duration::from_millis(901)).await;
        assert!(bgt.withdraw());

        time::advance(Duration::from_millis(2001)).await;

        bgt.deposit();
        time::advance(Duration::from_millis(301)).await;
        bgt.deposit();
        time::advance(Duration::from_millis(801)).await;
        bgt.deposit();

        assert!(bgt.withdraw());
    }

    #[tokio::test]
    async fn tps_reserve() {
        let bgt = TpsBudget::new(Duration::from_secs(1), 5, 1.0);
        assert!(bgt.withdraw());
        assert!(bgt.withdraw());
        assert!(bgt.withdraw());
        assert!(bgt.withdraw());
        assert!(bgt.withdraw());

        assert!(!bgt.withdraw());
    }

    #[test]
    fn tps_fractional_retry_percent_below_one() {
        let bgt = TpsBudget::new(Duration::from_secs(1), 0, 0.6);
        for _ in 0..10 {
            bgt.deposit();
        }
        let allowed = (0..10).filter(|_| bgt.withdraw()).count();
        // Exact, not an upper bound: 10 * 1000 tokens / 1666 per withdrawal = 6.
        // An upper bound alone would also accept an over-restrictive budget.
        assert_eq!(
            allowed, 6,
            "10 deposits at retry_percent=0.6 should allow exactly 6 retries, got {allowed}"
        );
    }

    #[tokio::test]
    async fn tps_long_idle_expires_everything() {
        time::pause();

        let bgt = TpsBudget::new(Duration::from_secs(1), 0, 1.0);
        bgt.deposit();

        // Far more elapsed slot widths than there are slots: the ring wraps
        // many times and every old deposit must be gone.
        time::advance(Duration::from_secs(24 * 60 * 60)).await;
        assert!(!bgt.withdraw());

        // The budget keeps working normally after the long gap.
        bgt.deposit();
        assert!(bgt.withdraw());
        assert!(!bgt.withdraw());
    }
}