"""
Agent-aware caching layer for AI agent workflows.
Provides:
- AgentCache: decorator for caching async read-tool results
- AgentCacheInvalidator: decorator for write-tools that invalidate cached reads
- AgentCacheSession: async context manager with per-session caching and loop detection
"""
import contextvars
import functools
import logging
import time as _time
import uuid

from cache import AsyncCache
from cache.async_cache import _MISSING as _CACHE_MISSING
from cache.key import KEY, make_key
logger = logging.getLogger("agent_cache")
# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------
[docs]
class AgentLoopDetectedError(Exception):
    """Error surfaced to callers when a session detects an agent loop.

    Raised by the session's loop handling when it is configured with
    ``on_loop="raise"``.
    """
class _LoopShortCircuit(Exception):
    """Private control-flow signal for ``on_loop="short_circuit"``.

    The tool decorators in this module catch it and return early, so it is
    not meant to reach user code.
    """
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# Session currently bound to this execution context; None when no
# AgentCacheSession has been entered.
_active_session: contextvars.ContextVar = contextvars.ContextVar(
    "_active_session",
    default=None,
)

# One shared cache per resource tag (resource -> AsyncCache).
_global_caches: dict = dict()

# Module-level counters aggregated across every session.
_global_metrics: dict = dict.fromkeys(
    ("hits", "misses", "invalidations", "loop_detections"),
    0,
)
[docs]
def get_metrics() -> dict:
    """Return a snapshot of the global counters plus a derived ``hit_rate``.

    ``hit_rate`` is ``hits / (hits + misses)``, or ``0.0`` when no lookups
    have been counted yet. The returned dict is a copy; mutating it does not
    affect the live counters.
    """
    hits = _global_metrics["hits"]
    lookups = hits + _global_metrics["misses"]
    snapshot = dict(_global_metrics)
    snapshot["hit_rate"] = hits / lookups if lookups > 0 else 0.0
    return snapshot
def _reset_global_state():
    """Drop every global cache and zero the global counters. **Testing only.**"""
    _global_caches.clear()
    # Reset in place so existing references to the metrics dict stay valid.
    for counter in ("hits", "misses", "invalidations", "loop_detections"):
        _global_metrics[counter] = 0
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_global_cache(resource: str, maxsize: int = 128, backup=None, remote_cache=None) -> AsyncCache:
    """Return the shared cache for *resource*, creating it on first use.

    ``maxsize``, ``backup`` and ``remote_cache`` are only consulted when the
    cache for this resource does not exist yet; later calls return the
    already-registered instance regardless of the arguments passed.
    """
    try:
        return _global_caches[resource]
    except KeyError:
        created = AsyncCache(
            maxsize=maxsize,
            default_ttl=None,
            backup=backup,
            remote_cache=remote_cache,
        )
        _global_caches[resource] = created
        return created
