pubmed_client/
rate_limit.rs

//! Rate limiting for NCBI E-utilities API compliance.
//!
//! This module provides a token-bucket rate limiter that keeps request rates within the
//! NCBI E-utilities usage guidelines. A single implementation is shared by native and
//! WASM targets.
5
6use crate::time::{Duration, Instant, sleep};
7use std::sync::Arc;
8use std::sync::Mutex;
9use tracing::{debug, instrument};
10
11/// NCBI E-utilities rate limits:
12/// - 3 requests per second without API key
13/// - 10 requests per second with API key
14/// - Violations can result in IP blocking
15///
16/// Token bucket rate limiter for NCBI API compliance
17#[derive(Clone)]
18pub struct RateLimiter {
19    bucket: Arc<Mutex<TokenBucket>>,
20}
21
/// Mutable token-bucket state shared by every clone of a [`RateLimiter`].
struct TokenBucket {
    /// Tokens added back per second of elapsed time (the configured request rate).
    refill_rate: f64,
    /// Upper bound on `tokens`; the largest burst allowed after an idle period.
    capacity: f64,
    /// Tokens currently available; one is spent per request.
    tokens: f64,
    /// The moment `tokens` was last brought up to date.
    last_refill: Instant,
}
28
29impl RateLimiter {
30    /// Create a new rate limiter with the specified rate
31    ///
32    /// # Arguments
33    ///
34    /// * `rate` - Maximum requests per second (e.g., 3.0 for NCBI default)
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use pubmed_client::RateLimiter;
40    ///
41    /// // Create rate limiter for NCBI API without key (3 req/sec)
42    /// let limiter_default = RateLimiter::new(3.0);
43    ///
44    /// // Create rate limiter for NCBI API with key (10 req/sec)
45    /// let limiter_with_key = RateLimiter::new(10.0);
46    /// ```
47    pub fn new(rate: f64) -> Self {
48        let capacity = rate.max(1.0); // Ensure minimum capacity
49        let now = Instant::now();
50        Self {
51            bucket: Arc::new(Mutex::new(TokenBucket {
52                tokens: capacity,
53                capacity,
54                refill_rate: rate,
55                last_refill: now,
56            })),
57        }
58    }
59
60    /// Create rate limiter for NCBI API without API key (3 requests/second)
61    pub fn ncbi_default() -> Self {
62        Self::new(3.0)
63    }
64
65    /// Create rate limiter for NCBI API with API key (10 requests/second)
66    pub fn ncbi_with_key() -> Self {
67        Self::new(10.0)
68    }
69
70    /// Acquire a token, waiting if necessary to respect rate limits
71    ///
72    /// This method implements a token bucket algorithm with the following behavior:
73    /// 1. Check if tokens are available in the bucket
74    /// 2. If available, consume one token and return immediately
75    /// 3. If not available, wait for the appropriate interval
76    /// 4. Refill the bucket and consume one token
77    ///
78    /// # Examples
79    ///
80    /// ```no_run
81    /// use pubmed_client::RateLimiter;
82    ///
83    /// #[tokio::main]
84    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
85    ///     let limiter = RateLimiter::ncbi_default();
86    ///
87    ///     // This will respect the 3 req/sec limit
88    ///     for i in 0..5 {
89    ///         limiter.acquire().await?;
90    ///         println!("Making API call {}", i + 1);
91    ///         // Make your API call here
92    ///     }
93    ///
94    ///     Ok(())
95    /// }
96    /// ```
97    #[instrument(skip(self))]
98    pub async fn acquire(&self) -> crate::Result<()> {
99        let should_wait = {
100            let mut bucket = self.bucket.lock().unwrap();
101            self.refill_bucket(&mut bucket);
102
103            if bucket.tokens >= 1.0 {
104                bucket.tokens -= 1.0;
105                debug!(remaining_tokens = %bucket.tokens, "Token acquired immediately");
106                false
107            } else {
108                debug!("No tokens available, need to wait");
109                true
110            }
111        };
112
113        if should_wait {
114            // Calculate wait time based on rate
115            let wait_duration = Duration::from_secs(1).as_secs_f64() / self.rate();
116            let wait_duration = Duration::from_millis((wait_duration * 1000.0) as u64);
117
118            debug!(
119                wait_duration_ms = wait_duration.as_millis(),
120                "Waiting for rate limit"
121            );
122            sleep(wait_duration).await;
123
124            // After waiting, refill bucket and consume token
125            let mut bucket = self.bucket.lock().unwrap();
126            self.refill_bucket(&mut bucket);
127            bucket.tokens = bucket.tokens.min(bucket.capacity);
128            if bucket.tokens >= 1.0 {
129                bucket.tokens -= 1.0;
130                debug!(remaining_tokens = %bucket.tokens, "Token acquired after waiting");
131            }
132        }
133
134        Ok(())
135    }
136
137    /// Check if a token is available without blocking
138    ///
139    /// Returns `true` if a token is available and can be acquired immediately.
140    /// This method does not consume a token.
141    pub fn check_available(&self) -> bool {
142        let mut bucket = self.bucket.lock().unwrap();
143        self.refill_bucket(&mut bucket);
144        bucket.tokens >= 1.0
145    }
146
147    /// Get current token count (for testing and monitoring)
148    pub fn token_count(&self) -> f64 {
149        let mut bucket = self.bucket.lock().unwrap();
150        self.refill_bucket(&mut bucket);
151        bucket.tokens
152    }
153
154    /// Get the configured rate limit (requests per second)
155    pub fn rate(&self) -> f64 {
156        let bucket = self.bucket.lock().unwrap();
157        bucket.refill_rate
158    }
159
160    /// Refill the token bucket based on elapsed time
161    fn refill_bucket(&self, bucket: &mut TokenBucket) {
162        let now = Instant::now();
163        let elapsed = now.duration_since(bucket.last_refill);
164
165        // Calculate tokens to add based on elapsed time
166        let tokens_to_add = elapsed.as_secs_f64() * bucket.refill_rate;
167        bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.capacity);
168
169        bucket.last_refill = now;
170    }
171}
172
173#[cfg(test)]
174mod tests {
175    use super::*;
176
177    #[tokio::test]
178    async fn test_basic_functionality() {
179        let limiter = RateLimiter::new(5.0);
180
181        // Should be able to acquire tokens
182        limiter.acquire().await.unwrap();
183
184        // Check rate
185        let rate = limiter.rate();
186        assert!((rate - 5.0).abs() < 0.1);
187    }
188
189    #[tokio::test]
190    async fn test_check_available() {
191        let limiter = RateLimiter::new(2.0);
192
193        // Should have tokens available initially
194        assert!(limiter.check_available());
195    }
196
197    #[tokio::test]
198    async fn test_ncbi_presets() {
199        let default_limiter = RateLimiter::ncbi_default();
200        let with_key_limiter = RateLimiter::ncbi_with_key();
201
202        assert!((default_limiter.rate() - 3.0).abs() < 0.1);
203        assert!((with_key_limiter.rate() - 10.0).abs() < 0.1);
204    }
205
206    #[tokio::test]
207    async fn test_rate_limiting_basic() {
208        let limiter = RateLimiter::new(1.0); // 1 request per second
209
210        // Should be able to acquire tokens
211        limiter.acquire().await.unwrap();
212        limiter.acquire().await.unwrap(); // This should involve a wait
213
214        // Rate limiter should still work
215        let tokens = limiter.token_count();
216        assert!(tokens >= 0.0);
217    }
218}