Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a61761d321 |
@@ -50,6 +50,78 @@ def sanitize_context(text: str) -> str:
|
||||
return _FENCE_TAG_RE.sub('', text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prefetch filtering helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Meta-instruction debris that memory providers sometimes echo back.
|
||||
# These are prompts/instructions, not user-generated content.
|
||||
_META_INSTRUCTION_PATTERNS = [
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*Focus on:\s*", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*Note:\s*", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*System\s+(note|prompt|instruction):", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*You are\s+", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*Please\s+(provide|respond|answer|write)", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*Do not\s+", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*Always\s+", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*Consider\s+(the following|these|this)\b", re.IGNORECASE),
|
||||
re.compile(r"^\s*[\-\*]?\s*>?\s*Here\s+(is|are)\s+(some|the|a few)\b", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def _is_meta_instruction_line(line: str) -> bool:
|
||||
"""Return True if the line looks like a prompt/template instruction, not memory content."""
|
||||
for pat in _META_INSTRUCTION_PATTERNS:
|
||||
if pat.search(line):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_low_signal_line(line: str) -> bool:
|
||||
"""Return True for very short or content-free lines."""
|
||||
stripped = line.strip()
|
||||
# Empty or just punctuation/list marker
|
||||
if not stripped or stripped in {"-", "*", ">", "•", "—", "--"}:
|
||||
return True
|
||||
# Too short to be meaningful (< 15 chars after stripping markers)
|
||||
cleaned = re.sub(r"^[\-\*•>\s]+", "", stripped)
|
||||
if len(cleaned) < 15:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _filter_prefetch_lines(text: str) -> str:
|
||||
"""Filter and deduplicate prefetch result lines.
|
||||
|
||||
Removes:
|
||||
- exact duplicate lines
|
||||
- meta-instruction debris (prompts, templates)
|
||||
- very short / content-free lines
|
||||
|
||||
Returns cleaned text, preserving original line grouping.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return ""
|
||||
|
||||
seen: set = set()
|
||||
kept: list = []
|
||||
for line in text.splitlines(keepends=False):
|
||||
stripped = line.strip()
|
||||
# Deduplicate exact lines
|
||||
if stripped in seen:
|
||||
continue
|
||||
# Skip meta-instructions
|
||||
if _is_meta_instruction_line(line):
|
||||
continue
|
||||
# Skip low-signal lines
|
||||
if _is_low_signal_line(line):
|
||||
continue
|
||||
seen.add(stripped)
|
||||
kept.append(line)
|
||||
|
||||
return "\n".join(kept)
|
||||
|
||||
|
||||
def build_memory_context_block(raw_context: str) -> str:
|
||||
"""Wrap prefetched memory in a fenced block with system note.
|
||||
|
||||
@@ -180,7 +252,14 @@ class MemoryManager:
|
||||
"Memory provider '%s' prefetch failed (non-fatal): %s",
|
||||
provider.name, e,
|
||||
)
|
||||
return "\n\n".join(parts)
|
||||
raw = "\n\n".join(parts)
|
||||
if not raw:
|
||||
return ""
|
||||
# Apply line-level filtering: dedupe, strip meta-instructions,
|
||||
# remove very short fragments. This prevents noisy providers
|
||||
# (e.g. MemPalace transcript recall) from bloating context.
|
||||
filtered = _filter_prefetch_lines(raw)
|
||||
return filtered
|
||||
|
||||
def queue_prefetch_all(self, query: str, *, session_id: str = "") -> None:
|
||||
"""Queue background prefetch on all providers for the next turn."""
|
||||
|
||||
@@ -198,14 +198,14 @@ class TestMemoryManager:
|
||||
def test_prefetch_skips_empty(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("builtin")
|
||||
p1._prefetch_result = "Has memories"
|
||||
p1._prefetch_result = "This provider has meaningful memories with enough length"
|
||||
p2 = FakeMemoryProvider("external")
|
||||
p2._prefetch_result = ""
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
result = mgr.prefetch_all("query")
|
||||
assert result == "Has memories"
|
||||
assert result == "This provider has meaningful memories with enough length"
|
||||
|
||||
def test_queue_prefetch_all(self):
|
||||
mgr = MemoryManager()
|
||||
@@ -695,3 +695,92 @@ class TestMemoryContextFencing:
|
||||
fence_end = combined.index("</memory-context>")
|
||||
assert "Alice" in combined[fence_start:fence_end]
|
||||
assert combined.index("weather") < fence_start
|
||||
|
||||
|
||||
class TestPrefetchFiltering:
|
||||
"""Tests for _filter_prefetch_lines and related helpers."""
|
||||
|
||||
def test_deduplicates_exact_lines(self):
|
||||
from agent.memory_manager import _filter_prefetch_lines
|
||||
raw = "- This is line one with enough characters\n- This is line two with enough characters\n- This is line one with enough characters\n- This is line three with enough characters"
|
||||
result = _filter_prefetch_lines(raw)
|
||||
lines = [l for l in result.splitlines() if l.strip()]
|
||||
assert len(lines) == 3
|
||||
assert "- This is line one with enough characters" in result
|
||||
assert "- This is line two with enough characters" in result
|
||||
assert "- This is line three with enough characters" in result
|
||||
|
||||
def test_removes_meta_instruction_debris(self):
|
||||
from agent.memory_manager import _filter_prefetch_lines
|
||||
raw = (
|
||||
"## Fleet Memories\n"
|
||||
"- > Focus on: was a non-trivial approach used\n"
|
||||
"- > Focus on: was a non-trivial approach used\n"
|
||||
"- Actual memory content about fleet ops\n"
|
||||
"- Note: this is just a note\n"
|
||||
)
|
||||
result = _filter_prefetch_lines(raw)
|
||||
assert "Focus on" not in result
|
||||
assert "Note:" not in result
|
||||
assert "Actual memory content about fleet ops" in result
|
||||
assert "Fleet Memories" in result
|
||||
|
||||
def test_removes_low_signal_short_lines(self):
|
||||
from agent.memory_manager import _filter_prefetch_lines
|
||||
raw = (
|
||||
"- \n"
|
||||
"- x\n"
|
||||
"- This is a meaningful memory entry with enough length\n"
|
||||
)
|
||||
result = _filter_prefetch_lines(raw)
|
||||
assert "- x" not in result
|
||||
assert "meaningful memory entry" in result
|
||||
|
||||
def test_preserves_structured_facts(self):
|
||||
from agent.memory_manager import _filter_prefetch_lines
|
||||
raw = (
|
||||
"## Local Facts (Hologram)\n"
|
||||
"- ALEXANDER: Prefers Gitea for reports and deliverables.\n"
|
||||
"- Telegram home channel is Timmy Time.\n"
|
||||
)
|
||||
result = _filter_prefetch_lines(raw)
|
||||
assert "ALEXANDER" in result
|
||||
assert "Gitea" in result
|
||||
assert "Telegram" in result
|
||||
|
||||
def test_is_meta_instruction_line(self):
|
||||
from agent.memory_manager import _is_meta_instruction_line
|
||||
assert _is_meta_instruction_line("- > Focus on: something") is True
|
||||
assert _is_meta_instruction_line("- Focus on: something") is True
|
||||
assert _is_meta_instruction_line("* Focus on: something") is True
|
||||
assert _is_meta_instruction_line("- Actual user memory content") is False
|
||||
assert _is_meta_instruction_line("ALEXANDER: Prefers Gitea") is False
|
||||
|
||||
def test_is_low_signal_line(self):
|
||||
from agent.memory_manager import _is_low_signal_line
|
||||
assert _is_low_signal_line("- ") is True
|
||||
assert _is_low_signal_line("*") is True
|
||||
assert _is_low_signal_line("- x") is True
|
||||
assert _is_low_signal_line("- Short line") is True
|
||||
assert _is_low_signal_line("- This is a long meaningful memory entry") is False
|
||||
|
||||
def test_prefetch_all_applies_filtering(self):
|
||||
from agent.memory_manager import MemoryManager
|
||||
mgr = MemoryManager()
|
||||
fake = FakeMemoryProvider(name="test")
|
||||
fake._prefetch_result = (
|
||||
"- > Focus on: was a non-trivial approach\n"
|
||||
"- > Focus on: was a non-trivial approach\n"
|
||||
"- Real memory fact\n"
|
||||
)
|
||||
mgr.add_provider(fake)
|
||||
result = mgr.prefetch_all("query")
|
||||
assert "Focus on" not in result
|
||||
assert "Real memory fact" in result
|
||||
assert result.count("Real memory fact") == 1
|
||||
|
||||
def test_empty_prefetch_returns_empty(self):
|
||||
from agent.memory_manager import _filter_prefetch_lines
|
||||
assert _filter_prefetch_lines("") == ""
|
||||
assert _filter_prefetch_lines(" ") == ""
|
||||
assert _filter_prefetch_lines("\n\n") == ""
|
||||
|
||||
@@ -1302,9 +1302,9 @@ class TestConcurrentToolExecution:
|
||||
mock_con.assert_not_called()
|
||||
|
||||
def test_malformed_json_args_forces_sequential(self, agent):
|
||||
"""Non-dict tool arguments (e.g. JSON array) should fall back to sequential."""
|
||||
"""Unparseable tool arguments should fall back to sequential."""
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="web_search", arguments='[1, 2, 3]', call_id="c2")
|
||||
tc2 = _mock_tool_call(name="web_search", arguments="NOT JSON {{{", call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
|
||||
@@ -1384,9 +1384,10 @@ class TestConcurrentToolExecution:
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
call_count = [0]
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
# Deterministic failure based on tool_call_id to avoid race conditions
|
||||
if kwargs.get("tool_call_id") == "c1":
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
raise RuntimeError("boom")
|
||||
return "success"
|
||||
|
||||
|
||||
@@ -416,219 +416,3 @@ class TestEdgeCases:
|
||||
"""Verify max workers constant exists and is reasonable."""
|
||||
from run_agent import _MAX_TOOL_WORKERS
|
||||
assert 1 <= _MAX_TOOL_WORKERS <= 32
|
||||
|
||||
|
||||
# ── Integration Tests: AIAgent Concurrent Execution ───────────────────────────
|
||||
|
||||
class TestAIAgentConcurrentExecution:
|
||||
"""Exercise _execute_tool_calls_concurrent through an AIAgent instance."""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self):
|
||||
"""Minimal AIAgent with mocked OpenAI client and tool loading."""
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from run_agent import AIAgent
|
||||
|
||||
def _make_tool_defs(*names):
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": n,
|
||||
"description": f"{n} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for n in names
|
||||
]
|
||||
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search", "read_file")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
return a
|
||||
|
||||
def _mock_assistant_msg(self, tool_calls=None):
|
||||
from types import SimpleNamespace
|
||||
return SimpleNamespace(content="", tool_calls=tool_calls)
|
||||
|
||||
def _mock_tool_call(self, name, arguments, call_id):
|
||||
from types import SimpleNamespace
|
||||
return SimpleNamespace(
|
||||
id=call_id,
|
||||
type="function",
|
||||
function=SimpleNamespace(name=name, arguments=json.dumps(arguments)),
|
||||
)
|
||||
|
||||
def test_two_tool_batch_executes_concurrently(self, agent):
|
||||
"""2-tool parallel batch: all execute, results ordered, 100% pass."""
|
||||
tc1 = self._mock_tool_call("read_file", {"path": "a.txt"}, "c1")
|
||||
tc2 = self._mock_tool_call("read_file", {"path": "b.txt"}, "c2")
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return json.dumps({"file": args.get("path", ""), "content": f"content_of_{args.get('path', '')}"})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["tool_call_id"] == "c1"
|
||||
assert messages[1]["tool_call_id"] == "c2"
|
||||
assert "a.txt" in messages[0]["content"]
|
||||
assert "b.txt" in messages[1]["content"]
|
||||
|
||||
def test_three_tool_batch_executes_concurrently(self, agent):
|
||||
"""3-tool parallel batch: all execute, results ordered, 100% pass."""
|
||||
tcs = [
|
||||
self._mock_tool_call("web_search", {"query": f"q{i}"}, f"c{i}")
|
||||
for i in range(3)
|
||||
]
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
|
||||
messages = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return json.dumps({"query": args.get("query", ""), "results": [f"result_{args.get('query', '')}"]})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert len(messages) == 3
|
||||
for i, tc in enumerate(tcs):
|
||||
assert messages[i]["tool_call_id"] == tc.id
|
||||
assert f"q{i}" in messages[i]["content"]
|
||||
|
||||
def test_four_tool_batch_executes_concurrently(self, agent):
|
||||
"""4-tool parallel batch: all execute, results ordered, 100% pass."""
|
||||
tcs = [
|
||||
self._mock_tool_call("read_file", {"path": f"file{i}.txt"}, f"c{i}")
|
||||
for i in range(4)
|
||||
]
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
|
||||
messages = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return json.dumps({"path": args.get("path", ""), "size": 100})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert len(messages) == 4
|
||||
for i, tc in enumerate(tcs):
|
||||
assert messages[i]["tool_call_id"] == tc.id
|
||||
assert f"file{i}.txt" in messages[i]["content"]
|
||||
|
||||
def test_mixed_read_and_search_batch(self, agent):
|
||||
"""read_file + search_files: safe parallel, different scopes."""
|
||||
tc1 = self._mock_tool_call("read_file", {"path": "config.yaml"}, "c1")
|
||||
tc2 = self._mock_tool_call("web_search", {"query": "provider"}, "c2")
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return json.dumps({"tool": name, "args": args})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["tool_call_id"] == "c1"
|
||||
assert messages[1]["tool_call_id"] == "c2"
|
||||
assert "config.yaml" in messages[0]["content"]
|
||||
assert "provider" in messages[1]["content"]
|
||||
|
||||
def test_concurrent_pass_rate_report(self, agent):
|
||||
"""Simulate 2/3/4-tool batches and report pass rate."""
|
||||
batch_sizes = [2, 3, 4]
|
||||
pass_rates = {}
|
||||
|
||||
for size in batch_sizes:
|
||||
tcs = [
|
||||
self._mock_tool_call("web_search", {"query": f"q{i}"}, f"c{i}")
|
||||
for i in range(size)
|
||||
]
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
|
||||
messages = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return json.dumps({"ok": True, "query": args.get("query", "")})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
passed = sum(1 for m in messages if "ok" in m.get("content", ""))
|
||||
pass_rates[size] = passed / size if size > 0 else 0.0
|
||||
|
||||
for size, rate in pass_rates.items():
|
||||
assert rate == 1.0, f"Expected 100% pass rate for {size}-tool batch, got {rate:.0%}"
|
||||
|
||||
def test_gemma4_style_two_read_files(self, agent):
|
||||
"""Gemma 4 may issue two reads simultaneously — verify both returned."""
|
||||
tc1 = self._mock_tool_call("read_file", {"path": "src/main.py"}, "c1")
|
||||
tc2 = self._mock_tool_call("read_file", {"path": "src/utils.py"}, "c2")
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return json.dumps({"content": f"# {args['path']}\nprint('hello')"})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "main.py" in messages[0]["content"]
|
||||
assert "utils.py" in messages[1]["content"]
|
||||
|
||||
def test_gemma4_style_three_reads(self, agent):
|
||||
"""Gemma 4 may issue 3 reads for different files — all returned."""
|
||||
tcs = [
|
||||
self._mock_tool_call("read_file", {"path": f"mod{i}.py"}, f"c{i}")
|
||||
for i in range(3)
|
||||
]
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=tcs)
|
||||
messages = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return json.dumps({"content": f"# {args['path']}"})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert len(messages) == 3
|
||||
for i in range(3):
|
||||
assert f"mod{i}.py" in messages[i]["content"]
|
||||
|
||||
def test_mixed_safe_and_write_tools_parallel(self, agent):
|
||||
"""Mix of read (safe) and write (path-scoped) on different paths — parallel."""
|
||||
tc1 = self._mock_tool_call("read_file", {"path": "input.txt"}, "c1")
|
||||
tc2 = self._mock_tool_call("write_file", {"path": "output.txt", "content": "x"}, "c2")
|
||||
tc3 = self._mock_tool_call("read_file", {"path": "config.txt"}, "c3")
|
||||
mock_msg = self._mock_assistant_msg(tool_calls=[tc1, tc2, tc3])
|
||||
messages = []
|
||||
|
||||
call_order = []
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
call_order.append(name)
|
||||
return json.dumps({"tool": name, "path": args.get("path", "")})
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert len(messages) == 3
|
||||
# Results ordered by tool call ID, not completion order
|
||||
assert messages[0]["tool_call_id"] == "c1"
|
||||
assert messages[1]["tool_call_id"] == "c2"
|
||||
assert messages[2]["tool_call_id"] == "c3"
|
||||
# All three should have executed
|
||||
assert len(call_order) == 3
|
||||
|
||||
Reference in New Issue
Block a user