Source code for agent_cache.core

"""
Agent-aware caching layer for AI agent workflows.

Provides:
- AgentCache: decorator for caching async read-tool results
- AgentCacheInvalidator: decorator for write-tools that invalidate cached reads
- AgentCacheSession: async context manager with per-session caching and loop detection
"""

import contextvars
import functools
import logging
import time as _time
import uuid

from cache import AsyncCache
from cache.async_cache import _MISSING as _CACHE_MISSING
from cache.key import KEY, make_key

logger = logging.getLogger("agent_cache")

# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------


class AgentLoopDetectedError(Exception):
    """Error surfaced to callers when a session configured with
    ``on_loop='raise'`` detects an agent execution loop."""
class _LoopShortCircuit(Exception):
    """Private control-flow signal used for ``on_loop='short_circuit'``.

    It is raised by the session's loop handler and caught inside the
    decorator wrappers in this module, so tool callers do not see it.
    """


# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------

# Session currently in effect for the running context (None outside a session).
_active_session: contextvars.ContextVar = contextvars.ContextVar(
    "_active_session", default=None
)

# resource -> AsyncCache
_global_caches: dict = {}

# Process-wide counters aggregated across every session.
_global_metrics: dict = dict.fromkeys(
    ("hits", "misses", "invalidations", "loop_detections"), 0
)
def get_metrics() -> dict:
    """Snapshot the process-wide counters and add a derived ``hit_rate``.

    ``hit_rate`` is ``hits / (hits + misses)``, or ``0.0`` when no lookups
    have been counted yet. The returned dict is a copy; mutating it does
    not affect the live counters.
    """
    hits = _global_metrics["hits"]
    lookups = hits + _global_metrics["misses"]

    snapshot = dict(_global_metrics)
    snapshot["hit_rate"] = hits / lookups if lookups > 0 else 0.0
    return snapshot
def _reset_global_state():
    """Drop every global cache and zero the global counters. **Testing only.**"""
    _global_caches.clear()
    _global_metrics.update(hits=0, misses=0, invalidations=0, loop_detections=0)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _get_global_cache(
    resource: str, maxsize: int = 128, backup=None, remote_cache=None
) -> AsyncCache:
    """Return the shared cache registered for *resource*, creating it on first use.

    The construction arguments are only consulted when the cache does not
    exist yet; later calls for the same resource get the existing instance
    regardless of the arguments passed.
    """
    if resource in _global_caches:
        return _global_caches[resource]

    created = AsyncCache(
        maxsize=maxsize,
        default_ttl=None,
        backup=backup,
        remote_cache=remote_cache,
    )
    _global_caches[resource] = created
    return _global_caches[resource]


