import asyncio
import logging
import time as _time
from collections import defaultdict
from collections import OrderedDict
from .lru import LRU
logger = logging.getLogger("async_cache")

# Default marker for the ``ttl`` argument of ``AsyncCache.set``: it lets the
# method tell "caller omitted ttl" (use default_ttl) apart from an explicit
# ``ttl=None`` (never expire).
_SENTINEL = object()

# Returned by ``_Cache.get_if_present`` on a miss or an expired entry, so a
# legitimately cached ``None`` can still be told apart from "not cached".
_MISSING = object()
class AsyncCache:
    """Async two-tier cache: an in-process LRU (L1) in front of an optional
    remote cache (L2).

    Features implemented here:

    * per-entry TTLs (``default_ttl`` or a per-call ``ttl``), stored in L1 as
      absolute ``time.monotonic()`` expirations;
    * thundering-herd protection for single-key ``loader`` calls: one leader
      runs the load, concurrent callers for the same key await its result;
    * micro-batching for ``batch_loader`` calls: requests are collected for
      ``batch_window_ms`` or until ``max_batch_size`` are pending;
    * best-effort, fire-and-forget writes/deletes/clears to ``remote_cache``;
    * optional persistence through ``backup`` (``load()`` / ``save(data)``).

    From its usage here, ``remote_cache`` must provide awaitable
    ``get(key)``, ``set(key, value, ttl=...)``, ``delete(key)`` and
    ``clear()``.
    """

    class _Cache(LRU):
        """LRU storage whose values are ``(value, expiration)`` pairs.

        ``expiration`` is an absolute ``time.monotonic()`` timestamp, or
        ``None`` for entries that never expire.
        """

        def __init__(self, maxsize, items=None):
            super().__init__(maxsize=maxsize)
            if items:
                entries = list(items.items())
                # The bulk load below bypasses LRU eviction, so enforce the
                # capacity up front by keeping the tail of the mapping. For
                # data written by save_to_backup the tail holds the most
                # recently used entries (cache hits call move_to_end).
                if maxsize is not None and len(entries) > maxsize:
                    entries = entries[len(entries) - maxsize:]
                # Bulk-load from backend without triggering eviction per item.
                for k, v in entries:
                    OrderedDict.__setitem__(self, k, v)

        def __contains__(self, key):
            """Return True if ``key`` is present and not expired.

            Expired entries found here are deleted as a side effect.
            """
            if not OrderedDict.__contains__(self, key):
                return False
            expiration = OrderedDict.__getitem__(self, key)[1]
            if expiration is None or expiration >= _time.monotonic():
                return True
            with self._lock:
                # Re-check presence AND expiration under the lock: another
                # thread may have deleted or refreshed the entry since the
                # unlocked read, and a refreshed entry must not be removed.
                if not OrderedDict.__contains__(self, key):
                    return False
                expiration = OrderedDict.__getitem__(self, key)[1]
                if expiration is not None and expiration < _time.monotonic():
                    OrderedDict.__delitem__(self, key)
                    return False
                return True

        def __getitem__(self, key):
            # LRU.__getitem__ returns the stored (value, expiration) pair.
            value = super().__getitem__(key)[0]
            return value

        def get_if_present(self, key):
            """Single-lock contains + TTL check + value read + move_to_end.

            Returns the value on a hit, ``_MISSING`` on a miss or an expired
            entry (which is removed under the same lock).
            """
            with self._lock:
                # O(1) dict contains check (bypasses LRU's __contains__).
                if not OrderedDict.__contains__(self, key):
                    return _MISSING
                # Raw read (bypasses LRU's __getitem__, which also locks).
                pair = OrderedDict.__getitem__(self, key)
                expiration = pair[1]
                if expiration is not None and expiration < _time.monotonic():
                    OrderedDict.__delitem__(self, key)
                    return _MISSING
                # Promote in LRU order.
                self.move_to_end(key)
                return pair[0]

        def _set(self, key, value, expiration):
            # Use LRU's __setitem__ so its eviction logic runs.
            LRU.__setitem__(self, key, (value, expiration))

    def __init__(
        self,
        maxsize=128,
        default_ttl=None,
        batch_window_ms=5,
        max_batch_size=100,
        backup=None,
        remote_cache=None,
    ):
        """Create the cache.

        Args:
            maxsize: L1 capacity, passed to the LRU.
            default_ttl: seconds before an entry expires when no ttl is given;
                ``None`` means entries never expire.
            batch_window_ms: how long batch requests are collected before a
                flush.
            max_batch_size: number of pending batch requests that forces an
                immediate flush.
            backup: optional object with ``load()`` / ``save(data)``; its data
                is loaded here as a warm start.
            remote_cache: optional async L2 cache.
        """
        self.maxsize = maxsize
        self.default_ttl = default_ttl
        self.batch_window_ms = batch_window_ms
        self.max_batch_size = max_batch_size
        self.backup = backup
        self.remote_cache = remote_cache
        # Warm start from the backup. Stored remaining TTLs are rebased onto
        # this process's monotonic clock and expired entries are dropped.
        init_items = None
        if backup is not None:
            init_items = self._restore_entries(backup.load())
        self.cache = self._Cache(maxsize=maxsize, items=init_items)
        # key -> Future shared by the leader and waiters of an in-flight load.
        self._pending = {}
        # (key, future, batch_loader, ttl) tuples awaiting the next flush.
        self._batch_pending = []
        self._batch_lock = asyncio.Lock()
        # Guards check-and-register of the single-loader leader.
        self._pending_lock = asyncio.Lock()
        self._batch_timer = None
        # Strong references to fire-and-forget L2 tasks; the event loop only
        # keeps weak references, so unreferenced tasks can vanish mid-flight.
        self._background_tasks = set()
        self.hits = 0
        self.misses = 0

    async def get(self, key, loader=None, batch_loader=None, ttl=None):
        """Two-tier cache get: L1 (local) -> L2 (remote) -> loader.

        Read pattern:
          1. Check L1 (in-memory) — immediate.
          2. On a miss, if ``remote_cache`` is configured, check L2.
          3. On a miss, call ``loader`` (with thundering-herd protection) or
             queue the key for ``batch_loader``.
          4. Write the result to L1 synchronously and to L2 in the background.

        Args:
            key: cache key (stringified via ``_remote_key`` for L2).
            loader: zero-argument coroutine function producing the value.
            batch_loader: coroutine function taking a list of keys and
                returning a dict ``{key: value}`` or a list of values in key
                order. Ignored when ``loader`` is given.
            ttl: seconds to keep a value populated by this call; ``None``
                means "use ``default_ttl``".

        Returns:
            The cached or loaded value, or ``None`` on a miss when no loader
            was supplied. An L2 value of ``None`` is treated as a miss.

        Raises:
            Whatever ``loader`` / ``batch_loader`` raises. Callers sharing a
            load receive the same exception. Remote-cache errors are logged
            at DEBUG level and treated as misses (cache degrades to L1 only).
        """
        # L1 hit path: one lock acquisition for contains + TTL + read + promote.
        result = self.cache.get_if_present(key)
        if result is not _MISSING:
            self.hits += 1
            return result

        # ``get`` spells "use the default TTL" as None; ``set`` uses _SENTINEL.
        ttl_arg = _SENTINEL if ttl is None else ttl

        # L1 miss: try L2. Only the remote call is guarded, so a failure in
        # our own L1 write is not misreported as a remote failure.
        if self.remote_cache is not None:
            try:
                l2_value = await self.remote_cache.get(self._remote_key(key))
            except Exception:
                logger.debug("remote_cache.get failed, degrading to L1", exc_info=True)
                l2_value = None
            if l2_value is not None:
                # L2 hit: populate L1 only (the value already lives in L2).
                self.set(key, l2_value, ttl=ttl_arg, write_remote=False)
                self.hits += 1
                return l2_value

        # L1 + L2 miss.
        self.misses += 1
        if loader is None and batch_loader is None:
            return None
        if loader is None:
            return await self._batch_get(key, batch_loader, ttl)

        # Single loader with thundering-herd protection.
        async with self._pending_lock:
            fut = self._pending.get(key)
            is_leader = fut is None
            if is_leader:
                fut = asyncio.get_running_loop().create_future()
                self._pending[key] = fut

        if not is_leader:
            # shield(): cancelling this waiter must not cancel the shared
            # future, which would break the leader and every other waiter.
            return await asyncio.shield(fut)

        # Leader only: load with the lock released so loads of different keys
        # are not serialized.
        try:
            value = await loader()
            # L1 write (synchronous); set() also schedules the L2 write.
            self.set(key, value, ttl=ttl_arg)
        except Exception as exc:
            if not fut.done():
                fut.set_exception(exc)
                # Mark the exception as retrieved so asyncio does not log
                # "Future exception was never retrieved" when nobody waits.
                fut.exception()
            raise
        else:
            if not fut.done():
                fut.set_result(value)
            return value
        finally:
            # If the leader was cancelled (CancelledError is not an Exception
            # subclass on Python 3.8+), release the waiters instead of
            # leaving them blocked forever.
            if not fut.done():
                fut.cancel()
            async with self._pending_lock:
                # Remove only our own registration.
                if self._pending.get(key) is fut:
                    del self._pending[key]

    async def _batch_get(self, key, batch_loader, ttl):
        """Queue ``key`` for the next batch flush and await its value."""
        fut = asyncio.get_running_loop().create_future()
        async with self._batch_lock:
            self._batch_pending.append((key, fut, batch_loader, ttl))
            if len(self._batch_pending) >= self.max_batch_size:
                # Size limit reached: flush now rather than waiting for the timer.
                await self._flush_batch()
            elif self._batch_timer is None:
                self._batch_timer = asyncio.create_task(self._schedule_flush())
        return await fut

    async def _schedule_flush(self):
        """Flush pending batch requests after ``batch_window_ms``."""
        try:
            await asyncio.sleep(self.batch_window_ms / 1000.0)
            async with self._batch_lock:
                await self._flush_batch()
        finally:
            # Always clear the timer, even on error or cancellation; otherwise
            # no new timer is ever scheduled and batching stalls.
            self._batch_timer = None

    async def _flush_batch(self):
        """Run every pending batch request, one call per distinct batch_loader.

        Must be called with ``_batch_lock`` held.
        """
        if not self._batch_pending:
            return
        # Group by batch_loader (mixed loaders are supported).
        groups = defaultdict(list)
        for item in self._batch_pending:
            groups[item[2]].append(item)
        self._batch_pending.clear()
        for b_loader, items in groups.items():
            keys = [item[0] for item in items]
            try:
                # batch_loader returns either a dict or a list in key order.
                results = await b_loader(keys)
                if isinstance(results, dict):
                    res_map = results
                else:
                    res_map = dict(zip(keys, results))
                for item_key, item_fut, _item_loader, item_ttl in items:
                    val = res_map.get(item_key)
                    # set() writes L1 (LRU lock handles atomicity + eviction)
                    # and schedules the single L2 write.
                    self.set(item_key, val, ttl=_SENTINEL if item_ttl is None else item_ttl)
                    # The caller may have been cancelled, which cancels its future.
                    if not item_fut.done():
                        item_fut.set_result(val)
            except Exception as exc:
                for item in items:
                    if not item[1].done():
                        item[1].set_exception(exc)

    @staticmethod
    def _remote_key(key):
        """Convert an internal cache key to a string suitable for remote storage."""
        return str(key)

    def _spawn_background(self, make_coro):
        """Run ``make_coro()`` as a fire-and-forget task on the running loop.

        A strong reference is kept until the task finishes. When no loop is
        running (sync context), nothing is scheduled, and the coroutine is
        never created, so no "never awaited" warning is emitted.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return  # no running loop (e.g., called from sync context)
        task = loop.create_task(make_coro())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _write_to_remote(self, key, value, ttl=_SENTINEL):
        """Fire-and-forget async write to L2 remote cache.

        ``ttl`` is the final TTL in seconds (``None`` = no expiry). Omit it
        to use ``default_ttl``.
        """
        if self.remote_cache is None:
            return
        use_ttl = self.default_ttl if ttl is _SENTINEL else ttl

        async def _do_write():
            try:
                await self.remote_cache.set(self._remote_key(key), value, ttl=use_ttl)
            except Exception:
                logger.debug("remote_cache.set failed (background), degrading to L1", exc_info=True)

        self._spawn_background(_do_write)

    def _delete_from_remote(self, key):
        """Fire-and-forget async delete from L2 remote cache."""
        if self.remote_cache is None:
            return

        async def _do_delete():
            try:
                await self.remote_cache.delete(self._remote_key(key))
            except Exception:
                logger.debug("remote_cache.delete failed (background)", exc_info=True)

        self._spawn_background(_do_delete)

    def _clear_remote(self):
        """Fire-and-forget async clear of L2 remote cache."""
        if self.remote_cache is None:
            return

        async def _do_clear():
            try:
                await self.remote_cache.clear()
            except Exception:
                logger.debug("remote_cache.clear failed (background)", exc_info=True)

        self._spawn_background(_do_clear)

    def set(self, key, value, ttl=_SENTINEL, write_remote=True):
        """Set a value in the cache. Writes to L1 synchronously, L2 async in background.

        Args:
            key: cache key.
            value: value to store.
            ttl: seconds until expiry. Omit it to use ``default_ttl``; pass
                ``None`` explicitly for "never expire". L1 and L2 receive the
                same resolved TTL.
            write_remote: when False, only L1 is written.
        """
        use_ttl = self.default_ttl if ttl is _SENTINEL else ttl
        expiration = None if use_ttl is None else _time.monotonic() + use_ttl
        self.cache._set(key, value, expiration)
        if write_remote:
            # Pass the already-resolved TTL so an explicit None stays "no expiry".
            self._write_to_remote(key, value, use_ttl)

    def delete(self, key):
        """Delete a key from L1 and (in the background) from L2."""
        self.cache.pop(key, None)
        self._delete_from_remote(key)

    def clear(self):
        """Clear L1 cache and reset metrics. Also clears L2 async."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0
        self._clear_remote()

    def get_metrics(self):
        """Get cache metrics (hits, misses, size, hit_rate)."""
        hits = self.hits
        misses = self.misses
        total = hits + misses
        return {
            'hits': hits,
            'misses': misses,
            'size': len(self.cache),
            'hit_rate': (hits / total) if total > 0 else 0.0,
        }

    async def warmup(self, keys_with_loaders):
        """Warm up by calling ``get(key, loader=loader)`` for each pair, serially."""
        for key, loader in keys_with_loaders.items():
            await self.get(key, loader=loader)

    def save_to_backup(self):
        """Persist current cache contents to the configured backup.

        Call this on application shutdown to survive restarts. Entries are
        stored as ``(value, remaining)``, where ``remaining`` is the TTL left
        in seconds (relative, because monotonic timestamps have no meaning in
        another process) or ``None`` for no expiry. Expired entries are
        skipped. No-op if no backup is configured.
        """
        if self.backup is None:
            return
        now = _time.monotonic()
        data = {}
        with self.cache._lock:
            for key, (value, expiration) in OrderedDict.items(self.cache):
                if expiration is None:
                    remaining = None
                else:
                    remaining = expiration - now
                    if remaining <= 0:
                        continue  # expired, skip
                data[key] = (value, remaining)
        # Persist outside the lock so backup I/O does not block cache users.
        self.backup.save(data)

    def load_from_backup(self):
        """Reload cache contents from the configured backup.

        Useful for manual reload without re-creating the cache object.
        Entries go through the LRU's normal insertion path, so ``maxsize`` is
        respected. No-op if no backup is configured.
        """
        if self.backup is None:
            return
        for key, (value, expiration) in self._restore_entries(self.backup.load()).items():
            self.cache._set(key, value, expiration)

    @staticmethod
    def _restore_entries(raw):
        """Convert backup data into L1 entries with absolute expirations.

        ``raw`` maps key -> ``(value, remaining)`` as written by
        ``save_to_backup``. ``remaining`` is rebased onto this process's
        monotonic clock. Entries with non-positive remaining TTL are dropped.
        A falsy ``raw`` (e.g. ``None`` or ``{}``) yields an empty dict.
        """
        now = _time.monotonic()
        restored = {}
        for key, (value, remaining) in (raw or {}).items():
            if remaining is None:
                restored[key] = (value, None)
            elif remaining > 0:
                restored[key] = (value, now + remaining)
        return restored