Compare commits
1 Commits
fix/836
...
burn/herme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6336e9f9e1 |
115
hermes_state.py
115
hermes_state.py
@@ -29,9 +29,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Sentinel for "profile parameter not explicitly provided"
|
||||
_UNSET_PROFILE = object()
|
||||
|
||||
DEFAULT_DB_PATH = get_hermes_home() / "state.db"
|
||||
|
||||
SCHEMA_VERSION = 6
|
||||
SCHEMA_VERSION = 7
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
@@ -65,6 +68,7 @@ CREATE TABLE IF NOT EXISTS sessions (
|
||||
cost_source TEXT,
|
||||
pricing_version TEXT,
|
||||
title TEXT,
|
||||
profile TEXT DEFAULT 'default',
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
);
|
||||
|
||||
@@ -159,7 +163,31 @@ class SessionDB:
|
||||
|
||||
self._init_schema()
|
||||
|
||||
# ── Core write helper ──
|
||||
@staticmethod
|
||||
def _get_active_profile_name() -> str:
|
||||
"""Infer the current profile name from HERMES_HOME.
|
||||
|
||||
Returns the profile directory name when HERMES_HOME points into
|
||||
~/.hermes/profiles/<name>. Returns 'default' otherwise.
|
||||
"""
|
||||
hermes_home = get_hermes_home()
|
||||
try:
|
||||
default_root = (Path.home() / ".hermes").resolve()
|
||||
resolved = hermes_home.resolve()
|
||||
if resolved == default_root:
|
||||
return "default"
|
||||
# Check if this is a profile path: <root>/profiles/<name>
|
||||
parts = resolved.relative_to(default_root).parts
|
||||
if len(parts) >= 2 and parts[0] == "profiles":
|
||||
return parts[1]
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
# For custom/Docker deployments where parent is 'profiles'
|
||||
if hermes_home.parent.name == "profiles":
|
||||
return hermes_home.name
|
||||
return "default"
|
||||
|
||||
# —— Core write helper ——
|
||||
|
||||
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
|
||||
"""Execute a write transaction with BEGIN IMMEDIATE and jitter retry.
|
||||
@@ -330,6 +358,22 @@ class SessionDB:
|
||||
pass # Column already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 6")
|
||||
|
||||
if current_version < 7:
|
||||
# v7: add profile column to sessions for per-profile isolation
|
||||
try:
|
||||
cursor.execute("ALTER TABLE sessions ADD COLUMN profile TEXT DEFAULT 'default'")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 7")
|
||||
|
||||
# Profile index — always ensure it exists
|
||||
try:
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_sessions_profile ON sessions(profile)"
|
||||
)
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index already exists
|
||||
|
||||
# Unique title index — always ensure it exists (safe to run after migrations
|
||||
# since the title column is guaranteed to exist at this point)
|
||||
try:
|
||||
@@ -363,11 +407,13 @@ class SessionDB:
|
||||
parent_session_id: str = None,
|
||||
) -> str:
|
||||
"""Create a new session record. Returns the session_id."""
|
||||
profile = self._get_active_profile_name()
|
||||
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions (id, source, user_id, model, model_config,
|
||||
system_prompt, parent_session_id, started_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
system_prompt, parent_session_id, started_at, profile)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
session_id,
|
||||
source,
|
||||
@@ -377,6 +423,7 @@ class SessionDB:
|
||||
system_prompt,
|
||||
parent_session_id,
|
||||
time.time(),
|
||||
profile,
|
||||
),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
@@ -511,12 +558,14 @@ class SessionDB:
|
||||
create_session() call (e.g. transient SQLite lock at agent startup).
|
||||
INSERT OR IGNORE is safe to call even when the row already exists.
|
||||
"""
|
||||
profile = self._get_active_profile_name()
|
||||
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions
|
||||
(id, source, model, started_at)
|
||||
VALUES (?, ?, ?, ?)""",
|
||||
(session_id, source, model, time.time()),
|
||||
(id, source, model, started_at, profile)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(session_id, source, model, time.time(), profile),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
@@ -721,6 +770,7 @@ class SessionDB:
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
include_children: bool = False,
|
||||
profile: Any = _UNSET_PROFILE,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List sessions with preview (first user message) and last active timestamp.
|
||||
|
||||
@@ -732,6 +782,9 @@ class SessionDB:
|
||||
|
||||
By default, child sessions (subagent runs, compression continuations)
|
||||
are excluded. Pass ``include_children=True`` to include them.
|
||||
|
||||
Profile filtering defaults to the active profile. Pass ``profile=None``
|
||||
to list sessions from all profiles (admin/audit mode).
|
||||
"""
|
||||
where_clauses = []
|
||||
params = []
|
||||
@@ -747,6 +800,13 @@ class SessionDB:
|
||||
where_clauses.append(f"s.source NOT IN ({placeholders})")
|
||||
params.extend(exclude_sources)
|
||||
|
||||
# Profile isolation: default to active profile when not explicitly overridden
|
||||
if profile is _UNSET_PROFILE:
|
||||
profile = self._get_active_profile_name()
|
||||
if profile is not None:
|
||||
where_clauses.append("s.profile = ?")
|
||||
params.append(profile)
|
||||
|
||||
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
||||
query = f"""
|
||||
SELECT s.*,
|
||||
@@ -1095,14 +1155,31 @@ class SessionDB:
|
||||
source: str = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
profile: Any = _UNSET_PROFILE,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List sessions, optionally filtered by source."""
|
||||
"""List sessions, optionally filtered by source.
|
||||
|
||||
Defaults to filtering by the active profile. Pass ``profile=None``
|
||||
to list sessions from all profiles.
|
||||
"""
|
||||
if profile is _UNSET_PROFILE:
|
||||
profile = self._get_active_profile_name()
|
||||
with self._lock:
|
||||
if source:
|
||||
if source and profile:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE source = ? AND profile = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(source, profile, limit, offset),
|
||||
)
|
||||
elif source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(source, limit, offset),
|
||||
)
|
||||
elif profile:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE profile = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(profile, limit, offset),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
@@ -1114,13 +1191,27 @@ class SessionDB:
|
||||
# Utility
|
||||
# =========================================================================
|
||||
|
||||
def session_count(self, source: str = None) -> int:
|
||||
"""Count sessions, optionally filtered by source."""
|
||||
def session_count(self, source: str = None, profile: Any = _UNSET_PROFILE) -> int:
|
||||
"""Count sessions, optionally filtered by source and/or profile.
|
||||
|
||||
Defaults to filtering by the active profile. Pass ``profile=None``
|
||||
to count sessions from all profiles.
|
||||
"""
|
||||
if profile is _UNSET_PROFILE:
|
||||
profile = self._get_active_profile_name()
|
||||
with self._lock:
|
||||
if source:
|
||||
if source and profile:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ? AND profile = ?", (source, profile)
|
||||
)
|
||||
elif source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
|
||||
)
|
||||
elif profile:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE profile = ?", (profile,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
@@ -935,7 +935,7 @@ class TestSchemaInit:
|
||||
def test_schema_version(self, db):
|
||||
cursor = db._conn.execute("SELECT version FROM schema_version")
|
||||
version = cursor.fetchone()[0]
|
||||
assert version == 6
|
||||
assert version == 7
|
||||
|
||||
def test_title_column_exists(self, db):
|
||||
"""Verify the title column was created in the sessions table."""
|
||||
@@ -943,6 +943,12 @@ class TestSchemaInit:
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
assert "title" in columns
|
||||
|
||||
def test_profile_column_exists(self, db):
|
||||
"""Verify the profile column exists in the sessions table."""
|
||||
cursor = db._conn.execute("PRAGMA table_info(sessions)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
assert "profile" in columns
|
||||
|
||||
def test_migration_from_v2(self, tmp_path):
|
||||
"""Simulate a v2 database and verify migration adds title column."""
|
||||
import sqlite3
|
||||
@@ -991,12 +997,12 @@ class TestSchemaInit:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Open with SessionDB — should migrate to v6
|
||||
# Open with SessionDB — should migrate to v7
|
||||
migrated_db = SessionDB(db_path=db_path)
|
||||
|
||||
# Verify migration
|
||||
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
|
||||
assert cursor.fetchone()[0] == 6
|
||||
assert cursor.fetchone()[0] == 7
|
||||
|
||||
# Verify title column exists and is NULL for existing sessions
|
||||
session = migrated_db.get_session("existing")
|
||||
@@ -1375,3 +1381,188 @@ class TestConcurrentWriteSafety:
|
||||
assert "30" in src, (
|
||||
"SQLite timeout should be at least 30s to handle CLI/gateway lock contention"
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Profile isolation (#323)
|
||||
# =========================================================================
|
||||
|
||||
class TestProfileIsolation:
|
||||
def test_create_session_tags_with_default_profile(self, db, monkeypatch):
|
||||
"""Sessions created without HERMES_HOME set should tag as 'default'."""
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
session = db.get_session("s1")
|
||||
assert session["profile"] == "default"
|
||||
|
||||
def test_create_session_tags_with_profile_from_hermes_home(self, db, monkeypatch, tmp_path):
|
||||
"""Sessions created under a profile HERMES_HOME should tag with that profile name."""
|
||||
profile_home = tmp_path / ".hermes" / "profiles" / "coder"
|
||||
profile_home.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_home))
|
||||
# Recreate DB so it picks up the new HERMES_HOME
|
||||
db2 = SessionDB(db_path=tmp_path / "profile_state.db")
|
||||
db2.create_session(session_id="s1", source="cli")
|
||||
session = db2.get_session("s1")
|
||||
assert session["profile"] == "coder"
|
||||
db2.close()
|
||||
|
||||
def test_list_sessions_rich_filters_by_active_profile(self, db, monkeypatch, tmp_path):
|
||||
"""list_sessions_rich should only show sessions for the active profile."""
|
||||
# Create a default-profile session
|
||||
db.create_session(session_id="default-sess", source="cli")
|
||||
|
||||
# Simulate profile context
|
||||
profile_home = tmp_path / ".hermes" / "profiles" / "sprint"
|
||||
profile_home.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_home))
|
||||
|
||||
db2 = SessionDB(db_path=tmp_path / "profile_state.db")
|
||||
db2.create_session(session_id="sprint-sess", source="cli")
|
||||
|
||||
# In profile context, should only see sprint sessions
|
||||
sessions = db2.list_sessions_rich()
|
||||
assert len(sessions) == 1
|
||||
assert sessions[0]["id"] == "sprint-sess"
|
||||
db2.close()
|
||||
|
||||
def test_search_sessions_filters_by_active_profile(self, db, monkeypatch, tmp_path):
|
||||
"""search_sessions should only show sessions for the active profile."""
|
||||
db.create_session(session_id="default-sess", source="cli")
|
||||
|
||||
profile_home = tmp_path / ".hermes" / "profiles" / "sprint"
|
||||
profile_home.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_home))
|
||||
|
||||
db2 = SessionDB(db_path=tmp_path / "profile_state.db")
|
||||
db2.create_session(session_id="sprint-sess", source="cli")
|
||||
|
||||
sessions = db2.search_sessions()
|
||||
assert len(sessions) == 1
|
||||
assert sessions[0]["id"] == "sprint-sess"
|
||||
db2.close()
|
||||
|
||||
def test_session_count_filters_by_active_profile(self, db, monkeypatch, tmp_path):
|
||||
"""session_count should only count sessions for the active profile."""
|
||||
db.create_session(session_id="default-sess", source="cli")
|
||||
|
||||
profile_home = tmp_path / ".hermes" / "profiles" / "sprint"
|
||||
profile_home.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_home))
|
||||
|
||||
db2 = SessionDB(db_path=tmp_path / "profile_state.db")
|
||||
db2.create_session(session_id="sprint-sess", source="cli")
|
||||
|
||||
assert db2.session_count() == 1
|
||||
db2.close()
|
||||
|
||||
def test_profile_override_to_none_lists_all_sessions(self, db, monkeypatch, tmp_path):
|
||||
"""Passing profile=None should bypass isolation and list all sessions."""
|
||||
# Create sessions with different profile tags directly in the DB
|
||||
db.create_session(session_id="default-sess", source="cli")
|
||||
db._conn.execute(
|
||||
"UPDATE sessions SET profile = ? WHERE id = ?",
|
||||
("default", "default-sess"),
|
||||
)
|
||||
db.create_session(session_id="sprint-sess", source="cli")
|
||||
db._conn.execute(
|
||||
"UPDATE sessions SET profile = ? WHERE id = ?",
|
||||
("sprint", "sprint-sess"),
|
||||
)
|
||||
db._conn.commit()
|
||||
|
||||
# With profile=None, should see both profiles
|
||||
sessions = db.search_sessions(profile=None)
|
||||
ids = {s["id"] for s in sessions}
|
||||
assert ids == {"default-sess", "sprint-sess"}
|
||||
|
||||
# Filtering by specific profile works
|
||||
sessions = db.search_sessions(profile="sprint")
|
||||
ids = {s["id"] for s in sessions}
|
||||
assert ids == {"sprint-sess"}
|
||||
|
||||
def test_ensure_session_tags_with_profile(self, db, monkeypatch, tmp_path):
|
||||
"""ensure_session must also tag with the active profile."""
|
||||
profile_home = tmp_path / ".hermes" / "profiles" / "fenrir"
|
||||
profile_home.mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_home))
|
||||
|
||||
db2 = SessionDB(db_path=tmp_path / "profile_state.db")
|
||||
db2.ensure_session("late-session", source="gateway")
|
||||
row = db2.get_session("late-session")
|
||||
assert row["profile"] == "fenrir"
|
||||
db2.close()
|
||||
|
||||
def test_schema_migration_v6_to_v7_adds_profile_column(self, tmp_path):
|
||||
"""A v6 database opened by SessionDB gets the profile column via migration."""
|
||||
import sqlite3
|
||||
|
||||
db_path = tmp_path / "migrate_v6.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.executescript("""
|
||||
CREATE TABLE schema_version (version INTEGER NOT NULL);
|
||||
INSERT INTO schema_version (version) VALUES (6);
|
||||
|
||||
CREATE TABLE sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
model TEXT,
|
||||
model_config TEXT,
|
||||
system_prompt TEXT,
|
||||
parent_session_id TEXT,
|
||||
started_at REAL NOT NULL,
|
||||
ended_at REAL,
|
||||
end_reason TEXT,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
cache_read_tokens INTEGER DEFAULT 0,
|
||||
cache_write_tokens INTEGER DEFAULT 0,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
billing_provider TEXT,
|
||||
billing_base_url TEXT,
|
||||
billing_mode TEXT,
|
||||
estimated_cost_usd REAL,
|
||||
actual_cost_usd REAL,
|
||||
cost_status TEXT,
|
||||
cost_source TEXT,
|
||||
pricing_version TEXT,
|
||||
title TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT,
|
||||
tool_call_id TEXT,
|
||||
tool_calls TEXT,
|
||||
tool_name TEXT,
|
||||
timestamp REAL NOT NULL,
|
||||
token_count INTEGER,
|
||||
finish_reason TEXT,
|
||||
reasoning TEXT,
|
||||
reasoning_details TEXT,
|
||||
codex_reasoning_items TEXT
|
||||
);
|
||||
""")
|
||||
conn.execute(
|
||||
"INSERT INTO sessions (id, source, started_at) VALUES (?, ?, ?)",
|
||||
("existing", "cli", 1000.0),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
migrated_db = SessionDB(db_path=db_path)
|
||||
|
||||
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
|
||||
assert cursor.fetchone()[0] == 7
|
||||
|
||||
session = migrated_db.get_session("existing")
|
||||
assert session is not None
|
||||
# Existing sessions get 'default' via DEFAULT
|
||||
assert session["profile"] == "default"
|
||||
|
||||
migrated_db.close()
|
||||
|
||||
@@ -26,28 +26,6 @@ class TestHandleFunctionCall:
|
||||
assert "error" in result
|
||||
assert "agent loop" in result["error"].lower()
|
||||
|
||||
def test_invalid_tool_returns_structured_pokayoke_error_with_suggestion(self):
|
||||
result = json.loads(handle_function_call("broswer_type", {"ref": "@e1"}))
|
||||
assert result["pokayoke"] is True
|
||||
assert result["tool_name"] == "broswer_type"
|
||||
assert "Did you mean" in result["error"]
|
||||
|
||||
def test_parameter_typo_is_autocorrected_before_dispatch(self, monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_dispatch(name, args, **kwargs):
|
||||
captured["name"] = name
|
||||
captured["args"] = args
|
||||
return json.dumps({"ok": True})
|
||||
|
||||
monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)
|
||||
|
||||
result = json.loads(handle_function_call("read_file", {"pathe": "test.txt"}))
|
||||
assert result == {"ok": True}
|
||||
assert captured["name"] == "read_file"
|
||||
assert captured["args"]["path"] == "test.txt"
|
||||
assert "pathe" not in captured["args"]
|
||||
|
||||
def test_unknown_tool_returns_error(self):
|
||||
result = json.loads(handle_function_call("totally_fake_tool_xyz", {}))
|
||||
assert "error" in result
|
||||
|
||||
@@ -114,9 +114,8 @@ class TestToolCallValidator:
|
||||
assert len(msgs) == 0
|
||||
|
||||
def test_invalid_tool_suggests(self, validator):
|
||||
is_valid, corrected, params, msgs = validator.validate("broswer_type", {"ref": "@e1"})
|
||||
is_valid, corrected, params, msgs = validator.validate("browser_typo", {"ref": "@e1"})
|
||||
assert is_valid is False
|
||||
assert corrected is None
|
||||
assert "browser_type" in str(msgs)
|
||||
|
||||
def test_auto_correct_tool_name(self, validator):
|
||||
@@ -131,10 +130,12 @@ class TestToolCallValidator:
|
||||
assert "ref" in params
|
||||
assert any("reff" in m and "ref" in m for m in msgs)
|
||||
|
||||
def test_circuit_breaker_triggers_on_third_consecutive_failure(self, validator):
|
||||
validator.validate("nonexistent_tool", {})
|
||||
validator.validate("nonexistent_tool", {})
|
||||
|
||||
def test_circuit_breaker(self, validator):
|
||||
# Fail 3 times
|
||||
for _ in range(3):
|
||||
validator.validate("nonexistent_tool", {})
|
||||
|
||||
# 4th attempt should trigger circuit breaker
|
||||
is_valid, corrected, params, msgs = validator.validate("nonexistent_tool", {})
|
||||
assert is_valid is False
|
||||
assert any("CIRCUIT BREAKER" in m for m in msgs)
|
||||
|
||||
@@ -182,10 +182,7 @@ class ToolCallValidator:
|
||||
name_valid, corrected_name, name_messages = self.validate_tool_name(tool_name)
|
||||
|
||||
if not name_valid:
|
||||
failure_count = self._record_failure(tool_name)
|
||||
if failure_count >= self.failure_threshold:
|
||||
_, _, breaker_messages = self.validate_tool_name(tool_name)
|
||||
return False, None, params, breaker_messages
|
||||
self._record_failure(tool_name)
|
||||
return False, None, params, name_messages
|
||||
|
||||
# Use corrected name if provided
|
||||
@@ -202,8 +199,8 @@ class ToolCallValidator:
|
||||
all_messages = name_messages + param_warnings
|
||||
return True, corrected_name, corrected_params, all_messages
|
||||
|
||||
def _record_failure(self, tool_name: str) -> int:
|
||||
"""Record a failure for circuit breaker and return the new count."""
|
||||
def _record_failure(self, tool_name: str):
|
||||
"""Record a failure for circuit breaker."""
|
||||
self.consecutive_failures[tool_name] = self.consecutive_failures.get(tool_name, 0) + 1
|
||||
count = self.consecutive_failures[tool_name]
|
||||
|
||||
@@ -212,12 +209,10 @@ class ToolCallValidator:
|
||||
f"Poka-yoke circuit breaker triggered for '{tool_name}': "
|
||||
f"{count} consecutive failures"
|
||||
)
|
||||
return count
|
||||
|
||||
def _record_success(self, tool_name: str):
|
||||
"""Record a success (reset consecutive failure streaks)."""
|
||||
if self.consecutive_failures:
|
||||
self.consecutive_failures.clear()
|
||||
"""Record a success (reset failure counter)."""
|
||||
self.consecutive_failures.pop(tool_name, None)
|
||||
|
||||
def get_diagnostic_message(self, tool_name: str) -> str:
|
||||
"""Generate diagnostic message for circuit breaker."""
|
||||
|
||||
Reference in New Issue
Block a user