fix: deep review — prefix matching, tool_calls extraction, query perf, serialization

Issues found and fixed during deep code path review:

1. CRITICAL: Prefix matching returned wrong prices for dated model names
   - 'gpt-4o-mini-2024-07-18' matched gpt-4o ($2.50) instead of gpt-4o-mini ($0.15)
   - Same for o3-mini→o3 (9x), gpt-4.1-mini→gpt-4.1 (5x), gpt-4.1-nano→gpt-4.1 (20x)
   - Fix: use longest-match-wins strategy instead of first-match
   - Removed dangerous key.startswith(bare) reverse matching

2. CRITICAL: Top Tools section was empty for CLI sessions
   - run_agent.py doesn't set tool_name on tool response messages (pre-existing)
   - Insights now also extracts tool names from tool_calls JSON on assistant
     messages, which IS populated for all sessions
   - Uses max() merge strategy to avoid double-counting between sources

3. SELECT * replaced with explicit column list
   - Skips system_prompt and model_config blobs (can be thousands of chars)
   - Reduces memory and I/O for large session counts

4. Sets in overview dict converted to sorted lists
   - models_with_pricing / models_without_pricing were Python sets
   - Sets aren't JSON-serializable — would crash json.dumps()

5. Negative duration guard
   - end > start check prevents negative durations from clock drift

6. Model breakdown sort fallback
   - When all tokens are 0, now sorts by session count instead of arbitrary order

7. Removed unused timedelta import

Added 6 new tests: dated model pricing (4), tool_calls JSON extraction,
JSON serialization safety. Total: 69 tests.
This commit is contained in:
teknium1
2026-03-06 14:50:57 -08:00
parent 75f523f5c0
commit 585f8528b2
2 changed files with 169 additions and 19 deletions

View File

@@ -16,9 +16,10 @@ Usage:
print(engine.format_terminal(report))
"""
import json
import time
from collections import Counter, defaultdict
from datetime import datetime, timedelta
from datetime import datetime
from typing import Any, Dict, List, Optional
# =========================================================================
@@ -82,12 +83,18 @@ def _get_pricing(model_name: str) -> Dict[str, float]:
if bare in MODEL_PRICING:
return MODEL_PRICING[bare]
# Fuzzy prefix match
# Fuzzy prefix match — prefer the LONGEST matching key to avoid
# e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
best_match = None
best_len = 0
for key, price in MODEL_PRICING.items():
if bare.startswith(key) or key.startswith(bare):
return price
if bare.startswith(key) and len(key) > best_len:
best_match = price
best_len = len(key)
if best_match:
return best_match
# Keyword heuristics
# Keyword heuristics (checked in most-specific-first order)
if "opus" in bare:
return {"input": 15.00, "output": 75.00}
if "sonnet" in bare:
@@ -211,26 +218,39 @@ class InsightsEngine:
# Data gathering (SQL queries)
# =========================================================================
# Columns we actually need (skip system_prompt, model_config blobs)
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
"message_count, tool_call_count, input_tokens, output_tokens")
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
"""Fetch sessions within the time window."""
if source:
cursor = self._conn.execute(
"""SELECT * FROM sessions
WHERE started_at >= ? AND source = ?
ORDER BY started_at DESC""",
f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ? AND source = ?
ORDER BY started_at DESC""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"""SELECT * FROM sessions
WHERE started_at >= ?
ORDER BY started_at DESC""",
f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ?
ORDER BY started_at DESC""",
(cutoff,),
)
return [dict(row) for row in cursor.fetchall()]
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
"""Get tool call counts from messages."""
"""Get tool call counts from messages.
Uses two sources:
1. tool_name column on 'tool' role messages (set by gateway)
2. tool_calls JSON on 'assistant' role messages (covers CLI where
tool_name is not populated on tool responses)
"""
tool_counts = Counter()
# Source 1: explicit tool_name on tool response messages
if source:
cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count
@@ -253,7 +273,64 @@ class InsightsEngine:
ORDER BY count DESC""",
(cutoff,),
)
return [dict(row) for row in cursor.fetchall()]
for row in cursor.fetchall():
tool_counts[row["tool_name"]] += row["count"]
# Source 2: extract from tool_calls JSON on assistant messages
# (covers CLI sessions where tool_name is NULL on tool responses)
if source:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff, source),
)
else:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff,),
)
tool_calls_counts = Counter()
for row in cursor2.fetchall():
try:
calls = row["tool_calls"]
if isinstance(calls, str):
calls = json.loads(calls)
if isinstance(calls, list):
for call in calls:
func = call.get("function", {}) if isinstance(call, dict) else {}
name = func.get("name")
if name:
tool_calls_counts[name] += 1
except (json.JSONDecodeError, TypeError, AttributeError):
continue
# Merge: prefer tool_name source, supplement with tool_calls source
# for tools not already counted
if not tool_counts and tool_calls_counts:
# No tool_name data at all — use tool_calls exclusively
tool_counts = tool_calls_counts
elif tool_counts and tool_calls_counts:
# Both sources have data — use whichever has the higher count per tool
# (they may overlap, so take the max to avoid double-counting)
all_tools = set(tool_counts) | set(tool_calls_counts)
merged = Counter()
for tool in all_tools:
merged[tool] = max(tool_counts.get(tool, 0), tool_calls_counts.get(tool, 0))
tool_counts = merged
# Convert to the expected format
return [
{"tool_name": name, "count": count}
for name, count in tool_counts.most_common()
]
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
"""Get aggregate message statistics."""
@@ -314,12 +391,12 @@ class InsightsEngine:
else:
models_without_pricing.add(display)
# Session duration stats
# Session duration stats (guard against negative durations from clock drift)
durations = []
for s in sessions:
start = s.get("started_at")
end = s.get("ended_at")
if start and end:
if start and end and end > start:
durations.append(end - start)
total_hours = sum(durations) / 3600 if durations else 0
@@ -347,8 +424,8 @@ class InsightsEngine:
"tool_messages": message_stats.get("tool_messages") or 0,
"date_range_start": date_range_start,
"date_range_end": date_range_end,
"models_with_pricing": models_with_pricing,
"models_without_pricing": models_without_pricing,
"models_with_pricing": sorted(models_with_pricing),
"models_without_pricing": sorted(models_without_pricing),
}
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
@@ -377,7 +454,8 @@ class InsightsEngine:
{"model": model, **data}
for model, data in model_data.items()
]
result.sort(key=lambda x: x["total_tokens"], reverse=True)
# Sort by tokens first, fall back to session count when tokens are 0
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
return result
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]: