fix: thread safety for concurrent subagent delegation (#1672)
* fix: thread safety for concurrent subagent delegation Five thread-safety fixes that prevent crashes and data races when running multiple subagents concurrently via delegate_task: 1. Remove redirect_stdout/stderr from delegate_tool — mutating global sys.stdout races with the spinner thread when multiple children start concurrently, causing segfaults. Children already run with quiet_mode=True so the redirect was redundant. 2. Split _run_single_child into _build_child_agent (main thread) + _run_single_child (worker thread). AIAgent construction creates httpx/SSL clients which are not thread-safe to initialize concurrently. 3. Add threading.Lock to SessionDB — subagents share the parent's SessionDB and call create_session/append_message from worker threads with no synchronization. 4. Add _active_children_lock to AIAgent — interrupt() iterates _active_children while worker threads append/remove children. 5. Add _client_cache_lock to auxiliary_client — multiple subagent threads may resolve clients concurrently via call_llm(). Based on PR #1471 by peteromallet. * feat: Honcho base_url override via config.yaml + quick command alias type Two features salvaged from PR #1576: 1. Honcho base_url override: allows pointing Hermes at a remote self-hosted Honcho deployment via config.yaml: honcho: base_url: "http://192.168.x.x:8000" When set, this overrides the Honcho SDK's environment mapping (production/local), enabling LAN/VPN Honcho deployments without requiring the server to live on localhost. Uses config.yaml instead of env var (HONCHO_URL) per project convention. 2. Quick command alias type: adds a new 'alias' quick command type that rewrites to another slash command before normal dispatch: quick_commands: sc: type: alias target: /context Supports both CLI and gateway. Arguments are forwarded to the target command. Based on PR #1576 by redhelix. 
--------- Co-authored-by: peteromallet <peteromallet@users.noreply.github.com> Co-authored-by: redhelix <redhelix@users.noreply.github.com>
This commit is contained in:
@@ -39,6 +39,7 @@ custom OpenAI-compatible endpoint without touching the main model settings.
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
@@ -1171,6 +1172,7 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
|||||||
|
|
||||||
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
|
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
|
||||||
_client_cache: Dict[tuple, tuple] = {}
|
_client_cache: Dict[tuple, tuple] = {}
|
||||||
|
_client_cache_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def _get_cached_client(
|
def _get_cached_client(
|
||||||
@@ -1182,9 +1184,11 @@ def _get_cached_client(
|
|||||||
) -> Tuple[Optional[Any], Optional[str]]:
|
) -> Tuple[Optional[Any], Optional[str]]:
|
||||||
"""Get or create a cached client for the given provider."""
|
"""Get or create a cached client for the given provider."""
|
||||||
cache_key = (provider, async_mode, base_url or "", api_key or "")
|
cache_key = (provider, async_mode, base_url or "", api_key or "")
|
||||||
if cache_key in _client_cache:
|
with _client_cache_lock:
|
||||||
cached_client, cached_default = _client_cache[cache_key]
|
if cache_key in _client_cache:
|
||||||
return cached_client, model or cached_default
|
cached_client, cached_default = _client_cache[cache_key]
|
||||||
|
return cached_client, model or cached_default
|
||||||
|
# Build outside the lock
|
||||||
client, default_model = resolve_provider_client(
|
client, default_model = resolve_provider_client(
|
||||||
provider,
|
provider,
|
||||||
model,
|
model,
|
||||||
@@ -1193,7 +1197,11 @@ def _get_cached_client(
|
|||||||
explicit_api_key=api_key,
|
explicit_api_key=api_key,
|
||||||
)
|
)
|
||||||
if client is not None:
|
if client is not None:
|
||||||
_client_cache[cache_key] = (client, default_model)
|
with _client_cache_lock:
|
||||||
|
if cache_key not in _client_cache:
|
||||||
|
_client_cache[cache_key] = (client, default_model)
|
||||||
|
else:
|
||||||
|
client, default_model = _client_cache[cache_key]
|
||||||
return client, model or default_model
|
return client, model or default_model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
11
cli.py
11
cli.py
@@ -3652,8 +3652,17 @@ class HermesCLI:
|
|||||||
self.console.print(f"[bold red]Quick command error: {e}[/]")
|
self.console.print(f"[bold red]Quick command error: {e}[/]")
|
||||||
else:
|
else:
|
||||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
|
self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
|
||||||
|
elif qcmd.get("type") == "alias":
|
||||||
|
target = qcmd.get("target", "").strip()
|
||||||
|
if target:
|
||||||
|
target = target if target.startswith("/") else f"/{target}"
|
||||||
|
user_args = cmd_original[len(base_cmd):].strip()
|
||||||
|
aliased_command = f"{target} {user_args}".strip()
|
||||||
|
return self.process_command(aliased_command)
|
||||||
|
else:
|
||||||
|
self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
|
||||||
else:
|
else:
|
||||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (only 'exec' is supported)[/]")
|
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
|
||||||
# Check for skill slash commands (/gif-search, /axolotl, etc.)
|
# Check for skill slash commands (/gif-search, /axolotl, etc.)
|
||||||
elif base_cmd in _skill_commands:
|
elif base_cmd in _skill_commands:
|
||||||
user_instruction = cmd_original[len(base_cmd):].strip()
|
user_instruction = cmd_original[len(base_cmd):].strip()
|
||||||
|
|||||||
@@ -1421,8 +1421,19 @@ class GatewayRunner:
|
|||||||
return f"Quick command error: {e}"
|
return f"Quick command error: {e}"
|
||||||
else:
|
else:
|
||||||
return f"Quick command '/{command}' has no command defined."
|
return f"Quick command '/{command}' has no command defined."
|
||||||
|
elif qcmd.get("type") == "alias":
|
||||||
|
target = qcmd.get("target", "").strip()
|
||||||
|
if target:
|
||||||
|
target = target if target.startswith("/") else f"/{target}"
|
||||||
|
target_command = target.lstrip("/")
|
||||||
|
user_args = event.get_command_args().strip()
|
||||||
|
event.text = f"{target} {user_args}".strip()
|
||||||
|
command = target_command
|
||||||
|
# Fall through to normal command dispatch below
|
||||||
|
else:
|
||||||
|
return f"Quick command '/{command}' has no target defined."
|
||||||
else:
|
else:
|
||||||
return f"Quick command '/{command}' has unsupported type (only 'exec' is supported)."
|
return f"Quick command '/{command}' has unsupported type (supported: 'exec', 'alias')."
|
||||||
|
|
||||||
# Skill slash commands: /skill-name loads the skill and sends to agent
|
# Skill slash commands: /skill-name loads the skill and sends to agent
|
||||||
if command:
|
if command:
|
||||||
|
|||||||
314
hermes_state.py
314
hermes_state.py
@@ -18,6 +18,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
@@ -104,6 +105,7 @@ class SessionDB:
|
|||||||
self.db_path = db_path or DEFAULT_DB_PATH
|
self.db_path = db_path or DEFAULT_DB_PATH
|
||||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self._lock = threading.Lock()
|
||||||
self._conn = sqlite3.connect(
|
self._conn = sqlite3.connect(
|
||||||
str(self.db_path),
|
str(self.db_path),
|
||||||
check_same_thread=False,
|
check_same_thread=False,
|
||||||
@@ -173,9 +175,10 @@ class SessionDB:
|
|||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the database connection."""
|
"""Close the database connection."""
|
||||||
if self._conn:
|
with self._lock:
|
||||||
self._conn.close()
|
if self._conn:
|
||||||
self._conn = None
|
self._conn.close()
|
||||||
|
self._conn = None
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Session lifecycle
|
# Session lifecycle
|
||||||
@@ -192,61 +195,66 @@ class SessionDB:
|
|||||||
parent_session_id: str = None,
|
parent_session_id: str = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new session record. Returns the session_id."""
|
"""Create a new session record. Returns the session_id."""
|
||||||
self._conn.execute(
|
with self._lock:
|
||||||
"""INSERT INTO sessions (id, source, user_id, model, model_config,
|
self._conn.execute(
|
||||||
system_prompt, parent_session_id, started_at)
|
"""INSERT INTO sessions (id, source, user_id, model, model_config,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
system_prompt, parent_session_id, started_at)
|
||||||
(
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
session_id,
|
(
|
||||||
source,
|
session_id,
|
||||||
user_id,
|
source,
|
||||||
model,
|
user_id,
|
||||||
json.dumps(model_config) if model_config else None,
|
model,
|
||||||
system_prompt,
|
json.dumps(model_config) if model_config else None,
|
||||||
parent_session_id,
|
system_prompt,
|
||||||
time.time(),
|
parent_session_id,
|
||||||
),
|
time.time(),
|
||||||
)
|
),
|
||||||
self._conn.commit()
|
)
|
||||||
|
self._conn.commit()
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||||
"""Mark a session as ended."""
|
"""Mark a session as ended."""
|
||||||
self._conn.execute(
|
with self._lock:
|
||||||
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
|
self._conn.execute(
|
||||||
(time.time(), end_reason, session_id),
|
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
|
||||||
)
|
(time.time(), end_reason, session_id),
|
||||||
self._conn.commit()
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
|
||||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||||
"""Store the full assembled system prompt snapshot."""
|
"""Store the full assembled system prompt snapshot."""
|
||||||
self._conn.execute(
|
with self._lock:
|
||||||
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
|
self._conn.execute(
|
||||||
(system_prompt, session_id),
|
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
|
||||||
)
|
(system_prompt, session_id),
|
||||||
self._conn.commit()
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
|
||||||
def update_token_counts(
|
def update_token_counts(
|
||||||
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0,
|
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0,
|
||||||
model: str = None,
|
model: str = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Increment token counters and backfill model if not already set."""
|
"""Increment token counters and backfill model if not already set."""
|
||||||
self._conn.execute(
|
with self._lock:
|
||||||
"""UPDATE sessions SET
|
self._conn.execute(
|
||||||
input_tokens = input_tokens + ?,
|
"""UPDATE sessions SET
|
||||||
output_tokens = output_tokens + ?,
|
input_tokens = input_tokens + ?,
|
||||||
model = COALESCE(model, ?)
|
output_tokens = output_tokens + ?,
|
||||||
WHERE id = ?""",
|
model = COALESCE(model, ?)
|
||||||
(input_tokens, output_tokens, model, session_id),
|
WHERE id = ?""",
|
||||||
)
|
(input_tokens, output_tokens, model, session_id),
|
||||||
self._conn.commit()
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
|
||||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Get a session by ID."""
|
"""Get a session by ID."""
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT * FROM sessions WHERE id = ?", (session_id,)
|
cursor = self._conn.execute(
|
||||||
)
|
"SELECT * FROM sessions WHERE id = ?", (session_id,)
|
||||||
row = cursor.fetchone()
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
|
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
|
||||||
@@ -331,38 +339,42 @@ class SessionDB:
|
|||||||
Empty/whitespace-only strings are normalized to None (clearing the title).
|
Empty/whitespace-only strings are normalized to None (clearing the title).
|
||||||
"""
|
"""
|
||||||
title = self.sanitize_title(title)
|
title = self.sanitize_title(title)
|
||||||
if title:
|
with self._lock:
|
||||||
# Check uniqueness (allow the same session to keep its own title)
|
if title:
|
||||||
|
# Check uniqueness (allow the same session to keep its own title)
|
||||||
|
cursor = self._conn.execute(
|
||||||
|
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
||||||
|
(title, session_id),
|
||||||
|
)
|
||||||
|
conflict = cursor.fetchone()
|
||||||
|
if conflict:
|
||||||
|
raise ValueError(
|
||||||
|
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||||
|
)
|
||||||
cursor = self._conn.execute(
|
cursor = self._conn.execute(
|
||||||
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||||
(title, session_id),
|
(title, session_id),
|
||||||
)
|
)
|
||||||
conflict = cursor.fetchone()
|
self._conn.commit()
|
||||||
if conflict:
|
rowcount = cursor.rowcount
|
||||||
raise ValueError(
|
return rowcount > 0
|
||||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
|
||||||
)
|
|
||||||
cursor = self._conn.execute(
|
|
||||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
|
||||||
(title, session_id),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
return cursor.rowcount > 0
|
|
||||||
|
|
||||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||||
"""Get the title for a session, or None."""
|
"""Get the title for a session, or None."""
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT title FROM sessions WHERE id = ?", (session_id,)
|
cursor = self._conn.execute(
|
||||||
)
|
"SELECT title FROM sessions WHERE id = ?", (session_id,)
|
||||||
row = cursor.fetchone()
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
return row["title"] if row else None
|
return row["title"] if row else None
|
||||||
|
|
||||||
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
|
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Look up a session by exact title. Returns session dict or None."""
|
"""Look up a session by exact title. Returns session dict or None."""
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT * FROM sessions WHERE title = ?", (title,)
|
cursor = self._conn.execute(
|
||||||
)
|
"SELECT * FROM sessions WHERE title = ?", (title,)
|
||||||
row = cursor.fetchone()
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
def resolve_session_by_title(self, title: str) -> Optional[str]:
|
def resolve_session_by_title(self, title: str) -> Optional[str]:
|
||||||
@@ -379,12 +391,13 @@ class SessionDB:
|
|||||||
# Also search for numbered variants: "title #2", "title #3", etc.
|
# Also search for numbered variants: "title #2", "title #3", etc.
|
||||||
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
|
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
|
||||||
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT id, title, started_at FROM sessions "
|
cursor = self._conn.execute(
|
||||||
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
"SELECT id, title, started_at FROM sessions "
|
||||||
(f"{escaped} #%",),
|
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||||
)
|
(f"{escaped} #%",),
|
||||||
numbered = cursor.fetchall()
|
)
|
||||||
|
numbered = cursor.fetchall()
|
||||||
|
|
||||||
if numbered:
|
if numbered:
|
||||||
# Return the most recent numbered variant
|
# Return the most recent numbered variant
|
||||||
@@ -409,11 +422,12 @@ class SessionDB:
|
|||||||
# Find all existing numbered variants
|
# Find all existing numbered variants
|
||||||
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
|
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
|
||||||
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
|
cursor = self._conn.execute(
|
||||||
(base, f"{escaped} #%"),
|
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
|
||||||
)
|
(base, f"{escaped} #%"),
|
||||||
existing = [row["title"] for row in cursor.fetchall()]
|
)
|
||||||
|
existing = [row["title"] for row in cursor.fetchall()]
|
||||||
|
|
||||||
if not existing:
|
if not existing:
|
||||||
return base # No conflict, use the base name as-is
|
return base # No conflict, use the base name as-is
|
||||||
@@ -461,9 +475,11 @@ class SessionDB:
|
|||||||
LIMIT ? OFFSET ?
|
LIMIT ? OFFSET ?
|
||||||
"""
|
"""
|
||||||
params = (source, limit, offset) if source else (limit, offset)
|
params = (source, limit, offset) if source else (limit, offset)
|
||||||
cursor = self._conn.execute(query, params)
|
with self._lock:
|
||||||
|
cursor = self._conn.execute(query, params)
|
||||||
|
rows = cursor.fetchall()
|
||||||
sessions = []
|
sessions = []
|
||||||
for row in cursor.fetchall():
|
for row in rows:
|
||||||
s = dict(row)
|
s = dict(row)
|
||||||
# Build the preview from the raw substring
|
# Build the preview from the raw substring
|
||||||
raw = s.pop("_preview_raw", "").strip()
|
raw = s.pop("_preview_raw", "").strip()
|
||||||
@@ -497,52 +513,54 @@ class SessionDB:
|
|||||||
Also increments the session's message_count (and tool_call_count
|
Also increments the session's message_count (and tool_call_count
|
||||||
if role is 'tool' or tool_calls is present).
|
if role is 'tool' or tool_calls is present).
|
||||||
"""
|
"""
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
cursor = self._conn.execute(
|
||||||
tool_calls, tool_name, timestamp, token_count, finish_reason)
|
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
tool_calls, tool_name, timestamp, token_count, finish_reason)
|
||||||
(
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
session_id,
|
(
|
||||||
role,
|
session_id,
|
||||||
content,
|
role,
|
||||||
tool_call_id,
|
content,
|
||||||
json.dumps(tool_calls) if tool_calls else None,
|
tool_call_id,
|
||||||
tool_name,
|
json.dumps(tool_calls) if tool_calls else None,
|
||||||
time.time(),
|
tool_name,
|
||||||
token_count,
|
time.time(),
|
||||||
finish_reason,
|
token_count,
|
||||||
),
|
finish_reason,
|
||||||
)
|
),
|
||||||
msg_id = cursor.lastrowid
|
|
||||||
|
|
||||||
# Update counters
|
|
||||||
# Count actual tool calls from the tool_calls list (not from tool responses).
|
|
||||||
# A single assistant message can contain multiple parallel tool calls.
|
|
||||||
num_tool_calls = 0
|
|
||||||
if tool_calls is not None:
|
|
||||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
|
||||||
if num_tool_calls > 0:
|
|
||||||
self._conn.execute(
|
|
||||||
"""UPDATE sessions SET message_count = message_count + 1,
|
|
||||||
tool_call_count = tool_call_count + ? WHERE id = ?""",
|
|
||||||
(num_tool_calls, session_id),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._conn.execute(
|
|
||||||
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
|
|
||||||
(session_id,),
|
|
||||||
)
|
)
|
||||||
|
msg_id = cursor.lastrowid
|
||||||
|
|
||||||
self._conn.commit()
|
# Update counters
|
||||||
|
# Count actual tool calls from the tool_calls list (not from tool responses).
|
||||||
|
# A single assistant message can contain multiple parallel tool calls.
|
||||||
|
num_tool_calls = 0
|
||||||
|
if tool_calls is not None:
|
||||||
|
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||||
|
if num_tool_calls > 0:
|
||||||
|
self._conn.execute(
|
||||||
|
"""UPDATE sessions SET message_count = message_count + 1,
|
||||||
|
tool_call_count = tool_call_count + ? WHERE id = ?""",
|
||||||
|
(num_tool_calls, session_id),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._conn.execute(
|
||||||
|
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
|
||||||
|
(session_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._conn.commit()
|
||||||
return msg_id
|
return msg_id
|
||||||
|
|
||||||
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
|
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
|
||||||
"""Load all messages for a session, ordered by timestamp."""
|
"""Load all messages for a session, ordered by timestamp."""
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
cursor = self._conn.execute(
|
||||||
(session_id,),
|
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||||
)
|
(session_id,),
|
||||||
rows = cursor.fetchall()
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
result = []
|
result = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
msg = dict(row)
|
msg = dict(row)
|
||||||
@@ -559,13 +577,15 @@ class SessionDB:
|
|||||||
Load messages in the OpenAI conversation format (role + content dicts).
|
Load messages in the OpenAI conversation format (role + content dicts).
|
||||||
Used by the gateway to restore conversation history.
|
Used by the gateway to restore conversation history.
|
||||||
"""
|
"""
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT role, content, tool_call_id, tool_calls, tool_name "
|
cursor = self._conn.execute(
|
||||||
"FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
"SELECT role, content, tool_call_id, tool_calls, tool_name "
|
||||||
(session_id,),
|
"FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||||
)
|
(session_id,),
|
||||||
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
messages = []
|
messages = []
|
||||||
for row in cursor.fetchall():
|
for row in rows:
|
||||||
msg = {"role": row["role"], "content": row["content"]}
|
msg = {"role": row["role"], "content": row["content"]}
|
||||||
if row["tool_call_id"]:
|
if row["tool_call_id"]:
|
||||||
msg["tool_call_id"] = row["tool_call_id"]
|
msg["tool_call_id"] = row["tool_call_id"]
|
||||||
@@ -675,31 +695,33 @@ class SessionDB:
|
|||||||
LIMIT ? OFFSET ?
|
LIMIT ? OFFSET ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
with self._lock:
|
||||||
cursor = self._conn.execute(sql, params)
|
|
||||||
except sqlite3.OperationalError:
|
|
||||||
# FTS5 query syntax error despite sanitization — return empty
|
|
||||||
return []
|
|
||||||
matches = [dict(row) for row in cursor.fetchall()]
|
|
||||||
|
|
||||||
# Add surrounding context (1 message before + after each match)
|
|
||||||
for match in matches:
|
|
||||||
try:
|
try:
|
||||||
ctx_cursor = self._conn.execute(
|
cursor = self._conn.execute(sql, params)
|
||||||
"""SELECT role, content FROM messages
|
except sqlite3.OperationalError:
|
||||||
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1
|
# FTS5 query syntax error despite sanitization — return empty
|
||||||
ORDER BY id""",
|
return []
|
||||||
(match["session_id"], match["id"], match["id"]),
|
matches = [dict(row) for row in cursor.fetchall()]
|
||||||
)
|
|
||||||
context_msgs = [
|
|
||||||
{"role": r["role"], "content": (r["content"] or "")[:200]}
|
|
||||||
for r in ctx_cursor.fetchall()
|
|
||||||
]
|
|
||||||
match["context"] = context_msgs
|
|
||||||
except Exception:
|
|
||||||
match["context"] = []
|
|
||||||
|
|
||||||
# Remove full content from result (snippet is enough, saves tokens)
|
# Add surrounding context (1 message before + after each match)
|
||||||
|
for match in matches:
|
||||||
|
try:
|
||||||
|
ctx_cursor = self._conn.execute(
|
||||||
|
"""SELECT role, content FROM messages
|
||||||
|
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1
|
||||||
|
ORDER BY id""",
|
||||||
|
(match["session_id"], match["id"], match["id"]),
|
||||||
|
)
|
||||||
|
context_msgs = [
|
||||||
|
{"role": r["role"], "content": (r["content"] or "")[:200]}
|
||||||
|
for r in ctx_cursor.fetchall()
|
||||||
|
]
|
||||||
|
match["context"] = context_msgs
|
||||||
|
except Exception:
|
||||||
|
match["context"] = []
|
||||||
|
|
||||||
|
# Remove full content from result (snippet is enough, saves tokens)
|
||||||
|
for match in matches:
|
||||||
match.pop("content", None)
|
match.pop("content", None)
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|||||||
@@ -69,6 +69,8 @@ class HonchoClientConfig:
|
|||||||
workspace_id: str = "hermes"
|
workspace_id: str = "hermes"
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
environment: str = "production"
|
environment: str = "production"
|
||||||
|
# Optional base URL for self-hosted Honcho (overrides environment mapping)
|
||||||
|
base_url: str | None = None
|
||||||
# Identity
|
# Identity
|
||||||
peer_name: str | None = None
|
peer_name: str | None = None
|
||||||
ai_peer: str = "hermes"
|
ai_peer: str = "hermes"
|
||||||
@@ -361,13 +363,34 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho:
|
|||||||
"Install it with: pip install honcho-ai"
|
"Install it with: pip install honcho-ai"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
|
# Allow config.yaml honcho.base_url to override the SDK's environment
|
||||||
|
# mapping, enabling remote self-hosted Honcho deployments without
|
||||||
|
# requiring the server to live on localhost.
|
||||||
|
resolved_base_url = config.base_url
|
||||||
|
if not resolved_base_url:
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import load_config
|
||||||
|
hermes_cfg = load_config()
|
||||||
|
honcho_cfg = hermes_cfg.get("honcho", {})
|
||||||
|
if isinstance(honcho_cfg, dict):
|
||||||
|
resolved_base_url = honcho_cfg.get("base_url", "").strip() or None
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
_honcho_client = Honcho(
|
if resolved_base_url:
|
||||||
workspace_id=config.workspace_id,
|
logger.info("Initializing Honcho client (base_url: %s, workspace: %s)", resolved_base_url, config.workspace_id)
|
||||||
api_key=config.api_key,
|
else:
|
||||||
environment=config.environment,
|
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
|
||||||
)
|
|
||||||
|
kwargs: dict = {
|
||||||
|
"workspace_id": config.workspace_id,
|
||||||
|
"api_key": config.api_key,
|
||||||
|
"environment": config.environment,
|
||||||
|
}
|
||||||
|
if resolved_base_url:
|
||||||
|
kwargs["base_url"] = resolved_base_url
|
||||||
|
|
||||||
|
_honcho_client = Honcho(**kwargs)
|
||||||
|
|
||||||
return _honcho_client
|
return _honcho_client
|
||||||
|
|
||||||
|
|||||||
@@ -407,6 +407,7 @@ class AIAgent:
|
|||||||
# Subagent delegation state
|
# Subagent delegation state
|
||||||
self._delegate_depth = 0 # 0 = top-level agent, incremented for children
|
self._delegate_depth = 0 # 0 = top-level agent, incremented for children
|
||||||
self._active_children = [] # Running child AIAgents (for interrupt propagation)
|
self._active_children = [] # Running child AIAgents (for interrupt propagation)
|
||||||
|
self._active_children_lock = threading.Lock()
|
||||||
|
|
||||||
# Store OpenRouter provider preferences
|
# Store OpenRouter provider preferences
|
||||||
self.providers_allowed = providers_allowed
|
self.providers_allowed = providers_allowed
|
||||||
@@ -1526,7 +1527,9 @@ class AIAgent:
|
|||||||
# Signal all tools to abort any in-flight operations immediately
|
# Signal all tools to abort any in-flight operations immediately
|
||||||
_set_interrupt(True)
|
_set_interrupt(True)
|
||||||
# Propagate interrupt to any running child agents (subagent delegation)
|
# Propagate interrupt to any running child agents (subagent delegation)
|
||||||
for child in self._active_children:
|
with self._active_children_lock:
|
||||||
|
children_copy = list(self._active_children)
|
||||||
|
for child in children_copy:
|
||||||
try:
|
try:
|
||||||
child.interrupt(message)
|
child.interrupt(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ def main() -> int:
|
|||||||
parent._interrupt_requested = False
|
parent._interrupt_requested = False
|
||||||
parent._interrupt_message = None
|
parent._interrupt_message = None
|
||||||
parent._active_children = []
|
parent._active_children = []
|
||||||
|
parent._active_children_lock = threading.Lock()
|
||||||
parent.quiet_mode = True
|
parent.quiet_mode = True
|
||||||
parent.model = "test/model"
|
parent.model = "test/model"
|
||||||
parent.base_url = "http://localhost:1"
|
parent.base_url = "http://localhost:1"
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class TestCLISubagentInterrupt(unittest.TestCase):
|
|||||||
parent._interrupt_requested = False
|
parent._interrupt_requested = False
|
||||||
parent._interrupt_message = None
|
parent._interrupt_message = None
|
||||||
parent._active_children = []
|
parent._active_children = []
|
||||||
|
parent._active_children_lock = threading.Lock()
|
||||||
parent.quiet_mode = True
|
parent.quiet_mode = True
|
||||||
parent.model = "test/model"
|
parent.model = "test/model"
|
||||||
parent.base_url = "http://localhost:1"
|
parent.base_url = "http://localhost:1"
|
||||||
@@ -112,21 +113,21 @@ class TestCLISubagentInterrupt(unittest.TestCase):
|
|||||||
mock_instance._interrupt_requested = False
|
mock_instance._interrupt_requested = False
|
||||||
mock_instance._interrupt_message = None
|
mock_instance._interrupt_message = None
|
||||||
mock_instance._active_children = []
|
mock_instance._active_children = []
|
||||||
|
mock_instance._active_children_lock = threading.Lock()
|
||||||
mock_instance.quiet_mode = True
|
mock_instance.quiet_mode = True
|
||||||
mock_instance.run_conversation = mock_child_run_conversation
|
mock_instance.run_conversation = mock_child_run_conversation
|
||||||
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
|
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
|
||||||
mock_instance.tools = []
|
mock_instance.tools = []
|
||||||
MockAgent.return_value = mock_instance
|
MockAgent.return_value = mock_instance
|
||||||
|
|
||||||
|
# Register child manually (normally done by _build_child_agent)
|
||||||
|
parent._active_children.append(mock_instance)
|
||||||
|
|
||||||
result = _run_single_child(
|
result = _run_single_child(
|
||||||
task_index=0,
|
task_index=0,
|
||||||
goal="Do something slow",
|
goal="Do something slow",
|
||||||
context=None,
|
child=mock_instance,
|
||||||
toolsets=["terminal"],
|
|
||||||
model=None,
|
|
||||||
max_iterations=50,
|
|
||||||
parent_agent=parent,
|
parent_agent=parent,
|
||||||
task_count=1,
|
|
||||||
)
|
)
|
||||||
delegate_result[0] = result
|
delegate_result[0] = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ def main() -> int:
|
|||||||
parent._interrupt_requested = False
|
parent._interrupt_requested = False
|
||||||
parent._interrupt_message = None
|
parent._interrupt_message = None
|
||||||
parent._active_children = []
|
parent._active_children = []
|
||||||
|
parent._active_children_lock = threading.Lock()
|
||||||
parent.quiet_mode = True
|
parent.quiet_mode = True
|
||||||
parent.model = "test/model"
|
parent.model = "test/model"
|
||||||
parent.base_url = "http://localhost:1"
|
parent.base_url = "http://localhost:1"
|
||||||
|
|||||||
@@ -30,12 +30,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||||||
parent._interrupt_requested = False
|
parent._interrupt_requested = False
|
||||||
parent._interrupt_message = None
|
parent._interrupt_message = None
|
||||||
parent._active_children = []
|
parent._active_children = []
|
||||||
|
parent._active_children_lock = threading.Lock()
|
||||||
parent.quiet_mode = True
|
parent.quiet_mode = True
|
||||||
|
|
||||||
child = AIAgent.__new__(AIAgent)
|
child = AIAgent.__new__(AIAgent)
|
||||||
child._interrupt_requested = False
|
child._interrupt_requested = False
|
||||||
child._interrupt_message = None
|
child._interrupt_message = None
|
||||||
child._active_children = []
|
child._active_children = []
|
||||||
|
child._active_children_lock = threading.Lock()
|
||||||
child.quiet_mode = True
|
child.quiet_mode = True
|
||||||
|
|
||||||
parent._active_children.append(child)
|
parent._active_children.append(child)
|
||||||
@@ -60,6 +62,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||||||
child._interrupt_message = "msg"
|
child._interrupt_message = "msg"
|
||||||
child.quiet_mode = True
|
child.quiet_mode = True
|
||||||
child._active_children = []
|
child._active_children = []
|
||||||
|
child._active_children_lock = threading.Lock()
|
||||||
|
|
||||||
# Global is set
|
# Global is set
|
||||||
set_interrupt(True)
|
set_interrupt(True)
|
||||||
@@ -78,6 +81,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||||||
child._interrupt_requested = False
|
child._interrupt_requested = False
|
||||||
child._interrupt_message = None
|
child._interrupt_message = None
|
||||||
child._active_children = []
|
child._active_children = []
|
||||||
|
child._active_children_lock = threading.Lock()
|
||||||
child.quiet_mode = True
|
child.quiet_mode = True
|
||||||
child.api_mode = "chat_completions"
|
child.api_mode = "chat_completions"
|
||||||
child.log_prefix = ""
|
child.log_prefix = ""
|
||||||
@@ -119,12 +123,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||||||
parent._interrupt_requested = False
|
parent._interrupt_requested = False
|
||||||
parent._interrupt_message = None
|
parent._interrupt_message = None
|
||||||
parent._active_children = []
|
parent._active_children = []
|
||||||
|
parent._active_children_lock = threading.Lock()
|
||||||
parent.quiet_mode = True
|
parent.quiet_mode = True
|
||||||
|
|
||||||
child = AIAgent.__new__(AIAgent)
|
child = AIAgent.__new__(AIAgent)
|
||||||
child._interrupt_requested = False
|
child._interrupt_requested = False
|
||||||
child._interrupt_message = None
|
child._interrupt_message = None
|
||||||
child._active_children = []
|
child._active_children = []
|
||||||
|
child._active_children_lock = threading.Lock()
|
||||||
child.quiet_mode = True
|
child.quiet_mode = True
|
||||||
|
|
||||||
# Register child (simulating what _run_single_child does)
|
# Register child (simulating what _run_single_child does)
|
||||||
|
|||||||
@@ -47,6 +47,28 @@ class TestCLIQuickCommands:
|
|||||||
args = cli.console.print.call_args[0][0]
|
args = cli.console.print.call_args[0][0]
|
||||||
assert "no output" in args.lower()
|
assert "no output" in args.lower()
|
||||||
|
|
||||||
|
def test_alias_command_routes_to_target(self):
|
||||||
|
"""Alias quick commands rewrite to the target command."""
|
||||||
|
cli = self._make_cli({"shortcut": {"type": "alias", "target": "/help"}})
|
||||||
|
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
|
||||||
|
cli.process_command("/shortcut")
|
||||||
|
# Should recursively call process_command with /help
|
||||||
|
spy.assert_any_call("/help")
|
||||||
|
|
||||||
|
def test_alias_command_passes_args(self):
|
||||||
|
"""Alias quick commands forward user arguments to the target."""
|
||||||
|
cli = self._make_cli({"sc": {"type": "alias", "target": "/context"}})
|
||||||
|
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
|
||||||
|
cli.process_command("/sc some args")
|
||||||
|
spy.assert_any_call("/context some args")
|
||||||
|
|
||||||
|
def test_alias_no_target_shows_error(self):
|
||||||
|
cli = self._make_cli({"broken": {"type": "alias", "target": ""}})
|
||||||
|
cli.process_command("/broken")
|
||||||
|
cli.console.print.assert_called_once()
|
||||||
|
args = cli.console.print.call_args[0][0]
|
||||||
|
assert "no target defined" in args.lower()
|
||||||
|
|
||||||
def test_unsupported_type_shows_error(self):
|
def test_unsupported_type_shows_error(self):
|
||||||
cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}})
|
cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}})
|
||||||
cli.process_command("/bad")
|
cli.process_command("/bad")
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ class TestRealSubagentInterrupt(unittest.TestCase):
|
|||||||
parent._interrupt_requested = False
|
parent._interrupt_requested = False
|
||||||
parent._interrupt_message = None
|
parent._interrupt_message = None
|
||||||
parent._active_children = []
|
parent._active_children = []
|
||||||
|
parent._active_children_lock = threading.Lock()
|
||||||
parent.quiet_mode = True
|
parent.quiet_mode = True
|
||||||
parent.model = "test/model"
|
parent.model = "test/model"
|
||||||
parent.base_url = "http://localhost:1"
|
parent.base_url = "http://localhost:1"
|
||||||
@@ -103,19 +104,28 @@ class TestRealSubagentInterrupt(unittest.TestCase):
|
|||||||
return original_run(self_agent, *args, **kwargs)
|
return original_run(self_agent, *args, **kwargs)
|
||||||
|
|
||||||
with patch.object(AIAgent, 'run_conversation', patched_run):
|
with patch.object(AIAgent, 'run_conversation', patched_run):
|
||||||
|
# Build a real child agent (AIAgent is NOT patched here,
|
||||||
|
# only run_conversation and _build_system_prompt are)
|
||||||
|
child = AIAgent(
|
||||||
|
base_url="http://localhost:1",
|
||||||
|
api_key="test-key",
|
||||||
|
model="test/model",
|
||||||
|
provider="test",
|
||||||
|
api_mode="chat_completions",
|
||||||
|
max_iterations=5,
|
||||||
|
enabled_toolsets=["terminal"],
|
||||||
|
quiet_mode=True,
|
||||||
|
skip_context_files=True,
|
||||||
|
skip_memory=True,
|
||||||
|
platform="cli",
|
||||||
|
)
|
||||||
|
child._delegate_depth = 1
|
||||||
|
parent._active_children.append(child)
|
||||||
result = _run_single_child(
|
result = _run_single_child(
|
||||||
task_index=0,
|
task_index=0,
|
||||||
goal="Test task",
|
goal="Test task",
|
||||||
context=None,
|
child=child,
|
||||||
toolsets=["terminal"],
|
|
||||||
model="test/model",
|
|
||||||
max_iterations=5,
|
|
||||||
parent_agent=parent,
|
parent_agent=parent,
|
||||||
task_count=1,
|
|
||||||
override_provider="test",
|
|
||||||
override_base_url="http://localhost:1",
|
|
||||||
override_api_key="test",
|
|
||||||
override_api_mode="chat_completions",
|
|
||||||
)
|
)
|
||||||
result_holder[0] = result
|
result_holder[0] = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ Run with: python -m pytest tests/test_delegate.py -v
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
@@ -44,6 +45,7 @@ def _make_mock_parent(depth=0):
|
|||||||
parent._session_db = None
|
parent._session_db = None
|
||||||
parent._delegate_depth = depth
|
parent._delegate_depth = depth
|
||||||
parent._active_children = []
|
parent._active_children = []
|
||||||
|
parent._active_children_lock = threading.Lock()
|
||||||
return parent
|
return parent
|
||||||
|
|
||||||
|
|
||||||
@@ -722,7 +724,12 @@ class TestDelegationProviderIntegration(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
parent = _make_mock_parent(depth=0)
|
parent = _make_mock_parent(depth=0)
|
||||||
|
|
||||||
with patch("tools.delegate_tool._run_single_child") as mock_run:
|
# Patch _build_child_agent since credentials are now passed there
|
||||||
|
# (agents are built in the main thread before being handed to workers)
|
||||||
|
with patch("tools.delegate_tool._build_child_agent") as mock_build, \
|
||||||
|
patch("tools.delegate_tool._run_single_child") as mock_run:
|
||||||
|
mock_child = MagicMock()
|
||||||
|
mock_build.return_value = mock_child
|
||||||
mock_run.return_value = {
|
mock_run.return_value = {
|
||||||
"task_index": 0, "status": "completed",
|
"task_index": 0, "status": "completed",
|
||||||
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0
|
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0
|
||||||
@@ -731,7 +738,8 @@ class TestDelegationProviderIntegration(unittest.TestCase):
|
|||||||
tasks = [{"goal": "Task A"}, {"goal": "Task B"}]
|
tasks = [{"goal": "Task A"}, {"goal": "Task B"}]
|
||||||
delegate_task(tasks=tasks, parent_agent=parent)
|
delegate_task(tasks=tasks, parent_agent=parent)
|
||||||
|
|
||||||
for call in mock_run.call_args_list:
|
self.assertEqual(mock_build.call_count, 2)
|
||||||
|
for call in mock_build.call_args_list:
|
||||||
self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout")
|
self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout")
|
||||||
self.assertEqual(call.kwargs.get("override_provider"), "openrouter")
|
self.assertEqual(call.kwargs.get("override_provider"), "openrouter")
|
||||||
self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1")
|
self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1")
|
||||||
|
|||||||
@@ -16,13 +16,10 @@ The parent's context only sees the delegation call and the summary result,
|
|||||||
never the child's intermediate tool calls or reasoning.
|
never the child's intermediate tool calls or reasoning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
@@ -150,7 +147,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||||||
return _callback
|
return _callback
|
||||||
|
|
||||||
|
|
||||||
def _run_single_child(
|
def _build_child_agent(
|
||||||
task_index: int,
|
task_index: int,
|
||||||
goal: str,
|
goal: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
@@ -158,16 +155,15 @@ def _run_single_child(
|
|||||||
model: Optional[str],
|
model: Optional[str],
|
||||||
max_iterations: int,
|
max_iterations: int,
|
||||||
parent_agent,
|
parent_agent,
|
||||||
task_count: int = 1,
|
|
||||||
# Credential overrides from delegation config (provider:model resolution)
|
# Credential overrides from delegation config (provider:model resolution)
|
||||||
override_provider: Optional[str] = None,
|
override_provider: Optional[str] = None,
|
||||||
override_base_url: Optional[str] = None,
|
override_base_url: Optional[str] = None,
|
||||||
override_api_key: Optional[str] = None,
|
override_api_key: Optional[str] = None,
|
||||||
override_api_mode: Optional[str] = None,
|
override_api_mode: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
):
|
||||||
"""
|
"""
|
||||||
Spawn and run a single child agent. Called from within a thread.
|
Build a child AIAgent on the main thread (thread-safe construction).
|
||||||
Returns a structured result dict.
|
Returns the constructed child agent without running it.
|
||||||
|
|
||||||
When override_* params are set (from delegation config), the child uses
|
When override_* params are set (from delegation config), the child uses
|
||||||
those credentials instead of inheriting from the parent. This enables
|
those credentials instead of inheriting from the parent. This enables
|
||||||
@@ -176,8 +172,6 @@ def _run_single_child(
|
|||||||
"""
|
"""
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
|
|
||||||
child_start = time.monotonic()
|
|
||||||
|
|
||||||
# When no explicit toolsets given, inherit from parent's enabled toolsets
|
# When no explicit toolsets given, inherit from parent's enabled toolsets
|
||||||
# so disabled tools (e.g. web) don't leak to subagents.
|
# so disabled tools (e.g. web) don't leak to subagents.
|
||||||
if toolsets:
|
if toolsets:
|
||||||
@@ -188,65 +182,84 @@ def _run_single_child(
|
|||||||
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
|
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
|
||||||
|
|
||||||
child_prompt = _build_child_system_prompt(goal, context)
|
child_prompt = _build_child_system_prompt(goal, context)
|
||||||
|
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
|
||||||
|
parent_api_key = getattr(parent_agent, "api_key", None)
|
||||||
|
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
|
||||||
|
parent_api_key = parent_agent._client_kwargs.get("api_key")
|
||||||
|
|
||||||
try:
|
# Build progress callback to relay tool calls to parent display
|
||||||
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
|
child_progress_cb = _build_child_progress_callback(task_index, parent_agent)
|
||||||
parent_api_key = getattr(parent_agent, "api_key", None)
|
|
||||||
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
|
|
||||||
parent_api_key = parent_agent._client_kwargs.get("api_key")
|
|
||||||
|
|
||||||
# Build progress callback to relay tool calls to parent display
|
# Share the parent's iteration budget so subagent tool calls
|
||||||
child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count)
|
# count toward the session-wide limit.
|
||||||
|
shared_budget = getattr(parent_agent, "iteration_budget", None)
|
||||||
|
|
||||||
# Share the parent's iteration budget so subagent tool calls
|
# Resolve effective credentials: config override > parent inherit
|
||||||
# count toward the session-wide limit.
|
effective_model = model or parent_agent.model
|
||||||
shared_budget = getattr(parent_agent, "iteration_budget", None)
|
effective_provider = override_provider or getattr(parent_agent, "provider", None)
|
||||||
|
effective_base_url = override_base_url or parent_agent.base_url
|
||||||
|
effective_api_key = override_api_key or parent_api_key
|
||||||
|
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
|
||||||
|
|
||||||
# Resolve effective credentials: config override > parent inherit
|
child = AIAgent(
|
||||||
effective_model = model or parent_agent.model
|
base_url=effective_base_url,
|
||||||
effective_provider = override_provider or getattr(parent_agent, "provider", None)
|
api_key=effective_api_key,
|
||||||
effective_base_url = override_base_url or parent_agent.base_url
|
model=effective_model,
|
||||||
effective_api_key = override_api_key or parent_api_key
|
provider=effective_provider,
|
||||||
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
|
api_mode=effective_api_mode,
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
max_tokens=getattr(parent_agent, "max_tokens", None),
|
||||||
|
reasoning_config=getattr(parent_agent, "reasoning_config", None),
|
||||||
|
prefill_messages=getattr(parent_agent, "prefill_messages", None),
|
||||||
|
enabled_toolsets=child_toolsets,
|
||||||
|
quiet_mode=True,
|
||||||
|
ephemeral_system_prompt=child_prompt,
|
||||||
|
log_prefix=f"[subagent-{task_index}]",
|
||||||
|
platform=parent_agent.platform,
|
||||||
|
skip_context_files=True,
|
||||||
|
skip_memory=True,
|
||||||
|
clarify_callback=None,
|
||||||
|
session_db=getattr(parent_agent, '_session_db', None),
|
||||||
|
providers_allowed=parent_agent.providers_allowed,
|
||||||
|
providers_ignored=parent_agent.providers_ignored,
|
||||||
|
providers_order=parent_agent.providers_order,
|
||||||
|
provider_sort=parent_agent.provider_sort,
|
||||||
|
tool_progress_callback=child_progress_cb,
|
||||||
|
iteration_budget=shared_budget,
|
||||||
|
)
|
||||||
|
|
||||||
child = AIAgent(
|
# Set delegation depth so children can't spawn grandchildren
|
||||||
base_url=effective_base_url,
|
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||||
api_key=effective_api_key,
|
|
||||||
model=effective_model,
|
|
||||||
provider=effective_provider,
|
|
||||||
api_mode=effective_api_mode,
|
|
||||||
max_iterations=max_iterations,
|
|
||||||
max_tokens=getattr(parent_agent, "max_tokens", None),
|
|
||||||
reasoning_config=getattr(parent_agent, "reasoning_config", None),
|
|
||||||
prefill_messages=getattr(parent_agent, "prefill_messages", None),
|
|
||||||
enabled_toolsets=child_toolsets,
|
|
||||||
quiet_mode=True,
|
|
||||||
ephemeral_system_prompt=child_prompt,
|
|
||||||
log_prefix=f"[subagent-{task_index}]",
|
|
||||||
platform=parent_agent.platform,
|
|
||||||
skip_context_files=True,
|
|
||||||
skip_memory=True,
|
|
||||||
clarify_callback=None,
|
|
||||||
session_db=getattr(parent_agent, '_session_db', None),
|
|
||||||
providers_allowed=parent_agent.providers_allowed,
|
|
||||||
providers_ignored=parent_agent.providers_ignored,
|
|
||||||
providers_order=parent_agent.providers_order,
|
|
||||||
provider_sort=parent_agent.provider_sort,
|
|
||||||
tool_progress_callback=child_progress_cb,
|
|
||||||
iteration_budget=shared_budget,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set delegation depth so children can't spawn grandchildren
|
# Register child for interrupt propagation
|
||||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
if hasattr(parent_agent, '_active_children'):
|
||||||
|
lock = getattr(parent_agent, '_active_children_lock', None)
|
||||||
# Register child for interrupt propagation
|
if lock:
|
||||||
if hasattr(parent_agent, '_active_children'):
|
with lock:
|
||||||
|
parent_agent._active_children.append(child)
|
||||||
|
else:
|
||||||
parent_agent._active_children.append(child)
|
parent_agent._active_children.append(child)
|
||||||
|
|
||||||
# Run with stdout/stderr suppressed to prevent interleaved output
|
return child
|
||||||
devnull = io.StringIO()
|
|
||||||
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
|
def _run_single_child(
|
||||||
result = child.run_conversation(user_message=goal)
|
task_index: int,
|
||||||
|
goal: str,
|
||||||
|
child=None,
|
||||||
|
parent_agent=None,
|
||||||
|
**_kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Run a pre-built child agent. Called from within a thread.
|
||||||
|
Returns a structured result dict.
|
||||||
|
"""
|
||||||
|
child_start = time.monotonic()
|
||||||
|
|
||||||
|
# Get the progress callback from the child agent
|
||||||
|
child_progress_cb = getattr(child, 'tool_progress_callback', None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = child.run_conversation(user_message=goal)
|
||||||
|
|
||||||
# Flush any remaining batched progress to gateway
|
# Flush any remaining batched progress to gateway
|
||||||
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
|
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
|
||||||
@@ -355,11 +368,15 @@ def _run_single_child(
|
|||||||
# Unregister child from interrupt propagation
|
# Unregister child from interrupt propagation
|
||||||
if hasattr(parent_agent, '_active_children'):
|
if hasattr(parent_agent, '_active_children'):
|
||||||
try:
|
try:
|
||||||
parent_agent._active_children.remove(child)
|
lock = getattr(parent_agent, '_active_children_lock', None)
|
||||||
|
if lock:
|
||||||
|
with lock:
|
||||||
|
parent_agent._active_children.remove(child)
|
||||||
|
else:
|
||||||
|
parent_agent._active_children.remove(child)
|
||||||
except (ValueError, UnboundLocalError) as e:
|
except (ValueError, UnboundLocalError) as e:
|
||||||
logger.debug("Could not remove child from active_children: %s", e)
|
logger.debug("Could not remove child from active_children: %s", e)
|
||||||
|
|
||||||
|
|
||||||
def delegate_task(
|
def delegate_task(
|
||||||
goal: Optional[str] = None,
|
goal: Optional[str] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
@@ -428,51 +445,38 @@ def delegate_task(
|
|||||||
# Track goal labels for progress display (truncated for readability)
|
# Track goal labels for progress display (truncated for readability)
|
||||||
task_labels = [t["goal"][:40] for t in task_list]
|
task_labels = [t["goal"][:40] for t in task_list]
|
||||||
|
|
||||||
if n_tasks == 1:
|
# Build all child agents on the main thread (thread-safe construction)
|
||||||
# Single task -- run directly (no thread pool overhead)
|
children = []
|
||||||
t = task_list[0]
|
for i, t in enumerate(task_list):
|
||||||
result = _run_single_child(
|
child = _build_child_agent(
|
||||||
task_index=0,
|
task_index=i, goal=t["goal"], context=t.get("context"),
|
||||||
goal=t["goal"],
|
toolsets=t.get("toolsets") or toolsets, model=creds["model"],
|
||||||
context=t.get("context"),
|
max_iterations=effective_max_iter, parent_agent=parent_agent,
|
||||||
toolsets=t.get("toolsets") or toolsets,
|
override_provider=creds["provider"], override_base_url=creds["base_url"],
|
||||||
model=creds["model"],
|
|
||||||
max_iterations=effective_max_iter,
|
|
||||||
parent_agent=parent_agent,
|
|
||||||
task_count=1,
|
|
||||||
override_provider=creds["provider"],
|
|
||||||
override_base_url=creds["base_url"],
|
|
||||||
override_api_key=creds["api_key"],
|
override_api_key=creds["api_key"],
|
||||||
override_api_mode=creds["api_mode"],
|
override_api_mode=creds["api_mode"],
|
||||||
)
|
)
|
||||||
|
children.append((i, t, child))
|
||||||
|
|
||||||
|
if n_tasks == 1:
|
||||||
|
# Single task -- run directly (no thread pool overhead)
|
||||||
|
_i, _t, child = children[0]
|
||||||
|
result = _run_single_child(0, _t["goal"], child, parent_agent)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
else:
|
else:
|
||||||
# Batch -- run in parallel with per-task progress lines
|
# Batch -- run in parallel with per-task progress lines
|
||||||
completed_count = 0
|
completed_count = 0
|
||||||
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
||||||
|
|
||||||
# Save stdout/stderr before the executor — redirect_stdout in child
|
|
||||||
# threads races on sys.stdout and can leave it as devnull permanently.
|
|
||||||
_saved_stdout = sys.stdout
|
|
||||||
_saved_stderr = sys.stderr
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
|
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
|
||||||
futures = {}
|
futures = {}
|
||||||
for i, t in enumerate(task_list):
|
for i, t, child in children:
|
||||||
future = executor.submit(
|
future = executor.submit(
|
||||||
_run_single_child,
|
_run_single_child,
|
||||||
task_index=i,
|
task_index=i,
|
||||||
goal=t["goal"],
|
goal=t["goal"],
|
||||||
context=t.get("context"),
|
child=child,
|
||||||
toolsets=t.get("toolsets") or toolsets,
|
|
||||||
model=creds["model"],
|
|
||||||
max_iterations=effective_max_iter,
|
|
||||||
parent_agent=parent_agent,
|
parent_agent=parent_agent,
|
||||||
task_count=n_tasks,
|
|
||||||
override_provider=creds["provider"],
|
|
||||||
override_base_url=creds["base_url"],
|
|
||||||
override_api_key=creds["api_key"],
|
|
||||||
override_api_mode=creds["api_mode"],
|
|
||||||
)
|
)
|
||||||
futures[future] = i
|
futures[future] = i
|
||||||
|
|
||||||
@@ -515,10 +519,6 @@ def delegate_task(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Spinner update_text failed: %s", e)
|
logger.debug("Spinner update_text failed: %s", e)
|
||||||
|
|
||||||
# Restore stdout/stderr in case redirect_stdout race left them as devnull
|
|
||||||
sys.stdout = _saved_stdout
|
|
||||||
sys.stderr = _saved_stderr
|
|
||||||
|
|
||||||
# Sort by task_index so results match input order
|
# Sort by task_index so results match input order
|
||||||
results.sort(key=lambda r: r["task_index"])
|
results.sort(key=lambda r: r["task_index"])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user