import asyncio
import datetime
import threading
from collections import defaultdict
from .lru import LRU
# sentinel for set ttl param to distinguish default vs explicit None
_SENTINEL = object()
[docs]
class AsyncCache:
class _Cache(LRU):
def __init__(self, maxsize):
super().__init__(maxsize=maxsize)
def __contains__(self, key):
if key not in self.keys():
return False
else:
key_expiration = super().__getitem__(key)[1]
if key_expiration and key_expiration < datetime.datetime.now():
del self[key]
return False
else:
return True
def __getitem__(self, key):
value = super().__getitem__(key)[0]
return value
def _set(self, key, value, expiration):
# Use LRU's __setitem__ to ensure eviction logic runs
LRU.__setitem__(self, key, (value, expiration))
def __init__(self, maxsize=128, default_ttl=None, batch_window_ms=5, max_batch_size=100):
self.maxsize = maxsize
self.default_ttl = default_ttl
self.batch_window_ms = batch_window_ms
self.max_batch_size = max_batch_size
self.cache = self._Cache(maxsize=maxsize)
self._pending = {}
self._batch_pending = []
self._batch_lock = asyncio.Lock()
self._pending_lock = asyncio.Lock() # protects thundering herd (single loader pending) from races under concurrency
self._batch_timer = None
self._metrics_lock = threading.Lock() # protects hits/misses counters
self.hits = 0
self.misses = 0
[docs]
async def get(self, key, loader=None, batch_loader=None, ttl=None):
"""Get from cache, loader on miss, with thundering herd protection for concurrent misses on same key.
Uses _pending_lock to avoid race conditions when multiple async tasks (e.g., HTTP requests) miss simultaneously.
Only one loader executes; others await the future result. Batch mode for multi-key.
LRU ops protected by RLock in lru.LRU (prevents eviction races in parallel re-runs/hits near maxsize).
"""
# hit path - LRU __getitem__ is protected by RLock internally
if key in self.cache:
with self._metrics_lock:
self.hits += 1
return self.cache[key]
# cache miss - count it
with self._metrics_lock:
self.misses += 1
if loader is None and batch_loader is None:
return None
if loader is not None:
# single loader with herd protection
async with self._pending_lock:
if key in self._pending:
# waiter: future already set by leader
fut = self._pending[key]
is_leader = False
else:
# leader: create fut
fut = asyncio.Future()
self._pending[key] = fut
is_leader = True
if not is_leader:
# await result from leader
return await fut
# leader only: perform load (lock released to avoid serializing loads)
try:
value = await loader()
ttl_arg = _SENTINEL if ttl is None else ttl
# set (LRU handles lock/evict internally)
self.set(key, value, ttl=ttl_arg)
fut.set_result(value)
return value
except Exception as exc:
fut.set_exception(exc)
fut.exception()
raise
finally:
# cleanup under lock
async with self._pending_lock:
self._pending.pop(key, None)
# batch_loader mode
return await self._batch_get(key, batch_loader, ttl)
async def _batch_get(self, key, batch_loader, ttl):
fut = asyncio.Future()
async with self._batch_lock:
self._batch_pending.append((key, fut, batch_loader, ttl))
if len(self._batch_pending) >= self.max_batch_size:
await self._flush_batch()
elif self._batch_timer is None:
self._batch_timer = asyncio.create_task(self._schedule_flush())
return await fut
async def _schedule_flush(self):
await asyncio.sleep(self.batch_window_ms / 1000.0)
async with self._batch_lock:
await self._flush_batch()
self._batch_timer = None
async def _flush_batch(self):
if not self._batch_pending:
return
# group by batch_loader (support mixed)
groups = defaultdict(list)
for item in self._batch_pending:
groups[item[2]].append(item)
self._batch_pending.clear()
for b_loader, items in groups.items():
keys = [it[0] for it in items]
try:
# assume batch_loader returns list in key order or dict
results = await b_loader(keys)
if isinstance(results, dict):
res_map = results
else:
res_map = dict(zip(keys, results))
for it in items:
val = res_map.get(it[0])
ttl_arg = _SENTINEL if it[3] is None else it[3]
# set (LRU _lock handles atomic + eviction)
self.set(it[0], val, ttl=ttl_arg)
it[1].set_result(val)
except Exception as exc:
for it in items:
it[1].set_exception(exc)
[docs]
def set(self, key, value, ttl=_SENTINEL):
"""Set a value in the cache. LRU __setitem__ is protected by RLock internally."""
if ttl is _SENTINEL:
use_ttl = self.default_ttl
else:
use_ttl = ttl
ttl_value = (
(datetime.datetime.now() + datetime.timedelta(seconds=use_ttl))
if use_ttl is not None
else None
)
self.cache._set(key, value, ttl_value)
[docs]
def delete(self, key):
"""Delete a key from the cache. LRU pop is protected by RLock internally."""
self.cache.pop(key, None)
[docs]
def clear(self):
"""Clear the cache and reset metrics. LRU clear is protected by RLock internally."""
self.cache.clear()
with self._metrics_lock:
self.hits = 0
self.misses = 0
[docs]
def get_metrics(self):
"""Get cache metrics (hits, misses, size, hit_rate)."""
with self._metrics_lock:
hits = self.hits
misses = self.misses
total = hits + misses
return {
'hits': hits,
'misses': misses,
'size': len(self.cache),
'hit_rate': (hits / total) if total > 0 else 0.0,
}
[docs]
async def warmup(self, keys_with_loaders):
"""Warmup: serial gets (each locks internally for hit/miss)."""
for key, loader in keys_with_loaders.items():
await self.get(key, loader=loader)