// pubmed_client/cache.rs — response caching layer for the PubMed client.

1use moka::future::Cache as MokaCache;
2use std::sync::Arc;
3use std::time::Duration;
4use tracing::{debug, info};
5
6use pubmed_parser::pmc::PmcArticle;
7
8// ---------------------------------------------------------------------------
9// CacheBackend trait
10// ---------------------------------------------------------------------------
11
/// Async storage backend for a cache.
///
/// Implement this trait to provide a custom caching layer.  The type
/// parameter `V` is the cached value type.
///
/// Note that `insert` and `clear` return `()` rather than a `Result`, so
/// implementations must handle storage failures internally (e.g. log and
/// swallow them); callers never observe a cache error.
///
/// # Object safety
///
/// `CacheBackend<V>` is object-safe when used with [`async_trait`], so
/// [`TypedCache<V>`] can hold any backend behind `Arc<dyn CacheBackend<V>>`.
#[async_trait::async_trait]
pub trait CacheBackend<V>: Send + Sync
where
    V: Send + Sync,
{
    /// Return the cached value for `key`, or `None` on a miss.
    async fn get(&self, key: &str) -> Option<V>;

    /// Store `value` under `key`.
    async fn insert(&self, key: String, value: V);

    /// Remove all entries.
    async fn clear(&self);

    /// Return the number of live entries (best-effort; may return 0 for some
    /// backends).
    fn entry_count(&self) -> u64;

    /// Flush any pending internal tasks (useful for testing).
    async fn sync(&self);
}
42
43// ---------------------------------------------------------------------------
44// CacheBackendConfig / CacheConfig
45// ---------------------------------------------------------------------------
46
/// Selects which storage backend to use for caching.
///
/// Feature-gated variants only exist when the corresponding Cargo feature is
/// enabled, so `match` expressions over this enum stay exhaustive per build.
#[derive(Debug, Clone, Default)]
pub enum CacheBackendConfig {
    /// In-memory cache using Moka (default).
    #[default]
    Memory,
    /// Redis-backed persistent cache.
    ///
    /// Requires the `cache-redis` feature.
    #[cfg(feature = "cache-redis")]
    Redis {
        /// Redis connection URL, e.g. `"redis://127.0.0.1/"`.
        url: String,
    },
    /// SQLite-backed persistent cache.
    ///
    /// Requires the `cache-sqlite` feature.
    /// Not supported on WASM targets.
    #[cfg(feature = "cache-sqlite")]
    Sqlite {
        /// Path to the SQLite database file.
        path: std::path::PathBuf,
    },
}
71
/// Configuration for response caching.
///
/// The [`Default`] impl below yields 1000 entries, a 7-day TTL, and the
/// in-memory backend.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of items to store (used by the memory backend).
    pub max_capacity: u64,
    /// Time-to-live for cached items (applied by every backend).
    pub time_to_live: Duration,
    /// Which storage backend to use.
    pub backend: CacheBackendConfig,
}
82
83impl Default for CacheConfig {
84    fn default() -> Self {
85        Self {
86            max_capacity: 1000,
87            time_to_live: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
88            backend: CacheBackendConfig::default(),
89        }
90    }
91}
92
93// ---------------------------------------------------------------------------
94// Memory backend
95// ---------------------------------------------------------------------------
96
/// In-memory cache backed by [Moka](https://docs.rs/moka).
#[derive(Clone)]
pub struct MemoryCache<V> {
    // Underlying Moka cache; capacity and TTL eviction are enforced by Moka.
    cache: MokaCache<String, V>,
}
102
103impl<V: Clone + Send + Sync + 'static> MemoryCache<V> {
104    /// Create a new in-memory cache from `config`.
105    pub fn new(config: &CacheConfig) -> Self {
106        let cache = MokaCache::builder()
107            .max_capacity(config.max_capacity)
108            .time_to_live(config.time_to_live)
109            .build();
110        Self { cache }
111    }
112}
113
114#[async_trait::async_trait]
115impl<V: Clone + Send + Sync + 'static> CacheBackend<V> for MemoryCache<V> {
116    async fn get(&self, key: &str) -> Option<V> {
117        let result = self.cache.get(key).await;
118        if result.is_some() {
119            debug!("Cache hit");
120        } else {
121            debug!("Cache miss");
122        }
123        result
124    }
125
126    async fn insert(&self, key: String, value: V) {
127        self.cache.insert(key, value).await;
128        info!("Item cached");
129    }
130
131    async fn clear(&self) {
132        self.cache.invalidate_all();
133        info!("Cache cleared");
134    }
135
136    fn entry_count(&self) -> u64 {
137        self.cache.entry_count()
138    }
139
140    async fn sync(&self) {
141        self.cache.run_pending_tasks().await;
142    }
143}
144
145// ---------------------------------------------------------------------------
146// Redis backend  (feature = "cache-redis")
147// ---------------------------------------------------------------------------
148
/// Redis-backed cache using JSON serialisation.
///
/// Each cache operation opens a new multiplexed connection.  TTL is applied
/// per entry via `SET … EX`.
///
/// `entry_count()` always returns 0; Redis `DBSIZE` counts all keys and
/// cannot be scoped to this cache without a full scan.
///
/// Requires the `cache-redis` feature.
#[cfg(feature = "cache-redis")]
#[derive(Clone)]
pub struct RedisCache<V> {
    // Connection factory; actual connections are created per operation.
    client: redis::Client,
    // Per-entry expiry applied on insert via `SET … EX`.
    ttl: Duration,
    // `fn() -> V` is always `Send + Sync`, so the marker ties `V` to this
    // type without affecting its auto traits or storing a `V`.
    _phantom: std::marker::PhantomData<fn() -> V>,
}
165
#[cfg(feature = "cache-redis")]
impl<V> RedisCache<V> {
    /// Create a cache backed by the Redis server at `url`.
    ///
    /// `redis::Client::open` only parses/validates the URL; connections are
    /// established lazily, once per cache operation.
    pub fn new(url: &str, ttl: Duration) -> Result<Self, redis::RedisError> {
        redis::Client::open(url).map(|client| RedisCache {
            client,
            ttl,
            _phantom: std::marker::PhantomData,
        })
    }
}
178
#[cfg(feature = "cache-redis")]
#[async_trait::async_trait]
impl<V> CacheBackend<V> for RedisCache<V>
where
    V: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    /// Fetch and JSON-decode `key`.
    ///
    /// Any connection, lookup, or deserialization failure is treated as a
    /// cache miss — the cache never propagates errors to callers.
    async fn get(&self, key: &str) -> Option<V> {
        use redis::AsyncCommands;
        let mut conn = self.client.get_multiplexed_async_connection().await.ok()?;
        let json: String = conn.get(key).await.ok()?;
        if json.is_empty() {
            return None;
        }
        let value = serde_json::from_str(&json).ok();
        if value.is_some() {
            debug!("Cache hit (Redis)");
        } else {
            debug!("Cache miss (Redis): deserialization failed");
        }
        value
    }

    /// JSON-encode `value` and store it with a per-entry TTL (`SET … EX`).
    ///
    /// Best-effort: failures are logged at debug level and swallowed so a
    /// broken cache never fails the caller.  Previously the success message
    /// was logged unconditionally, even when the command failed.
    async fn insert(&self, key: String, value: V) {
        use redis::AsyncCommands;
        let Ok(mut conn) = self.client.get_multiplexed_async_connection().await else {
            debug!("Failed to cache item (Redis): connection error");
            return;
        };
        let Ok(json) = serde_json::to_string(&value) else {
            debug!("Failed to cache item (Redis): serialization error");
            return;
        };
        // Redis rejects `SET … EX 0`, so clamp sub-second TTLs to one second.
        let ttl_secs = self.ttl.as_secs().max(1);
        let result: Result<(), _> = conn.set_ex(key, json, ttl_secs).await;
        match result {
            Ok(()) => info!("Item cached (Redis)"),
            Err(e) => debug!("Failed to cache item (Redis): {e}"),
        }
    }

    /// Remove every key in the current Redis database (`FLUSHDB`).
    ///
    /// NOTE(review): this flushes the whole DB, not only entries written by
    /// this cache — point the cache at a dedicated database number.
    async fn clear(&self) {
        let Ok(mut conn) = self.client.get_multiplexed_async_connection().await else {
            debug!("Failed to clear cache (Redis): connection error");
            return;
        };
        let result: Result<(), _> = redis::cmd("FLUSHDB").query_async(&mut conn).await;
        match result {
            Ok(()) => info!("Cache cleared (Redis)"),
            Err(e) => debug!("Failed to clear cache (Redis): {e}"),
        }
    }

    /// Always 0: `DBSIZE` counts all keys in the DB and cannot be scoped to
    /// this cache without a full key scan.
    fn entry_count(&self) -> u64 {
        0
    }

    /// No-op: Redis has no pending-task queue to flush.
    async fn sync(&self) {}
}
229
230// ---------------------------------------------------------------------------
231// SQLite backend  (feature = "cache-sqlite")
232// ---------------------------------------------------------------------------
233
/// SQLite-backed cache using JSON serialisation.
///
/// The database schema is created automatically on first use.  Expired entries
/// are not purged automatically; call [`CacheBackend::clear`] or run
/// `DELETE FROM cache WHERE expires_at <= unixepoch()` periodically.
///
/// `entry_count()` uses `try_lock`; returns 0 if the mutex is currently held.
///
/// Requires the `cache-sqlite` feature.  Not available on WASM targets.
#[cfg(feature = "cache-sqlite")]
#[derive(Clone)]
pub struct SqliteCache<V> {
    // Single shared connection; operations serialise on this mutex and run
    // on `spawn_blocking` threads (rusqlite is synchronous).
    conn: Arc<std::sync::Mutex<rusqlite::Connection>>,
    // Relative TTL; converted to an absolute `expires_at` timestamp on insert.
    ttl: Duration,
    // Variance/auto-trait marker only; no `V` is actually stored.
    _phantom: std::marker::PhantomData<fn() -> V>,
}
250
#[cfg(feature = "cache-sqlite")]
impl<V> SqliteCache<V> {
    /// Open (or create) a SQLite database at `path` and ensure the cache
    /// schema (table plus expiry index) exists.
    pub fn new(path: &std::path::Path, ttl: Duration) -> rusqlite::Result<Self> {
        let conn = rusqlite::Connection::open(path)?;
        // Idempotent schema setup: `IF NOT EXISTS` makes reopening an
        // existing database safe.
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS cache (
                key        TEXT    PRIMARY KEY,
                value      TEXT    NOT NULL,
                expires_at INTEGER NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_cache_expires ON cache (expires_at);",
        )?;
        let shared = Arc::new(std::sync::Mutex::new(conn));
        Ok(SqliteCache {
            conn: shared,
            ttl,
            _phantom: std::marker::PhantomData,
        })
    }
}
271
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Clamps to 0 if the system clock reports a time before the epoch.
#[cfg(feature = "cache-sqlite")]
fn sqlite_now_secs() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs() as i64)
}
279
#[cfg(feature = "cache-sqlite")]
#[async_trait::async_trait]
impl<V> CacheBackend<V> for SqliteCache<V>
where
    V: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    /// Look up `key`, honouring per-row expiry.
    ///
    /// The blocking SQLite query runs on a `spawn_blocking` thread so the
    /// async executor is never stalled.  A row whose payload fails to
    /// deserialize is deleted so it cannot shadow the key forever
    /// (previously it lingered and produced a miss on every lookup).
    async fn get(&self, key: &str) -> Option<V> {
        let key = key.to_owned();
        let conn = Arc::clone(&self.conn);
        tokio::task::spawn_blocking(move || {
            let now = sqlite_now_secs();
            // Recover from a poisoned mutex: the SQLite connection is not
            // left inconsistent by a panicking sibling task, and a cache
            // should degrade, not cascade panics.
            let guard = conn.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            let result: rusqlite::Result<String> = guard.query_row(
                "SELECT value FROM cache WHERE key = ?1 AND expires_at > ?2",
                rusqlite::params![key, now],
                |row| row.get(0),
            );
            match result {
                Ok(json) => {
                    let value = serde_json::from_str(&json).ok();
                    if value.is_some() {
                        debug!("Cache hit (SQLite)");
                    } else {
                        debug!("Cache miss (SQLite): deserialization failed");
                        // Purge the corrupt row; the next insert recreates it.
                        let _ = guard.execute(
                            "DELETE FROM cache WHERE key = ?1",
                            rusqlite::params![key],
                        );
                    }
                    value
                }
                Err(_) => {
                    debug!("Cache miss (SQLite)");
                    None
                }
            }
        })
        .await
        .unwrap_or(None)
    }

    /// JSON-encode `value` and upsert it with an absolute expiry timestamp.
    ///
    /// Best-effort: failures are logged and swallowed.  Previously the
    /// success message was logged even when the `INSERT` failed.
    async fn insert(&self, key: String, value: V) {
        let conn = Arc::clone(&self.conn);
        let ttl = self.ttl;
        tokio::task::spawn_blocking(move || {
            let Ok(json) = serde_json::to_string(&value) else {
                debug!("Failed to cache item (SQLite): serialization error");
                return;
            };
            let expires_at = sqlite_now_secs() + ttl.as_secs() as i64;
            let guard = conn.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            match guard.execute(
                "INSERT OR REPLACE INTO cache (key, value, expires_at) VALUES (?1, ?2, ?3)",
                rusqlite::params![key, json, expires_at],
            ) {
                Ok(_) => info!("Item cached (SQLite)"),
                Err(e) => debug!("Failed to cache item (SQLite): {e}"),
            }
        })
        .await
        .ok();
    }

    /// Delete every row, expired or not.
    async fn clear(&self) {
        let conn = Arc::clone(&self.conn);
        tokio::task::spawn_blocking(move || {
            let guard = conn.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            match guard.execute("DELETE FROM cache", []) {
                Ok(_) => info!("Cache cleared (SQLite)"),
                Err(e) => debug!("Failed to clear cache (SQLite): {e}"),
            }
        })
        .await
        .ok();
    }

    /// Count unexpired rows.
    ///
    /// Uses `try_lock` because this method is synchronous; returns 0 when the
    /// mutex is currently held by another operation.
    fn entry_count(&self) -> u64 {
        let now = sqlite_now_secs();
        if let Ok(guard) = self.conn.try_lock() {
            guard
                .query_row(
                    "SELECT COUNT(*) FROM cache WHERE expires_at > ?1",
                    rusqlite::params![now],
                    |row| row.get::<_, i64>(0),
                )
                .map(|c| c as u64)
                .unwrap_or(0)
        } else {
            0
        }
    }

    /// No-op: every write has completed by the time `insert` returns, since
    /// `insert` awaits its `spawn_blocking` task.
    async fn sync(&self) {}
}
367
368// ---------------------------------------------------------------------------
369// TypedCache — type-erased wrapper
370// ---------------------------------------------------------------------------
371
/// A type-erased, cloneable cache for values of type `V`.
///
/// Wraps any [`CacheBackend<V>`] behind an `Arc<dyn …>`, so it can be
/// cloned cheaply and stored in structs without generics leaking into the
/// public API.
///
/// The concrete backend is selected via [`CacheConfig::backend`].
///
/// # Example
///
/// ```no_run
/// use pubmed_client::cache::{TypedCache, MemoryCache, CacheConfig};
///
/// # tokio_test::block_on(async {
/// let config = CacheConfig::default();
/// let cache: TypedCache<String> = TypedCache::new(MemoryCache::new(&config));
/// cache.insert("key".to_string(), "value".to_string()).await;
/// assert_eq!(cache.get("key").await, Some("value".to_string()));
/// # });
/// ```
#[derive(Clone)]
pub struct TypedCache<V: Send + Sync + 'static> {
    // Shared, type-erased backend; clones of `TypedCache` share this Arc.
    inner: Arc<dyn CacheBackend<V>>,
}
396
impl<V: Send + Sync + 'static> TypedCache<V> {
    /// Wrap `backend` in a [`TypedCache`].
    pub fn new(backend: impl CacheBackend<V> + 'static) -> Self {
        Self {
            inner: Arc::new(backend),
        }
    }

    /// Return the cached value for `key`, or `None` on a miss.
    pub async fn get(&self, key: &str) -> Option<V> {
        self.inner.get(key).await
    }

    /// Store `value` under `key` in the underlying backend.
    pub async fn insert(&self, key: String, value: V) {
        self.inner.insert(key, value).await;
    }

    /// Remove all entries from the underlying backend.
    pub async fn clear(&self) {
        self.inner.clear().await;
    }

    /// Number of live entries (best-effort; 0 for some backends).
    pub fn entry_count(&self) -> u64 {
        self.inner.entry_count()
    }

    /// Flush pending backend tasks (useful before `entry_count` in tests).
    pub async fn sync(&self) {
        self.inner.sync().await;
    }
}
425
/// Type alias for the PMC full-text response cache.
///
/// Caches [`PmcArticle`] values keyed by caller-chosen strings.
pub type PmcCache = TypedCache<PmcArticle>;
428
429// ---------------------------------------------------------------------------
430// Factory
431// ---------------------------------------------------------------------------
432
433/// Create a [`PmcCache`] from `config`.
434///
435/// Falls back to the in-memory backend (with a logged error) when the
436/// configured backend cannot be initialised.
437pub fn create_cache(config: &CacheConfig) -> PmcCache {
438    match &config.backend {
439        CacheBackendConfig::Memory => TypedCache::new(MemoryCache::new(config)),
440        #[cfg(feature = "cache-redis")]
441        CacheBackendConfig::Redis { url } => match RedisCache::new(url, config.time_to_live) {
442            Ok(c) => TypedCache::new(c),
443            Err(e) => {
444                tracing::error!("Failed to create Redis cache, falling back to memory: {e}");
445                TypedCache::new(MemoryCache::new(config))
446            }
447        },
448        #[cfg(feature = "cache-sqlite")]
449        CacheBackendConfig::Sqlite { path } => match SqliteCache::new(path, config.time_to_live) {
450            Ok(c) => TypedCache::new(c),
451            Err(e) => {
452                tracing::error!("Failed to create SQLite cache, falling back to memory: {e}");
453                TypedCache::new(MemoryCache::new(config))
454            }
455        },
456    }
457}
458
459// ---------------------------------------------------------------------------
460// Tests
461// ---------------------------------------------------------------------------
462
#[cfg(test)]
mod tests {
    use super::*;

    // Insert / get / clear round-trip against the memory backend.
    #[tokio::test]
    async fn test_memory_cache_basic() {
        let config = CacheConfig {
            max_capacity: 10,
            time_to_live: Duration::from_secs(60),
            ..Default::default()
        };
        let cache = TypedCache::new(MemoryCache::<String>::new(&config));

        cache.insert("key1".to_string(), "value1".to_string()).await;
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));
        assert_eq!(cache.get("nonexistent").await, None);

        cache.clear().await;
        assert_eq!(cache.get("key1").await, None);
    }

    // entry_count is only reliable after sync() drains Moka's pending
    // maintenance tasks, hence the sync() call before each assertion.
    #[tokio::test]
    async fn test_cache_entry_count() {
        let config = CacheConfig::default();
        let cache = TypedCache::new(MemoryCache::<String>::new(&config));

        assert_eq!(cache.entry_count(), 0);

        cache.insert("key1".to_string(), "value1".to_string()).await;
        cache.sync().await;
        assert_eq!(cache.entry_count(), 1);

        cache.insert("key2".to_string(), "value2".to_string()).await;
        cache.sync().await;
        assert_eq!(cache.entry_count(), 2);

        cache.clear().await;
        cache.sync().await;
        assert_eq!(cache.entry_count(), 0);
    }
}
503}