1use std::{fmt::Display, future::Future};
7
8use crate::time::{Duration, sleep};
9use rand::Rng;
10use tracing::{debug, warn};
11
/// Settings that control how [`with_retry`] retries a failing operation.
///
/// The delay before retry `n` (0-based) is `initial_delay * backoff_base^n`,
/// capped at `max_delay`, and optionally randomized with jitter.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries after the first attempt (total attempts = `max_retries + 1`).
    pub max_retries: usize,
    /// Base delay used before the first retry.
    pub initial_delay: Duration,
    /// Cap applied to the exponential backoff delay.
    pub max_delay: Duration,
    /// Multiplier applied to the delay for each successive retry.
    pub backoff_base: f64,
    /// Whether each delay is multiplied by a random jitter factor.
    pub use_jitter: bool,
}
26
27impl Default for RetryConfig {
28 fn default() -> Self {
29 Self {
30 max_retries: 3,
31 initial_delay: Duration::from_secs(1),
32 max_delay: Duration::from_secs(60),
33 backoff_base: 2.0,
34 use_jitter: true,
35 }
36 }
37}
38
39impl RetryConfig {
40 pub fn new() -> Self {
42 Self::default()
43 }
44
45 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
47 self.max_retries = max_retries;
48 self
49 }
50
51 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
53 self.initial_delay = delay;
54 self
55 }
56
57 pub fn with_max_delay(mut self, delay: Duration) -> Self {
59 self.max_delay = delay;
60 self
61 }
62
63 pub fn without_jitter(mut self) -> Self {
65 self.use_jitter = false;
66 self
67 }
68
69 fn calculate_delay(&self, attempt: usize) -> Duration {
71 let base_delay = self.initial_delay.as_millis() as f64;
72 let exponential_delay = base_delay * self.backoff_base.powi(attempt as i32);
73 let capped_delay = exponential_delay.min(self.max_delay.as_millis() as f64);
74
75 let final_delay = if self.use_jitter {
76 let mut rng = rand::thread_rng();
78 let jitter_factor = rng.gen_range(0.5..1.5);
79 capped_delay * jitter_factor
80 } else {
81 capped_delay
82 };
83
84 Duration::from_millis(final_delay as u64)
85 }
86}
87
/// Classifies errors as transient (worth retrying) or permanent.
///
/// [`with_retry`] consults [`RetryableError::is_retryable`] after every failed
/// attempt and returns the error immediately when it is `false`.
pub trait RetryableError {
    /// Returns `true` if the operation that produced this error may succeed
    /// when attempted again.
    fn is_retryable(&self) -> bool;

    /// Human-readable explanation of the retry decision, used in log output.
    ///
    /// The default text is derived from [`is_retryable`](Self::is_retryable).
    fn retry_reason(&self) -> &str {
        if !self.is_retryable() {
            return "Non-transient error, will not retry";
        }
        "Transient error, will retry"
    }
}
102
103pub async fn with_retry<F, Fut, T, E>(
115 mut operation: F,
116 config: &RetryConfig,
117 operation_name: &str,
118) -> Result<T, E>
119where
120 F: FnMut() -> Fut,
121 Fut: Future<Output = Result<T, E>>,
122 E: RetryableError + Display,
123{
124 let mut attempt = 0;
125 let mut last_error = None;
126
127 while attempt <= config.max_retries {
128 debug!(
129 operation = operation_name,
130 attempt = attempt,
131 max_retries = config.max_retries,
132 "Attempting operation"
133 );
134
135 match operation().await {
136 Ok(result) => {
137 if attempt > 0 {
138 debug!(
139 operation = operation_name,
140 attempt = attempt,
141 "Operation succeeded after retry"
142 );
143 }
144 return Ok(result);
145 }
146 Err(error) => {
147 debug!(
148 operation = operation_name,
149 error = %error,
150 is_retryable = error.is_retryable(),
151 reason = error.retry_reason(),
152 "Error encountered"
153 );
154
155 if !error.is_retryable() {
156 debug!(
157 operation = operation_name,
158 error = %error,
159 reason = error.retry_reason(),
160 "Non-retryable error encountered"
161 );
162 return Err(error);
163 }
164
165 last_error = Some(error);
166
167 if attempt < config.max_retries {
168 let delay = config.calculate_delay(attempt);
169 debug!(
170 operation = operation_name,
171 attempt = attempt + 1,
172 max_retries = config.max_retries,
173 delay_ms = delay.as_millis(),
174 error = %last_error.as_ref().unwrap(),
175 "Retryable error encountered, will retry after delay"
176 );
177 sleep(delay).await;
178 } else {
179 warn!(
180 operation = operation_name,
181 attempts = attempt + 1,
182 error = %last_error.as_ref().unwrap(),
183 "Max retries exceeded, operation failed"
184 );
185 }
186 }
187 }
188
189 attempt += 1;
190 }
191
192 Err(last_error.unwrap())
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Error type whose retryability is fixed per variant.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Retryable error")]
        Retryable,
        #[error("Non-retryable error")]
        NonRetryable,
    }

    impl RetryableError for TestError {
        fn is_retryable(&self) -> bool {
            matches!(self, TestError::Retryable)
        }
    }

    /// Deterministic config with 10 ms base delays so retry tests finish quickly.
    fn fast_config(max_retries: usize) -> RetryConfig {
        RetryConfig::new()
            .with_max_retries(max_retries)
            .with_initial_delay(Duration::from_millis(10))
            .without_jitter()
    }

    #[tokio::test]
    async fn test_successful_operation() {
        let config = RetryConfig::new().without_jitter();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Ok::<_, TestError>(42) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_then_success() {
        let config = fast_config(3);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                let previous = calls_in_op.fetch_add(1, Ordering::SeqCst);
                async move {
                    // Fail the first two attempts, succeed on the third.
                    if previous >= 2 {
                        Ok(42)
                    } else {
                        Err(TestError::Retryable)
                    }
                }
            },
            &config,
            "test_operation",
        )
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_non_retryable_error() {
        let config = RetryConfig::new().with_max_retries(3);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(TestError::NonRetryable) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert!(matches!(outcome, Err(TestError::NonRetryable)));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(TestError::Retryable) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert!(matches!(outcome, Err(TestError::Retryable)));
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_exponential_backoff_calculation() {
        let config = RetryConfig::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_max_delay(Duration::from_secs(30))
            .without_jitter();

        // Doubling from 1 s until the 30 s ceiling takes over.
        let expected_secs = [(0, 1), (1, 2), (2, 4), (3, 8), (4, 16), (5, 30), (10, 30)];
        for &(attempt, secs) in &expected_secs {
            assert_eq!(config.calculate_delay(attempt), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_jitter() {
        let config = RetryConfig::new().with_initial_delay(Duration::from_secs(1));

        // Attempt 1 has a 2 s base delay; jitter scales it into [1 s, 3 s].
        for _ in 0..2 {
            let millis = config.calculate_delay(1).as_millis();
            assert!(millis >= 1000);
            assert!(millis <= 3000);
        }
    }
}