""" engine/token_tracker.py ------------------------ Process-wide capture of Anthropic API token usage. Monkey-patches ``anthropic.Anthropic.__init__`` so every client instance — no matter where in the codebase it's constructed — gets its ``messages.create`` method wrapped to extract ``response.usage`` and persist one row per call to the ``token_usage`` table. Why patch instead of wrapping each call site: • 12+ ``messages.create(...)`` call sites across pipeline.py, content_planner.py, template_analyzer.py, etc — touching all of them is fragile and any new call site would silently miss tracking. • A single patch survives future call sites for free. Per-call rows include model id and cache tokens (prompt-cache hits/misses) so we can audit cost composition, not just totals. To link API calls to the job that issued them, the web layer calls ``set_active_job(job_id)`` at the start of each pipeline run (and clears it on exit). The patched wrapper reads the thread-local and stamps each row. This module replaces the legacy ``web/data/tokens.json`` writer that lived here pre-Phase-1. The migrated daily aggregates from that file live in the same ``token_usage`` table tagged ``model='aggregate-legacy'``. """ from __future__ import annotations import threading from datetime import datetime import anthropic # ── Job context (thread-local) ──────────────────────────────────────────────── # Each pipeline thread sets its own job_id at the start of a build so the # token rows can be linked back. Concurrent builds are isolated automatically. 
_ctx = threading.local()


def set_active_job(job_id: str | None) -> None:
    """Set (or clear, with None) the active job_id for the current thread.

    The value is stored on a ``threading.local``, so it is only visible to
    the thread that set it.
    """
    _ctx.job_id = job_id


def get_active_job() -> str | None:
    """Return the current thread's active job_id, or None if none was set."""
    return getattr(_ctx, "job_id", None)


# ── DB write (one row per messages.create call) ───────────────────────────────

def _record(model: str, input_tokens: int, output_tokens: int,
            cache_creation_tokens: int = 0, cache_read_tokens: int = 0) -> None:
    """Persist one ``TokenUsage`` row for a single API call (best effort).

    Args:
        model: Model id to store; falsy values are stored as ``"unknown"``.
        input_tokens: Input token count.
        output_tokens: Output token count.
        cache_creation_tokens: Cache-creation token count.
        cache_read_tokens: Cache-read token count.

    Nothing is written when all four counts are falsy, or when the web DB
    layer cannot be imported. This function never raises: failures while
    opening the session, writing, rolling back or closing are logged and
    swallowed so that tracking cannot break the API call being tracked.
    """
    if not (input_tokens or output_tokens or cache_creation_tokens or cache_read_tokens):
        return

    # Local import to avoid module-load coupling — the engine doesn't have a
    # hard dependency on the web layer; the DB is just one persistence backend.
    try:
        from web.db import SessionLocal
        from web import models
    except Exception:
        # If the DB isn't available (e.g. test runs that don't init it),
        # silently drop the record rather than crashing the API call.
        return

    # Block-local imports keep this function self-contained.
    import logging
    from datetime import timezone

    log = logging.getLogger(__name__)

    # Opening the session can fail too (bad URL, pool exhausted, ...); that
    # must not escape either.
    try:
        db = SessionLocal()
    except Exception:
        log.warning("token_usage: could not open DB session", exc_info=True)
        return

    try:
        db.add(models.TokenUsage(
            job_id=get_active_job(),
            # Naive UTC, i.e. the same value the deprecated
            # datetime.utcnow() produced, so stored rows stay comparable.
            timestamp=datetime.now(timezone.utc).replace(tzinfo=None),
            model=model or "unknown",
            input_tokens=int(input_tokens or 0),
            output_tokens=int(output_tokens or 0),
            cache_creation_tokens=int(cache_creation_tokens or 0),
            cache_read_tokens=int(cache_read_tokens or 0),
        ))
        db.commit()
    except Exception:
        # Tracking must never block a generation.
        log.warning("token_usage: failed to persist usage row", exc_info=True)
        try:
            db.rollback()
        except Exception:
            log.debug("token_usage: rollback failed", exc_info=True)
    finally:
        try:
            db.close()
        except Exception:
            log.debug("token_usage: session close failed", exc_info=True)


