// pubmed_client/retry.rs

1//! Retry logic with exponential backoff for handling transient network failures
2//!
3//! This module provides a configurable retry mechanism with exponential backoff
4//! and jitter to handle temporary network issues when communicating with NCBI APIs.
5
6use std::{fmt::Display, future::Future};
7
8use crate::time::{Duration, sleep};
9use rand::Rng;
10use tracing::{debug, warn};
11
/// Configuration for retry behavior
///
/// Controls how many attempts are made and how the delay between them
/// grows. Build one with [`RetryConfig::new`] plus the `with_*` builder
/// methods, or rely on [`Default`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts (total attempts = max_retries + 1)
    pub max_retries: usize,
    /// Initial delay before first retry
    pub initial_delay: Duration,
    /// Maximum delay between retries
    pub max_delay: Duration,
    /// Base for exponential backoff (typically 2.0)
    pub backoff_base: f64,
    /// Whether to add jitter (random factor in [0.5, 1.5)) to retry delays
    pub use_jitter: bool,
}
26
27impl Default for RetryConfig {
28    fn default() -> Self {
29        Self {
30            max_retries: 3,
31            initial_delay: Duration::from_secs(1),
32            max_delay: Duration::from_secs(60),
33            backoff_base: 2.0,
34            use_jitter: true,
35        }
36    }
37}
38
39impl RetryConfig {
40    /// Create a new retry configuration with custom settings
41    pub fn new() -> Self {
42        Self::default()
43    }
44
45    /// Set the maximum number of retries
46    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
47        self.max_retries = max_retries;
48        self
49    }
50
51    /// Set the initial delay
52    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
53        self.initial_delay = delay;
54        self
55    }
56
57    /// Set the maximum delay
58    pub fn with_max_delay(mut self, delay: Duration) -> Self {
59        self.max_delay = delay;
60        self
61    }
62
63    /// Disable jitter
64    pub fn without_jitter(mut self) -> Self {
65        self.use_jitter = false;
66        self
67    }
68
69    /// Calculate delay for a given retry attempt
70    fn calculate_delay(&self, attempt: usize) -> Duration {
71        let base_delay = self.initial_delay.as_millis() as f64;
72        let exponential_delay = base_delay * self.backoff_base.powi(attempt as i32);
73        let capped_delay = exponential_delay.min(self.max_delay.as_millis() as f64);
74
75        let final_delay = if self.use_jitter {
76            // Add jitter: random value between 0.5 and 1.5 of the calculated delay
77            let mut rng = rand::thread_rng();
78            let jitter_factor = rng.gen_range(0.5..1.5);
79            capped_delay * jitter_factor
80        } else {
81            capped_delay
82        };
83
84        Duration::from_millis(final_delay as u64)
85    }
86}
87
/// Trait for errors that can be retried
pub trait RetryableError {
    /// Returns true if the error is transient and the operation should be retried
    fn is_retryable(&self) -> bool;

