Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
4d8e004b5f fix: extend JSON repair to remaining json.loads sites in run_agent.py
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Successful in 42s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Nix / nix (ubuntu-latest) (pull_request) Failing after 4s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 36s
Tests / test (pull_request) Failing after 1h13m6s
Tests / e2e (pull_request) Successful in 1m32s
Nix / nix (macos-latest) (pull_request) Has been cancelled
Adds `repair_and_load_json()` to utils.py using the `json_repair` library
as a fallback when `json.loads()` fails. Replaces 8 non-hot-path json.loads
sites identified in issue #809:

- L2250: trajectory/sanitization message content parsing
- L2500: tool_call dict reconstruction in trajectory conversion
- L2535: tool_content parsing (JSON-like strings in tool responses)
- L2888: session log file loading (with warning on unrecoverable parse)
- L3119: todo content parsing in message processing
- L5963: vision result_json parsing
- L6761: memory flush tool call argument parsing
- L8300: cache serialization tool call args normalization

Each site uses an appropriate default ({} for tool args, None/continue for
content parsing) and a context label for debug tracing.

Fixes #809
2026-04-15 22:56:39 -04:00
4 changed files with 89 additions and 456 deletions

View File

@@ -106,7 +106,7 @@ from agent.trajectory import (
convert_scratchpad_to_think, has_incomplete_scratchpad,
save_trajectory as _save_trajectory_to_file,
)
from utils import atomic_json_write, env_var_enabled
from utils import atomic_json_write, env_var_enabled, repair_and_load_json
@@ -2246,9 +2246,8 @@ class AIAgent:
for msg in getattr(review_agent, "_session_messages", []):
if not isinstance(msg, dict) or msg.get("role") != "tool":
continue
try:
data = json.loads(msg.get("content", "{}"))
except (json.JSONDecodeError, TypeError):
data = repair_and_load_json(msg.get("content", "{}"), default=None, context="trajectory_content")
if data is None:
continue
if not data.get("success"):
continue
@@ -2496,13 +2495,13 @@ class AIAgent:
if not tool_call or not isinstance(tool_call, dict): continue
# Parse arguments - should always succeed since we validate during conversation
# but keep try-except as safety net
try:
arguments = json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
except json.JSONDecodeError:
# This shouldn't happen since we validate and retry during conversation,
# but if it does, log warning and use empty dict
logging.warning(f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}")
arguments = {}
raw_args = tool_call["function"]["arguments"]
if isinstance(raw_args, str):
arguments = repair_and_load_json(raw_args, default={}, context="trajectory_tool_call")
if arguments == {} and raw_args.strip() not in ("{}", ""):
logging.warning("Unexpected invalid JSON in trajectory conversion: %.100s", raw_args)
else:
arguments = raw_args
tool_call_json = {
"name": tool_call["function"]["name"],
@@ -2530,11 +2529,10 @@ class AIAgent:
# Try to parse tool content as JSON if it looks like JSON
tool_content = tool_msg["content"]
try:
if tool_content.strip().startswith(("{", "[")):
tool_content = json.loads(tool_content)
except (json.JSONDecodeError, AttributeError):
pass # Keep as string if not valid JSON
if isinstance(tool_content, str) and tool_content.strip().startswith(("{", "[")):
parsed = repair_and_load_json(tool_content, default=None, context="trajectory_tool_content")
if parsed is not None:
tool_content = parsed
tool_index = len(tool_responses)
tool_name = (
@@ -2885,14 +2883,21 @@ class AIAgent:
# with partial history and would otherwise clobber the full JSON log.
if self.session_log_file.exists():
try:
existing = json.loads(self.session_log_file.read_text(encoding="utf-8"))
existing_count = existing.get("message_count", len(existing.get("messages", [])))
if existing_count > len(cleaned):
logging.debug(
"Skipping session log overwrite: existing has %d messages, current has %d",
existing_count, len(cleaned),
)
return
existing = repair_and_load_json(
self.session_log_file.read_text(encoding="utf-8"),
default=None,
context="session_log_load",
)
if existing is None:
logging.warning("Session log at %s could not be parsed; allowing overwrite", self.session_log_file)
else:
existing_count = existing.get("message_count", len(existing.get("messages", [])))
if existing_count > len(cleaned):
logging.debug(
"Skipping session log overwrite: existing has %d messages, current has %d",
existing_count, len(cleaned),
)
return
except Exception:
pass # corrupted existing file — allow the overwrite
@@ -3115,13 +3120,12 @@ class AIAgent:
# Quick check: todo responses contain "todos" key
if '"todos"' not in content:
continue
try:
data = json.loads(content)
if "todos" in data and isinstance(data["todos"], list):
last_todo_response = data["todos"]
break
except (json.JSONDecodeError, TypeError):
data = repair_and_load_json(content, default=None, context="todo_content")
if data is None:
continue
if "todos" in data and isinstance(data["todos"], list):
last_todo_response = data["todos"]
break
if last_todo_response:
# Replay the items into the store (replace mode)
@@ -5960,7 +5964,7 @@ class AIAgent:
result_json = asyncio.run(
vision_analyze_tool(image_url=vision_source, user_prompt=analysis_prompt)
)
result = json.loads(result_json) if isinstance(result_json, str) else {}
result = repair_and_load_json(result_json, default={}, context="vision_result") if isinstance(result_json, str) else {}
description = (result.get("analysis") or "").strip()
except Exception as e:
description = f"Image analysis failed: {e}"
@@ -6758,7 +6762,7 @@ class AIAgent:
for tc in tool_calls:
if tc.function.name == "memory":
try:
args = json.loads(tc.function.arguments)
args = repair_and_load_json(tc.function.arguments, default={}, context="memory_flush")
flush_target = args.get("target", "memory")
from tools.memory_tool import memory_tool as _memory_tool
_memory_tool(
@@ -8297,14 +8301,15 @@ class AIAgent:
for tc in tcs:
if isinstance(tc, dict) and "function" in tc:
try:
args_obj = json.loads(tc["function"]["arguments"])
tc = {**tc, "function": {
**tc["function"],
"arguments": json.dumps(
args_obj, separators=(",", ":"),
sort_keys=True,
),
}}
args_obj = repair_and_load_json(tc["function"]["arguments"], default=None, context="cache_serialization")
if args_obj is not None:
tc = {**tc, "function": {
**tc["function"],
"arguments": json.dumps(
args_obj, separators=(",", ":"),
sort_keys=True,
),
}}
except Exception:
pass
new_tcs.append(tc)

View File

@@ -1,136 +0,0 @@
"""Tests for batch tool execution — Issue #749."""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from tools.batch_executor import (
ToolSafety, ToolCall, BatchResult,
classify_tool_safety, classify_calls,
execute_batch_sync, get_tool_safety_report
)
class TestClassification:
    """Checks how ``classify_tool_safety`` labels individual tool names."""

    def test_parallel_safe_read(self):
        """A known read tool is labelled parallel-safe."""
        verdict = classify_tool_safety("file_read")
        assert verdict == ToolSafety.PARALLEL_SAFE

    def test_sequential_write(self):
        """A known write tool is labelled sequential."""
        verdict = classify_tool_safety("file_write")
        assert verdict == ToolSafety.SEQUENTIAL

    def test_destructive_terminal(self):
        """The terminal tool is labelled destructive."""
        verdict = classify_tool_safety("terminal")
        assert verdict == ToolSafety.DESTRUCTIVE

    def test_unknown_defaults_sequential(self):
        """A name that matches nothing falls back to sequential."""
        verdict = classify_tool_safety("unknown_tool")
        assert verdict == ToolSafety.SEQUENTIAL

    def test_prefix_match(self):
        """A name that merely starts with a known tool name inherits its label."""
        verdict = classify_tool_safety("file_read_special")
        assert verdict == ToolSafety.PARALLEL_SAFE
class TestClassifyCalls:
    """Checks ``classify_calls`` on a list of flat tool-call dicts."""

    def test_classifies_multiple(self):
        """Each call in the list receives the safety level of its tool name."""
        expected_by_name = (
            ("file_read", ToolSafety.PARALLEL_SAFE),
            ("file_write", ToolSafety.SEQUENTIAL),
            ("terminal", ToolSafety.DESTRUCTIVE),
        )
        calls = [{"name": tool, "arguments": "{}"} for tool, _ in expected_by_name]

        classified = classify_calls(calls)

        assert len(classified) == 3
        for position, (_, expected) in enumerate(expected_by_name):
            assert classified[position].safety == expected
class TestBatchExecution:
    """End-to-end checks of ``execute_batch_sync``: scheduling, counts, errors."""

    def test_parallel_execution(self):
        """Parallel-safe calls should execute faster than sequential."""
        import time

        def slow_executor(name, args):
            time.sleep(0.1)
            return f"result_{name}"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_search", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        # Use a monotonic clock for elapsed time: time.time() follows the
        # wall clock and can jump when the system clock is adjusted.
        start = time.monotonic()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.monotonic() - start

        # Should be faster than 0.3s (3 * 0.1) since parallel
        assert duration < 0.25
        assert result.parallel_count == 3
        assert len(result.errors) == 0

    def test_sequential_execution(self):
        """Sequential calls should execute one at a time."""
        import time

        def slow_executor(name, args):
            time.sleep(0.05)
            return f"result_{name}"

        calls = [
            {"name": "file_write", "arguments": "{}"},
            {"name": "file_patch", "arguments": "{}"},
        ]

        # Monotonic clock, for the same reason as in test_parallel_execution.
        start = time.monotonic()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.monotonic() - start

        # Should take at least 0.1s (2 * 0.05) since sequential
        assert duration >= 0.1
        assert result.sequential_count == 2

    def test_mixed_execution(self):
        """Mixed calls: parallel first, then sequential."""
        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        def executor(name, args):
            return f"result_{name}"

        result = execute_batch_sync(calls, executor)
        assert result.parallel_count == 2
        assert result.sequential_count == 1
        assert len(result.errors) == 0

    def test_error_handling(self):
        """Errors in one call shouldn't stop others."""

        def failing_executor(name, args):
            if name == "file_write":
                raise Exception("Write failed")
            return "ok"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
        ]
        result = execute_batch_sync(calls, failing_executor)
        # Exactly one failure is reported, and it names the failing tool.
        assert len(result.errors) == 1
        assert "file_write" in result.errors[0]
class TestSafetyReport:
    """Checks the text produced by ``get_tool_safety_report``."""

    def test_report_format(self):
        """The summary lists one parallel-safe and one sequential call."""
        read_call = ToolCall(
            name="file_read", args={}, safety=ToolSafety.PARALLEL_SAFE, duration=0.1
        )
        write_call = ToolCall(
            name="file_write", args={}, safety=ToolSafety.SEQUENTIAL, duration=0.2
        )

        report = get_tool_safety_report([read_call, write_call])

        for expected_fragment in ("Parallel-safe: 1", "Sequential: 1"):
            assert expected_fragment in report
if __name__ == "__main__":
    import pytest

    # Propagate pytest's exit code so that running this file directly reports
    # failure to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))

View File

@@ -1,280 +0,0 @@
"""Batch tool execution with parallel safety checks.
Classifies tool calls as parallel-safe vs sequential and executes
parallel-safe calls concurrently while keeping destructive ops serialized.
Issue #749: feat: batch tool execution with parallel safety checks
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class ToolSafety(Enum):
    """Scheduling class assigned to a tool call."""

    # May run at the same time as other calls.
    PARALLEL_SAFE = "parallel_safe"
    # Must run one at a time.
    SEQUENTIAL = "sequential"
    # Destructive operation; needs approval.
    DESTRUCTIVE = "destructive"
# Tool safety classifications.
# Built group by group; key insertion order is the same as a flat literal
# listing the parallel-safe, then sequential, then destructive tools.
_TOOL_SAFETY: Dict[str, ToolSafety] = {
    tool: level
    for level, tools in (
        # Parallel-safe: reads, searches, non-destructive
        (
            ToolSafety.PARALLEL_SAFE,
            (
                "file_read",
                "file_search",
                "web_search",
                "web_extract",
                "browser_snapshot",
                "browser_vision",
                "browser_get_images",
                "skill_view",
                "memory_search",
                "memory_recall",
                "session_search",
            ),
        ),
        # Sequential: writes, edits, state changes
        (
            ToolSafety.SEQUENTIAL,
            (
                "file_write",
                "file_patch",
                "file_append",
                "browser_navigate",
                "browser_click",
                "browser_type",
                "browser_scroll",
                "memory_store",
                "memory_update",
                "cronjob",
                "send_message",
            ),
        ),
        # Destructive: needs approval
        (
            ToolSafety.DESTRUCTIVE,
            (
                "terminal",
                "execute_code",
                "browser_execute_js",
                "delegate_task",
            ),
        ),
    )
    for tool in tools
}
@dataclass
class ToolCall:
    """One tool invocation together with its outcome and timing fields."""

    # What to invoke and with which keyword-style arguments.
    name: str
    args: Dict[str, Any]
    # Identifier of the call; empty when none was supplied.
    call_id: str = ""
    # Scheduling class; sequential unless classified otherwise.
    safety: ToolSafety = ToolSafety.SEQUENTIAL
    # Outcome slots: a return value on success, an error message on failure.
    result: Optional[Any] = None
    error: Optional[str] = None
    # Timing slots; all default to 0.0 until filled in.
    duration: float = 0.0
    started_at: float = 0.0
    completed_at: float = 0.0
@dataclass
class BatchResult:
    """Aggregate outcome of one batch of tool calls."""

    # Every call in the batch.
    calls: List[ToolCall] = field(default_factory=list)
    # How many calls were scheduled in the parallel / sequential groups.
    parallel_count: int = 0
    sequential_count: int = 0
    # Elapsed time for the whole batch.
    total_duration: float = 0.0
    # One message per failed call.
    errors: List[str] = field(default_factory=list)
def classify_tool_safety(tool_name: str) -> ToolSafety:
    """Return the safety level for *tool_name*.

    An exact entry in ``_TOOL_SAFETY`` wins. Otherwise the first table key
    (in table order) that *tool_name* starts with decides. A name matching
    nothing is treated as ``ToolSafety.SEQUENTIAL``.
    """
    exact = _TOOL_SAFETY.get(tool_name)
    if exact is not None:
        return exact

    # Prefix fallback, with sequential as the default for unknown tools.
    return next(
        (
            level
            for known_name, level in _TOOL_SAFETY.items()
            if tool_name.startswith(known_name)
        ),
        ToolSafety.SEQUENTIAL,
    )
def classify_calls(tool_calls: List[Dict[str, Any]]) -> List[ToolCall]:
    """Turn raw tool-call dicts into classified ``ToolCall`` objects.

    Both flat dicts (``{"name": ..., "arguments": ...}``) and nested ones
    (``{"function": {"name": ..., "arguments": ...}}``) are accepted; the
    flat keys take precedence when both are present.

    Args:
        tool_calls: Tool-call dicts to classify.

    Returns:
        One ``ToolCall`` per input dict, in input order. ``args`` is always a
        dict: arguments that cannot be parsed as JSON, or that parse to
        something other than an object, become ``{}``.
    """
    import json  # imported once per call rather than once per loop iteration

    calls = []
    for i, tc in enumerate(tool_calls):
        # Tolerate an explicit ``"function": None`` as well as a missing key.
        function = tc.get("function") or {}
        name = tc.get("name", function.get("name", ""))
        args = tc.get("arguments", function.get("arguments", {}))

        if isinstance(args, str):
            try:
                args = json.loads(args)
            except Exception:
                # Best effort: malformed argument JSON degrades to no arguments.
                args = {}
        if not isinstance(args, dict):
            # Valid JSON such as "null" or "[]" (or a non-dict value passed
            # directly) would otherwise reach the executor as a non-dict.
            args = {}

        calls.append(ToolCall(
            name=name,
            args=args,
            call_id=tc.get("id", f"call_{i}"),
            safety=classify_tool_safety(name),
        ))
    return calls
async def execute_parallel(
    calls: List[ToolCall],
    executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
    """Execute parallel-safe calls concurrently.

    Each call runs ``executor(call.name, call.args)`` through the event
    loop's ``run_in_executor`` so a blocking executor does not stall the
    loop. Result, error and timing are written onto each ``ToolCall`` in
    place.

    Args:
        calls: Calls to run; mutated in place.
        executor: Synchronous callable ``(name, args) -> result``.

    Returns:
        The same ``ToolCall`` objects, in input order.
    """
    # Inside a coroutine, get_running_loop() is the supported way to obtain
    # the loop; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()

    async def run_call(call: ToolCall) -> ToolCall:
        call.started_at = time.time()
        try:
            # Run in thread pool to avoid blocking
            call.result = await loop.run_in_executor(
                None, executor, call.name, call.args
            )
        except Exception as e:
            call.error = str(e)
            logger.error(f"Parallel call {call.name} failed: {e}")
        finally:
            call.completed_at = time.time()
            call.duration = call.completed_at - call.started_at
        return call

    # Execute all parallel-safe calls concurrently
    outcomes = await asyncio.gather(
        *(run_call(call) for call in calls), return_exceptions=True
    )

    # run_call handles Exception itself, so anything gather hands back as an
    # exception object is a BaseException such as CancelledError. Check for
    # BaseException so it is recorded on the call instead of being returned
    # in place of a ToolCall.
    processed = []
    for call, outcome in zip(calls, outcomes):
        if isinstance(outcome, BaseException):
            # str(CancelledError()) is empty; fall back to the type name so
            # the failure is still visible as an error.
            call.error = str(outcome) or type(outcome).__name__
            call.completed_at = time.time()
            call.duration = call.completed_at - call.started_at
            processed.append(call)
        else:
            processed.append(outcome)
    return processed
async def execute_sequential(
    calls: List[ToolCall],
    executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
    """Run sequential/destructive calls one after another, in list order.

    The executor is invoked directly, so each call finishes before the next
    starts. Result, error and timing are stored on each ``ToolCall`` in
    place; a failing call is logged and does not stop the remaining ones.

    Returns:
        The same list that was passed in.
    """
    for pending in calls:
        pending.started_at = time.time()
        try:
            outcome = executor(pending.name, pending.args)
            pending.result = outcome
        except Exception as exc:
            pending.error = str(exc)
            logger.error(f"Sequential call {pending.name} failed: {exc}")
        finally:
            finished = time.time()
            pending.completed_at = finished
            pending.duration = finished - pending.started_at
    return calls
async def execute_batch(
    tool_calls: List[Dict[str, Any]],
    executor: Callable[[str, Dict[str, Any]], Any],
    max_parallel: int = 5,
) -> BatchResult:
    """Execute a batch of tool calls with parallel safety checks.

    All parallel-safe calls run first, concurrently in chunks of at most
    ``max_parallel``. Every remaining call (sequential or destructive) then
    runs one at a time. Calls are therefore NOT guaranteed to run in their
    input order.

    Args:
        tool_calls: List of tool call dicts (OpenAI format)
        executor: Function to execute a single tool call (name, args) -> result
        max_parallel: Maximum concurrent parallel calls. Values below 1 are
            treated as 1.

    Returns:
        BatchResult with all call results and timing info
    """
    start_time = time.time()

    # Classify all calls
    calls = classify_calls(tool_calls)

    # Split by safety level
    parallel_calls = [c for c in calls if c.safety == ToolSafety.PARALLEL_SAFE]
    sequential_calls = [c for c in calls if c.safety != ToolSafety.PARALLEL_SAFE]

    result = BatchResult(
        calls=calls,
        parallel_count=len(parallel_calls),
        sequential_count=len(sequential_calls),
    )

    # Execute parallel calls concurrently
    if parallel_calls:
        logger.info(f"Executing {len(parallel_calls)} parallel-safe calls concurrently")
        # A step of 0 makes range() raise, and a negative step yields an
        # empty range that would silently skip every call; clamp to 1.
        chunk_size = max(1, max_parallel)
        # Batch into chunks of max_parallel
        for i in range(0, len(parallel_calls), chunk_size):
            chunk = parallel_calls[i:i + chunk_size]
            await execute_parallel(chunk, executor)

    # Execute sequential calls one at a time
    if sequential_calls:
        logger.info(f"Executing {len(sequential_calls)} sequential calls")
        await execute_sequential(sequential_calls, executor)

    # Collect errors
    result.errors.extend(
        f"{call.name}: {call.error}" for call in calls if call.error
    )

    result.total_duration = time.time() - start_time
    return result
def execute_batch_sync(
    tool_calls: List[Dict[str, Any]],
    executor: Callable[[str, Dict[str, Any]], Any],
    max_parallel: int = 5,
) -> BatchResult:
    """Run ``execute_batch`` to completion from synchronous code.

    The coroutine is driven with ``asyncio.run`` and its ``BatchResult`` is
    returned unchanged.
    """
    batch_coroutine = execute_batch(tool_calls, executor, max_parallel)
    return asyncio.run(batch_coroutine)
def get_tool_safety_report(calls: List[ToolCall]) -> str:
    """Generate a human-readable safety report.

    The report starts with a count per safety level, followed by one section
    per non-empty level listing each call with a status marker, its name and
    its duration.

    Args:
        calls: Calls to summarise.

    Returns:
        The report as a single newline-joined string.
    """
    sections = (
        ("Parallel-safe", [c for c in calls if c.safety == ToolSafety.PARALLEL_SAFE]),
        ("Sequential", [c for c in calls if c.safety == ToolSafety.SEQUENTIAL]),
        ("Destructive", [c for c in calls if c.safety == ToolSafety.DESTRUCTIVE]),
    )

    lines = ["Tool Safety Report:"]
    lines.extend(f" {label}: {len(group)}" for label, group in sections)

    for label, group in sections:
        if not group:
            continue
        lines.append(f"\n{label} calls:")
        for c in group:
            # Both branches of this conditional used to be the empty string,
            # so failed and successful calls were indistinguishable.
            status = "✅" if not c.error else "❌"
            lines.append(f" {status} {c.name} ({c.duration:.2f}s)")

    return "\n".join(lines)

View File

@@ -145,6 +145,50 @@ def safe_json_loads(text: str, default: Any = None) -> Any:
return default
def repair_and_load_json(text: str, default: Any = None, *, context: str = "") -> Any:
    """Parse JSON with automatic repair fallback.

    Tries ``json.loads`` first. On failure, attempts to repair the string
    using the ``json_repair`` library before falling back to *default*.
    Logs at debug level whenever the repair path is taken so that callers
    can observe silent-failure patterns without raising exceptions.

    Args:
        text: The JSON string to parse. Non-string input yields *default*.
        default: Value returned when both parse and repair fail.
        context: Optional label included in the debug log (e.g. the call-site
            name) to aid tracing.

    Returns:
        Parsed Python object, or *default* on unrecoverable failure.
    """
    if not isinstance(text, str):
        return default

    try:
        return json.loads(text)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass. RecursionError is
        # raised for pathologically deep nesting and must not escape, since
        # this function promises *default* rather than an exception.
        pass

    tag = f" [{context}]" if context else ""

    try:
        import json_repair  # optional dependency

        repaired = json_repair.repair_json(text, return_objects=True)
    except Exception as exc:
        # Covers a missing json_repair install as well as failures inside it.
        logger.debug("repair_and_load_json%s: repair failed (%s); returning default", tag, exc)
        return default

    # json_repair returns "" when it cannot produce a valid structure.
    # Guard against returning that sentinel as if it were a successful parse.
    # Exception: if the original text was an empty-string literal such as
    # `''`, then "" is the correct parse result.
    if repaired == "" and text.strip() not in ('""', "''"):
        logger.debug("repair_and_load_json%s: repair yielded empty string; returning default", tag)
        return default

    logger.debug("repair_and_load_json%s: repaired malformed JSON (first 120 chars): %.120s", tag, text)
    return repaired
# ─── Environment Variable Helpers ─────────────────────────────────────────────