# ---------------------------------------------------------------------------
# AgentCacheSession
# ---------------------------------------------------------------------------
[docs]
class AgentCacheSession:
    """Session-scoped cache storage and loop detection as an async context manager.

    Entering the context publishes the session through the ``_active_session``
    context variable; leaving it restores the previous value and discards the
    session's execution trace and cached entries.

    Usage::

        async with AgentCacheSession(loop_detection=True, max_tool_repeats=5) as session:
            await some_cached_tool(...)

    :param session_id: Identifier for the session; a random UUID4 string is
        generated when a falsy value is given.
    :param loop_detection: When true, every recorded tool call is checked
        for loops.
    :param max_tool_repeats: Maximum number of identical (tool, args) calls
        tolerated before a loop is reported.
    :param max_execution_depth: Maximum trace length tolerated before a loop
        is reported.
    :param on_loop: ``"raise"``, ``"warn"`` or ``"short_circuit"``; any other
        value only increments the loop counters.
    """

    def __init__(
        self,
        session_id=None,
        loop_detection=True,
        max_tool_repeats=5,
        max_execution_depth=50,
        on_loop="raise",
    ):
        if not session_id:
            session_id = str(uuid.uuid4())
        self.session_id = session_id

        self.loop_detection = loop_detection
        self.max_tool_repeats = max_tool_repeats
        self.max_execution_depth = max_execution_depth
        self.on_loop = on_loop

        # resource -> {key: (value, expiry)}
        self._session_caches: dict = {}
        # [(tool_name, args_key), ...] in call order
        self._execution_trace: list = []
        # Token returned by ContextVar.set(); needed to restore on exit.
        self._token = None

        self.metrics: dict = dict.fromkeys(
            ("hits", "misses", "invalidations", "loop_detections"), 0
        )

    # -- context manager ----------------------------------------------------

    async def __aenter__(self):
        """Make this session the active one and return it."""
        self._token = _active_session.set(self)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Restore the previously active session and drop per-session state."""
        _active_session.reset(self._token)
        self._execution_trace.clear()
        self._session_caches.clear()
        # Returning False lets any in-flight exception propagate.
        return False

    # -- public helpers -----------------------------------------------------

    def get_execution_trace(self):
        """Return a shallow copy of the recorded ``(tool_name, args_key)`` calls."""
        return self._execution_trace[:]

    def get_metrics(self) -> dict:
        """Return this session's counters plus a derived ``hit_rate``.

        ``hit_rate`` is ``hits / (hits + misses)``, or ``0.0`` when nothing
        has been looked up yet.
        """
        hits = self.metrics["hits"]
        lookups = hits + self.metrics["misses"]
        rate = hits / lookups if lookups > 0 else 0.0
        return dict(self.metrics, hit_rate=rate)

    # -- loop detection internals -------------------------------------------

    def _record_tool_call(self, tool_name: str, args_key):
        """Record one tool call, then run loop checks if enabled. May raise."""
        self._execution_trace.append((tool_name, args_key))
        if not self.loop_detection:
            return
        self._check_loops(tool_name, args_key)

    def _check_loops(self, tool_name: str, args_key):
        """Run the loop checks in order; report at most one finding per call.

        Checks, in order: total trace depth, repeats of the identical
        (tool, args) call, and a repeated tail of tool names involving at
        least two distinct tools.
        """
        history = self._execution_trace
        depth = len(history)

        # 1. Execution depth
        if depth > self.max_execution_depth:
            self._handle_loop(
                f"Execution depth exceeded: {depth} > {self.max_execution_depth}"
            )
            return

        # 2. Repeated identical calls (same tool + same args)
        repeats = history.count((tool_name, args_key))
        if repeats > self.max_tool_repeats:
            self._handle_loop(
                f"Tool '{tool_name}' called {repeats} times with same arguments "
                f"(limit: {self.max_tool_repeats})"
            )
            return

        # 3. Oscillation / cycle detection (tool names only, 2+ distinct tools)
        if depth < 4:
            return
        names = [name for name, _ in history]
        # Only windows of up to 10 names are compared.
        longest = min(depth // 2, 10)
        for width in range(2, longest + 1):
            tail = names[-width:]
            before_tail = names[-2 * width : -width]
            if tail == before_tail and len(set(tail)) >= 2:
                self._handle_loop(f"Cycle detected: {' -> '.join(tail)}")
                return

    def _handle_loop(self, message: str):
        """Count the detection, then act according to ``on_loop``."""
        self.metrics["loop_detections"] += 1
        _global_metrics["loop_detections"] += 1

        mode = self.on_loop
        if mode == "raise":
            raise AgentLoopDetectedError(message)
        if mode == "short_circuit":
            raise _LoopShortCircuit(message)
        if mode == "warn":
            logger.warning("Agent loop detected: %s", message)
# ---------------------------------------------------------------------------
# AgentCache decorator
# ---------------------------------------------------------------------------
[docs]
class AgentCache:
    """Decorator that caches async tool results with resource tagging and scope.

    Example::

        @AgentCache(resource="cart", scope="global", ttl=60)
        async def get_cart(user_id):
            ...

    The decorated coroutine keeps the wrapped function's metadata
    (``__name__``, ``__qualname__``, ``__doc__``, ``__module__``, ...) and
    additionally exposes ``resource`` and ``scope`` attributes.

    :param resource: Tag naming the cache bucket the results are stored under.
    :param scope: ``"session"`` stores results in the active
        ``AgentCacheSession`` (and bypasses caching when no session is
        active); any other value uses the shared global cache.
    :param ttl: Time-to-live handed to the cache for stored results.
    :param maxsize: Forwarded to ``_get_global_cache`` for the global scope.
    :param skip_args: Forwarded to ``make_key`` and used to drop that many
        leading positional arguments from the loop-detection key
        (presumably to ignore ``self``/``cls`` — confirm against ``make_key``).
    :param backup: Forwarded to ``_get_global_cache`` for the global scope.
    :param remote_cache: Forwarded to ``_get_global_cache`` for the global scope.
    """

    def __init__(
        self,
        resource: str,
        scope: str = "global",
        ttl=None,
        maxsize: int = 128,
        skip_args: int = 0,
        backup=None,
        remote_cache=None,
    ):
        self.resource = resource
        self.scope = scope
        self.ttl = ttl
        self.maxsize = maxsize
        self.skip_args = skip_args
        self.backup = backup
        self.remote_cache = remote_cache

    def __call__(self, func):
        """Wrap *func* so its results are cached under this decorator's settings."""
        # Snapshot the settings so later mutation of the decorator instance
        # cannot change an already-decorated tool.
        resource = self.resource
        scope = self.scope
        ttl = self.ttl
        maxsize = self.maxsize
        skip_args = self.skip_args
        backup = self.backup
        remote_cache = self.remote_cache

        # functools.wraps preserves __doc__, __module__, __annotations__ and
        # __dict__ in addition to the name attributes and __wrapped__.
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            session = _active_session.get()
            key = make_key(func, args, kwargs, skip_args)

            # --- trace recording + loop detection ---
            if session is not None:
                tool_name = getattr(func, "__qualname__", func.__name__)
                args_key = KEY(args[skip_args:], kwargs)
                try:
                    session._record_tool_call(tool_name, args_key)
                except _LoopShortCircuit:
                    # Loop detected in short-circuit mode: answer from cache
                    # if possible (None otherwise) without calling the tool.
                    return _try_get_cached(scope, resource, key, session)

            # --- caching ---
            if scope == "session":
                if session is None:
                    # No session to cache in: call straight through.
                    return await func(*args, **kwargs)
                return await _session_cache_get(
                    session, resource, key, func, args, kwargs, ttl
                )

            # Any other scope is treated as global.
            global_cache = _get_global_cache(
                resource, maxsize, backup=backup, remote_cache=remote_cache
            )
            return await _global_cache_get(
                global_cache, key, func, args, kwargs, ttl, session
            )

        # Keep the historical fallback for callables without __qualname__.
        wrapper.__qualname__ = getattr(func, "__qualname__", func.__name__)
        wrapper.resource = resource
        wrapper.scope = scope
        return wrapper