# ---------------------------------------------------------------------------
# AgentCacheSession
# ---------------------------------------------------------------------------
class AgentCacheSession:
    """Async context manager providing session-scoped caching and loop detection.

    Entering the context publishes this session through the module-level
    ``_active_session`` context variable; leaving it restores the previous
    value and discards the session's execution trace and cached entries.

    Usage::

        async with AgentCacheSession(loop_detection=True, max_tool_repeats=5) as session:
            await some_cached_tool(...)

    :param session_id: Identifier for the session. A random UUID4 string is
        generated when omitted (or falsy).
    :param loop_detection: When false, tool calls are still appended to the
        trace but no loop checks are run.
    :param max_tool_repeats: Largest allowed number of identical
        ``(tool_name, args_key)`` entries in the trace.
    :param max_execution_depth: Largest allowed trace length.
    :param on_loop: What to do when a loop is detected: ``"raise"`` raises
        ``AgentLoopDetectedError``, ``"warn"`` logs a warning and continues,
        ``"short_circuit"`` raises the internal ``_LoopShortCircuit`` signal.
    :raises ValueError: If ``on_loop`` is not one of the modes above.
    """

    # Modes that ``_handle_loop`` knows how to act on.
    _ON_LOOP_MODES = ("raise", "warn", "short_circuit")

    def __init__(
        self,
        session_id=None,
        loop_detection=True,
        max_tool_repeats=5,
        max_execution_depth=50,
        on_loop="raise",
    ):
        # Fail fast: an unknown mode would otherwise make _handle_loop fall
        # through every branch and silently disable loop protection.
        if on_loop not in self._ON_LOOP_MODES:
            raise ValueError(
                f"on_loop must be one of {self._ON_LOOP_MODES!r}, got {on_loop!r}"
            )

        self.session_id = session_id or str(uuid.uuid4())
        self.loop_detection = loop_detection
        self.max_tool_repeats = max_tool_repeats
        self.max_execution_depth = max_execution_depth
        self.on_loop = on_loop
        self._session_caches: dict = {}  # resource -> {key: (value, expiry)}
        self._execution_trace: list = []  # [(tool_name, args_key), ...]
        self._token = None  # contextvars token, set while the session is active
        self.metrics: dict = {
            "hits": 0,
            "misses": 0,
            "invalidations": 0,
            "loop_detections": 0,
        }

    # -- context manager ----------------------------------------------------

    async def __aenter__(self):
        """Make this session the active one and return it."""
        self._token = _active_session.set(self)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Restore the previously active session and drop trace and caches."""
        _active_session.reset(self._token)
        self._execution_trace.clear()
        self._session_caches.clear()
        return False  # do not suppress exceptions

    # -- public helpers -----------------------------------------------------

    def get_execution_trace(self):
        """Return a copy of the execution trace."""
        return list(self._execution_trace)

    def get_metrics(self) -> dict:
        """Return a copy of this session's counters plus a derived ``hit_rate``.

        ``hit_rate`` is ``hits / (hits + misses)``, or ``0.0`` before any lookup.
        """
        total = self.metrics["hits"] + self.metrics["misses"]
        return {
            **self.metrics,
            "hit_rate": self.metrics["hits"] / total if total > 0 else 0.0,
        }

    # -- loop detection internals -------------------------------------------

    def _record_tool_call(self, tool_name: str, args_key):
        """Append to trace and run loop checks. May raise."""
        self._execution_trace.append((tool_name, args_key))
        if self.loop_detection:
            self._check_loops(tool_name, args_key)

    def _check_loops(self, tool_name: str, args_key):
        """Run the three loop heuristics against the trace; stop at the first hit.

        The call being checked is expected to already be the last trace entry.
        """
        trace = self._execution_trace

        # 1. Execution depth
        if len(trace) > self.max_execution_depth:
            self._handle_loop(
                f"Execution depth exceeded: {len(trace)} > {self.max_execution_depth}"
            )
            return

        # 2. Repeated identical calls (same tool + same args)
        count = trace.count((tool_name, args_key))
        if count > self.max_tool_repeats:
            self._handle_loop(
                f"Tool '{tool_name}' called {count} times with same arguments "
                f"(limit: {self.max_tool_repeats})"
            )
            return

        # 3. Oscillation / cycle detection (tool names only, 2+ distinct tools)
        n = len(trace)
        if n >= 4:
            tool_names = [t for t, _ in trace]
            # Only cycles of length 2..10 are considered, and the trace must
            # be long enough to hold two back-to-back copies.
            max_cycle = min(n // 2, 10)
            for cycle_len in range(2, max_cycle + 1):
                cycle = tool_names[-cycle_len:]
                prev = tool_names[-2 * cycle_len : -cycle_len]
                if cycle == prev and len(set(cycle)) >= 2:
                    self._handle_loop(
                        f"Cycle detected: {' -> '.join(cycle)}"
                    )
                    return

    def _handle_loop(self, message: str):
        """Count the detection (session and global), then act per ``on_loop``.

        :raises AgentLoopDetectedError: when ``on_loop == "raise"``.
        :raises _LoopShortCircuit: when ``on_loop == "short_circuit"``.
        """
        self.metrics["loop_detections"] += 1
        _global_metrics["loop_detections"] += 1

        if self.on_loop == "raise":
            raise AgentLoopDetectedError(message)
        elif self.on_loop == "warn":
            logger.warning("Agent loop detected: %s", message)
        elif self.on_loop == "short_circuit":
            raise _LoopShortCircuit(message)
# ---------------------------------------------------------------------------
# AgentCache decorator
# ---------------------------------------------------------------------------
class AgentCache:
    """Decorator that caches async tool results with resource tagging and scope.

    Example::

        @AgentCache(resource="cart", scope="global", ttl=60)
        async def get_cart(user_id): ...

    :param resource: Tag grouping cached entries; it selects which global
        cache or which per-session bucket the results are stored in.
    :param scope: ``"session"`` stores results on the active
        ``AgentCacheSession`` (and bypasses caching when no session is
        active); any other value uses the global cache for ``resource``.
    :param ttl: Time-to-live forwarded to the selected cache helper.
    :param maxsize: Size used when the global cache for ``resource`` is first created.
    :param skip_args: Number of leading positional arguments excluded from
        the cache key and from the loop-detection argument key.
    :param backup: Forwarded to the global cache on first creation.
    :param remote_cache: Forwarded to the global cache on first creation.
    """

    def __init__(
        self,
        resource: str,
        scope: str = "global",
        ttl=None,
        maxsize: int = 128,
        skip_args: int = 0,
        backup=None,
        remote_cache=None,
    ):
        self.resource = resource
        self.scope = scope
        self.ttl = ttl
        self.maxsize = maxsize
        self.skip_args = skip_args
        self.backup = backup
        self.remote_cache = remote_cache

    def __call__(self, func):
        """Wrap the async callable *func* and return the caching wrapper.

        The wrapper keeps *func*'s metadata (name, qualname, docstring,
        module, custom attributes) and additionally exposes ``resource``
        and ``scope`` attributes.
        """
        resource = self.resource
        scope = self.scope
        ttl = self.ttl
        maxsize = self.maxsize
        skip_args = self.skip_args
        backup = self.backup
        remote_cache = self.remote_cache
        # Resolved once at decoration time; this is the name recorded in
        # the session's execution trace.
        tool_name = getattr(func, "__qualname__", func.__name__)

        # functools.wraps carries over __doc__/__module__/__dict__ as well
        # as the name and __wrapped__, so tool descriptions survive.
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            session = _active_session.get()
            key = make_key(func, args, kwargs, skip_args)

            # --- trace recording + loop detection ---
            if session:
                call_args = args[skip_args:] if skip_args else args
                args_key = KEY(call_args, kwargs)
                try:
                    session._record_tool_call(tool_name, args_key)
                except _LoopShortCircuit:
                    # Short-circuit: answer from cache (or None) without
                    # invoking the tool again.
                    return _try_get_cached(scope, resource, key, session)

            # --- caching ---
            if scope == "session":
                if session is None:
                    # No active session to cache on: call through uncached.
                    return await func(*args, **kwargs)
                return await _session_cache_get(
                    session, resource, key, func, args, kwargs, ttl
                )

            # global
            global_cache = _get_global_cache(
                resource, maxsize, backup=backup, remote_cache=remote_cache
            )
            return await _global_cache_get(
                global_cache, key, func, args, kwargs, ttl, session
            )

        # Keeps the historical fallback to __name__ for callables that have
        # no __qualname__ of their own.
        wrapper.__qualname__ = tool_name
        wrapper.resource = resource
        wrapper.scope = scope
        return wrapper
# ---------------------------------------------------------------------------
# AgentCacheInvalidator decorator
# ---------------------------------------------------------------------------
class AgentCacheInvalidator:
    """Decorator for write/mutation tools that invalidate related cached reads.

    Invalidation happens **before** the mutation executes, so entries are
    dropped even if the mutation raises.

    Example::

        @AgentCacheInvalidator(resource="cart", scope="global")
        async def add_to_cart(user_id, item): ...

    :param resource: The resource tag to invalidate (must match an ``@AgentCache``).
    :param scope: ``"global"`` or ``"session"``.
    :param clear_all: If True (default), clear all entries for the resource.
        If False, requires ``key_fn`` to target a specific key.
    :param key_fn: Optional ``(args, kwargs) -> cache_key`` for selective
        invalidation when ``clear_all=False``.
    :param skip_args: **Removed.** Passing a non-zero value raises ``TypeError``.
        Use ``key_fn`` instead to map mutation args to cache-key args.
    :raises TypeError: If ``skip_args`` is non-zero.
    :raises ValueError: If ``clear_all`` is false and no ``key_fn`` is given.
    """

    def __init__(self, resource: str, scope: str = "global",
                 skip_args: int = 0, clear_all: bool = True, key_fn=None):
        if skip_args != 0:
            raise TypeError(
                "skip_args is no longer supported on invalidators because it "
                "can silently produce the wrong cache key. Use key_fn to "
                "map the mutation function's arguments to the cached "
                "function's arguments instead."
            )
        # Without this check the wrapper would invalidate nothing at all
        # while still counting an invalidation in the metrics.
        if not clear_all and key_fn is None:
            raise ValueError(
                "clear_all=False requires key_fn to select the cache key "
                "to invalidate."
            )
        self.resource = resource
        self.scope = scope
        self.clear_all = clear_all
        self.key_fn = key_fn

    def __call__(self, func):
        """Wrap the async mutation *func* and return the invalidating wrapper.

        The wrapper keeps *func*'s metadata (name, qualname, docstring,
        module, custom attributes) and additionally exposes ``resource``
        and ``scope`` attributes.
        """
        resource = self.resource
        scope = self.scope
        clear_all = self.clear_all
        key_fn = self.key_fn
        # Resolved once at decoration time; this is the name recorded in
        # the session's execution trace.
        tool_name = getattr(func, "__qualname__", func.__name__)

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            session = _active_session.get()

            # --- trace recording + loop detection ---
            if session:
                args_key = KEY(args, kwargs)
                try:
                    session._record_tool_call(tool_name, args_key)
                except _LoopShortCircuit:
                    # Short-circuit: skip both invalidation and the mutation.
                    return None

            # --- invalidate BEFORE the mutation so the entries are gone
            #     even if the mutation raises ---
            # NOTE(review): clear()/delete() on the global cache are called
            # without await; this assumes they are synchronous - confirm.
            # NOTE(review): a read that runs while the mutation is awaiting
            # could repopulate the cache with pre-mutation data - confirm
            # whether a second invalidation after the mutation is wanted.
            if clear_all:
                if scope == "global" and resource in _global_caches:
                    _global_caches[resource].clear()
                if scope == "session" and session and resource in session._session_caches:
                    session._session_caches[resource].clear()
            elif key_fn is not None:
                cache_key = key_fn(args, kwargs)
                if scope == "global" and resource in _global_caches:
                    _global_caches[resource].delete(cache_key)
                if scope == "session" and session and resource in session._session_caches:
                    session._session_caches[resource].pop(cache_key, None)

            # --- metrics ---
            if session:
                session.metrics["invalidations"] += 1
            _global_metrics["invalidations"] += 1

            # --- execute the mutation ---
            result = await func(*args, **kwargs)
            return result

        # Keeps the historical fallback to __name__ for callables that have
        # no __qualname__ of their own.
        wrapper.__qualname__ = tool_name
        wrapper.resource = resource
        wrapper.scope = scope
        return wrapper
# ---------------------------------------------------------------------------
# Internal cache helpers
# ---------------------------------------------------------------------------


def _try_get_cached(scope, resource, key, session):
    """Best-effort lookup used when a loop is short-circuited.

    Returns the cached value for *key* if an unexpired one is present in
    the session bucket (``scope == "session"`` with a session) or in the
    global cache (``scope == "global"``); otherwise returns None. No loader
    is invoked and no metrics are touched.
    """
    if scope == "session" and session:
        bucket = session._session_caches.get(resource, {})
        try:
            cached_value, deadline = bucket[key]
        except KeyError:
            return None
        # A deadline of None marks an entry that never expires.
        if deadline is None or deadline > _time.monotonic():
            return cached_value
        return None

    if scope != "global":
        return None

    store = _global_caches.get(resource)
    if not store:
        return None
    found = store.cache.get_if_present(key)
    return None if found is _CACHE_MISSING else found


async def _session_cache_get(session, resource, key, func, args, kwargs, ttl):
    """Serve *key* from the session's bucket for *resource*, loading on a miss.

    A fresh entry counts as a hit; an absent or expired entry counts as a
    miss, after which ``func(*args, **kwargs)`` is awaited and its result
    stored. A falsy *ttl* stores the entry without an expiry.
    """
    bucket = session._session_caches.setdefault(resource, {})

    if key in bucket:
        cached_value, deadline = bucket[key]
        still_fresh = deadline is None or deadline > _time.monotonic()
        if not still_fresh:
            # Expired: evict and fall through to the miss path.
            del bucket[key]
        else:
            session.metrics["hits"] += 1
            _global_metrics["hits"] += 1
            return cached_value

    session.metrics["misses"] += 1
    _global_metrics["misses"] += 1

    loaded = await func(*args, **kwargs)
    if ttl:
        deadline = _time.monotonic() + ttl
    else:
        deadline = None
    bucket[key] = (loaded, deadline)
    return loaded


async def _global_cache_get(global_cache, key, func, args, kwargs, ttl, session):
    """Serve *key* from *global_cache*, delegating the load to it on a miss.

    The hit/miss counters are updated globally and, when a session is
    active, on that session as well.
    """
    present = global_cache.cache.get_if_present(key)
    was_hit = present is not _CACHE_MISSING

    counter = "hits" if was_hit else "misses"
    if session:
        session.metrics[counter] += 1
    _global_metrics[counter] += 1

    if was_hit:
        return present

    async def load_fresh():
        return await func(*args, **kwargs)

    return await global_cache.get(key, loader=load_fresh, ttl=ttl)