diff --git a/environments/agent_loop.py b/environments/agent_loop.py index 11a8a01f3..ba2db0b57 100644 --- a/environments/agent_loop.py +++ b/environments/agent_loop.py @@ -21,6 +21,8 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set from model_tools import handle_function_call +from tools.terminal_tool import get_active_env +from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget # Thread pool for running sync tool calls that internally use asyncio.run() # (e.g., the Modal/Docker/Daytona terminal backends). Running them in a separate @@ -446,8 +448,17 @@ class HermesAgentLoop: except (json.JSONDecodeError, TypeError): pass - # Add tool response to conversation tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id + try: + tool_result = maybe_persist_tool_result( + content=tool_result, + tool_name=tool_name, + tool_use_id=tc_id, + env=get_active_env(self.task_id), + ) + except Exception: + pass # Persistence is best-effort in eval path + messages.append( { "role": "tool", @@ -456,6 +467,13 @@ class HermesAgentLoop: } ) + try: + num_tcs = len(assistant_msg.tool_calls) + if num_tcs > 0: + enforce_turn_budget(messages[-num_tcs:], env=get_active_env(self.task_id)) + except Exception: + pass + turn_elapsed = _time.monotonic() - turn_start logger.info( "[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs", diff --git a/run_agent.py b/run_agent.py index 22928bb18..49f36da41 100644 --- a/run_agent.py +++ b/run_agent.py @@ -66,7 +66,8 @@ from model_tools import ( handle_function_call, check_toolset_requirements, ) -from tools.terminal_tool import cleanup_vm +from tools.terminal_tool import cleanup_vm, get_active_env +from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget from tools.interrupt import set_interrupt as _set_interrupt from tools.browser_tool import cleanup_browser @@ -411,63 +412,6 @@ def _strip_budget_warnings_from_history(messages: list) -> None: # 
Large tool result handler — save oversized output to temp file # ========================================================================= -# Threshold at which tool results are saved to a file instead of kept inline. -# 100K chars ≈ 25K tokens — generous for any reasonable output but prevents -# catastrophic context explosions. -_LARGE_RESULT_CHARS = 100_000 - -# How many characters of the original result to include as an inline preview -# so the model has immediate context about what the tool returned. -_LARGE_RESULT_PREVIEW_CHARS = 1_500 - - -def _save_oversized_tool_result(function_name: str, function_result: str) -> str: - """Replace oversized tool results with a file reference + preview. - - When a tool returns more than ``_LARGE_RESULT_CHARS`` characters, the full - content is written to a temporary file under ``HERMES_HOME/cache/tool_responses/`` - and the result sent to the model is replaced with: - • a brief head preview (first ``_LARGE_RESULT_PREVIEW_CHARS`` chars) - • the file path so the model can use ``read_file`` / ``search_files`` - - Falls back to destructive truncation if the file write fails. - """ - original_len = len(function_result) - if original_len <= _LARGE_RESULT_CHARS: - return function_result - - # Build the target directory - try: - response_dir = os.path.join(get_hermes_home(), "cache", "tool_responses") - os.makedirs(response_dir, exist_ok=True) - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") - # Sanitize tool name for use in filename - safe_name = re.sub(r"[^\w\-]", "_", function_name)[:40] - filename = f"{safe_name}_{timestamp}.txt" - filepath = os.path.join(response_dir, filename) - - with open(filepath, "w", encoding="utf-8") as f: - f.write(function_result) - - preview = function_result[:_LARGE_RESULT_PREVIEW_CHARS] - return ( - f"{preview}\n\n" - f"[Large tool response: {original_len:,} characters total — " - f"only the first {_LARGE_RESULT_PREVIEW_CHARS:,} shown above. 
" - f"Full output saved to: {filepath}\n" - f"Use read_file or search_files on that path to access the rest.]" - ) - except Exception as exc: - # Fall back to destructive truncation if file write fails - logger.warning("Failed to save large tool result to file: %s", exc) - return ( - function_result[:_LARGE_RESULT_CHARS] - + f"\n\n[Truncated: tool response was {original_len:,} chars, " - f"exceeding the {_LARGE_RESULT_CHARS:,} char limit. " - f"File save failed: {exc}]" - ) - class AIAgent: """ @@ -6262,15 +6206,17 @@ class AIAgent: except Exception as cb_err: logging.debug(f"Tool complete callback error: {cb_err}") - # Save oversized results to file instead of destructive truncation - function_result = _save_oversized_tool_result(name, function_result) + function_result = maybe_persist_tool_result( + content=function_result, + tool_name=name, + tool_use_id=tc.id, + env=get_active_env(effective_task_id), + ) - # Discover subdirectory context files from tool arguments subdir_hints = self._subdirectory_hints.check_tool_call(name, args) if subdir_hints: function_result += subdir_hints - # Append tool result message in order tool_msg = { "role": "tool", "content": function_result, @@ -6278,6 +6224,12 @@ class AIAgent: } messages.append(tool_msg) + # ── Per-turn aggregate budget enforcement ───────────────────────── + num_tools = len(parsed_calls) + if num_tools > 0: + turn_tool_msgs = messages[-num_tools:] + enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id)) + # ── Budget pressure injection ──────────────────────────────────── budget_warning = self._get_budget_warning(api_call_count) if budget_warning and messages and messages[-1].get("role") == "tool": @@ -6562,8 +6514,12 @@ class AIAgent: except Exception as cb_err: logging.debug(f"Tool complete callback error: {cb_err}") - # Save oversized results to file instead of destructive truncation - function_result = _save_oversized_tool_result(function_name, function_result) + function_result = 
maybe_persist_tool_result( + content=function_result, + tool_name=function_name, + tool_use_id=tool_call.id, + env=get_active_env(effective_task_id), + ) # Discover subdirectory context files from tool arguments subdir_hints = self._subdirectory_hints.check_tool_call(function_name, function_args) @@ -6601,6 +6557,11 @@ class AIAgent: if self.tool_delay > 0 and i < len(assistant_message.tool_calls): time.sleep(self.tool_delay) + # ── Per-turn aggregate budget enforcement ───────────────────────── + num_tools_seq = len(assistant_message.tool_calls) + if num_tools_seq > 0: + enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id)) + # ── Budget pressure injection ───────────────────────────────── # After all tool calls in this turn are processed, check if we're # approaching max_iterations. If so, inject a warning into the LAST diff --git a/tests/run_agent/test_large_tool_result.py b/tests/run_agent/test_large_tool_result.py deleted file mode 100644 index ef51f2fe5..000000000 --- a/tests/run_agent/test_large_tool_result.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Tests for _save_oversized_tool_result() — the large tool response handler. - -When a tool returns more than _LARGE_RESULT_CHARS characters, the full content -is saved to a file and the model receives a preview + file path instead. 
-""" - -import os -import re - -import pytest - -from run_agent import ( - _save_oversized_tool_result, - _LARGE_RESULT_CHARS, - _LARGE_RESULT_PREVIEW_CHARS, -) - - -class TestSaveOversizedToolResult: - """Unit tests for the large tool result handler.""" - - def test_small_result_returned_unchanged(self): - """Results under the threshold pass through untouched.""" - small = "x" * 1000 - assert _save_oversized_tool_result("terminal", small) is small - - def test_exactly_at_threshold_returned_unchanged(self): - """Results exactly at the threshold pass through.""" - exact = "y" * _LARGE_RESULT_CHARS - assert _save_oversized_tool_result("terminal", exact) is exact - - def test_oversized_result_saved_to_file(self, tmp_path, monkeypatch): - """Results over the threshold are written to a file.""" - monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) - os.makedirs(tmp_path / ".hermes", exist_ok=True) - - big = "A" * (_LARGE_RESULT_CHARS + 500) - result = _save_oversized_tool_result("terminal", big) - - # Should contain the preview - assert result.startswith("A" * _LARGE_RESULT_PREVIEW_CHARS) - # Should mention the file path - assert "Full output saved to:" in result - # Should mention original size - assert f"{len(big):,}" in result - - # Extract the file path and verify the file exists with full content - match = re.search(r"Full output saved to: (.+?)\n", result) - assert match, f"No file path found in result: {result[:300]}" - filepath = match.group(1) - assert os.path.isfile(filepath) - with open(filepath, "r", encoding="utf-8") as f: - saved = f.read() - assert saved == big - assert len(saved) == _LARGE_RESULT_CHARS + 500 - - def test_file_placed_in_cache_tool_responses(self, tmp_path, monkeypatch): - """Saved file lives under HERMES_HOME/cache/tool_responses/.""" - hermes_home = str(tmp_path / ".hermes") - monkeypatch.setenv("HERMES_HOME", hermes_home) - os.makedirs(hermes_home, exist_ok=True) - - big = "B" * (_LARGE_RESULT_CHARS + 1) - result = 
_save_oversized_tool_result("web_search", big) - - match = re.search(r"Full output saved to: (.+?)\n", result) - filepath = match.group(1) - expected_dir = os.path.join(hermes_home, "cache", "tool_responses") - assert filepath.startswith(expected_dir) - - def test_filename_contains_tool_name(self, tmp_path, monkeypatch): - """The saved filename includes a sanitized version of the tool name.""" - monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) - os.makedirs(tmp_path / ".hermes", exist_ok=True) - - big = "C" * (_LARGE_RESULT_CHARS + 1) - result = _save_oversized_tool_result("browser_navigate", big) - - match = re.search(r"Full output saved to: (.+?)\n", result) - filename = os.path.basename(match.group(1)) - assert filename.startswith("browser_navigate_") - assert filename.endswith(".txt") - - def test_tool_name_sanitized(self, tmp_path, monkeypatch): - """Special characters in tool names are replaced in the filename.""" - monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) - os.makedirs(tmp_path / ".hermes", exist_ok=True) - - big = "D" * (_LARGE_RESULT_CHARS + 1) - result = _save_oversized_tool_result("mcp:some/weird tool", big) - - match = re.search(r"Full output saved to: (.+?)\n", result) - filename = os.path.basename(match.group(1)) - # No slashes or colons in filename - assert "/" not in filename - assert ":" not in filename - - def test_fallback_on_write_failure(self, tmp_path, monkeypatch): - """When file write fails, falls back to destructive truncation.""" - # Point HERMES_HOME to a path that will fail (file, not directory) - bad_path = str(tmp_path / "not_a_dir.txt") - with open(bad_path, "w") as f: - f.write("I'm a file, not a directory") - monkeypatch.setenv("HERMES_HOME", bad_path) - - big = "E" * (_LARGE_RESULT_CHARS + 50_000) - result = _save_oversized_tool_result("terminal", big) - - # Should still contain data (fallback truncation) - assert len(result) > 0 - assert result.startswith("E" * 1000) - # Should mention the failure - 
assert "File save failed" in result - # Should be truncated to approximately _LARGE_RESULT_CHARS + error msg - assert len(result) < len(big) - - def test_preview_length_capped(self, tmp_path, monkeypatch): - """The inline preview is capped at _LARGE_RESULT_PREVIEW_CHARS.""" - monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) - os.makedirs(tmp_path / ".hermes", exist_ok=True) - - # Use distinct chars so we can measure the preview - big = "Z" * (_LARGE_RESULT_CHARS + 5000) - result = _save_oversized_tool_result("terminal", big) - - # The preview section is the content before the "[Large tool response:" marker - marker_pos = result.index("[Large tool response:") - preview_section = result[:marker_pos].rstrip() - assert len(preview_section) == _LARGE_RESULT_PREVIEW_CHARS - - def test_guidance_message_mentions_tools(self, tmp_path, monkeypatch): - """The replacement message tells the model how to access the file.""" - monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) - os.makedirs(tmp_path / ".hermes", exist_ok=True) - - big = "F" * (_LARGE_RESULT_CHARS + 1) - result = _save_oversized_tool_result("terminal", big) - - assert "read_file" in result - assert "search_files" in result - - def test_empty_result_passes_through(self): - """Empty strings are not oversized.""" - assert _save_oversized_tool_result("terminal", "") == "" - - def test_unicode_content_preserved(self, tmp_path, monkeypatch): - """Unicode content is fully preserved in the saved file.""" - monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) - os.makedirs(tmp_path / ".hermes", exist_ok=True) - - # Mix of ASCII and multi-byte unicode to exceed threshold - unit = "Hello 世界! 
🎉 " * 100 # ~1400 chars per repeat - big = unit * ((_LARGE_RESULT_CHARS // len(unit)) + 1) - assert len(big) > _LARGE_RESULT_CHARS - - result = _save_oversized_tool_result("terminal", big) - match = re.search(r"Full output saved to: (.+?)\n", result) - filepath = match.group(1) - - with open(filepath, "r", encoding="utf-8") as f: - saved = f.read() - assert saved == big diff --git a/tests/tools/test_tool_result_storage.py b/tests/tools/test_tool_result_storage.py new file mode 100644 index 000000000..7c757027a --- /dev/null +++ b/tests/tools/test_tool_result_storage.py @@ -0,0 +1,494 @@ +"""Tests for tools/tool_result_storage.py -- 3-layer tool result persistence.""" + +import pytest +from unittest.mock import MagicMock, patch + +from tools.tool_result_storage import ( + DEFAULT_MAX_RESULT_SIZE_CHARS, + HEREDOC_MARKER, + MAX_TURN_BUDGET_CHARS, + PERSISTED_OUTPUT_TAG, + PERSISTED_OUTPUT_CLOSING_TAG, + PREVIEW_SIZE_CHARS, + STORAGE_DIR, + _build_persisted_message, + _extract_raw_output, + _heredoc_marker, + _write_to_sandbox, + enforce_turn_budget, + generate_preview, + maybe_persist_tool_result, +) + + +# ── generate_preview ────────────────────────────────────────────────── + +class TestGeneratePreview: + def test_short_content_unchanged(self): + text = "short result" + preview, has_more = generate_preview(text) + assert preview == text + assert has_more is False + + def test_long_content_truncated(self): + text = "x" * 5000 + preview, has_more = generate_preview(text, max_chars=2000) + assert len(preview) <= 2000 + assert has_more is True + + def test_truncates_at_newline_boundary(self): + # 1500 chars + newline + 600 chars (past halfway) + text = "a" * 1500 + "\n" + "b" * 600 + preview, has_more = generate_preview(text, max_chars=2000) + assert preview == "a" * 1500 + "\n" + assert has_more is True + + def test_ignores_early_newline(self): + # Newline at position 100, well before halfway of 2000 + text = "a" * 100 + "\n" + "b" * 3000 + preview, has_more = 
generate_preview(text, max_chars=2000) + assert len(preview) == 2000 + assert has_more is True + + def test_empty_content(self): + preview, has_more = generate_preview("") + assert preview == "" + assert has_more is False + + def test_exact_boundary(self): + text = "x" * PREVIEW_SIZE_CHARS + preview, has_more = generate_preview(text) + assert preview == text + assert has_more is False + + +# ── _extract_raw_output ──────────────────────────────────────────────── + +class TestExtractRawOutput: + def test_extracts_output_from_terminal_json(self): + import json + content = json.dumps({"output": "hello world\nline2", "exit_code": 0, "error": None}) + assert _extract_raw_output(content) == "hello world\nline2" + + def test_passes_through_non_json(self): + assert _extract_raw_output("plain text output") == "plain text output" + + def test_passes_through_json_without_output_key(self): + import json + content = json.dumps({"result": "something", "status": "ok"}) + assert _extract_raw_output(content) == content + + def test_extracts_large_output(self): + import json + big = "x\n" * 30_000 + content = json.dumps({"output": big, "exit_code": 0, "error": None}) + assert _extract_raw_output(content) == big + + +# ── _heredoc_marker ─────────────────────────────────────────────────── + +class TestHeredocMarker: + def test_default_marker_when_no_collision(self): + assert _heredoc_marker("normal content") == HEREDOC_MARKER + + def test_uuid_marker_on_collision(self): + content = f"some text with {HEREDOC_MARKER} embedded" + marker = _heredoc_marker(content) + assert marker != HEREDOC_MARKER + assert marker.startswith("HERMES_PERSIST_") + assert marker not in content + + +# ── _write_to_sandbox ───────────────────────────────────────────────── + +class TestWriteToSandbox: + def test_success(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + result = _write_to_sandbox("hello world", "/tmp/hermes-results/abc.txt", env) + assert result is True + 
env.execute.assert_called_once() + cmd = env.execute.call_args[0][0] + assert "mkdir -p" in cmd + assert "hello world" in cmd + assert HEREDOC_MARKER in cmd + + def test_failure_returns_false(self): + env = MagicMock() + env.execute.return_value = {"output": "error", "returncode": 1} + result = _write_to_sandbox("content", "/tmp/hermes-results/abc.txt", env) + assert result is False + + def test_heredoc_collision_uses_uuid_marker(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + content = f"text with {HEREDOC_MARKER} inside" + _write_to_sandbox(content, "/tmp/hermes-results/abc.txt", env) + cmd = env.execute.call_args[0][0] + # The default marker should NOT be used as the delimiter + lines = cmd.split("\n") + # The first and last lines contain the actual delimiter + assert HEREDOC_MARKER not in lines[0].split("<<")[1] + + def test_timeout_passed(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + _write_to_sandbox("content", "/tmp/hermes-results/abc.txt", env) + assert env.execute.call_args[1]["timeout"] == 30 + + +# ── _build_persisted_message ────────────────────────────────────────── + +class TestBuildPersistedMessage: + def test_structure(self): + msg = _build_persisted_message( + preview="first 100 chars...", + has_more=True, + original_size=50_000, + file_path="/tmp/hermes-results/test123.txt", + ) + assert msg.startswith(PERSISTED_OUTPUT_TAG) + assert msg.endswith(PERSISTED_OUTPUT_CLOSING_TAG) + assert "50,000 characters" in msg + assert "/tmp/hermes-results/test123.txt" in msg + assert "read_file" in msg + assert "first 100 chars..." in msg + assert "..." in msg # has_more indicator + + def test_no_ellipsis_when_complete(self): + msg = _build_persisted_message( + preview="complete content", + has_more=False, + original_size=16, + file_path="/tmp/hermes-results/x.txt", + ) + # Should not have the trailing "..." 
indicator before closing tag + lines = msg.strip().split("\n") + assert lines[-2] != "..." + + def test_large_size_shows_mb(self): + msg = _build_persisted_message( + preview="x", + has_more=True, + original_size=2_000_000, + file_path="/tmp/hermes-results/big.txt", + ) + assert "MB" in msg + + +# ── maybe_persist_tool_result ───────────────────────────────────────── + +class TestMaybePersistToolResult: + def test_below_threshold_returns_unchanged(self): + content = "small result" + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_123", + env=None, + threshold=50_000, + ) + assert result == content + + def test_above_threshold_with_env_persists(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + content = "x" * 60_000 + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_456", + env=env, + threshold=30_000, + ) + assert PERSISTED_OUTPUT_TAG in result + assert "tc_456.txt" in result + assert len(result) < len(content) + env.execute.assert_called_once() + + def test_persists_raw_output_not_json_wrapper(self): + """When content is JSON with 'output' key, file should contain raw output.""" + import json + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + raw = "line1\nline2\n" * 5_000 + content = json.dumps({"output": raw, "exit_code": 0, "error": None}) + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_json", + env=env, + threshold=30_000, + ) + assert PERSISTED_OUTPUT_TAG in result + # The heredoc written to sandbox should contain raw text, not JSON + cmd = env.execute.call_args[0][0] + assert "line1\nline2\n" in cmd + assert '"exit_code"' not in cmd + + def test_above_threshold_no_env_truncates_inline(self): + content = "x" * 60_000 + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_789", + env=None, + 
threshold=30_000, + ) + assert PERSISTED_OUTPUT_TAG not in result + assert "Truncated" in result + assert len(result) < len(content) + + def test_env_write_failure_falls_back_to_truncation(self): + env = MagicMock() + env.execute.return_value = {"output": "disk full", "returncode": 1} + content = "x" * 60_000 + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_fail", + env=env, + threshold=30_000, + ) + assert PERSISTED_OUTPUT_TAG not in result + assert "Truncated" in result + + def test_env_execute_exception_falls_back(self): + env = MagicMock() + env.execute.side_effect = RuntimeError("connection lost") + content = "x" * 60_000 + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_exc", + env=env, + threshold=30_000, + ) + assert "Truncated" in result + + def test_read_file_never_persisted(self): + """read_file has threshold=inf, should never be persisted.""" + env = MagicMock() + content = "x" * 200_000 + result = maybe_persist_tool_result( + content=content, + tool_name="read_file", + tool_use_id="tc_rf", + env=env, + threshold=float("inf"), + ) + assert result == content + env.execute.assert_not_called() + + def test_uses_registry_threshold_when_not_provided(self): + """When threshold=None, looks up from registry.""" + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + content = "x" * 60_000 + + mock_registry = MagicMock() + mock_registry.get_max_result_size.return_value = 30_000 + + with patch("tools.registry.registry", mock_registry): + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_reg", + env=env, + threshold=None, + ) + # Should have persisted since 60K > 30K + assert PERSISTED_OUTPUT_TAG in result or "Truncated" in result + + def test_unicode_content_survives(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + content = "日本語テスト " * 10_000 # ~60K 
chars of unicode + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_uni", + env=env, + threshold=30_000, + ) + assert PERSISTED_OUTPUT_TAG in result + # Preview should contain unicode + assert "日本語テスト" in result + + def test_empty_content_returns_unchanged(self): + result = maybe_persist_tool_result( + content="", + tool_name="terminal", + tool_use_id="tc_empty", + env=None, + threshold=30_000, + ) + assert result == "" + + def test_whitespace_only_below_threshold(self): + content = " " * 100 + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_ws", + env=None, + threshold=30_000, + ) + assert result == content + + def test_file_path_uses_tool_use_id(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + content = "x" * 60_000 + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="unique_id_abc", + env=env, + threshold=30_000, + ) + assert "unique_id_abc.txt" in result + + def test_preview_included_in_persisted_output(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + # Create content with a distinctive start + content = "DISTINCTIVE_START_MARKER" + "x" * 60_000 + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_prev", + env=env, + threshold=30_000, + ) + assert "DISTINCTIVE_START_MARKER" in result + + def test_threshold_zero_forces_persist(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + content = "even short content" + result = maybe_persist_tool_result( + content=content, + tool_name="terminal", + tool_use_id="tc_zero", + env=env, + threshold=0, + ) + # Any non-empty content with threshold=0 should be persisted + assert PERSISTED_OUTPUT_TAG in result + + +# ── enforce_turn_budget ─────────────────────────────────────────────── + +class TestEnforceTurnBudget: + def 
test_under_budget_no_changes(self): + msgs = [ + {"role": "tool", "tool_call_id": "t1", "content": "small"}, + {"role": "tool", "tool_call_id": "t2", "content": "also small"}, + ] + result = enforce_turn_budget(msgs, env=None, budget=200_000) + assert result[0]["content"] == "small" + assert result[1]["content"] == "also small" + + def test_over_budget_largest_persisted_first(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + msgs = [ + {"role": "tool", "tool_call_id": "t1", "content": "a" * 80_000}, + {"role": "tool", "tool_call_id": "t2", "content": "b" * 130_000}, + ] + # Total 210K > 200K budget + enforce_turn_budget(msgs, env=env, budget=200_000) + # The larger one (130K) should be persisted first + assert PERSISTED_OUTPUT_TAG in msgs[1]["content"] + + def test_already_persisted_results_skipped(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + msgs = [ + {"role": "tool", "tool_call_id": "t1", + "content": f"{PERSISTED_OUTPUT_TAG}\nalready persisted\n{PERSISTED_OUTPUT_CLOSING_TAG}"}, + {"role": "tool", "tool_call_id": "t2", "content": "x" * 250_000}, + ] + enforce_turn_budget(msgs, env=env, budget=200_000) + # t1 should be untouched (already persisted) + assert msgs[0]["content"].startswith(PERSISTED_OUTPUT_TAG) + # t2 should be persisted + assert PERSISTED_OUTPUT_TAG in msgs[1]["content"] + + def test_medium_result_regression(self): + """6 results of 42K chars each (252K total) — each under 50K default + threshold but aggregate exceeds 200K budget. 
L3 should persist.""" + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + msgs = [ + {"role": "tool", "tool_call_id": f"t{i}", "content": "x" * 42_000} + for i in range(6) + ] + enforce_turn_budget(msgs, env=env, budget=200_000) + # At least some results should be persisted to get under 200K + persisted_count = sum( + 1 for m in msgs if PERSISTED_OUTPUT_TAG in m["content"] + ) + assert persisted_count >= 2 # Need to shed at least ~52K + + def test_no_env_falls_back_to_truncation(self): + msgs = [ + {"role": "tool", "tool_call_id": "t1", "content": "x" * 250_000}, + ] + enforce_turn_budget(msgs, env=None, budget=200_000) + # Should be truncated (no sandbox available) + assert "Truncated" in msgs[0]["content"] or PERSISTED_OUTPUT_TAG in msgs[0]["content"] + + def test_returns_same_list(self): + msgs = [{"role": "tool", "tool_call_id": "t1", "content": "ok"}] + result = enforce_turn_budget(msgs, env=None, budget=200_000) + assert result is msgs + + def test_empty_messages(self): + result = enforce_turn_budget([], env=None, budget=200_000) + assert result == [] + + +# ── Per-tool threshold integration ──────────────────────────────────── + +class TestPerToolThresholds: + """Verify registry wiring for per-tool thresholds.""" + + def test_registry_has_get_max_result_size(self): + from tools.registry import registry + assert hasattr(registry, "get_max_result_size") + + def test_default_threshold(self): + from tools.registry import registry + # Unknown tool should return the default + val = registry.get_max_result_size("nonexistent_tool_xyz") + assert val == DEFAULT_MAX_RESULT_SIZE_CHARS + + def test_terminal_threshold(self): + from tools.registry import registry + # Trigger import of terminal_tool to register the tool + try: + import tools.terminal_tool # noqa: F401 + val = registry.get_max_result_size("terminal") + assert val == 30_000 + except ImportError: + pytest.skip("terminal_tool not importable in test env") + + def 
test_read_file_never_persisted(self): + from tools.registry import registry + try: + import tools.file_tools # noqa: F401 + val = registry.get_max_result_size("read_file") + assert val == float("inf") + except ImportError: + pytest.skip("file_tools not importable in test env") + + def test_search_files_threshold(self): + from tools.registry import registry + try: + import tools.file_tools # noqa: F401 + val = registry.get_max_result_size("search_files") + assert val == 20_000 + except ImportError: + pytest.skip("file_tools not importable in test env") diff --git a/tools/binary_extensions.py b/tools/binary_extensions.py new file mode 100644 index 000000000..f7e63bdad --- /dev/null +++ b/tools/binary_extensions.py @@ -0,0 +1,42 @@ +"""Binary file extensions to skip for text-based operations. + +These files can't be meaningfully compared as text and are often large. +Ported from free-code src/constants/files.ts. +""" + +BINARY_EXTENSIONS = frozenset({ + # Images + ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ico", ".webp", ".tiff", ".tif", + # Videos + ".mp4", ".mov", ".avi", ".mkv", ".webm", ".wmv", ".flv", ".m4v", ".mpeg", ".mpg", + # Audio + ".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", ".wma", ".aiff", ".opus", + # Archives + ".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz", ".z", ".tgz", ".iso", + # Executables/binaries + ".exe", ".dll", ".so", ".dylib", ".bin", ".o", ".a", ".obj", ".lib", + ".app", ".msi", ".deb", ".rpm", + # Documents (PDF is here; read_file excludes it at the call site) + ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx", + ".odt", ".ods", ".odp", + # Fonts + ".ttf", ".otf", ".woff", ".woff2", ".eot", + # Bytecode / VM artifacts + ".pyc", ".pyo", ".class", ".jar", ".war", ".ear", ".node", ".wasm", ".rlib", + # Database files + ".sqlite", ".sqlite3", ".db", ".mdb", ".idx", + # Design / 3D + ".psd", ".ai", ".eps", ".sketch", ".fig", ".xd", ".blend", ".3ds", ".max", + # Flash + ".swf", ".fla", + # Lock/profiling data + ".lockb", 
".dat", ".data", +}) + + +def has_binary_extension(path: str) -> bool: + """Check if a file path has a binary extension. Pure string check, no I/O.""" + dot = path.rfind(".") + if dot == -1: + return False + return path[dot:].lower() in BINARY_EXTENSIONS diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index 77be55697..f48c4b99e 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -1343,4 +1343,5 @@ registry.register( enabled_tools=kw.get("enabled_tools")), check_fn=check_sandbox_requirements, emoji="🐍", + max_result_size_chars=30_000, ) diff --git a/tools/file_tools.py b/tools/file_tools.py index 43e40315f..265c9ed2e 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -7,6 +7,7 @@ import logging import os import threading from pathlib import Path +from tools.binary_extensions import has_binary_extension from tools.file_operations import ShellFileOperations from agent.redact import redact_sensitive_text @@ -290,11 +291,24 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = ), }) + # Resolve path once for all guards below + import pathlib as _pathlib + _resolved = _pathlib.Path(path).expanduser().resolve() + + # ── Binary file guard ───────────────────────────────────────── + # Block binary files by extension (no I/O). + if has_binary_extension(str(_resolved)): + _ext = _resolved.suffix.lower() + return json.dumps({ + "error": ( + f"Cannot read binary file '{path}' ({_ext}). " + "Use vision_analyze for images, or terminal to inspect binary files." + ), + }) + # ── Hermes internal path guard ──────────────────────────────── # Prevent prompt injection via catalog or hub metadata files. 
- import pathlib as _pathlib from hermes_constants import get_hermes_home as _get_hh - _resolved = _pathlib.Path(path).expanduser().resolve() _hermes_home = _get_hh().resolve() _blocked_dirs = [ _hermes_home / "skills" / ".hub" / "index-cache", @@ -313,6 +327,27 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = except ValueError: pass + # ── Pre-read file size guard ────────────────────────────────── + # Stat the file before reading. If it's large and the model + # didn't request a narrow range, block and tell it to use + # offset/limit — cheaper than reading 200K chars then rejecting. + _PRE_READ_MAX_BYTES = 100_000 + _NARROW_LIMIT = 200 + try: + _fsize = os.path.getsize(str(_resolved)) + except OSError: + _fsize = 0 + if _fsize > _PRE_READ_MAX_BYTES and limit > _NARROW_LIMIT: + return json.dumps({ + "error": ( + f"File is too large to read in full ({_fsize:,} bytes). " + f"Use offset and limit parameters to read specific sections " + f"(e.g. offset=1, limit=100 for the first 100 lines)." + ), + "path": path, + "file_size": _fsize, + }, ensure_ascii=False) + # ── Dedup check ─────────────────────────────────────────────── # If we already read this exact (path, offset, limit) and the # file hasn't been modified since, return a lightweight stub @@ -726,7 +761,7 @@ def _check_file_reqs(): READ_FILE_SCHEMA = { "name": "read_file", - "description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. Reads exceeding ~100K characters are rejected; use offset and limit to read specific sections of large files. NOTE: Cannot read images or binary files — use vision_analyze for images.", + "description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. 
When you already know which part of the file you need, only read that part using offset and limit — this is important for larger files. Files over 100KB will be rejected unless you specify a narrow range (limit <= 200). NOTE: Cannot read images or binary files — use vision_analyze for images.", "parameters": { "type": "object", "properties": { @@ -817,7 +852,7 @@ def _handle_search_files(args, **kw): output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid) -registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs, emoji="📖") -registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs, emoji="✍️") -registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs, emoji="🔧") -registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs, emoji="🔎") +registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs, emoji="📖", max_result_size_chars=float('inf')) +registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs, emoji="✍️", max_result_size_chars=100_000) +registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs, emoji="🔧", max_result_size_chars=100_000) +registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs, emoji="🔎", max_result_size_chars=20_000) diff --git a/tools/registry.py b/tools/registry.py index 079052a3f..c01c60c09 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -27,10 +27,12 @@ class ToolEntry: __slots__ = ( "name", "toolset", "schema", "handler", 
def get_active_env(task_id: str):
    """Fetch the BaseEnvironment registered for *task_id* (None when absent).

    Takes the environment lock so the lookup is consistent with concurrent
    registration/cleanup of environments.
    """
    with _env_lock:
        env = _active_environments.get(task_id)
    return env
import json
import logging
import uuid

logger = logging.getLogger(__name__)

# Global fallback threshold for Layer-2 persistence; individual tools may
# override it via registry.register(max_result_size_chars=...).
DEFAULT_MAX_RESULT_SIZE_CHARS: int = 50_000
# Layer-3 aggregate cap across all tool results produced in a single turn.
MAX_TURN_BUDGET_CHARS: int = 200_000
# How many characters of a persisted result stay inline as a preview.
PREVIEW_SIZE_CHARS: int = 2_000
# Sentinel tags wrapped around replacement blocks. They MUST be non-empty:
# enforce_turn_budget uses `PERSISTED_OUTPUT_TAG not in content` to skip
# already-persisted results, and `"" in s` is True for every string, which
# would make the skip check vacuous and the replacement carry no marker.
PERSISTED_OUTPUT_TAG = "<persisted_output>"
PERSISTED_OUTPUT_CLOSING_TAG = "</persisted_output>"
# Directory inside the sandbox where oversized results are written.
STORAGE_DIR = "/tmp/hermes-results"
HEREDOC_MARKER = "HERMES_PERSIST_EOF"
_BUDGET_TOOL_NAME = "__budget_enforcement__"


def generate_preview(content: str, max_chars: int = PREVIEW_SIZE_CHARS) -> tuple[str, bool]:
    """Truncate *content* at the last newline within *max_chars*.

    Returns:
        (preview, has_more) — ``has_more`` is True when content was cut.
        The cut moves back to the last newline only when that newline falls
        in the second half of the window, so a single-line blob still
        yields a full-size preview instead of an empty one.
    """
    if len(content) <= max_chars:
        return content, False
    truncated = content[:max_chars]
    last_nl = truncated.rfind("\n")
    if last_nl > max_chars // 2:
        truncated = truncated[:last_nl + 1]
    return truncated, True


def _heredoc_marker(content: str) -> str:
    """Return a heredoc delimiter guaranteed not to collide with *content*.

    Uses the fixed marker when safe; otherwise appends a random hex suffix
    (uuid4 makes a second collision practically impossible).
    """
    if HEREDOC_MARKER not in content:
        return HEREDOC_MARKER
    return f"HERMES_PERSIST_{uuid.uuid4().hex[:8]}"


def _extract_raw_output(content: str) -> str:
    """Extract the 'output' field from JSON tool results for cleaner persistence.

    Tool handlers return json.dumps({"output": ..., "exit_code": ...}) for the
    API, but persisted files should contain readable text, not a JSON blob.
    Non-JSON (or JSON without an "output" key) is returned unchanged.
    """
    try:
        data = json.loads(content)
        if isinstance(data, dict) and "output" in data:
            return data["output"]
    except (json.JSONDecodeError, TypeError):
        pass
    return content


def _write_to_sandbox(content: str, remote_path: str, env) -> bool:
    """Write *content* into the sandbox via env.execute(). Returns True on success.

    Uses a quoted heredoc so the content is taken literally (no shell
    expansion). NOTE(review): the terminator has no trailing newline —
    bash accepts a heredoc delimiter at end-of-input, but confirm the
    non-bash backends do too.
    """
    marker = _heredoc_marker(content)
    cmd = (
        f"mkdir -p {STORAGE_DIR} && cat > {remote_path} << '{marker}'\n"
        f"{content}\n"
        f"{marker}"
    )
    result = env.execute(cmd, timeout=30)
    # Missing returncode is treated as failure (default 1).
    return result.get("returncode", 1) == 0


def _build_persisted_message(
    preview: str,
    has_more: bool,
    original_size: int,
    file_path: str,
) -> str:
    """Build the tagged replacement block shown to the model in place of the
    oversized result: size summary, saved-file path, and an inline preview."""
    size_kb = original_size / 1024
    if size_kb >= 1024:
        size_str = f"{size_kb / 1024:.1f} MB"
    else:
        size_str = f"{size_kb:.1f} KB"

    msg = f"{PERSISTED_OUTPUT_TAG}\n"
    msg += f"This tool result was too large ({original_size:,} characters, {size_str}).\n"
    msg += f"Full output saved to: {file_path}\n"
    msg += "Use the read_file tool with offset and limit to access specific sections of this output.\n\n"
    msg += f"Preview (first {len(preview)} chars):\n"
    msg += preview
    if has_more:
        msg += "\n..."
    msg += f"\n{PERSISTED_OUTPUT_CLOSING_TAG}"
    return msg
def enforce_turn_budget(
    tool_messages: list[dict],
    env=None,
    budget: int = MAX_TURN_BUDGET_CHARS,
) -> list[dict]:
    """Layer 3: enforce an aggregate size budget across one turn's tool results.

    When the combined character count exceeds *budget*, spill the largest
    not-yet-persisted results to the sandbox (largest first) until the
    aggregate fits. Results already carrying the persisted-output tag are
    left untouched.

    Mutates the list in-place and returns it.
    """
    sizes = [len(m.get("content", "")) for m in tool_messages]
    remaining = sum(sizes)
    if remaining <= budget:
        return tool_messages

    spillable = [
        (pos, sizes[pos])
        for pos, m in enumerate(tool_messages)
        if PERSISTED_OUTPUT_TAG not in m.get("content", "")
    ]

    # Largest first: spilling a big result frees the most budget per write.
    for pos, original_len in sorted(spillable, key=lambda item: item[1], reverse=True):
        if remaining <= budget:
            break
        entry = tool_messages[pos]
        original = entry["content"]
        call_id = entry.get("tool_call_id", f"budget_{pos}")

        spilled = maybe_persist_tool_result(
            content=original,
            tool_name=_BUDGET_TOOL_NAME,
            tool_use_id=call_id,
            env=env,
            threshold=0,
        )
        if spilled != original:
            remaining += len(spilled) - original_len
            tool_messages[pos]["content"] = spilled
            logger.info(
                "Budget enforcement: persisted tool result %s (%d chars)",
                call_id, original_len,
            )

    return tool_messages