Source code for cache.async_cache

import asyncio
import logging
import time as _time
from collections import defaultdict

from collections import OrderedDict

from .lru import LRU

# Default value for the ``ttl`` parameter of AsyncCache.set(): lets the
# cache tell "ttl not passed, use default_ttl" apart from an explicit
# ``ttl=None``.
_SENTINEL = object()

# Returned by _Cache.get_if_present() for a missing or expired key. It is
# distinct from None so that a cached None value still counts as a hit.
_MISSING = object()

# Module-level logger; the cache reports remote-tier failures through it.
logger = logging.getLogger(name="async_cache")


class AsyncCache:
    """Two-tier async cache: an in-process LRU (L1) with an optional remote L2.

    - L1 entries are stored as ``(value, expiration)`` where ``expiration`` is
      an absolute ``time.monotonic()`` deadline or ``None`` (never expires).
    - ``get`` falls back L1 -> ``remote_cache`` -> ``loader``/``batch_loader``.
    - Concurrent single-key loads of the same key share one ``loader`` call.
    - ``batch_loader`` requests are grouped for up to ``batch_window_ms``
      milliseconds, or flushed as soon as ``max_batch_size`` are queued.
    - Remote writes/deletes/clears run as fire-and-forget tasks; their
      failures are logged at DEBUG level and otherwise ignored.
    - ``backup`` must provide ``load()`` and ``save(data)``. Data is persisted
      as ``{key: (value, remaining_ttl_seconds_or_None)}``.
    """

    class _Cache(LRU):
        """LRU storage whose values are ``(value, expiration)`` pairs.

        Several methods call ``OrderedDict`` methods directly so they can
        bypass ``LRU``'s own accessors and work under one acquisition of
        ``self._lock``.
        """

        def __init__(self, maxsize, items=None):
            super().__init__(maxsize=maxsize)
            if items:
                entries = list(items.items())
                # Bulk-load without per-item eviction, but never beyond
                # maxsize: keep only the most recently inserted entries.
                if maxsize is not None and len(entries) > maxsize:
                    entries = entries[len(entries) - maxsize:]
                for k, v in entries:
                    OrderedDict.__setitem__(self, k, v)

        def __contains__(self, key):
            # A single raw lookup: no LRU promotion and no lock on the fast path.
            pair = OrderedDict.get(self, key)
            if pair is None:
                return False
            expiration = pair[1]
            if expiration is not None and expiration < _time.monotonic():
                with self._lock:
                    # Re-check under the lock. Only drop the entry if it is
                    # still this same expired pair; it may have been replaced
                    # or removed concurrently.
                    if OrderedDict.get(self, key) is pair:
                        OrderedDict.__delitem__(self, key)
                return False
            return True

        def __getitem__(self, key):
            # Unwrap the stored (value, expiration) pair. Goes through LRU's
            # __getitem__; note that expiration is NOT checked here.
            return super().__getitem__(key)[0]

        def get_if_present(self, key):
            """Look up ``key`` with one lock acquisition.

            Does the membership test, the TTL check, the value read and the
            LRU promotion under the lock. Returns the value on a hit, or
            ``_MISSING`` on a miss or an expired entry. Expired entries are
            deleted.
            """
            with self._lock:
                # Raw read; LRU's __getitem__ would try to take the lock again.
                pair = OrderedDict.get(self, key)
                if pair is None:
                    return _MISSING
                expiration = pair[1]
                if expiration is not None and expiration < _time.monotonic():
                    # Expired: remove it under the same lock.
                    OrderedDict.__delitem__(self, key)
                    return _MISSING
                # Promote to most-recently-used.
                self.move_to_end(key)
                return pair[0]

        def _set(self, key, value, expiration):
            # Go through LRU.__setitem__ so its eviction logic runs.
            LRU.__setitem__(self, key, (value, expiration))

    def __init__(self, maxsize=128, default_ttl=None, batch_window_ms=5,
                 max_batch_size=100, backup=None, remote_cache=None):
        self.maxsize = maxsize
        self.default_ttl = default_ttl
        self.batch_window_ms = batch_window_ms
        self.max_batch_size = max_batch_size
        self.backup = backup
        self.remote_cache = remote_cache

        # Warm L1 from the backup, if one was configured and has data.
        init_items = None
        if backup is not None:
            init_items = self._restore_entries(backup.load(), _time.monotonic())
        self.cache = self._Cache(maxsize=maxsize, items=init_items)

        # key -> Future shared by the leader loading that key and its waiters.
        self._pending = {}
        # Queued (key, future, batch_loader, ttl) requests awaiting a flush.
        self._batch_pending = []
        self._batch_lock = asyncio.Lock()
        # Guards self._pending so exactly one leader is elected per key.
        self._pending_lock = asyncio.Lock()
        self._batch_timer = None
        # Strong references to fire-and-forget remote tasks. The event loop
        # only keeps weak references, so without this they could be
        # garbage-collected before they finish.
        self._background_tasks = set()
        self.hits = 0
        self.misses = 0

    @staticmethod
    def _restore_entries(raw, now):
        """Convert persisted backup data into L1 ``(value, expiration)`` entries.

        ``save_to_backup`` stores the *remaining* TTL in seconds (or ``None``),
        because monotonic deadlines are meaningless after a restart. Each
        remaining TTL is rebased onto ``now``. Entries whose TTL has already
        run out are dropped.

        Args:
            raw: mapping ``{key: (value, remaining_seconds_or_None)}``, or a
                falsy value when the backup is empty.
            now: current ``time.monotonic()`` reading.

        Returns:
            dict ``{key: (value, absolute_expiration_or_None)}``, in the
            mapping's iteration order.
        """
        restored = {}
        if not raw:
            return restored
        for key, (value, remaining) in raw.items():
            if remaining is None:
                restored[key] = (value, None)
            elif remaining > 0:
                restored[key] = (value, now + remaining)
        return restored

    async def get(self, key, loader=None, batch_loader=None, ttl=None):
        """Two-tier cache get: L1 (local) -> L2 (remote) -> loader.

        Read pattern:
          1. Check L1 (in-memory); no await.
          2. On a miss, if ``remote_cache`` is configured, check L2.
          3. On a miss, call ``loader`` (de-duplicated per key) or queue the
             key for ``batch_loader``.
          4. Write the result to L1 synchronously and to L2 in the background.

        ``ttl=None`` means "use default_ttl". Remote errors are caught and
        logged, so the cache degrades to L1 only.

        Returns:
            The cached or loaded value. Returns ``None`` on a miss when
            neither loader is given.

        Raises:
            Whatever ``loader``/``batch_loader`` raises. ``CancelledError``
            is raised if the shared load this call waited on was cancelled.
        """
        # L1 hit path: one lock acquisition for contains + TTL + read + promote.
        result = self.cache.get_if_present(key)
        if result is not _MISSING:
            self.hits += 1
            return result

        ttl_arg = _SENTINEL if ttl is None else ttl

        # L1 miss: try L2 (remote cache) if configured.
        if self.remote_cache is not None:
            try:
                l2_value = await self.remote_cache.get(self._remote_key(key))
            except Exception:
                logger.debug("remote_cache.get failed, degrading to L1", exc_info=True)
            else:
                if l2_value is not None:
                    # L2 hit: populate L1 only (the value is already in L2).
                    self.hits += 1
                    self.set(key, l2_value, ttl=ttl_arg, write_remote=False)
                    return l2_value

        # Miss in both L1 and L2.
        self.misses += 1
        if loader is None and batch_loader is None:
            return None
        if loader is not None:
            return await self._load_single(key, loader, ttl_arg)
        return await self._batch_get(key, batch_loader, ttl)

    async def _load_single(self, key, loader, ttl_arg):
        """Run ``loader`` at most once per key; concurrent callers share the result."""
        async with self._pending_lock:
            fut = self._pending.get(key)
            is_leader = fut is None
            if is_leader:
                fut = asyncio.get_running_loop().create_future()
                self._pending[key] = fut

        if not is_leader:
            # shield(): cancelling this waiter must not cancel the shared
            # future that the leader and the other waiters depend on.
            return await asyncio.shield(fut)

        # Leader: do the load with the lock released, so loads of different
        # keys are not serialized.
        try:
            value = await loader()
            self.set(key, value, ttl=ttl_arg)
            fut.set_result(value)
            return value
        except Exception as exc:
            fut.set_exception(exc)
            # Mark the exception as retrieved so asyncio does not log
            # "Future exception was never retrieved" when nobody waited.
            fut.exception()
            raise
        finally:
            if not fut.done():
                # The leader was cancelled (or hit a BaseException). Release
                # the waiters instead of leaving them blocked forever.
                fut.cancel()
            async with self._pending_lock:
                self._pending.pop(key, None)

    async def _batch_get(self, key, batch_loader, ttl):
        """Queue ``key`` for the next ``batch_loader`` flush and await its value."""
        fut = asyncio.get_running_loop().create_future()
        async with self._batch_lock:
            self._batch_pending.append((key, fut, batch_loader, ttl))
            if len(self._batch_pending) >= self.max_batch_size:
                await self._flush_batch()
            elif self._batch_timer is None:
                self._batch_timer = asyncio.create_task(self._schedule_flush())
        return await fut

    async def _schedule_flush(self):
        """Flush the pending batch once the batch window elapses."""
        try:
            await asyncio.sleep(self.batch_window_ms / 1000.0)
            async with self._batch_lock:
                await self._flush_batch()
        finally:
            # Always clear the timer, even if the flush failed or this task
            # was cancelled; otherwise no new timer would ever be started.
            self._batch_timer = None

    async def _flush_batch(self):
        """Resolve every queued request, calling each distinct batch_loader once.

        Must be called with ``_batch_lock`` held.
        """
        if not self._batch_pending:
            return
        batch = list(self._batch_pending)
        self._batch_pending.clear()

        # Group requests by batch_loader, so mixed loaders are supported.
        groups = defaultdict(list)
        for item in batch:
            groups[item[2]].append(item)

        try:
            for b_loader, items in groups.items():
                await self._run_batch_group(b_loader, items)
        finally:
            # If the flush was interrupted (e.g. cancelled), do not leave any
            # caller awaiting a future that will never be resolved.
            for _, fut, _, _ in batch:
                if not fut.done():
                    fut.cancel()

    async def _run_batch_group(self, b_loader, items):
        """Call ``b_loader`` once for ``items`` and resolve their futures."""
        keys = [item[0] for item in items]
        try:
            # batch_loader returns either a dict or a list in key order.
            results = await b_loader(keys)
            if isinstance(results, dict):
                res_map = results
            else:
                res_map = dict(zip(keys, results))
            for item_key, fut, _, item_ttl in items:
                val = res_map.get(item_key)
                ttl_arg = _SENTINEL if item_ttl is None else item_ttl
                # L1 write now. set() also schedules the single L2 write.
                self.set(item_key, val, ttl=ttl_arg)
                # The caller may have been cancelled, which cancels its future.
                if not fut.done():
                    fut.set_result(val)
        except Exception as exc:
            for _, fut, _, _ in items:
                if not fut.done():
                    fut.set_exception(exc)

    @staticmethod
    def _remote_key(key):
        """Convert an internal cache key to a string suitable for remote storage."""
        return str(key)

    def _spawn_background(self, make_coro):
        """Run ``make_coro()`` as a fire-and-forget task if an event loop is running.

        A strong reference to the task is kept until it finishes. When there
        is no running loop (sync caller), this is a no-op and the coroutine
        is never created.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return  # no running loop (e.g., called from sync context)
        task = loop.create_task(make_coro())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _write_to_remote(self, key, value, ttl=_SENTINEL):
        """Fire-and-forget async write to the L2 remote cache.

        ``ttl`` is forwarded as-is. If it is omitted, ``default_ttl`` is used.
        """
        if self.remote_cache is None:
            return
        use_ttl = self.default_ttl if ttl is _SENTINEL else ttl

        async def _do_write():
            try:
                await self.remote_cache.set(self._remote_key(key), value, ttl=use_ttl)
            except Exception:
                logger.debug("remote_cache.set failed (background), degrading to L1", exc_info=True)

        self._spawn_background(_do_write)

    def _delete_from_remote(self, key):
        """Fire-and-forget async delete from the L2 remote cache."""
        if self.remote_cache is None:
            return

        async def _do_delete():
            try:
                await self.remote_cache.delete(self._remote_key(key))
            except Exception:
                logger.debug("remote_cache.delete failed (background)", exc_info=True)

        self._spawn_background(_do_delete)

    def _clear_remote(self):
        """Fire-and-forget async clear of the L2 remote cache."""
        if self.remote_cache is None:
            return

        async def _do_clear():
            try:
                await self.remote_cache.clear()
            except Exception:
                logger.debug("remote_cache.clear failed (background)", exc_info=True)

        self._spawn_background(_do_clear)

    def set(self, key, value, ttl=_SENTINEL, write_remote=True):
        """Set a value: L1 synchronously, L2 asynchronously in the background.

        Args:
            key: cache key (must be hashable).
            value: value to store.
            ttl: seconds until expiry. If omitted, ``default_ttl`` is used.
                An explicit ``None`` disables expiry in L1 and is forwarded
                as ``ttl=None`` to ``remote_cache.set``.
            write_remote: when False, only L1 is written (used when the value
                was just read from L2).
        """
        use_ttl = self.default_ttl if ttl is _SENTINEL else ttl
        expiration = None if use_ttl is None else _time.monotonic() + use_ttl
        self.cache._set(key, value, expiration)
        if write_remote:
            # Forward the already-resolved TTL so both tiers agree.
            self._write_to_remote(key, value, use_ttl)

    def delete(self, key):
        """Delete a key from L1 and (in the background) from L2."""
        self.cache.pop(key, None)
        self._delete_from_remote(key)

    def clear(self):
        """Clear L1 and reset metrics; L2 is cleared in the background."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0
        self._clear_remote()

    def get_metrics(self):
        """Return cache metrics: hits, misses, L1 size, and hit_rate (0.0 when unused)."""
        hits = self.hits
        misses = self.misses
        total = hits + misses
        return {
            'hits': hits,
            'misses': misses,
            'size': len(self.cache),
            'hit_rate': (hits / total) if total > 0 else 0.0,
        }

    async def warmup(self, keys_with_loaders):
        """Pre-populate the cache with sequential ``get`` calls.

        Args:
            keys_with_loaders: mapping ``{key: async loader}``.
        """
        for key, loader in keys_with_loaders.items():
            await self.get(key, loader=loader)

    def save_to_backup(self):
        """Persist current L1 contents to the configured backup.

        Call this on application shutdown to survive restarts. Entries are
        stored as ``(value, remaining_ttl_seconds_or_None)``; already-expired
        entries are skipped. No-op if no backup is configured.
        """
        if self.backup is None:
            return
        now = _time.monotonic()
        data = {}
        with self.cache._lock:
            for key in list(OrderedDict.keys(self.cache)):
                value, expiration = OrderedDict.__getitem__(self.cache, key)
                if expiration is not None:
                    remaining = expiration - now
                    if remaining <= 0:
                        continue  # expired, skip
                else:
                    remaining = None
                data[key] = (value, remaining)
        # Hand off to the backup outside the lock.
        self.backup.save(data)

    def load_from_backup(self):
        """Reload L1 contents from the configured backup.

        Useful for a manual reload without re-creating the cache object.
        Remaining TTLs are rebased onto the current monotonic clock. The
        cache is trimmed back to ``maxsize`` afterwards. No-op if no backup
        is configured or the backup is empty.
        """
        if self.backup is None:
            return
        entries = self._restore_entries(self.backup.load(), _time.monotonic())
        if not entries:
            return
        cache = self.cache
        with cache._lock:
            for key, pair in entries.items():
                OrderedDict.__setitem__(cache, key, pair)
            # The raw inserts above bypass LRU eviction, so drop the
            # least-recently-used entries beyond maxsize by hand.
            if self.maxsize is not None:
                overflow = OrderedDict.__len__(cache) - self.maxsize
                if overflow > 0:
                    for stale in list(OrderedDict.keys(cache))[:overflow]:
                        OrderedDict.__delitem__(cache, stale)