1use moka::future::Cache as MokaCache;
2use std::sync::Arc;
3use std::time::Duration;
4use tracing::{debug, info};
5
6use pubmed_parser::pmc::PmcArticle;
7
/// Pluggable storage backend for the typed article cache.
///
/// Implementations are best-effort: `get`/`insert`/`clear` failures are
/// swallowed by the backend rather than surfaced to callers.
#[async_trait::async_trait]
pub trait CacheBackend<V>: Send + Sync
where
    V: Send + Sync,
{
    /// Looks up `key`, returning the cached value if present and not expired.
    async fn get(&self, key: &str) -> Option<V>;

    /// Stores `value` under `key`, replacing any existing entry.
    async fn insert(&self, key: String, value: V);

    /// Removes every entry from the backend.
    async fn clear(&self);

    /// Number of live entries. May be approximate, and some backends
    /// (e.g. Redis) always report 0 because counting is not cheap.
    fn entry_count(&self) -> u64;

    /// Flushes pending maintenance work (e.g. moka's deferred tasks).
    async fn sync(&self);
}
42
/// Selects which storage backend [`create_cache`] builds.
#[derive(Debug, Clone, Default)]
pub enum CacheBackendConfig {
    /// In-process memory cache (moka); the default.
    #[default]
    Memory,
    /// Redis-backed cache (requires the `cache-redis` feature).
    #[cfg(feature = "cache-redis")]
    Redis {
        /// Redis connection URL.
        url: String,
    },
    /// SQLite-backed cache (requires the `cache-sqlite` feature).
    #[cfg(feature = "cache-sqlite")]
    Sqlite {
        /// Path to the SQLite database file.
        path: std::path::PathBuf,
    },
}
71
/// Tuning knobs for the article cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum entry count; only the memory backend consults this.
    pub max_capacity: u64,
    /// How long an entry stays valid after insertion.
    pub time_to_live: Duration,
    /// Which storage backend to build.
    pub backend: CacheBackendConfig,
}
82
83impl Default for CacheConfig {
84 fn default() -> Self {
85 Self {
86 max_capacity: 1000,
87 time_to_live: Duration::from_secs(7 * 24 * 60 * 60), backend: CacheBackendConfig::default(),
89 }
90 }
91}
92
/// In-process cache backend built on moka's async cache.
#[derive(Clone)]
pub struct MemoryCache<V> {
    // Moka enforces max capacity and TTL expiry internally.
    cache: MokaCache<String, V>,
}
102
103impl<V: Clone + Send + Sync + 'static> MemoryCache<V> {
104 pub fn new(config: &CacheConfig) -> Self {
106 let cache = MokaCache::builder()
107 .max_capacity(config.max_capacity)
108 .time_to_live(config.time_to_live)
109 .build();
110 Self { cache }
111 }
112}
113
114#[async_trait::async_trait]
115impl<V: Clone + Send + Sync + 'static> CacheBackend<V> for MemoryCache<V> {
116 async fn get(&self, key: &str) -> Option<V> {
117 let result = self.cache.get(key).await;
118 if result.is_some() {
119 debug!("Cache hit");
120 } else {
121 debug!("Cache miss");
122 }
123 result
124 }
125
126 async fn insert(&self, key: String, value: V) {
127 self.cache.insert(key, value).await;
128 info!("Item cached");
129 }
130
131 async fn clear(&self) {
132 self.cache.invalidate_all();
133 info!("Cache cleared");
134 }
135
136 fn entry_count(&self) -> u64 {
137 self.cache.entry_count()
138 }
139
140 async fn sync(&self) {
141 self.cache.run_pending_tasks().await;
142 }
143}
144
/// Redis-backed cache that stores values as JSON strings.
#[cfg(feature = "cache-redis")]
#[derive(Clone)]
pub struct RedisCache<V> {
    // Client handles connection setup; connections are created per call.
    client: redis::Client,
    /// Per-entry expiry applied on insert.
    ttl: Duration,
    /// `fn() -> V` ties the type parameter to this struct without storing
    /// a `V` or affecting auto traits.
    _phantom: std::marker::PhantomData<fn() -> V>,
}
165
#[cfg(feature = "cache-redis")]
impl<V> RedisCache<V> {
    /// Creates a Redis-backed cache for the server at `url` with the
    /// given per-entry TTL.
    ///
    /// # Errors
    /// Returns the `redis` error if `url` is not a valid connection URL.
    pub fn new(url: &str, ttl: Duration) -> Result<Self, redis::RedisError> {
        redis::Client::open(url).map(|client| Self {
            client,
            ttl,
            _phantom: std::marker::PhantomData,
        })
    }
}
178
#[cfg(feature = "cache-redis")]
#[async_trait::async_trait]
impl<V> CacheBackend<V> for RedisCache<V>
where
    V: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    /// Fetches and deserializes `key`; any connection, lookup, or decode
    /// failure is treated as a cache miss.
    async fn get(&self, key: &str) -> Option<V> {
        use redis::AsyncCommands;
        let mut conn = self.client.get_multiplexed_async_connection().await.ok()?;
        let json: String = conn.get(key).await.ok()?;
        if json.is_empty() {
            // Defensive: serde_json never emits an empty document, so an
            // empty value means a foreign/corrupt entry — treat as a miss.
            return None;
        }
        match serde_json::from_str(&json) {
            Ok(value) => {
                debug!("Cache hit (Redis)");
                Some(value)
            }
            Err(_) => {
                debug!("Cache miss (Redis): deserialization failed");
                None
            }
        }
    }

    /// Serializes `value` and stores it with the configured TTL.
    ///
    /// Best-effort: failures are logged (previously they were silently
    /// dropped while still logging success).
    async fn insert(&self, key: String, value: V) {
        use redis::AsyncCommands;
        let Ok(mut conn) = self.client.get_multiplexed_async_connection().await else {
            tracing::warn!("Redis insert skipped: connection failed");
            return;
        };
        let Ok(json) = serde_json::to_string(&value) else {
            tracing::warn!("Redis insert skipped: serialization failed");
            return;
        };
        // Redis rejects SET ... EX 0, so a sub-second TTL would make every
        // insert fail; clamp to at least one second.
        let result: Result<(), redis::RedisError> =
            conn.set_ex(key, json, self.ttl.as_secs().max(1)).await;
        match result {
            Ok(()) => info!("Item cached (Redis)"),
            Err(e) => tracing::warn!("Redis insert failed: {e}"),
        }
    }

    /// Flushes the entire Redis database (best effort, logged on failure).
    async fn clear(&self) {
        let Ok(mut conn) = self.client.get_multiplexed_async_connection().await else {
            tracing::warn!("Redis clear skipped: connection failed");
            return;
        };
        let result: Result<(), redis::RedisError> =
            redis::cmd("FLUSHDB").query_async(&mut conn).await;
        match result {
            Ok(()) => info!("Cache cleared (Redis)"),
            Err(e) => tracing::warn!("Redis clear failed: {e}"),
        }
    }

    /// Redis offers no cheap per-cache count here; always reports 0.
    fn entry_count(&self) -> u64 {
        0
    }

    /// Redis applies writes immediately; nothing to flush.
    async fn sync(&self) {}
}
229
/// SQLite-backed cache storing JSON values with absolute expiry timestamps.
#[cfg(feature = "cache-sqlite")]
#[derive(Clone)]
pub struct SqliteCache<V> {
    /// Shared connection; access is serialized through the mutex because
    /// queries run on blocking threads (see the trait impl).
    conn: Arc<std::sync::Mutex<rusqlite::Connection>>,
    /// Per-entry expiry applied on insert.
    ttl: Duration,
    /// `fn() -> V` ties the type parameter to this struct without storing
    /// a `V` or affecting auto traits.
    _phantom: std::marker::PhantomData<fn() -> V>,
}
250
#[cfg(feature = "cache-sqlite")]
impl<V> SqliteCache<V> {
    /// Opens (or creates) the database at `path` and ensures the cache
    /// schema exists.
    ///
    /// # Errors
    /// Returns any `rusqlite` error from opening the file or running the
    /// schema statements.
    pub fn new(path: &std::path::Path, ttl: Duration) -> rusqlite::Result<Self> {
        let conn = rusqlite::Connection::open(path)?;
        // Idempotent schema setup: key/value rows plus an index on the
        // expiry column so expiry filtering stays cheap.
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS cache (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL,
                expires_at INTEGER NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_cache_expires ON cache (expires_at);",
        )?;
        Ok(Self {
            conn: Arc::new(std::sync::Mutex::new(conn)),
            ttl,
            _phantom: std::marker::PhantomData,
        })
    }
}
271
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Clamps to 0 if the system clock reports a pre-epoch time.
#[cfg(feature = "cache-sqlite")]
fn sqlite_now_secs() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs() as i64)
        .unwrap_or(0)
}
279
#[cfg(feature = "cache-sqlite")]
#[async_trait::async_trait]
impl<V> CacheBackend<V> for SqliteCache<V>
where
    V: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
{
    /// Looks up `key`, honoring `expires_at`; any SQL or decode failure
    /// is reported as a miss.
    async fn get(&self, key: &str) -> Option<V> {
        let key = key.to_owned();
        let conn = Arc::clone(&self.conn);
        // rusqlite is synchronous; keep queries off the async executor.
        tokio::task::spawn_blocking(move || {
            let now = sqlite_now_secs();
            let guard = conn.lock().unwrap();
            let result: rusqlite::Result<String> = guard.query_row(
                "SELECT value FROM cache WHERE key = ?1 AND expires_at > ?2",
                rusqlite::params![key, now],
                |row| row.get(0),
            );
            match result {
                Ok(json) => {
                    let value = serde_json::from_str(&json).ok();
                    if value.is_some() {
                        debug!("Cache hit (SQLite)");
                    } else {
                        debug!("Cache miss (SQLite): deserialization failed");
                    }
                    value
                }
                Err(_) => {
                    // Covers both "no row" and genuine query errors.
                    debug!("Cache miss (SQLite)");
                    None
                }
            }
        })
        .await
        .unwrap_or(None)
    }

    /// Serializes `value` and upserts it with an absolute expiry timestamp.
    ///
    /// Best-effort: failures are logged (previously the success message was
    /// emitted even when serialization or the upsert failed).
    async fn insert(&self, key: String, value: V) {
        let conn = Arc::clone(&self.conn);
        let ttl = self.ttl;
        tokio::task::spawn_blocking(move || {
            let Ok(json) = serde_json::to_string(&value) else {
                tracing::warn!("SQLite insert skipped: serialization failed");
                return;
            };
            let expires_at = sqlite_now_secs() + ttl.as_secs() as i64;
            let guard = conn.lock().unwrap();
            match guard.execute(
                "INSERT OR REPLACE INTO cache (key, value, expires_at) VALUES (?1, ?2, ?3)",
                rusqlite::params![key, json, expires_at],
            ) {
                Ok(_) => info!("Item cached (SQLite)"),
                Err(e) => tracing::warn!("SQLite insert failed: {e}"),
            }
        })
        .await
        .ok();
    }

    /// Deletes every row from the cache table.
    async fn clear(&self) {
        let conn = Arc::clone(&self.conn);
        tokio::task::spawn_blocking(move || {
            let guard = conn.lock().unwrap();
            let _ = guard.execute("DELETE FROM cache", []);
            info!("Cache cleared (SQLite)");
        })
        .await
        .ok();
    }

    /// Counts non-expired rows; returns 0 if the connection is currently
    /// held by another thread (this method is sync and must not block).
    fn entry_count(&self) -> u64 {
        let now = sqlite_now_secs();
        if let Ok(guard) = self.conn.try_lock() {
            guard
                .query_row(
                    "SELECT COUNT(*) FROM cache WHERE expires_at > ?1",
                    rusqlite::params![now],
                    |row| row.get::<_, i64>(0),
                )
                .map(|c| c as u64)
                .unwrap_or(0)
        } else {
            0
        }
    }

    /// Prunes expired rows. Without this the database file grows without
    /// bound: expired entries were only filtered out at read time, never
    /// deleted (previously this method was a no-op).
    async fn sync(&self) {
        let conn = Arc::clone(&self.conn);
        tokio::task::spawn_blocking(move || {
            let now = sqlite_now_secs();
            let guard = conn.lock().unwrap();
            let _ = guard.execute(
                "DELETE FROM cache WHERE expires_at <= ?1",
                rusqlite::params![now],
            );
        })
        .await
        .ok();
    }
}
367
/// Type-safe handle over any [`CacheBackend`], erased behind `Arc<dyn _>`.
///
/// Cloning is cheap (an `Arc` refcount bump); clones share the same backend.
#[derive(Clone)]
pub struct TypedCache<V: Send + Sync + 'static> {
    inner: Arc<dyn CacheBackend<V>>,
}
396
397impl<V: Send + Sync + 'static> TypedCache<V> {
398 pub fn new(backend: impl CacheBackend<V> + 'static) -> Self {
400 Self {
401 inner: Arc::new(backend),
402 }
403 }
404
405 pub async fn get(&self, key: &str) -> Option<V> {
406 self.inner.get(key).await
407 }
408
409 pub async fn insert(&self, key: String, value: V) {
410 self.inner.insert(key, value).await;
411 }
412
413 pub async fn clear(&self) {
414 self.inner.clear().await;
415 }
416
417 pub fn entry_count(&self) -> u64 {
418 self.inner.entry_count()
419 }
420
421 pub async fn sync(&self) {
422 self.inner.sync().await;
423 }
424}
425
/// Cache specialized to parsed PMC articles, the payload this crate serves.
pub type PmcCache = TypedCache<PmcArticle>;
428
429pub fn create_cache(config: &CacheConfig) -> PmcCache {
438 match &config.backend {
439 CacheBackendConfig::Memory => TypedCache::new(MemoryCache::new(config)),
440 #[cfg(feature = "cache-redis")]
441 CacheBackendConfig::Redis { url } => match RedisCache::new(url, config.time_to_live) {
442 Ok(c) => TypedCache::new(c),
443 Err(e) => {
444 tracing::error!("Failed to create Redis cache, falling back to memory: {e}");
445 TypedCache::new(MemoryCache::new(config))
446 }
447 },
448 #[cfg(feature = "cache-sqlite")]
449 CacheBackendConfig::Sqlite { path } => match SqliteCache::new(path, config.time_to_live) {
450 Ok(c) => TypedCache::new(c),
451 Err(e) => {
452 tracing::error!("Failed to create SQLite cache, falling back to memory: {e}");
453 TypedCache::new(MemoryCache::new(config))
454 }
455 },
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips insert/get and verifies clear() empties the cache.
    #[tokio::test]
    async fn test_memory_cache_basic() {
        let config = CacheConfig {
            max_capacity: 10,
            time_to_live: Duration::from_secs(60),
            ..Default::default()
        };
        let cache = TypedCache::new(MemoryCache::<String>::new(&config));

        cache.insert("key1".to_string(), "value1".to_string()).await;
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));
        assert_eq!(cache.get("nonexistent").await, None);

        cache.clear().await;
        assert_eq!(cache.get("key1").await, None);
    }

    // entry_count() is only reliable after sync() flushes moka's pending
    // maintenance tasks, hence the sync() before every count assertion.
    #[tokio::test]
    async fn test_cache_entry_count() {
        let config = CacheConfig::default();
        let cache = TypedCache::new(MemoryCache::<String>::new(&config));

        assert_eq!(cache.entry_count(), 0);

        cache.insert("key1".to_string(), "value1".to_string()).await;
        cache.sync().await;
        assert_eq!(cache.entry_count(), 1);

        cache.insert("key2".to_string(), "value2".to_string()).await;
        cache.sync().await;
        assert_eq!(cache.entry_count(), 2);

        cache.clear().await;
        cache.sync().await;
        assert_eq!(cache.entry_count(), 0);
    }
}
503}