    /// Returns a human-readable description of why the error is/isn't retryable
    fn retry_reason(&self) -> &str {
        match self.is_retryable() {
            true => "Transient error, will retry",
            false => "Non-transient error, will not retry",
        }
    }
}
102
103/// Execute an operation with retry logic
104///
105/// # Arguments
106///
107/// * `operation` - A closure that returns a future with the operation to retry
108/// * `config` - Retry configuration
109/// * `operation_name` - A descriptive name for logging
110///
111/// # Returns
112///
113/// Returns the result of the operation, or the last error if all retries failed
114pub async fn with_retry<F, Fut, T, E>(
115    mut operation: F,
116    config: &RetryConfig,
117    operation_name: &str,
118) -> Result<T, E>
119where
120    F: FnMut() -> Fut,
121    Fut: Future<Output = Result<T, E>>,
122    E: RetryableError + Display,
123{
124    let mut attempt = 0;
125    let mut last_error = None;
126
127    while attempt <= config.max_retries {
128        debug!(
129            operation = operation_name,
130            attempt = attempt,
131            max_retries = config.max_retries,
132            "Attempting operation"
133        );
134
135        match operation().await {
136            Ok(result) => {
137                if attempt > 0 {
138                    debug!(
139                        operation = operation_name,
140                        attempt = attempt,
141                        "Operation succeeded after retry"
142                    );
143                }
144                return Ok(result);
145            }
146            Err(error) => {
147                debug!(
148                    operation = operation_name,
149                    error = %error,
150                    is_retryable = error.is_retryable(),
151                    reason = error.retry_reason(),
152                    "Error encountered"
153                );
154
155                if !error.is_retryable() {
156                    debug!(
157                        operation = operation_name,
158                        error = %error,
159                        reason = error.retry_reason(),
160                        "Non-retryable error encountered"
161                    );
162                    return Err(error);
163                }
164
165                last_error = Some(error);
166
167                if attempt < config.max_retries {
168                    let delay = config.calculate_delay(attempt);
169                    debug!(
170                        operation = operation_name,
171                        attempt = attempt + 1,
172                        max_retries = config.max_retries,
173                        delay_ms = delay.as_millis(),
174                        error = %last_error.as_ref().unwrap(),
175                        "Retryable error encountered, will retry after delay"
176                    );
177                    sleep(delay).await;
178                } else {
179                    warn!(
180                        operation = operation_name,
181                        attempts = attempt + 1,
182                        error = %last_error.as_ref().unwrap(),
183                        "Max retries exceeded, operation failed"
184                    );
185                }
186            }
187        }
188
189        attempt += 1;
190    }
191
192    // This should never be reached due to the loop logic, but just in case
193    Err(last_error.unwrap())
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal error type used to exercise both retry paths.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Retryable error")]
        Retryable,
        #[error("Non-retryable error")]
        NonRetryable,
    }

    impl RetryableError for TestError {
        fn is_retryable(&self) -> bool {
            matches!(self, TestError::Retryable)
        }
    }

    /// An operation that succeeds immediately should run exactly once.
    #[tokio::test]
    async fn test_successful_operation() {
        let config = RetryConfig::new().without_jitter();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let result = with_retry(
            || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                async { Ok::<_, TestError>(42) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1); // Called only once
    }

    /// Transient failures should be retried until the operation succeeds.
    #[tokio::test]
    async fn test_retry_then_success() {
        let config = RetryConfig::new()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(10))
            .without_jitter();

        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let result = with_retry(
            || {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                async move {
                    // Fail the first two attempts, succeed on the third.
                    if count < 2 {
                        Err(TestError::Retryable)
                    } else {
                        Ok(42)
                    }
                }
            },
            &config,
            "test_operation",
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3); // Failed twice, succeeded on third
    }

    /// Non-retryable errors must short-circuit: no further attempts.
    #[tokio::test]
    async fn test_non_retryable_error() {
        let config = RetryConfig::new().with_max_retries(3);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let result = with_retry(
            || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(TestError::NonRetryable) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert!(matches!(result, Err(TestError::NonRetryable)));
        assert_eq!(counter.load(Ordering::SeqCst), 1); // Called only once
    }

    /// With max_retries = N, a persistently failing operation runs N + 1 times
    /// and the final error is returned to the caller.
    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let config = RetryConfig::new()
            .with_max_retries(2)
            .with_initial_delay(Duration::from_millis(10))
            .without_jitter();

        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let result = with_retry(
            || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(TestError::Retryable) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert!(matches!(result, Err(TestError::Retryable)));
        assert_eq!(counter.load(Ordering::SeqCst), 3); // Initial attempt + 2 retries
    }

    /// With jitter disabled the delays are deterministic powers of the base.
    #[test]
    fn test_exponential_backoff_calculation() {
        let config = RetryConfig::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_max_delay(Duration::from_secs(30))
            .without_jitter();

        // Test exponential growth
        assert_eq!(config.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(config.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(config.calculate_delay(2), Duration::from_secs(4));
        assert_eq!(config.calculate_delay(3), Duration::from_secs(8));
        assert_eq!(config.calculate_delay(4), Duration::from_secs(16));

        // Test max delay cap
        assert_eq!(config.calculate_delay(5), Duration::from_secs(30)); // Would be 32, capped at 30
        assert_eq!(config.calculate_delay(10), Duration::from_secs(30)); // Still capped
    }

    /// Jittered delays must stay within the 0.5x–1.5x band of the base delay.
    #[test]
    fn test_jitter() {
        let config = RetryConfig::new().with_initial_delay(Duration::from_secs(1));

        // With jitter, delays should vary
        let delay1 = config.calculate_delay(1);
        let delay2 = config.calculate_delay(1);

        // Both should be between 1-3 seconds (2 seconds * 0.5-1.5 jitter)
        assert!(delay1.as_millis() >= 1000);
        assert!(delay1.as_millis() <= 3000);
        assert!(delay2.as_millis() >= 1000);
        assert!(delay2.as_millis() <= 3000);

        // They're unlikely to be exactly the same with jitter
        // (though theoretically possible, so we don't assert inequality)
    }
}
346}