Compare commits
1 Commits
burn/herme
...
fix/800
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f88e57bcfe |
@@ -1,171 +0,0 @@
|
||||
"""
|
||||
Warm session provisioning — pre-proficient agent sessions.
|
||||
|
||||
Part of #327: [Research] Warm session provisioning - pre-proficient agent sessions.
|
||||
|
||||
Saves a session's conversation history as a reusable warm template.
|
||||
New sessions can start from a warm template instead of a cold start,
|
||||
carrying over established patterns and successful tool-call examples.
|
||||
|
||||
Usage:
|
||||
from agent.warm_session import WarmSessionStore
|
||||
store = WarmSessionStore()
|
||||
store.bake_from_session(session_id="2026...", name="fullstack-dev")
|
||||
history = store.load_template("fullstack-dev")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from hermes_cli.config import get_hermes_home, load_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TEMPLATES_DIR = "warm_sessions"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WarmSessionTemplate:
|
||||
"""A reusable warm session template."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
created_from_session_id: str
|
||||
created_at: float
|
||||
message_count: int
|
||||
messages: List[Dict[str, Any]]
|
||||
model: str = ""
|
||||
source: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "WarmSessionTemplate":
|
||||
valid_keys = {f.name for f in cls.__dataclass_fields__.values()}
|
||||
kwargs = {k: v for k, v in data.items() if k in valid_keys}
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
class WarmSessionStore:
|
||||
"""Manage warm session templates on disk."""
|
||||
|
||||
def __init__(self, templates_dir: Path = None) -> None:
|
||||
if templates_dir is None:
|
||||
cfg = load_config()
|
||||
warm_cfg = cfg.get("warm_session", {})
|
||||
custom_dir = warm_cfg.get("templates_dir")
|
||||
if custom_dir:
|
||||
templates_dir = Path(custom_dir).expanduser()
|
||||
else:
|
||||
templates_dir = get_hermes_home() / DEFAULT_TEMPLATES_DIR
|
||||
self.templates_dir = Path(templates_dir)
|
||||
self.templates_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _template_path(self, name: str) -> Path:
|
||||
"""Sanitize name and return path."""
|
||||
safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in name).lower()
|
||||
return self.templates_dir / f"{safe}.json"
|
||||
|
||||
def bake_from_session(
|
||||
self,
|
||||
session_id: str,
|
||||
name: str,
|
||||
description: str = "",
|
||||
max_messages: Optional[int] = None,
|
||||
session_db=None,
|
||||
) -> WarmSessionTemplate:
|
||||
"""Extract conversation history from a session and save as a warm template.
|
||||
|
||||
Args:
|
||||
session_id: Source session ID.
|
||||
name: Template name (must be unique).
|
||||
description: Human-readable description.
|
||||
max_messages: If set, keep only the last N messages.
|
||||
session_db: Optional SessionDB instance. If None, a new one is created.
|
||||
"""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = session_db or SessionDB()
|
||||
try:
|
||||
session_meta = db.get_session(session_id)
|
||||
if not session_meta:
|
||||
raise ValueError(f"Session '{session_id}' not found.")
|
||||
|
||||
messages = db.get_messages_as_conversation(session_id)
|
||||
if not messages:
|
||||
raise ValueError(f"Session '{session_id}' has no messages.")
|
||||
|
||||
# Strip session_meta pseudo-messages
|
||||
messages = [m for m in messages if m.get("role") != "session_meta"]
|
||||
|
||||
if max_messages is not None and max_messages > 0:
|
||||
messages = messages[-max_messages:]
|
||||
|
||||
template = WarmSessionTemplate(
|
||||
name=name,
|
||||
description=description,
|
||||
created_from_session_id=session_id,
|
||||
created_at=time.time(),
|
||||
message_count=len(messages),
|
||||
messages=messages,
|
||||
model=session_meta.get("model", ""),
|
||||
source=session_meta.get("source", ""),
|
||||
)
|
||||
path = self._template_path(name)
|
||||
path.write_text(json.dumps(template.to_dict(), indent=2), encoding="utf-8")
|
||||
logger.info("Baked warm session template '%s' from %s (%d messages)", name, session_id, len(messages))
|
||||
return template
|
||||
finally:
|
||||
if session_db is None:
|
||||
db.close()
|
||||
|
||||
def load_template(self, name: str) -> WarmSessionTemplate:
|
||||
"""Load a warm session template by name."""
|
||||
path = self._template_path(name)
|
||||
if not path.exists():
|
||||
available = [t["name"] for t in self.list_templates()]
|
||||
raise ValueError(
|
||||
f"Warm session template '{name}' not found. "
|
||||
f"Available: {', '.join(available) or 'none'}"
|
||||
)
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
return WarmSessionTemplate.from_dict(data)
|
||||
|
||||
def delete_template(self, name: str) -> bool:
|
||||
"""Delete a warm session template. Returns True if it existed."""
|
||||
path = self._template_path(name)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_templates(self) -> List[Dict[str, Any]]:
|
||||
"""List all warm session templates with metadata."""
|
||||
templates = []
|
||||
for path in sorted(self.templates_dir.glob("*.json")):
|
||||
try:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
templates.append({
|
||||
"name": data.get("name", path.stem),
|
||||
"description": data.get("description", ""),
|
||||
"created_from_session_id": data.get("created_from_session_id", ""),
|
||||
"created_at": data.get("created_at", 0),
|
||||
"message_count": data.get("message_count", 0),
|
||||
"model": data.get("model", ""),
|
||||
"source": data.get("source", ""),
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to read warm session template %s: %s", path, exc)
|
||||
return templates
|
||||
|
||||
def provision_conversation_history(self, name: str) -> List[Dict[str, Any]]:
|
||||
"""Return the conversation history messages for a template."""
|
||||
template = self.load_template(name)
|
||||
return list(template.messages)
|
||||
31
cli.py
31
cli.py
@@ -1595,7 +1595,6 @@ class HermesCLI:
|
||||
resume: str = None,
|
||||
checkpoints: bool = False,
|
||||
pass_session_id: bool = False,
|
||||
warm_template: str = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Hermes CLI.
|
||||
@@ -1779,7 +1778,6 @@ class HermesCLI:
|
||||
self.conversation_history: List[Dict[str, Any]] = []
|
||||
self.session_start = datetime.now()
|
||||
self._resumed = False
|
||||
self.warm_template = warm_template
|
||||
# Initialize SQLite session store early so /title works before first message
|
||||
self._session_db = None
|
||||
try:
|
||||
@@ -2859,33 +2857,6 @@ class HermesCLI:
|
||||
self._session_db._conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Warm session provisioning (#327): preload template messages into a fresh session
|
||||
if self.warm_template and not self.conversation_history and self._session_db:
|
||||
try:
|
||||
from agent.warm_session import WarmSessionStore
|
||||
store = WarmSessionStore()
|
||||
template = store.load_template(self.warm_template)
|
||||
self.conversation_history = list(template.messages)
|
||||
# Persist warm messages into the new session so they survive
|
||||
for msg in template.messages:
|
||||
self._session_db.append_message(
|
||||
session_id=self.session_id,
|
||||
role=msg.get("role", "user"),
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
reasoning=msg.get("reasoning"),
|
||||
reasoning_details=msg.get("reasoning_details"),
|
||||
codex_reasoning_items=msg.get("codex_reasoning_items"),
|
||||
)
|
||||
_cprint(
|
||||
f"\033[1;36m\u267e Warm start from template '{template.name}'"
|
||||
f" ({template.message_count} messages){_RST}"
|
||||
)
|
||||
except Exception as e:
|
||||
_cprint(f"\033[1;31mWarm template '{self.warm_template}' failed: {e}{_RST}")
|
||||
|
||||
try:
|
||||
runtime = runtime_override or {
|
||||
@@ -9862,7 +9833,6 @@ def main(
|
||||
w: bool = False,
|
||||
checkpoints: bool = False,
|
||||
pass_session_id: bool = False,
|
||||
warm_template: str = None,
|
||||
):
|
||||
"""
|
||||
Hermes Agent CLI - Interactive AI Assistant
|
||||
@@ -9972,7 +9942,6 @@ def main(
|
||||
resume=resume,
|
||||
checkpoints=checkpoints,
|
||||
pass_session_id=pass_session_id,
|
||||
warm_template=warm_template,
|
||||
)
|
||||
|
||||
if parsed_skills:
|
||||
|
||||
@@ -702,16 +702,8 @@ DEFAULT_CONFIG = {
|
||||
"force_ipv4": False,
|
||||
},
|
||||
|
||||
# Warm session provisioning (#327)
|
||||
"warm_session": {
|
||||
# Directory for warm session templates (default: ~/.hermes/warm_sessions)
|
||||
"templates_dir": "",
|
||||
# Default template to use when --warm is passed without a name
|
||||
"default_template": "",
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 18,
|
||||
"_config_version": 17,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -1794,7 +1786,7 @@ _KNOWN_ROOT_KEYS = {
|
||||
"_config_version", "model", "providers", "fallback_model",
|
||||
"fallback_providers", "credential_pool_strategies", "toolsets",
|
||||
"agent", "terminal", "display", "compression", "delegation",
|
||||
"auxiliary", "custom_providers", "context", "memory", "gateway", "warm_session",
|
||||
"auxiliary", "custom_providers", "context", "memory", "gateway",
|
||||
}
|
||||
|
||||
# Valid fields inside a custom_providers list entry
|
||||
|
||||
@@ -772,7 +772,6 @@ def cmd_chat(args):
|
||||
"checkpoints": getattr(args, "checkpoints", False),
|
||||
"pass_session_id": getattr(args, "pass_session_id", False),
|
||||
"max_turns": getattr(args, "max_turns", None),
|
||||
"warm_template": getattr(args, "warm", None),
|
||||
}
|
||||
# Filter out None values
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
@@ -4578,12 +4577,6 @@ For more help on a command:
|
||||
default=False,
|
||||
help="Include the session ID in the agent's system prompt"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warm",
|
||||
metavar="TEMPLATE",
|
||||
default=None,
|
||||
help="Start from a warm session template (see 'hermes sessions warm')"
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
@@ -5557,28 +5550,6 @@ Examples:
|
||||
sessions_browse.add_argument("--source", help="Filter by source (cli, telegram, discord, etc.)")
|
||||
sessions_browse.add_argument("--limit", type=int, default=50, help="Max sessions to load (default: 50)")
|
||||
|
||||
# warm session provisioning (#327)
|
||||
sessions_warm = sessions_subparsers.add_parser(
|
||||
"warm",
|
||||
help="Warm session provisioning — bake, list, delete, and use warm templates",
|
||||
)
|
||||
warm_subparsers = sessions_warm.add_subparsers(dest="warm_action")
|
||||
|
||||
warm_bake = warm_subparsers.add_parser("bake", help="Save a session as a warm template")
|
||||
warm_bake.add_argument("session_id", help="Session ID or title to bake")
|
||||
warm_bake.add_argument("--name", required=True, help="Template name")
|
||||
warm_bake.add_argument("--description", default="", help="Template description")
|
||||
warm_bake.add_argument("--max-msgs", type=int, default=None, help="Keep only the last N messages")
|
||||
|
||||
warm_subparsers.add_parser("list", help="List warm session templates")
|
||||
|
||||
warm_delete = warm_subparsers.add_parser("delete", help="Delete a warm template")
|
||||
warm_delete.add_argument("name", help="Template name")
|
||||
warm_delete.add_argument("--yes", "-y", action="store_true", help="Skip confirmation")
|
||||
|
||||
warm_use = warm_subparsers.add_parser("use", help="Start a new session from a warm template")
|
||||
warm_use.add_argument("name", help="Template name")
|
||||
|
||||
def _confirm_prompt(prompt: str) -> bool:
|
||||
"""Prompt for y/N confirmation, safe against non-TTY environments."""
|
||||
try:
|
||||
@@ -5735,83 +5706,6 @@ Examples:
|
||||
size_mb = os.path.getsize(db_path) / (1024 * 1024)
|
||||
print(f"Database size: {size_mb:.1f} MB")
|
||||
|
||||
elif action == "warm":
|
||||
from agent.warm_session import WarmSessionStore
|
||||
|
||||
store = WarmSessionStore()
|
||||
warm_action = getattr(args, "warm_action", None)
|
||||
|
||||
if warm_action == "bake":
|
||||
resolved = db.resolve_session_id(args.session_id)
|
||||
if not resolved:
|
||||
# Try resolving by title
|
||||
resolved = db.resolve_session_by_title(args.session_id)
|
||||
if not resolved:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
db.close()
|
||||
return
|
||||
try:
|
||||
template = store.bake_from_session(
|
||||
session_id=resolved,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
max_messages=args.max_msgs,
|
||||
session_db=db,
|
||||
)
|
||||
print(f"Baked warm template '{template.name}' from session {resolved}")
|
||||
print(f" Messages: {template.message_count}")
|
||||
if template.model:
|
||||
print(f" Model: {template.model}")
|
||||
except Exception as e:
|
||||
print(f"Error baking template: {e}")
|
||||
|
||||
elif warm_action == "list":
|
||||
templates = store.list_templates()
|
||||
if not templates:
|
||||
print("No warm session templates found.")
|
||||
print(f" Bake one with: hermes sessions warm bake <session_id> --name <template>")
|
||||
db.close()
|
||||
return
|
||||
print(f"{'Name':<20} {'Msgs':<6} {'Model':<24} {'Description'}")
|
||||
print("─" * 80)
|
||||
for t in templates:
|
||||
desc = t.get("description", "")[:30]
|
||||
model = t.get("model", "")[:22]
|
||||
print(f"{t['name']:<20} {t['message_count']:<6} {model:<24} {desc}")
|
||||
|
||||
elif warm_action == "delete":
|
||||
if not args.yes:
|
||||
if not _confirm_prompt(f"Delete warm template '{args.name}'? [y/N] "):
|
||||
print("Cancelled.")
|
||||
db.close()
|
||||
return
|
||||
if store.delete_template(args.name):
|
||||
print(f"Deleted warm template '{args.name}'.")
|
||||
else:
|
||||
print(f"Warm template '{args.name}' not found.")
|
||||
|
||||
elif warm_action == "use":
|
||||
try:
|
||||
template = store.load_template(args.name)
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
db.close()
|
||||
return
|
||||
print(f"Starting session from warm template '{template.name}' ({template.message_count} messages)")
|
||||
import shutil
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if hermes_bin:
|
||||
os.execvp(hermes_bin, ["hermes", "--warm", args.name])
|
||||
else:
|
||||
os.execvp(
|
||||
sys.executable,
|
||||
[sys.executable, "-m", "hermes_cli.main", "--warm", args.name],
|
||||
)
|
||||
return # won't reach here after execvp
|
||||
|
||||
else:
|
||||
sessions_warm.print_help()
|
||||
|
||||
else:
|
||||
sessions_parser.print_help()
|
||||
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
"""Tests for agent.warm_session (#327)."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.warm_session import WarmSessionStore, WarmSessionTemplate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("agent.warm_session.get_hermes_home", lambda: tmp_path)
|
||||
return WarmSessionStore(templates_dir=tmp_path / "warm_sessions")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
"""Return a mock SessionDB with a fake session."""
|
||||
db = MagicMock()
|
||||
db.get_session.return_value = {
|
||||
"id": "sess_123",
|
||||
"model": "claude-sonnet-4",
|
||||
"source": "cli",
|
||||
}
|
||||
db.get_messages_as_conversation.return_value = [
|
||||
{"role": "system", "content": "You are Hermes"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "tool", "content": "{\"result\": 42}", "tool_call_id": "tc1"},
|
||||
]
|
||||
return db
|
||||
|
||||
|
||||
class TestWarmSessionTemplate:
|
||||
def test_roundtrip(self):
|
||||
t = WarmSessionTemplate(
|
||||
name="dev",
|
||||
description="Fullstack dev warm start",
|
||||
created_from_session_id="s1",
|
||||
created_at=123.0,
|
||||
message_count=2,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="m1",
|
||||
source="cli",
|
||||
)
|
||||
d = t.to_dict()
|
||||
restored = WarmSessionTemplate.from_dict(d)
|
||||
assert restored.name == t.name
|
||||
assert restored.messages == t.messages
|
||||
|
||||
|
||||
class TestWarmSessionStore:
|
||||
def test_bake_from_session(self, store, mock_db):
|
||||
template = store.bake_from_session(
|
||||
session_id="sess_123",
|
||||
name="fullstack-dev",
|
||||
description="A warm template",
|
||||
session_db=mock_db,
|
||||
)
|
||||
assert template.name == "fullstack-dev"
|
||||
assert template.message_count == 4
|
||||
assert template.model == "claude-sonnet-4"
|
||||
assert store._template_path("fullstack-dev").exists()
|
||||
|
||||
def test_bake_max_messages(self, store, mock_db):
|
||||
template = store.bake_from_session(
|
||||
session_id="sess_123",
|
||||
name="recent",
|
||||
max_messages=2,
|
||||
session_db=mock_db,
|
||||
)
|
||||
assert template.message_count == 2
|
||||
assert template.messages[0]["role"] == "assistant"
|
||||
|
||||
def test_bake_session_not_found(self, store):
|
||||
db = MagicMock()
|
||||
db.get_session.return_value = None
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
store.bake_from_session("missing", "x", session_db=db)
|
||||
|
||||
def test_list_templates(self, store, mock_db):
|
||||
store.bake_from_session("sess_123", "t1", session_db=mock_db)
|
||||
store.bake_from_session("sess_123", "t2", session_db=mock_db)
|
||||
templates = store.list_templates()
|
||||
assert len(templates) == 2
|
||||
names = {t["name"] for t in templates}
|
||||
assert names == {"t1", "t2"}
|
||||
|
||||
def test_load_template(self, store, mock_db):
|
||||
store.bake_from_session("sess_123", "prod", session_db=mock_db)
|
||||
loaded = store.load_template("prod")
|
||||
assert loaded.name == "prod"
|
||||
assert loaded.message_count == 4
|
||||
|
||||
def test_load_missing_template(self, store):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
store.load_template("missing")
|
||||
|
||||
def test_delete_template(self, store, mock_db):
|
||||
store.bake_from_session("sess_123", "tmp", session_db=mock_db)
|
||||
assert store.delete_template("tmp") is True
|
||||
assert store.delete_template("tmp") is False
|
||||
|
||||
def test_provision_conversation_history(self, store, mock_db):
|
||||
store.bake_from_session("sess_123", "ctx", session_db=mock_db)
|
||||
hist = store.provision_conversation_history("ctx")
|
||||
assert len(hist) == 4
|
||||
assert hist[0]["role"] == "system"
|
||||
|
||||
def test_strips_session_meta(self, store):
|
||||
db = MagicMock()
|
||||
db.get_session.return_value = {"id": "s1"}
|
||||
db.get_messages_as_conversation.return_value = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "session_meta", "content": "meta"},
|
||||
]
|
||||
template = store.bake_from_session("s1", "meta-test", session_db=db)
|
||||
assert template.message_count == 1
|
||||
assert template.messages[0]["role"] == "user"
|
||||
@@ -459,7 +459,7 @@ class TestCustomProviderCompatibility:
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == DEFAULT_CONFIG["_config_version"]
|
||||
assert raw["_config_version"] == 17
|
||||
assert raw["providers"]["openai-direct"] == {
|
||||
"api": "https://api.openai.com/v1",
|
||||
"api_key": "test-key",
|
||||
@@ -606,6 +606,6 @@ class TestInterimAssistantMessageConfig:
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == DEFAULT_CONFIG["_config_version"]
|
||||
assert raw["_config_version"] == 17
|
||||
assert raw["display"]["tool_progress"] == "off"
|
||||
assert raw["display"]["interim_assistant_messages"] is True
|
||||
|
||||
39
tests/tools/test_binary_extensions.py
Normal file
39
tests/tools/test_binary_extensions.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Tests for binary_extensions helpers."""
|
||||
|
||||
from tools.binary_extensions import has_binary_extension, has_image_extension
|
||||
|
||||
|
||||
def test_has_image_extension_png():
|
||||
assert has_image_extension("/tmp/test.png") is True
|
||||
assert has_image_extension("/tmp/test.PNG") is True
|
||||
|
||||
|
||||
def test_has_image_extension_jpg_variants():
|
||||
assert has_image_extension("/tmp/test.jpg") is True
|
||||
assert has_image_extension("/tmp/test.jpeg") is True
|
||||
assert has_image_extension("/tmp/test.JPG") is True
|
||||
|
||||
|
||||
def test_has_image_extension_webp():
|
||||
assert has_image_extension("/tmp/test.webp") is True
|
||||
|
||||
|
||||
def test_has_image_extension_gif():
|
||||
assert has_image_extension("/tmp/test.gif") is True
|
||||
|
||||
|
||||
def test_has_image_extension_no_ext():
|
||||
assert has_image_extension("/tmp/test") is False
|
||||
|
||||
|
||||
def test_has_image_extension_non_image():
|
||||
assert has_image_extension("/tmp/test.txt") is False
|
||||
assert has_image_extension("/tmp/test.exe") is False
|
||||
assert has_image_extension("/tmp/test.pdf") is False
|
||||
|
||||
|
||||
def test_has_binary_extension_includes_images():
|
||||
"""All image extensions must also be in binary extensions."""
|
||||
assert has_binary_extension("/tmp/test.png") is True
|
||||
assert has_binary_extension("/tmp/test.jpg") is True
|
||||
assert has_binary_extension("/tmp/test.webp") is True
|
||||
@@ -59,9 +59,9 @@ class TestCamofoxConfigDefaults:
|
||||
browser_cfg = DEFAULT_CONFIG["browser"]
|
||||
assert browser_cfg["camofox"]["managed_persistence"] is False
|
||||
|
||||
def test_config_version_is_at_least_browser_schema(self):
|
||||
def test_config_version_matches_current_schema(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
# The browser camofox schema was introduced at version 17.
|
||||
# The global version may be higher; this ensures we haven't regressed.
|
||||
assert DEFAULT_CONFIG["_config_version"] >= 17
|
||||
# The current schema version is tracked globally; unrelated default
|
||||
# options may bump it after browser defaults are added.
|
||||
assert DEFAULT_CONFIG["_config_version"] == 17
|
||||
|
||||
@@ -294,3 +294,67 @@ class TestSearchHints:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class TestReadFileImageRouting:
|
||||
"""Tests that image files are routed through vision analysis."""
|
||||
|
||||
@patch("tools.file_tools._analyze_image_with_vision")
|
||||
def test_image_png_routes_to_vision(self, mock_analyze, tmp_path):
|
||||
mock_analyze.return_value = json.dumps({"analysis": "test image"})
|
||||
img = tmp_path / "test.png"
|
||||
img.write_bytes(b"fake png data")
|
||||
|
||||
from tools.file_tools import read_file_tool
|
||||
result = read_file_tool(str(img))
|
||||
mock_analyze.assert_called_once()
|
||||
assert json.loads(result)["analysis"] == "test image"
|
||||
|
||||
@patch("tools.file_tools._analyze_image_with_vision")
|
||||
def test_image_jpeg_routes_to_vision(self, mock_analyze, tmp_path):
|
||||
mock_analyze.return_value = json.dumps({"analysis": "test image"})
|
||||
img = tmp_path / "test.jpeg"
|
||||
img.write_bytes(b"fake jpeg data")
|
||||
|
||||
from tools.file_tools import read_file_tool
|
||||
result = read_file_tool(str(img))
|
||||
mock_analyze.assert_called_once()
|
||||
assert json.loads(result)["analysis"] == "test image"
|
||||
|
||||
@patch("tools.file_tools._analyze_image_with_vision")
|
||||
def test_image_webp_routes_to_vision(self, mock_analyze, tmp_path):
|
||||
mock_analyze.return_value = json.dumps({"analysis": "test image"})
|
||||
img = tmp_path / "test.webp"
|
||||
img.write_bytes(b"fake webp data")
|
||||
|
||||
from tools.file_tools import read_file_tool
|
||||
result = read_file_tool(str(img))
|
||||
mock_analyze.assert_called_once()
|
||||
assert json.loads(result)["analysis"] == "test image"
|
||||
|
||||
def test_non_image_binary_blocked(self, tmp_path):
|
||||
from tools.file_tools import read_file_tool
|
||||
exe = tmp_path / "test.exe"
|
||||
exe.write_bytes(b"fake exe data")
|
||||
result = json.loads(read_file_tool(str(exe)))
|
||||
assert "error" in result
|
||||
assert "Cannot read binary" in result["error"]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class TestAnalyzeImageWithVision:
|
||||
"""Tests for the _analyze_image_with_vision helper."""
|
||||
|
||||
def test_import_error_fallback(self):
|
||||
with patch.dict("sys.modules", {"tools.vision_tools": None}):
|
||||
from tools.file_tools import _analyze_image_with_vision
|
||||
result = json.loads(_analyze_image_with_vision("/tmp/test.png"))
|
||||
assert "error" in result
|
||||
assert "vision_analyze tool is not available" in result["error"]
|
||||
|
||||
@@ -34,9 +34,22 @@ BINARY_EXTENSIONS = frozenset({
|
||||
})
|
||||
|
||||
|
||||
IMAGE_EXTENSIONS = frozenset({
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ico", ".webp", ".tiff", ".tif",
|
||||
})
|
||||
|
||||
|
||||
def has_binary_extension(path: str) -> bool:
|
||||
"""Check if a file path has a binary extension. Pure string check, no I/O."""
|
||||
dot = path.rfind(".")
|
||||
if dot == -1:
|
||||
return False
|
||||
return path[dot:].lower() in BINARY_EXTENSIONS
|
||||
|
||||
|
||||
def has_image_extension(path: str) -> bool:
|
||||
"""Check if a file path has an image extension. Pure string check, no I/O."""
|
||||
dot = path.rfind(".")
|
||||
if dot == -1:
|
||||
return False
|
||||
return path[dot:].lower() in IMAGE_EXTENSIONS
|
||||
|
||||
@@ -1893,11 +1893,13 @@ def browser_get_images(task_id: Optional[str] = None) -> str:
|
||||
def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Take a screenshot of the current page and analyze it with vision AI.
|
||||
|
||||
|
||||
This tool captures what's visually displayed in the browser and sends it
|
||||
to Gemini for analysis. Useful for understanding visual content that the
|
||||
text-based snapshot may not capture (CAPTCHAs, verification challenges,
|
||||
images, complex layouts, etc.).
|
||||
to the configured vision model for analysis. When the active model is
|
||||
natively multimodal (e.g. Gemma 4) it is used directly; otherwise the
|
||||
auxiliary vision backend is used. Useful for understanding visual content
|
||||
that the text-based snapshot may not capture (CAPTCHAs, verification
|
||||
challenges, images, complex layouts, etc.).
|
||||
|
||||
The screenshot is saved persistently and its file path is returned alongside
|
||||
the analysis, so it can be shared with users via MEDIA:<path> in the response.
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from tools.binary_extensions import has_binary_extension
|
||||
from tools.binary_extensions import has_binary_extension, has_image_extension
|
||||
from tools.file_operations import ShellFileOperations
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
@@ -279,6 +279,52 @@ def clear_file_ops_cache(task_id: str = None):
|
||||
_file_ops_cache.clear()
|
||||
|
||||
|
||||
def _analyze_image_with_vision(image_path: str, task_id: str = "default") -> str:
|
||||
"""Route an image file through the vision analysis pipeline.
|
||||
|
||||
Uses vision_analyze_tool with a default descriptive prompt. Falls back
|
||||
to a manual error when no vision backend is available.
|
||||
"""
|
||||
import asyncio
|
||||
try:
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
except ImportError:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Image file '{image_path}' detected but vision_analyze tool "
|
||||
"is not available. Use vision_analyze directly if configured."
|
||||
),
|
||||
})
|
||||
|
||||
prompt = (
|
||||
"Describe this image in detail. If it contains text, transcribe "
|
||||
"the text. If it is a diagram, chart, or UI screenshot, describe "
|
||||
"the layout, colors, labels, and any visible data."
|
||||
)
|
||||
|
||||
try:
|
||||
result = asyncio.run(vision_analyze_tool(image_url=image_path, question=prompt))
|
||||
except Exception as exc:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Image file '{image_path}' detected but vision analysis failed: {exc}. "
|
||||
"Use vision_analyze directly if configured."
|
||||
),
|
||||
})
|
||||
|
||||
try:
|
||||
parsed = json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
parsed = {"content": result}
|
||||
|
||||
# Wrap the vision result so the caller knows it came from image analysis
|
||||
return json.dumps({
|
||||
"image_path": image_path,
|
||||
"analysis": parsed.get("content") or parsed.get("analysis") or result,
|
||||
"source": "vision_analyze",
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = "default") -> str:
|
||||
"""Read a file with pagination and line numbers."""
|
||||
try:
|
||||
@@ -295,10 +341,13 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
|
||||
|
||||
_resolved = Path(path).expanduser().resolve()
|
||||
|
||||
# ── Binary file guard ─────────────────────────────────────────
|
||||
# Block binary files by extension (no I/O).
|
||||
# ── Binary / image file guard ─────────────────────────────────
|
||||
# Block binary files by extension (no I/O). Images are routed
|
||||
# through the vision analysis pipeline when a backend is available.
|
||||
if has_binary_extension(str(_resolved)):
|
||||
_ext = _resolved.suffix.lower()
|
||||
if has_image_extension(str(_resolved)):
|
||||
return _analyze_image_with_vision(str(_resolved), task_id=task_id)
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Cannot read binary file '{path}' ({_ext}). "
|
||||
@@ -729,7 +778,7 @@ def _check_file_reqs():
|
||||
|
||||
READ_FILE_SCHEMA = {
|
||||
"name": "read_file",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. Reads exceeding ~100K characters are rejected; use offset and limit to read specific sections of large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. Reads exceeding ~100K characters are rejected; use offset and limit to read specific sections of large files. NOTE: Image files (PNG, JPEG, WebP, GIF, etc.) are automatically analyzed via vision_analyze. Other binary files cannot be read as text.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Reference in New Issue
Block a user