# ---------------------------------------------------------------------------
# AgentCacheInvalidator decorator
# ---------------------------------------------------------------------------
[docs]
class AgentCacheInvalidator:
    """Decorator for write/mutation tools that invalidate related cached reads.

    Invalidation happens **before** the mutation executes, so the affected
    cache entries are dropped even if the mutation raises.

    Example::

        @AgentCacheInvalidator(resource="cart", scope="global")
        async def add_to_cart(user_id, item):
            ...

    When the active session short-circuits on a detected loop, the wrapper
    returns ``None`` without invalidating anything or running the mutation.

    The decorated coroutine keeps the wrapped function's metadata
    (``__name__``, ``__qualname__``, ``__doc__``, ``__module__``, ...) and
    additionally exposes ``resource`` and ``scope`` attributes.

    :param resource: The resource tag to invalidate (must match an ``@AgentCache``).
    :param scope: ``"global"`` or ``"session"``.
    :param clear_all: If True (default), clear all entries for the resource.
        If False, requires ``key_fn`` to target a specific key; without it
        nothing is invalidated and a warning is logged at construction time.
    :param key_fn: Optional ``(args, kwargs) -> cache_key`` for selective
        invalidation when ``clear_all=False``.
    :param skip_args: **Removed.** Passing a non-zero value raises ``TypeError``.
        Use ``key_fn`` instead to map mutation args to cache-key args.
    :raises TypeError: If ``skip_args`` is non-zero.
    """

    def __init__(self, resource: str, scope: str = "global",
                 skip_args: int = 0, clear_all: bool = True, key_fn=None):
        if skip_args != 0:
            raise TypeError(
                "skip_args is no longer supported on invalidators because it "
                "can silently produce the wrong cache key. Use key_fn to "
                "map the mutation function's arguments to the cached "
                "function's arguments instead."
            )
        if not clear_all and key_fn is None:
            # This configuration makes every call a no-op for invalidation;
            # surface it instead of silently leaving stale entries in place.
            logger.warning(
                "AgentCacheInvalidator(resource=%r) has clear_all=False and no "
                "key_fn; calls will not invalidate any cache entries.",
                resource,
            )
        self.resource = resource
        self.scope = scope
        self.clear_all = clear_all
        self.key_fn = key_fn

    def __call__(self, func):
        """Wrap *func* so the configured cache entries are dropped before it runs."""
        resource = self.resource
        scope = self.scope
        clear_all = self.clear_all
        key_fn = self.key_fn

        # functools.wraps preserves __doc__, __module__, __annotations__ and
        # __dict__ in addition to the name attributes and __wrapped__.
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            session = _active_session.get()

            # --- trace recording + loop detection ---
            if session is not None:
                tool_name = getattr(func, "__qualname__", func.__name__)
                args_key = KEY(args, kwargs)
                try:
                    session._record_tool_call(tool_name, args_key)
                except _LoopShortCircuit:
                    return None

            # --- invalidate BEFORE the mutation, so the entries are dropped
            #     even if the mutation raises ---
            if clear_all:
                if scope == "global" and resource in _global_caches:
                    _global_caches[resource].clear()
                if scope == "session" and session and resource in session._session_caches:
                    session._session_caches[resource].clear()
            elif key_fn is not None:
                cache_key = key_fn(args, kwargs)
                if scope == "global" and resource in _global_caches:
                    _global_caches[resource].delete(cache_key)
                if scope == "session" and session and resource in session._session_caches:
                    session._session_caches[resource].pop(cache_key, None)

            # --- metrics ---
            if session is not None:
                session.metrics["invalidations"] += 1
            _global_metrics["invalidations"] += 1

            # --- execute the mutation ---
            return await func(*args, **kwargs)

        # Keep the historical fallback for callables without __qualname__.
        wrapper.__qualname__ = getattr(func, "__qualname__", func.__name__)
        wrapper.resource = resource
        wrapper.scope = scope
        return wrapper
