fix: thread safety for concurrent subagent delegation (#1672)

* fix: thread safety for concurrent subagent delegation

Five thread-safety fixes that prevent crashes and data races when
running multiple subagents concurrently via delegate_task:

1. Remove redirect_stdout/stderr from delegate_tool — mutating global
   sys.stdout races with the spinner thread when multiple children start
   concurrently, causing segfaults. Children already run with
   quiet_mode=True so the redirect was redundant.

2. Split _run_single_child into _build_child_agent (main thread) +
   _run_single_child (worker thread). AIAgent construction creates
   httpx/SSL clients which are not thread-safe to initialize
   concurrently.

3. Add threading.Lock to SessionDB — subagents share the parent's
   SessionDB and call create_session/append_message from worker threads
   with no synchronization.

4. Add _active_children_lock to AIAgent — interrupt() iterates
   _active_children while worker threads append/remove children.

5. Add _client_cache_lock to auxiliary_client — multiple subagent
   threads may resolve clients concurrently via call_llm().

Based on PR #1471 by peteromallet.

* feat: Honcho base_url override via config.yaml + quick command alias type

Two features salvaged from PR #1576:

1. Honcho base_url override: allows pointing Hermes at a remote
   self-hosted Honcho deployment via config.yaml:

     honcho:
       base_url: "http://192.168.x.x:8000"

   When set, this overrides the Honcho SDK's environment mapping
   (production/local), enabling LAN/VPN Honcho deployments without
   requiring the server to live on localhost. Uses config.yaml instead
   of env var (HONCHO_URL) per project convention.

2. Quick command alias type: adds a new 'alias' quick command type
   that rewrites to another slash command before normal dispatch:

     quick_commands:
       sc:
         type: alias
         target: /context

   Supports both CLI and gateway. Arguments are forwarded to the
   target command.

Based on PR #1576 by redhelix.

---------

Co-authored-by: peteromallet <peteromallet@users.noreply.github.com>
Co-authored-by: redhelix <redhelix@users.noreply.github.com>
This commit is contained in:
Teknium
2026-03-17 02:53:33 -07:00
committed by GitHub
parent fd61ae13e5
commit 1d5a39e002
14 changed files with 397 additions and 272 deletions

View File

@@ -39,6 +39,7 @@ custom OpenAI-compatible endpoint without touching the main model settings.
import json import json
import logging import logging
import os import os
import threading
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@@ -1171,6 +1172,7 @@ def auxiliary_max_tokens_param(value: int) -> dict:
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model) # Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
_client_cache: Dict[tuple, tuple] = {} _client_cache: Dict[tuple, tuple] = {}
_client_cache_lock = threading.Lock()
def _get_cached_client( def _get_cached_client(
@@ -1182,9 +1184,11 @@ def _get_cached_client(
) -> Tuple[Optional[Any], Optional[str]]: ) -> Tuple[Optional[Any], Optional[str]]:
"""Get or create a cached client for the given provider.""" """Get or create a cached client for the given provider."""
cache_key = (provider, async_mode, base_url or "", api_key or "") cache_key = (provider, async_mode, base_url or "", api_key or "")
if cache_key in _client_cache: with _client_cache_lock:
cached_client, cached_default = _client_cache[cache_key] if cache_key in _client_cache:
return cached_client, model or cached_default cached_client, cached_default = _client_cache[cache_key]
return cached_client, model or cached_default
# Build outside the lock
client, default_model = resolve_provider_client( client, default_model = resolve_provider_client(
provider, provider,
model, model,
@@ -1193,7 +1197,11 @@ def _get_cached_client(
explicit_api_key=api_key, explicit_api_key=api_key,
) )
if client is not None: if client is not None:
_client_cache[cache_key] = (client, default_model) with _client_cache_lock:
if cache_key not in _client_cache:
_client_cache[cache_key] = (client, default_model)
else:
client, default_model = _client_cache[cache_key]
return client, model or default_model return client, model or default_model

11
cli.py
View File

@@ -3652,8 +3652,17 @@ class HermesCLI:
self.console.print(f"[bold red]Quick command error: {e}[/]") self.console.print(f"[bold red]Quick command error: {e}[/]")
else: else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]") self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
elif qcmd.get("type") == "alias":
target = qcmd.get("target", "").strip()
if target:
target = target if target.startswith("/") else f"/{target}"
user_args = cmd_original[len(base_cmd):].strip()
aliased_command = f"{target} {user_args}".strip()
return self.process_command(aliased_command)
else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
else: else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (only 'exec' is supported)[/]") self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
# Check for skill slash commands (/gif-search, /axolotl, etc.) # Check for skill slash commands (/gif-search, /axolotl, etc.)
elif base_cmd in _skill_commands: elif base_cmd in _skill_commands:
user_instruction = cmd_original[len(base_cmd):].strip() user_instruction = cmd_original[len(base_cmd):].strip()

View File

@@ -1421,8 +1421,19 @@ class GatewayRunner:
return f"Quick command error: {e}" return f"Quick command error: {e}"
else: else:
return f"Quick command '/{command}' has no command defined." return f"Quick command '/{command}' has no command defined."
elif qcmd.get("type") == "alias":
target = qcmd.get("target", "").strip()
if target:
target = target if target.startswith("/") else f"/{target}"
target_command = target.lstrip("/")
user_args = event.get_command_args().strip()
event.text = f"{target} {user_args}".strip()
command = target_command
# Fall through to normal command dispatch below
else:
return f"Quick command '/{command}' has no target defined."
else: else:
return f"Quick command '/{command}' has unsupported type (only 'exec' is supported)." return f"Quick command '/{command}' has unsupported type (supported: 'exec', 'alias')."
# Skill slash commands: /skill-name loads the skill and sends to agent # Skill slash commands: /skill-name loads the skill and sends to agent
if command: if command:

View File

@@ -18,6 +18,7 @@ import json
import os import os
import re import re
import sqlite3 import sqlite3
import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
@@ -104,6 +105,7 @@ class SessionDB:
self.db_path = db_path or DEFAULT_DB_PATH self.db_path = db_path or DEFAULT_DB_PATH
self.db_path.parent.mkdir(parents=True, exist_ok=True) self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._lock = threading.Lock()
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
str(self.db_path), str(self.db_path),
check_same_thread=False, check_same_thread=False,
@@ -173,9 +175,10 @@ class SessionDB:
def close(self): def close(self):
"""Close the database connection.""" """Close the database connection."""
if self._conn: with self._lock:
self._conn.close() if self._conn:
self._conn = None self._conn.close()
self._conn = None
# ========================================================================= # =========================================================================
# Session lifecycle # Session lifecycle
@@ -192,61 +195,66 @@ class SessionDB:
parent_session_id: str = None, parent_session_id: str = None,
) -> str: ) -> str:
"""Create a new session record. Returns the session_id.""" """Create a new session record. Returns the session_id."""
self._conn.execute( with self._lock:
"""INSERT INTO sessions (id, source, user_id, model, model_config, self._conn.execute(
system_prompt, parent_session_id, started_at) """INSERT INTO sessions (id, source, user_id, model, model_config,
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", system_prompt, parent_session_id, started_at)
( VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
session_id, (
source, session_id,
user_id, source,
model, user_id,
json.dumps(model_config) if model_config else None, model,
system_prompt, json.dumps(model_config) if model_config else None,
parent_session_id, system_prompt,
time.time(), parent_session_id,
), time.time(),
) ),
self._conn.commit() )
self._conn.commit()
return session_id return session_id
def end_session(self, session_id: str, end_reason: str) -> None: def end_session(self, session_id: str, end_reason: str) -> None:
"""Mark a session as ended.""" """Mark a session as ended."""
self._conn.execute( with self._lock:
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", self._conn.execute(
(time.time(), end_reason, session_id), "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
) (time.time(), end_reason, session_id),
self._conn.commit() )
self._conn.commit()
def update_system_prompt(self, session_id: str, system_prompt: str) -> None: def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
"""Store the full assembled system prompt snapshot.""" """Store the full assembled system prompt snapshot."""
self._conn.execute( with self._lock:
"UPDATE sessions SET system_prompt = ? WHERE id = ?", self._conn.execute(
(system_prompt, session_id), "UPDATE sessions SET system_prompt = ? WHERE id = ?",
) (system_prompt, session_id),
self._conn.commit() )
self._conn.commit()
def update_token_counts( def update_token_counts(
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0, self, session_id: str, input_tokens: int = 0, output_tokens: int = 0,
model: str = None, model: str = None,
) -> None: ) -> None:
"""Increment token counters and backfill model if not already set.""" """Increment token counters and backfill model if not already set."""
self._conn.execute( with self._lock:
"""UPDATE sessions SET self._conn.execute(
input_tokens = input_tokens + ?, """UPDATE sessions SET
output_tokens = output_tokens + ?, input_tokens = input_tokens + ?,
model = COALESCE(model, ?) output_tokens = output_tokens + ?,
WHERE id = ?""", model = COALESCE(model, ?)
(input_tokens, output_tokens, model, session_id), WHERE id = ?""",
) (input_tokens, output_tokens, model, session_id),
self._conn.commit() )
self._conn.commit()
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get a session by ID.""" """Get a session by ID."""
cursor = self._conn.execute( with self._lock:
"SELECT * FROM sessions WHERE id = ?", (session_id,) cursor = self._conn.execute(
) "SELECT * FROM sessions WHERE id = ?", (session_id,)
row = cursor.fetchone() )
row = cursor.fetchone()
return dict(row) if row else None return dict(row) if row else None
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]: def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
@@ -331,38 +339,42 @@ class SessionDB:
Empty/whitespace-only strings are normalized to None (clearing the title). Empty/whitespace-only strings are normalized to None (clearing the title).
""" """
title = self.sanitize_title(title) title = self.sanitize_title(title)
if title: with self._lock:
# Check uniqueness (allow the same session to keep its own title) if title:
# Check uniqueness (allow the same session to keep its own title)
cursor = self._conn.execute(
"SELECT id FROM sessions WHERE title = ? AND id != ?",
(title, session_id),
)
conflict = cursor.fetchone()
if conflict:
raise ValueError(
f"Title '{title}' is already in use by session {conflict['id']}"
)
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT id FROM sessions WHERE title = ? AND id != ?", "UPDATE sessions SET title = ? WHERE id = ?",
(title, session_id), (title, session_id),
) )
conflict = cursor.fetchone() self._conn.commit()
if conflict: rowcount = cursor.rowcount
raise ValueError( return rowcount > 0
f"Title '{title}' is already in use by session {conflict['id']}"
)
cursor = self._conn.execute(
"UPDATE sessions SET title = ? WHERE id = ?",
(title, session_id),
)
self._conn.commit()
return cursor.rowcount > 0
def get_session_title(self, session_id: str) -> Optional[str]: def get_session_title(self, session_id: str) -> Optional[str]:
"""Get the title for a session, or None.""" """Get the title for a session, or None."""
cursor = self._conn.execute( with self._lock:
"SELECT title FROM sessions WHERE id = ?", (session_id,) cursor = self._conn.execute(
) "SELECT title FROM sessions WHERE id = ?", (session_id,)
row = cursor.fetchone() )
row = cursor.fetchone()
return row["title"] if row else None return row["title"] if row else None
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]: def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
"""Look up a session by exact title. Returns session dict or None.""" """Look up a session by exact title. Returns session dict or None."""
cursor = self._conn.execute( with self._lock:
"SELECT * FROM sessions WHERE title = ?", (title,) cursor = self._conn.execute(
) "SELECT * FROM sessions WHERE title = ?", (title,)
row = cursor.fetchone() )
row = cursor.fetchone()
return dict(row) if row else None return dict(row) if row else None
def resolve_session_by_title(self, title: str) -> Optional[str]: def resolve_session_by_title(self, title: str) -> Optional[str]:
@@ -379,12 +391,13 @@ class SessionDB:
# Also search for numbered variants: "title #2", "title #3", etc. # Also search for numbered variants: "title #2", "title #3", etc.
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
cursor = self._conn.execute( with self._lock:
"SELECT id, title, started_at FROM sessions " cursor = self._conn.execute(
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", "SELECT id, title, started_at FROM sessions "
(f"{escaped} #%",), "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
) (f"{escaped} #%",),
numbered = cursor.fetchall() )
numbered = cursor.fetchall()
if numbered: if numbered:
# Return the most recent numbered variant # Return the most recent numbered variant
@@ -409,11 +422,12 @@ class SessionDB:
# Find all existing numbered variants # Find all existing numbered variants
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
cursor = self._conn.execute( with self._lock:
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'", cursor = self._conn.execute(
(base, f"{escaped} #%"), "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
) (base, f"{escaped} #%"),
existing = [row["title"] for row in cursor.fetchall()] )
existing = [row["title"] for row in cursor.fetchall()]
if not existing: if not existing:
return base # No conflict, use the base name as-is return base # No conflict, use the base name as-is
@@ -461,9 +475,11 @@ class SessionDB:
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""" """
params = (source, limit, offset) if source else (limit, offset) params = (source, limit, offset) if source else (limit, offset)
cursor = self._conn.execute(query, params) with self._lock:
cursor = self._conn.execute(query, params)
rows = cursor.fetchall()
sessions = [] sessions = []
for row in cursor.fetchall(): for row in rows:
s = dict(row) s = dict(row)
# Build the preview from the raw substring # Build the preview from the raw substring
raw = s.pop("_preview_raw", "").strip() raw = s.pop("_preview_raw", "").strip()
@@ -497,52 +513,54 @@ class SessionDB:
Also increments the session's message_count (and tool_call_count Also increments the session's message_count (and tool_call_count
if role is 'tool' or tool_calls is present). if role is 'tool' or tool_calls is present).
""" """
cursor = self._conn.execute( with self._lock:
"""INSERT INTO messages (session_id, role, content, tool_call_id, cursor = self._conn.execute(
tool_calls, tool_name, timestamp, token_count, finish_reason) """INSERT INTO messages (session_id, role, content, tool_call_id,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", tool_calls, tool_name, timestamp, token_count, finish_reason)
( VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
session_id, (
role, session_id,
content, role,
tool_call_id, content,
json.dumps(tool_calls) if tool_calls else None, tool_call_id,
tool_name, json.dumps(tool_calls) if tool_calls else None,
time.time(), tool_name,
token_count, time.time(),
finish_reason, token_count,
), finish_reason,
) ),
msg_id = cursor.lastrowid
# Update counters
# Count actual tool calls from the tool_calls list (not from tool responses).
# A single assistant message can contain multiple parallel tool calls.
num_tool_calls = 0
if tool_calls is not None:
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
if num_tool_calls > 0:
self._conn.execute(
"""UPDATE sessions SET message_count = message_count + 1,
tool_call_count = tool_call_count + ? WHERE id = ?""",
(num_tool_calls, session_id),
)
else:
self._conn.execute(
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
(session_id,),
) )
msg_id = cursor.lastrowid
self._conn.commit() # Update counters
# Count actual tool calls from the tool_calls list (not from tool responses).
# A single assistant message can contain multiple parallel tool calls.
num_tool_calls = 0
if tool_calls is not None:
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
if num_tool_calls > 0:
self._conn.execute(
"""UPDATE sessions SET message_count = message_count + 1,
tool_call_count = tool_call_count + ? WHERE id = ?""",
(num_tool_calls, session_id),
)
else:
self._conn.execute(
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
(session_id,),
)
self._conn.commit()
return msg_id return msg_id
def get_messages(self, session_id: str) -> List[Dict[str, Any]]: def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
"""Load all messages for a session, ordered by timestamp.""" """Load all messages for a session, ordered by timestamp."""
cursor = self._conn.execute( with self._lock:
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", cursor = self._conn.execute(
(session_id,), "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
) (session_id,),
rows = cursor.fetchall() )
rows = cursor.fetchall()
result = [] result = []
for row in rows: for row in rows:
msg = dict(row) msg = dict(row)
@@ -559,13 +577,15 @@ class SessionDB:
Load messages in the OpenAI conversation format (role + content dicts). Load messages in the OpenAI conversation format (role + content dicts).
Used by the gateway to restore conversation history. Used by the gateway to restore conversation history.
""" """
cursor = self._conn.execute( with self._lock:
"SELECT role, content, tool_call_id, tool_calls, tool_name " cursor = self._conn.execute(
"FROM messages WHERE session_id = ? ORDER BY timestamp, id", "SELECT role, content, tool_call_id, tool_calls, tool_name "
(session_id,), "FROM messages WHERE session_id = ? ORDER BY timestamp, id",
) (session_id,),
)
rows = cursor.fetchall()
messages = [] messages = []
for row in cursor.fetchall(): for row in rows:
msg = {"role": row["role"], "content": row["content"]} msg = {"role": row["role"], "content": row["content"]}
if row["tool_call_id"]: if row["tool_call_id"]:
msg["tool_call_id"] = row["tool_call_id"] msg["tool_call_id"] = row["tool_call_id"]
@@ -675,31 +695,33 @@ class SessionDB:
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""" """
try: with self._lock:
cursor = self._conn.execute(sql, params)
except sqlite3.OperationalError:
# FTS5 query syntax error despite sanitization — return empty
return []
matches = [dict(row) for row in cursor.fetchall()]
# Add surrounding context (1 message before + after each match)
for match in matches:
try: try:
ctx_cursor = self._conn.execute( cursor = self._conn.execute(sql, params)
"""SELECT role, content FROM messages except sqlite3.OperationalError:
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1 # FTS5 query syntax error despite sanitization — return empty
ORDER BY id""", return []
(match["session_id"], match["id"], match["id"]), matches = [dict(row) for row in cursor.fetchall()]
)
context_msgs = [
{"role": r["role"], "content": (r["content"] or "")[:200]}
for r in ctx_cursor.fetchall()
]
match["context"] = context_msgs
except Exception:
match["context"] = []
# Remove full content from result (snippet is enough, saves tokens) # Add surrounding context (1 message before + after each match)
for match in matches:
try:
ctx_cursor = self._conn.execute(
"""SELECT role, content FROM messages
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1
ORDER BY id""",
(match["session_id"], match["id"], match["id"]),
)
context_msgs = [
{"role": r["role"], "content": (r["content"] or "")[:200]}
for r in ctx_cursor.fetchall()
]
match["context"] = context_msgs
except Exception:
match["context"] = []
# Remove full content from result (snippet is enough, saves tokens)
for match in matches:
match.pop("content", None) match.pop("content", None)
return matches return matches

View File

@@ -69,6 +69,8 @@ class HonchoClientConfig:
workspace_id: str = "hermes" workspace_id: str = "hermes"
api_key: str | None = None api_key: str | None = None
environment: str = "production" environment: str = "production"
# Optional base URL for self-hosted Honcho (overrides environment mapping)
base_url: str | None = None
# Identity # Identity
peer_name: str | None = None peer_name: str | None = None
ai_peer: str = "hermes" ai_peer: str = "hermes"
@@ -361,13 +363,34 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho:
"Install it with: pip install honcho-ai" "Install it with: pip install honcho-ai"
) )
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id) # Allow config.yaml honcho.base_url to override the SDK's environment
# mapping, enabling remote self-hosted Honcho deployments without
# requiring the server to live on localhost.
resolved_base_url = config.base_url
if not resolved_base_url:
try:
from hermes_cli.config import load_config
hermes_cfg = load_config()
honcho_cfg = hermes_cfg.get("honcho", {})
if isinstance(honcho_cfg, dict):
resolved_base_url = honcho_cfg.get("base_url", "").strip() or None
except Exception:
pass
_honcho_client = Honcho( if resolved_base_url:
workspace_id=config.workspace_id, logger.info("Initializing Honcho client (base_url: %s, workspace: %s)", resolved_base_url, config.workspace_id)
api_key=config.api_key, else:
environment=config.environment, logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
)
kwargs: dict = {
"workspace_id": config.workspace_id,
"api_key": config.api_key,
"environment": config.environment,
}
if resolved_base_url:
kwargs["base_url"] = resolved_base_url
_honcho_client = Honcho(**kwargs)
return _honcho_client return _honcho_client

View File

@@ -407,6 +407,7 @@ class AIAgent:
# Subagent delegation state # Subagent delegation state
self._delegate_depth = 0 # 0 = top-level agent, incremented for children self._delegate_depth = 0 # 0 = top-level agent, incremented for children
self._active_children = [] # Running child AIAgents (for interrupt propagation) self._active_children = [] # Running child AIAgents (for interrupt propagation)
self._active_children_lock = threading.Lock()
# Store OpenRouter provider preferences # Store OpenRouter provider preferences
self.providers_allowed = providers_allowed self.providers_allowed = providers_allowed
@@ -1526,7 +1527,9 @@ class AIAgent:
# Signal all tools to abort any in-flight operations immediately # Signal all tools to abort any in-flight operations immediately
_set_interrupt(True) _set_interrupt(True)
# Propagate interrupt to any running child agents (subagent delegation) # Propagate interrupt to any running child agents (subagent delegation)
for child in self._active_children: with self._active_children_lock:
children_copy = list(self._active_children)
for child in children_copy:
try: try:
child.interrupt(message) child.interrupt(message)
except Exception as e: except Exception as e:

View File

@@ -24,6 +24,7 @@ def main() -> int:
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"

View File

@@ -43,6 +43,7 @@ class TestCLISubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"
@@ -112,21 +113,21 @@ class TestCLISubagentInterrupt(unittest.TestCase):
mock_instance._interrupt_requested = False mock_instance._interrupt_requested = False
mock_instance._interrupt_message = None mock_instance._interrupt_message = None
mock_instance._active_children = [] mock_instance._active_children = []
mock_instance._active_children_lock = threading.Lock()
mock_instance.quiet_mode = True mock_instance.quiet_mode = True
mock_instance.run_conversation = mock_child_run_conversation mock_instance.run_conversation = mock_child_run_conversation
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg) mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
mock_instance.tools = [] mock_instance.tools = []
MockAgent.return_value = mock_instance MockAgent.return_value = mock_instance
# Register child manually (normally done by _build_child_agent)
parent._active_children.append(mock_instance)
result = _run_single_child( result = _run_single_child(
task_index=0, task_index=0,
goal="Do something slow", goal="Do something slow",
context=None, child=mock_instance,
toolsets=["terminal"],
model=None,
max_iterations=50,
parent_agent=parent, parent_agent=parent,
task_count=1,
) )
delegate_result[0] = result delegate_result[0] = result
except Exception as e: except Exception as e:

View File

@@ -57,6 +57,7 @@ def main() -> int:
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"

View File

@@ -30,12 +30,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
child = AIAgent.__new__(AIAgent) child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False child._interrupt_requested = False
child._interrupt_message = None child._interrupt_message = None
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True child.quiet_mode = True
parent._active_children.append(child) parent._active_children.append(child)
@@ -60,6 +62,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_message = "msg" child._interrupt_message = "msg"
child.quiet_mode = True child.quiet_mode = True
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
# Global is set # Global is set
set_interrupt(True) set_interrupt(True)
@@ -78,6 +81,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_requested = False child._interrupt_requested = False
child._interrupt_message = None child._interrupt_message = None
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True child.quiet_mode = True
child.api_mode = "chat_completions" child.api_mode = "chat_completions"
child.log_prefix = "" child.log_prefix = ""
@@ -119,12 +123,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
child = AIAgent.__new__(AIAgent) child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False child._interrupt_requested = False
child._interrupt_message = None child._interrupt_message = None
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True child.quiet_mode = True
# Register child (simulating what _run_single_child does) # Register child (simulating what _run_single_child does)

View File

@@ -47,6 +47,28 @@ class TestCLIQuickCommands:
args = cli.console.print.call_args[0][0] args = cli.console.print.call_args[0][0]
assert "no output" in args.lower() assert "no output" in args.lower()
def test_alias_command_routes_to_target(self):
"""Alias quick commands rewrite to the target command."""
cli = self._make_cli({"shortcut": {"type": "alias", "target": "/help"}})
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
cli.process_command("/shortcut")
# Should recursively call process_command with /help
spy.assert_any_call("/help")
def test_alias_command_passes_args(self):
"""Alias quick commands forward user arguments to the target."""
cli = self._make_cli({"sc": {"type": "alias", "target": "/context"}})
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
cli.process_command("/sc some args")
spy.assert_any_call("/context some args")
def test_alias_no_target_shows_error(self):
cli = self._make_cli({"broken": {"type": "alias", "target": ""}})
cli.process_command("/broken")
cli.console.print.assert_called_once()
args = cli.console.print.call_args[0][0]
assert "no target defined" in args.lower()
def test_unsupported_type_shows_error(self): def test_unsupported_type_shows_error(self):
cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}}) cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}})
cli.process_command("/bad") cli.process_command("/bad")

View File

@@ -55,6 +55,7 @@ class TestRealSubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"
@@ -103,19 +104,28 @@ class TestRealSubagentInterrupt(unittest.TestCase):
return original_run(self_agent, *args, **kwargs) return original_run(self_agent, *args, **kwargs)
with patch.object(AIAgent, 'run_conversation', patched_run): with patch.object(AIAgent, 'run_conversation', patched_run):
# Build a real child agent (AIAgent is NOT patched here,
# only run_conversation and _build_system_prompt are)
child = AIAgent(
base_url="http://localhost:1",
api_key="test-key",
model="test/model",
provider="test",
api_mode="chat_completions",
max_iterations=5,
enabled_toolsets=["terminal"],
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
platform="cli",
)
child._delegate_depth = 1
parent._active_children.append(child)
result = _run_single_child( result = _run_single_child(
task_index=0, task_index=0,
goal="Test task", goal="Test task",
context=None, child=child,
toolsets=["terminal"],
model="test/model",
max_iterations=5,
parent_agent=parent, parent_agent=parent,
task_count=1,
override_provider="test",
override_base_url="http://localhost:1",
override_api_key="test",
override_api_mode="chat_completions",
) )
result_holder[0] = result result_holder[0] = result
except Exception as e: except Exception as e:

View File

@@ -12,6 +12,7 @@ Run with: python -m pytest tests/test_delegate.py -v
import json import json
import os import os
import sys import sys
import threading
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -44,6 +45,7 @@ def _make_mock_parent(depth=0):
parent._session_db = None parent._session_db = None
parent._delegate_depth = depth parent._delegate_depth = depth
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
return parent return parent
@@ -722,7 +724,12 @@ class TestDelegationProviderIntegration(unittest.TestCase):
} }
parent = _make_mock_parent(depth=0) parent = _make_mock_parent(depth=0)
with patch("tools.delegate_tool._run_single_child") as mock_run: # Patch _build_child_agent since credentials are now passed there
# (agents are built in the main thread before being handed to workers)
with patch("tools.delegate_tool._build_child_agent") as mock_build, \
patch("tools.delegate_tool._run_single_child") as mock_run:
mock_child = MagicMock()
mock_build.return_value = mock_child
mock_run.return_value = { mock_run.return_value = {
"task_index": 0, "status": "completed", "task_index": 0, "status": "completed",
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0 "summary": "Done", "api_calls": 1, "duration_seconds": 1.0
@@ -731,7 +738,8 @@ class TestDelegationProviderIntegration(unittest.TestCase):
tasks = [{"goal": "Task A"}, {"goal": "Task B"}] tasks = [{"goal": "Task A"}, {"goal": "Task B"}]
delegate_task(tasks=tasks, parent_agent=parent) delegate_task(tasks=tasks, parent_agent=parent)
for call in mock_run.call_args_list: self.assertEqual(mock_build.call_count, 2)
for call in mock_build.call_args_list:
self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout") self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout")
self.assertEqual(call.kwargs.get("override_provider"), "openrouter") self.assertEqual(call.kwargs.get("override_provider"), "openrouter")
self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1") self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1")

View File

@@ -16,13 +16,10 @@ The parent's context only sees the delegation call and the summary result,
never the child's intermediate tool calls or reasoning. never the child's intermediate tool calls or reasoning.
""" """
import contextlib
import io
import json import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import os import os
import sys
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@@ -150,7 +147,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
return _callback return _callback
def _run_single_child( def _build_child_agent(
task_index: int, task_index: int,
goal: str, goal: str,
context: Optional[str], context: Optional[str],
@@ -158,16 +155,15 @@ def _run_single_child(
model: Optional[str], model: Optional[str],
max_iterations: int, max_iterations: int,
parent_agent, parent_agent,
task_count: int = 1,
# Credential overrides from delegation config (provider:model resolution) # Credential overrides from delegation config (provider:model resolution)
override_provider: Optional[str] = None, override_provider: Optional[str] = None,
override_base_url: Optional[str] = None, override_base_url: Optional[str] = None,
override_api_key: Optional[str] = None, override_api_key: Optional[str] = None,
override_api_mode: Optional[str] = None, override_api_mode: Optional[str] = None,
) -> Dict[str, Any]: ):
""" """
Spawn and run a single child agent. Called from within a thread. Build a child AIAgent on the main thread (thread-safe construction).
Returns a structured result dict. Returns the constructed child agent without running it.
When override_* params are set (from delegation config), the child uses When override_* params are set (from delegation config), the child uses
those credentials instead of inheriting from the parent. This enables those credentials instead of inheriting from the parent. This enables
@@ -176,8 +172,6 @@ def _run_single_child(
""" """
from run_agent import AIAgent from run_agent import AIAgent
child_start = time.monotonic()
# When no explicit toolsets given, inherit from parent's enabled toolsets # When no explicit toolsets given, inherit from parent's enabled toolsets
# so disabled tools (e.g. web) don't leak to subagents. # so disabled tools (e.g. web) don't leak to subagents.
if toolsets: if toolsets:
@@ -188,65 +182,84 @@ def _run_single_child(
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS) child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
child_prompt = _build_child_system_prompt(goal, context) child_prompt = _build_child_system_prompt(goal, context)
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
parent_api_key = getattr(parent_agent, "api_key", None)
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
parent_api_key = parent_agent._client_kwargs.get("api_key")
try: # Build progress callback to relay tool calls to parent display
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal). child_progress_cb = _build_child_progress_callback(task_index, parent_agent)
parent_api_key = getattr(parent_agent, "api_key", None)
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
parent_api_key = parent_agent._client_kwargs.get("api_key")
# Build progress callback to relay tool calls to parent display # Share the parent's iteration budget so subagent tool calls
child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count) # count toward the session-wide limit.
shared_budget = getattr(parent_agent, "iteration_budget", None)
# Share the parent's iteration budget so subagent tool calls # Resolve effective credentials: config override > parent inherit
# count toward the session-wide limit. effective_model = model or parent_agent.model
shared_budget = getattr(parent_agent, "iteration_budget", None) effective_provider = override_provider or getattr(parent_agent, "provider", None)
effective_base_url = override_base_url or parent_agent.base_url
effective_api_key = override_api_key or parent_api_key
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
# Resolve effective credentials: config override > parent inherit child = AIAgent(
effective_model = model or parent_agent.model base_url=effective_base_url,
effective_provider = override_provider or getattr(parent_agent, "provider", None) api_key=effective_api_key,
effective_base_url = override_base_url or parent_agent.base_url model=effective_model,
effective_api_key = override_api_key or parent_api_key provider=effective_provider,
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None) api_mode=effective_api_mode,
max_iterations=max_iterations,
max_tokens=getattr(parent_agent, "max_tokens", None),
reasoning_config=getattr(parent_agent, "reasoning_config", None),
prefill_messages=getattr(parent_agent, "prefill_messages", None),
enabled_toolsets=child_toolsets,
quiet_mode=True,
ephemeral_system_prompt=child_prompt,
log_prefix=f"[subagent-{task_index}]",
platform=parent_agent.platform,
skip_context_files=True,
skip_memory=True,
clarify_callback=None,
session_db=getattr(parent_agent, '_session_db', None),
providers_allowed=parent_agent.providers_allowed,
providers_ignored=parent_agent.providers_ignored,
providers_order=parent_agent.providers_order,
provider_sort=parent_agent.provider_sort,
tool_progress_callback=child_progress_cb,
iteration_budget=shared_budget,
)
child = AIAgent( # Set delegation depth so children can't spawn grandchildren
base_url=effective_base_url, child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
api_key=effective_api_key,
model=effective_model,
provider=effective_provider,
api_mode=effective_api_mode,
max_iterations=max_iterations,
max_tokens=getattr(parent_agent, "max_tokens", None),
reasoning_config=getattr(parent_agent, "reasoning_config", None),
prefill_messages=getattr(parent_agent, "prefill_messages", None),
enabled_toolsets=child_toolsets,
quiet_mode=True,
ephemeral_system_prompt=child_prompt,
log_prefix=f"[subagent-{task_index}]",
platform=parent_agent.platform,
skip_context_files=True,
skip_memory=True,
clarify_callback=None,
session_db=getattr(parent_agent, '_session_db', None),
providers_allowed=parent_agent.providers_allowed,
providers_ignored=parent_agent.providers_ignored,
providers_order=parent_agent.providers_order,
provider_sort=parent_agent.provider_sort,
tool_progress_callback=child_progress_cb,
iteration_budget=shared_budget,
)
# Set delegation depth so children can't spawn grandchildren # Register child for interrupt propagation
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1 if hasattr(parent_agent, '_active_children'):
lock = getattr(parent_agent, '_active_children_lock', None)
# Register child for interrupt propagation if lock:
if hasattr(parent_agent, '_active_children'): with lock:
parent_agent._active_children.append(child)
else:
parent_agent._active_children.append(child) parent_agent._active_children.append(child)
# Run with stdout/stderr suppressed to prevent interleaved output return child
devnull = io.StringIO()
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull): def _run_single_child(
result = child.run_conversation(user_message=goal) task_index: int,
goal: str,
child=None,
parent_agent=None,
**_kwargs,
) -> Dict[str, Any]:
"""
Run a pre-built child agent. Called from within a thread.
Returns a structured result dict.
"""
child_start = time.monotonic()
# Get the progress callback from the child agent
child_progress_cb = getattr(child, 'tool_progress_callback', None)
try:
result = child.run_conversation(user_message=goal)
# Flush any remaining batched progress to gateway # Flush any remaining batched progress to gateway
if child_progress_cb and hasattr(child_progress_cb, '_flush'): if child_progress_cb and hasattr(child_progress_cb, '_flush'):
@@ -355,11 +368,15 @@ def _run_single_child(
# Unregister child from interrupt propagation # Unregister child from interrupt propagation
if hasattr(parent_agent, '_active_children'): if hasattr(parent_agent, '_active_children'):
try: try:
parent_agent._active_children.remove(child) lock = getattr(parent_agent, '_active_children_lock', None)
if lock:
with lock:
parent_agent._active_children.remove(child)
else:
parent_agent._active_children.remove(child)
except (ValueError, UnboundLocalError) as e: except (ValueError, UnboundLocalError) as e:
logger.debug("Could not remove child from active_children: %s", e) logger.debug("Could not remove child from active_children: %s", e)
def delegate_task( def delegate_task(
goal: Optional[str] = None, goal: Optional[str] = None,
context: Optional[str] = None, context: Optional[str] = None,
@@ -428,51 +445,38 @@ def delegate_task(
# Track goal labels for progress display (truncated for readability) # Track goal labels for progress display (truncated for readability)
task_labels = [t["goal"][:40] for t in task_list] task_labels = [t["goal"][:40] for t in task_list]
if n_tasks == 1: # Build all child agents on the main thread (thread-safe construction)
# Single task -- run directly (no thread pool overhead) children = []
t = task_list[0] for i, t in enumerate(task_list):
result = _run_single_child( child = _build_child_agent(
task_index=0, task_index=i, goal=t["goal"], context=t.get("context"),
goal=t["goal"], toolsets=t.get("toolsets") or toolsets, model=creds["model"],
context=t.get("context"), max_iterations=effective_max_iter, parent_agent=parent_agent,
toolsets=t.get("toolsets") or toolsets, override_provider=creds["provider"], override_base_url=creds["base_url"],
model=creds["model"],
max_iterations=effective_max_iter,
parent_agent=parent_agent,
task_count=1,
override_provider=creds["provider"],
override_base_url=creds["base_url"],
override_api_key=creds["api_key"], override_api_key=creds["api_key"],
override_api_mode=creds["api_mode"], override_api_mode=creds["api_mode"],
) )
children.append((i, t, child))
if n_tasks == 1:
# Single task -- run directly (no thread pool overhead)
_i, _t, child = children[0]
result = _run_single_child(0, _t["goal"], child, parent_agent)
results.append(result) results.append(result)
else: else:
# Batch -- run in parallel with per-task progress lines # Batch -- run in parallel with per-task progress lines
completed_count = 0 completed_count = 0
spinner_ref = getattr(parent_agent, '_delegate_spinner', None) spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
# Save stdout/stderr before the executor — redirect_stdout in child
# threads races on sys.stdout and can leave it as devnull permanently.
_saved_stdout = sys.stdout
_saved_stderr = sys.stderr
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor: with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
futures = {} futures = {}
for i, t in enumerate(task_list): for i, t, child in children:
future = executor.submit( future = executor.submit(
_run_single_child, _run_single_child,
task_index=i, task_index=i,
goal=t["goal"], goal=t["goal"],
context=t.get("context"), child=child,
toolsets=t.get("toolsets") or toolsets,
model=creds["model"],
max_iterations=effective_max_iter,
parent_agent=parent_agent, parent_agent=parent_agent,
task_count=n_tasks,
override_provider=creds["provider"],
override_base_url=creds["base_url"],
override_api_key=creds["api_key"],
override_api_mode=creds["api_mode"],
) )
futures[future] = i futures[future] = i
@@ -515,10 +519,6 @@ def delegate_task(
except Exception as e: except Exception as e:
logger.debug("Spinner update_text failed: %s", e) logger.debug("Spinner update_text failed: %s", e)
# Restore stdout/stderr in case redirect_stdout race left them as devnull
sys.stdout = _saved_stdout
sys.stderr = _saved_stderr
# Sort by task_index so results match input order # Sort by task_index so results match input order
results.sort(key=lambda r: r["task_index"]) results.sort(key=lambda r: r["task_index"])