# ── Patch ─────────────────────────────────────────────────────────────────────
# Guard against being imported twice — only patch the first time.
if not getattr(anthropic.Anthropic, "_tk_patched", False):
    # Imported here so the patch block is self-contained.
    import functools

    _orig_init = anthropic.Anthropic.__init__

    # functools.wraps keeps the original __init__'s name, docstring and
    # (via __wrapped__) its introspectable signature on the replacement.
    @functools.wraps(_orig_init)
    def _patched_init(self, *args, **kwargs):
        # Run the real constructor first, then wrap this instance's
        # ``messages.create`` so every call through it is recorded.
        _orig_init(self, *args, **kwargs)
        if not hasattr(self, "messages"):
            return

        _orig_create = self.messages.create

        @functools.wraps(_orig_create)
        def _tracked_create(*c_args, **c_kwargs):
            # The API call itself is NOT guarded: its exceptions must reach
            # the caller unchanged. Only the bookkeeping below is best-effort.
            response = _orig_create(*c_args, **c_kwargs)
            try:
                u = getattr(response, "usage", None)
                if u is not None:
                    # Prefer the ``model`` keyword; fall back to a leading
                    # positional string, then to "unknown".
                    model_id = (
                        c_kwargs.get("model")
                        or (c_args[0] if c_args and isinstance(c_args[0], str) else None)
                        or "unknown"
                    )
                    _record(
                        model=model_id,
                        input_tokens=int(getattr(u, "input_tokens", 0) or 0),
                        output_tokens=int(getattr(u, "output_tokens", 0) or 0),
                        cache_creation_tokens=int(getattr(u, "cache_creation_input_tokens", 0) or 0),
                        cache_read_tokens=int(getattr(u, "cache_read_input_tokens", 0) or 0),
                    )
            except Exception:
                # Never let tracking break a generation.
                pass
            return response

        # Instance-level override: only this client's ``messages`` object
        # gets the tracked method.
        self.messages.create = _tracked_create

    anthropic.Anthropic.__init__ = _patched_init
    anthropic.Anthropic._tk_patched = True


# ── Public helpers (for callers / debugging) ──────────────────────────────────

def totals() -> dict:
    """Return input/output token totals and a per-day breakdown.

    Values are aggregated from the ``token_usage`` table. Only input and
    output tokens are summed; cache token columns are not included.

    Returns:
        ``{"total_input": int, "total_output": int, "days": {day: {"input":
        int, "output": int}}}``. When SQLAlchemy or the web DB layer cannot
        be imported, the totals are 0 and ``days`` is empty.

    Errors raised by the query itself propagate to the caller; the session
    is closed either way.
    """
    try:
        from sqlalchemy import func
        from web.db import SessionLocal
        from web import models
    except Exception:
        return {"total_input": 0, "total_output": 0, "days": {}}

    db = SessionLocal()
    try:
        rows = (
            db.query(
                func.date(models.TokenUsage.timestamp).label("day"),
                func.sum(models.TokenUsage.input_tokens).label("inp"),
                func.sum(models.TokenUsage.output_tokens).label("out"),
            )
            .group_by("day")
            .all()
        )
        # NOTE(review): the type of ``r.day`` depends on what the backend's
        # DATE() function returns — confirm it is a string if callers
        # serialise ``days`` to JSON.
        days = {
            r.day: {"input": int(r.inp or 0), "output": int(r.out or 0)}
            for r in rows
        }
    finally:
        db.close()

    return {
        "total_input": sum(d["input"] for d in days.values()),
        "total_output": sum(d["output"] for d in days.values()),
        "days": days,
    }