# ---------------------------------------------------------------------------
# Internal cache helpers
# ---------------------------------------------------------------------------
def _try_get_cached(scope, resource, key, session):
    """Look up *key* without loading; used when a loop is short-circuited.

    Returns the cached value when a live entry exists in the cache selected
    by *scope*, otherwise ``None``. Never calls the underlying tool and never
    touches the hit/miss counters.
    """
    if scope == "session" and session:
        bucket = session._session_caches.get(resource)
        if bucket and key in bucket:
            cached, deadline = bucket[key]
            # A deadline of None means the entry never expires.
            if deadline is None or deadline > _time.monotonic():
                return cached
        return None

    if scope == "global":
        store = _global_caches.get(resource)
        if store:
            found = store.cache.get_if_present(key)
            if found is not _CACHE_MISSING:
                return found

    return None
async def _session_cache_get(session, resource, key, func, args, kwargs, ttl):
    """Return the session-cached value for *key*, loading it on a miss.

    A live entry counts as a hit on both the session and global counters.
    An expired entry is removed and handled like a miss: the counters are
    bumped, ``func(*args, **kwargs)`` is awaited, and the result is stored
    with a deadline derived from *ttl*.
    """
    bucket = session._session_caches.setdefault(resource, {})

    if key in bucket:
        cached, deadline = bucket[key]
        fresh = deadline is None or deadline > _time.monotonic()
        if fresh:
            session.metrics["hits"] += 1
            _global_metrics["hits"] += 1
            return cached
        # Expired: evict and fall through to the loader.
        del bucket[key]

    session.metrics["misses"] += 1
    _global_metrics["misses"] += 1

    loaded = await func(*args, **kwargs)

    # The deadline is computed after the load, so the TTL runs from the
    # moment the value is available.
    # NOTE(review): a falsy ttl (None or 0) stores the entry without expiry —
    # confirm that 0 is meant to behave like None.
    deadline = None
    if ttl:
        deadline = _time.monotonic() + ttl
    bucket[key] = (loaded, deadline)
    return loaded
async def _global_cache_get(global_cache, key, func, args, kwargs, ttl, session):
    """Return the globally cached value for *key*, loading it on a miss.

    The cache is probed first so that a hit or miss can be counted (on the
    session, when one is given, and on the global counters). On a miss the
    load is delegated to ``global_cache.get`` with a loader that awaits
    ``func(*args, **kwargs)``.
    """
    probed = global_cache.cache.get_if_present(key)
    was_hit = probed is not _CACHE_MISSING

    counter = "hits" if was_hit else "misses"
    if session:
        session.metrics[counter] += 1
    _global_metrics[counter] += 1

    if was_hit:
        return probed

    async def load_value():
        return await func(*args, **kwargs)

    return await global_cache.get(key, loader=load_value, ttl=ttl)