diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index d008361b5..3142e4894 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -39,6 +39,7 @@ custom OpenAI-compatible endpoint without touching the main model settings. import json import logging import os +import threading from pathlib import Path from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple @@ -1171,6 +1172,7 @@ def auxiliary_max_tokens_param(value: int) -> dict: # Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model) _client_cache: Dict[tuple, tuple] = {} +_client_cache_lock = threading.Lock() def _get_cached_client( @@ -1182,9 +1184,11 @@ def _get_cached_client( ) -> Tuple[Optional[Any], Optional[str]]: """Get or create a cached client for the given provider.""" cache_key = (provider, async_mode, base_url or "", api_key or "") - if cache_key in _client_cache: - cached_client, cached_default = _client_cache[cache_key] - return cached_client, model or cached_default + with _client_cache_lock: + if cache_key in _client_cache: + cached_client, cached_default = _client_cache[cache_key] + return cached_client, model or cached_default + # Build outside the lock client, default_model = resolve_provider_client( provider, model, @@ -1193,7 +1197,11 @@ def _get_cached_client( explicit_api_key=api_key, ) if client is not None: - _client_cache[cache_key] = (client, default_model) + with _client_cache_lock: + if cache_key not in _client_cache: + _client_cache[cache_key] = (client, default_model) + else: + client, default_model = _client_cache[cache_key] return client, model or default_model diff --git a/cli.py b/cli.py index febe32789..2b0c4ad82 100755 --- a/cli.py +++ b/cli.py @@ -3652,8 +3652,17 @@ class HermesCLI: self.console.print(f"[bold red]Quick command error: {e}[/]") else: self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]") + elif qcmd.get("type") == "alias": + target = qcmd.get("target", "").strip() + if target: + target = target if target.startswith("/") else f"/{target}" + user_args = cmd_original[len(base_cmd):].strip() + aliased_command = f"{target} {user_args}".strip() + return self.process_command(aliased_command) + else: + self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]") else: - self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (only 'exec' is supported)[/]") + self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]") # Check for skill slash commands (/gif-search, /axolotl, etc.) elif base_cmd in _skill_commands: user_instruction = cmd_original[len(base_cmd):].strip() diff --git a/gateway/run.py b/gateway/run.py index f1e1be68a..7856e6a03 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1421,8 +1421,19 @@ class GatewayRunner: return f"Quick command error: {e}" else: return f"Quick command '/{command}' has no command defined." + elif qcmd.get("type") == "alias": + target = qcmd.get("target", "").strip() + if target: + target = target if target.startswith("/") else f"/{target}" + target_command = target.lstrip("/") + user_args = event.get_command_args().strip() + event.text = f"{target} {user_args}".strip() + command = target_command + # Fall through to normal command dispatch below + else: + return f"Quick command '/{command}' has no target defined." else: - return f"Quick command '/{command}' has unsupported type (only 'exec' is supported)." + return f"Quick command '/{command}' has unsupported type (supported: 'exec', 'alias')." # Skill slash commands: /skill-name loads the skill and sends to agent if command: diff --git a/hermes_state.py b/hermes_state.py index 3f4715067..d0237a5bb 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -18,6 +18,7 @@ import json import os import re import sqlite3 +import threading import time from pathlib import Path from typing import Dict, Any, List, Optional @@ -104,6 +105,7 @@ class SessionDB: self.db_path = db_path or DEFAULT_DB_PATH self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._lock = threading.Lock() self._conn = sqlite3.connect( str(self.db_path), check_same_thread=False, @@ -173,9 +175,10 @@ class SessionDB: def close(self): """Close the database connection.""" - if self._conn: - self._conn.close() - self._conn = None + with self._lock: + if self._conn: + self._conn.close() + self._conn = None # ========================================================================= # Session lifecycle @@ -192,61 +195,66 @@ class SessionDB: parent_session_id: str = None, ) -> str: """Create a new session record. Returns the session_id.""" - self._conn.execute( - """INSERT INTO sessions (id, source, user_id, model, model_config, - system_prompt, parent_session_id, started_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - ( - session_id, - source, - user_id, - model, - json.dumps(model_config) if model_config else None, - system_prompt, - parent_session_id, - time.time(), - ), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + """INSERT INTO sessions (id, source, user_id, model, model_config, + system_prompt, parent_session_id, started_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + session_id, + source, + user_id, + model, + json.dumps(model_config) if model_config else None, + system_prompt, + parent_session_id, + time.time(), + ), + ) + self._conn.commit() return session_id def end_session(self, session_id: str, end_reason: str) -> None: """Mark a session as ended.""" - self._conn.execute( - "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", - (time.time(), end_reason, session_id), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", + (time.time(), end_reason, session_id), + ) + self._conn.commit() def update_system_prompt(self, session_id: str, system_prompt: str) -> None: """Store the full assembled system prompt snapshot.""" - self._conn.execute( - "UPDATE sessions SET system_prompt = ? WHERE id = ?", - (system_prompt, session_id), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + "UPDATE sessions SET system_prompt = ? WHERE id = ?", + (system_prompt, session_id), + ) + self._conn.commit() def update_token_counts( self, session_id: str, input_tokens: int = 0, output_tokens: int = 0, model: str = None, ) -> None: """Increment token counters and backfill model if not already set.""" - self._conn.execute( - """UPDATE sessions SET - input_tokens = input_tokens + ?, - output_tokens = output_tokens + ?, - model = COALESCE(model, ?) - WHERE id = ?""", - (input_tokens, output_tokens, model, session_id), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + """UPDATE sessions SET + input_tokens = input_tokens + ?, + output_tokens = output_tokens + ?, + model = COALESCE(model, ?) + WHERE id = ?""", + (input_tokens, output_tokens, model, session_id), + ) + self._conn.commit() def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """Get a session by ID.""" - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE id = ?", (session_id,) - ) - row = cursor.fetchone() + with self._lock: + cursor = self._conn.execute( + "SELECT * FROM sessions WHERE id = ?", (session_id,) + ) + row = cursor.fetchone() return dict(row) if row else None def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]: @@ -331,38 +339,42 @@ class SessionDB: Empty/whitespace-only strings are normalized to None (clearing the title). """ title = self.sanitize_title(title) - if title: - # Check uniqueness (allow the same session to keep its own title) + with self._lock: + if title: + # Check uniqueness (allow the same session to keep its own title) + cursor = self._conn.execute( + "SELECT id FROM sessions WHERE title = ? AND id != ?", + (title, session_id), + ) + conflict = cursor.fetchone() + if conflict: + raise ValueError( + f"Title '{title}' is already in use by session {conflict['id']}" + ) cursor = self._conn.execute( - "SELECT id FROM sessions WHERE title = ? AND id != ?", + "UPDATE sessions SET title = ? WHERE id = ?", (title, session_id), ) - conflict = cursor.fetchone() - if conflict: - raise ValueError( - f"Title '{title}' is already in use by session {conflict['id']}" - ) - cursor = self._conn.execute( - "UPDATE sessions SET title = ? WHERE id = ?", - (title, session_id), - ) - self._conn.commit() - return cursor.rowcount > 0 + self._conn.commit() + rowcount = cursor.rowcount + return rowcount > 0 def get_session_title(self, session_id: str) -> Optional[str]: """Get the title for a session, or None.""" - cursor = self._conn.execute( - "SELECT title FROM sessions WHERE id = ?", (session_id,) - ) - row = cursor.fetchone() + with self._lock: + cursor = self._conn.execute( + "SELECT title FROM sessions WHERE id = ?", (session_id,) + ) + row = cursor.fetchone() return row["title"] if row else None def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]: """Look up a session by exact title. Returns session dict or None.""" - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE title = ?", (title,) - ) - row = cursor.fetchone() + with self._lock: + cursor = self._conn.execute( + "SELECT * FROM sessions WHERE title = ?", (title,) + ) + row = cursor.fetchone() return dict(row) if row else None def resolve_session_by_title(self, title: str) -> Optional[str]: @@ -379,12 +391,13 @@ class SessionDB: # Also search for numbered variants: "title #2", "title #3", etc. # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - cursor = self._conn.execute( - "SELECT id, title, started_at FROM sessions " - "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", - (f"{escaped} #%",), - ) - numbered = cursor.fetchall() + with self._lock: + cursor = self._conn.execute( + "SELECT id, title, started_at FROM sessions " + "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", + (f"{escaped} #%",), + ) + numbered = cursor.fetchall() if numbered: # Return the most recent numbered variant @@ -409,11 +422,12 @@ class SessionDB: # Find all existing numbered variants # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - cursor = self._conn.execute( - "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'", - (base, f"{escaped} #%"), - ) - existing = [row["title"] for row in cursor.fetchall()] + with self._lock: + cursor = self._conn.execute( + "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'", + (base, f"{escaped} #%"), + ) + existing = [row["title"] for row in cursor.fetchall()] if not existing: return base # No conflict, use the base name as-is @@ -461,9 +475,11 @@ class SessionDB: LIMIT ? OFFSET ? """ params = (source, limit, offset) if source else (limit, offset) - cursor = self._conn.execute(query, params) + with self._lock: + cursor = self._conn.execute(query, params) + rows = cursor.fetchall() sessions = [] - for row in cursor.fetchall(): + for row in rows: s = dict(row) # Build the preview from the raw substring raw = s.pop("_preview_raw", "").strip() @@ -497,52 +513,54 @@ class SessionDB: Also increments the session's message_count (and tool_call_count if role is 'tool' or tool_calls is present). """ - cursor = self._conn.execute( - """INSERT INTO messages (session_id, role, content, tool_call_id, - tool_calls, tool_name, timestamp, token_count, finish_reason) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - session_id, - role, - content, - tool_call_id, - json.dumps(tool_calls) if tool_calls else None, - tool_name, - time.time(), - token_count, - finish_reason, - ), - ) - msg_id = cursor.lastrowid - - # Update counters - # Count actual tool calls from the tool_calls list (not from tool responses). - # A single assistant message can contain multiple parallel tool calls. - num_tool_calls = 0 - if tool_calls is not None: - num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 - if num_tool_calls > 0: - self._conn.execute( - """UPDATE sessions SET message_count = message_count + 1, - tool_call_count = tool_call_count + ? WHERE id = ?""", - (num_tool_calls, session_id), - ) - else: - self._conn.execute( - "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?", - (session_id,), + with self._lock: + cursor = self._conn.execute( + """INSERT INTO messages (session_id, role, content, tool_call_id, + tool_calls, tool_name, timestamp, token_count, finish_reason) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + session_id, + role, + content, + tool_call_id, + json.dumps(tool_calls) if tool_calls else None, + tool_name, + time.time(), + token_count, + finish_reason, + ), ) + msg_id = cursor.lastrowid - self._conn.commit() + # Update counters + # Count actual tool calls from the tool_calls list (not from tool responses). + # A single assistant message can contain multiple parallel tool calls. + num_tool_calls = 0 + if tool_calls is not None: + num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 + if num_tool_calls > 0: + self._conn.execute( + """UPDATE sessions SET message_count = message_count + 1, + tool_call_count = tool_call_count + ? WHERE id = ?""", + (num_tool_calls, session_id), + ) + else: + self._conn.execute( + "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?", + (session_id,), + ) + + self._conn.commit() return msg_id def get_messages(self, session_id: str) -> List[Dict[str, Any]]: """Load all messages for a session, ordered by timestamp.""" - cursor = self._conn.execute( - "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", - (session_id,), - ) - rows = cursor.fetchall() + with self._lock: + cursor = self._conn.execute( + "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", + (session_id,), + ) + rows = cursor.fetchall() result = [] for row in rows: msg = dict(row) @@ -559,13 +577,15 @@ class SessionDB: Load messages in the OpenAI conversation format (role + content dicts). Used by the gateway to restore conversation history. """ - cursor = self._conn.execute( - "SELECT role, content, tool_call_id, tool_calls, tool_name " - "FROM messages WHERE session_id = ? ORDER BY timestamp, id", - (session_id,), - ) + with self._lock: + cursor = self._conn.execute( + "SELECT role, content, tool_call_id, tool_calls, tool_name " + "FROM messages WHERE session_id = ? ORDER BY timestamp, id", + (session_id,), + ) + rows = cursor.fetchall() messages = [] - for row in cursor.fetchall(): + for row in rows: msg = {"role": row["role"], "content": row["content"]} if row["tool_call_id"]: msg["tool_call_id"] = row["tool_call_id"] @@ -675,31 +695,33 @@ class SessionDB: LIMIT ? OFFSET ? """ - try: - cursor = self._conn.execute(sql, params) - except sqlite3.OperationalError: - # FTS5 query syntax error despite sanitization — return empty - return [] - matches = [dict(row) for row in cursor.fetchall()] - - # Add surrounding context (1 message before + after each match) - for match in matches: + with self._lock: try: - ctx_cursor = self._conn.execute( - """SELECT role, content FROM messages - WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1 - ORDER BY id""", - (match["session_id"], match["id"], match["id"]), - ) - context_msgs = [ - {"role": r["role"], "content": (r["content"] or "")[:200]} - for r in ctx_cursor.fetchall() - ] - match["context"] = context_msgs - except Exception: - match["context"] = [] + cursor = self._conn.execute(sql, params) + except sqlite3.OperationalError: + # FTS5 query syntax error despite sanitization — return empty + return [] + matches = [dict(row) for row in cursor.fetchall()] - # Remove full content from result (snippet is enough, saves tokens) + # Add surrounding context (1 message before + after each match) + for match in matches: + try: + ctx_cursor = self._conn.execute( + """SELECT role, content FROM messages + WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1 + ORDER BY id""", + (match["session_id"], match["id"], match["id"]), + ) + context_msgs = [ + {"role": r["role"], "content": (r["content"] or "")[:200]} + for r in ctx_cursor.fetchall() + ] + match["context"] = context_msgs + except Exception: + match["context"] = [] + + # Remove full content from result (snippet is enough, saves tokens) + for match in matches: match.pop("content", None) return matches diff --git a/honcho_integration/client.py b/honcho_integration/client.py index ccc2f6f25..759576ada 100644 --- a/honcho_integration/client.py +++ b/honcho_integration/client.py @@ -69,6 +69,8 @@ class HonchoClientConfig: workspace_id: str = "hermes" api_key: str | None = None environment: str = "production" + # Optional base URL for self-hosted Honcho (overrides environment mapping) + base_url: str | None = None # Identity peer_name: str | None = None ai_peer: str = "hermes" @@ -361,13 +363,34 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho: "Install it with: pip install honcho-ai" ) - logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id) + # Allow config.yaml honcho.base_url to override the SDK's environment + # mapping, enabling remote self-hosted Honcho deployments without + # requiring the server to live on localhost. + resolved_base_url = config.base_url + if not resolved_base_url: + try: + from hermes_cli.config import load_config + hermes_cfg = load_config() + honcho_cfg = hermes_cfg.get("honcho", {}) + if isinstance(honcho_cfg, dict): + resolved_base_url = honcho_cfg.get("base_url", "").strip() or None + except Exception: + pass - _honcho_client = Honcho( - workspace_id=config.workspace_id, - api_key=config.api_key, - environment=config.environment, - ) + if resolved_base_url: + logger.info("Initializing Honcho client (base_url: %s, workspace: %s)", resolved_base_url, config.workspace_id) + else: + logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id) + + kwargs: dict = { + "workspace_id": config.workspace_id, + "api_key": config.api_key, + "environment": config.environment, + } + if resolved_base_url: + kwargs["base_url"] = resolved_base_url + + _honcho_client = Honcho(**kwargs) return _honcho_client diff --git a/run_agent.py b/run_agent.py index 2c8fad0b8..e8bf35c47 100644 --- a/run_agent.py +++ b/run_agent.py @@ -407,6 +407,7 @@ class AIAgent: # Subagent delegation state self._delegate_depth = 0 # 0 = top-level agent, incremented for children self._active_children = [] # Running child AIAgents (for interrupt propagation) + self._active_children_lock = threading.Lock() # Store OpenRouter provider preferences self.providers_allowed = providers_allowed @@ -1526,7 +1527,9 @@ class AIAgent: # Signal all tools to abort any in-flight operations immediately _set_interrupt(True) # Propagate interrupt to any running child agents (subagent delegation) - for child in self._active_children: + with self._active_children_lock: + children_copy = list(self._active_children) + for child in children_copy: try: child.interrupt(message) except Exception as e: diff --git a/tests/run_interrupt_test.py b/tests/run_interrupt_test.py index 845060ffa..a539c6ca9 100644 --- a/tests/run_interrupt_test.py +++ b/tests/run_interrupt_test.py @@ -24,6 +24,7 @@ def main() -> int: parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" diff --git a/tests/test_cli_interrupt_subagent.py b/tests/test_cli_interrupt_subagent.py index b91a7b654..f4322ea6b 100644 --- a/tests/test_cli_interrupt_subagent.py +++ b/tests/test_cli_interrupt_subagent.py @@ -43,6 +43,7 @@ class TestCLISubagentInterrupt(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" @@ -112,21 +113,21 @@ class TestCLISubagentInterrupt(unittest.TestCase): mock_instance._interrupt_requested = False mock_instance._interrupt_message = None mock_instance._active_children = [] + mock_instance._active_children_lock = threading.Lock() mock_instance.quiet_mode = True mock_instance.run_conversation = mock_child_run_conversation mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg) mock_instance.tools = [] MockAgent.return_value = mock_instance - + + # Register child manually (normally done by _build_child_agent) + parent._active_children.append(mock_instance) + result = _run_single_child( task_index=0, goal="Do something slow", - context=None, - toolsets=["terminal"], - model=None, - max_iterations=50, + child=mock_instance, parent_agent=parent, - task_count=1, ) delegate_result[0] = result except Exception as e: diff --git a/tests/test_interactive_interrupt.py b/tests/test_interactive_interrupt.py index c01404e1c..8c0d328c2 100644 --- a/tests/test_interactive_interrupt.py +++ b/tests/test_interactive_interrupt.py @@ -57,6 +57,7 @@ def main() -> int: parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" diff --git a/tests/test_interrupt_propagation.py b/tests/test_interrupt_propagation.py index ff1cafdc8..7f8cb01c3 100644 --- a/tests/test_interrupt_propagation.py +++ b/tests/test_interrupt_propagation.py @@ -30,12 +30,14 @@ class TestInterruptPropagationToChild(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True child = AIAgent.__new__(AIAgent) child._interrupt_requested = False child._interrupt_message = None child._active_children = [] + child._active_children_lock = threading.Lock() child.quiet_mode = True parent._active_children.append(child) @@ -60,6 +62,7 @@ class TestInterruptPropagationToChild(unittest.TestCase): child._interrupt_message = "msg" child.quiet_mode = True child._active_children = [] + child._active_children_lock = threading.Lock() # Global is set set_interrupt(True) @@ -78,6 +81,7 @@ class TestInterruptPropagationToChild(unittest.TestCase): child._interrupt_requested = False child._interrupt_message = None child._active_children = [] + child._active_children_lock = threading.Lock() child.quiet_mode = True child.api_mode = "chat_completions" child.log_prefix = "" @@ -119,12 +123,14 @@ class TestInterruptPropagationToChild(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True child = AIAgent.__new__(AIAgent) child._interrupt_requested = False child._interrupt_message = None child._active_children = [] + child._active_children_lock = threading.Lock() child.quiet_mode = True # Register child (simulating what _run_single_child does) diff --git a/tests/test_quick_commands.py b/tests/test_quick_commands.py index 9708b1fb3..7a89d4ca2 100644 --- a/tests/test_quick_commands.py +++ b/tests/test_quick_commands.py @@ -47,6 +47,28 @@ class TestCLIQuickCommands: args = cli.console.print.call_args[0][0] assert "no output" in args.lower() + def test_alias_command_routes_to_target(self): + """Alias quick commands rewrite to the target command.""" + cli = self._make_cli({"shortcut": {"type": "alias", "target": "/help"}}) + with patch.object(cli, "process_command", wraps=cli.process_command) as spy: + cli.process_command("/shortcut") + # Should recursively call process_command with /help + spy.assert_any_call("/help") + + def test_alias_command_passes_args(self): + """Alias quick commands forward user arguments to the target.""" + cli = self._make_cli({"sc": {"type": "alias", "target": "/context"}}) + with patch.object(cli, "process_command", wraps=cli.process_command) as spy: + cli.process_command("/sc some args") + spy.assert_any_call("/context some args") + + def test_alias_no_target_shows_error(self): + cli = self._make_cli({"broken": {"type": "alias", "target": ""}}) + cli.process_command("/broken") + cli.console.print.assert_called_once() + args = cli.console.print.call_args[0][0] + assert "no target defined" in args.lower() + def test_unsupported_type_shows_error(self): cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}}) cli.process_command("/bad") diff --git a/tests/test_real_interrupt_subagent.py b/tests/test_real_interrupt_subagent.py index f1a16753a..e0e681cdf 100644 --- a/tests/test_real_interrupt_subagent.py +++ b/tests/test_real_interrupt_subagent.py @@ -55,6 +55,7 @@ class TestRealSubagentInterrupt(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" @@ -103,19 +104,28 @@ class TestRealSubagentInterrupt(unittest.TestCase): return original_run(self_agent, *args, **kwargs) with patch.object(AIAgent, 'run_conversation', patched_run): + # Build a real child agent (AIAgent is NOT patched here, + # only run_conversation and _build_system_prompt are) + child = AIAgent( + base_url="http://localhost:1", + api_key="test-key", + model="test/model", + provider="test", + api_mode="chat_completions", + max_iterations=5, + enabled_toolsets=["terminal"], + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + platform="cli", + ) + child._delegate_depth = 1 + parent._active_children.append(child) result = _run_single_child( task_index=0, goal="Test task", - context=None, - toolsets=["terminal"], - model="test/model", - max_iterations=5, + child=child, parent_agent=parent, - task_count=1, - override_provider="test", - override_base_url="http://localhost:1", - override_api_key="test", - override_api_mode="chat_completions", ) result_holder[0] = result except Exception as e: diff --git a/tests/tools/test_delegate.py b/tests/tools/test_delegate.py index a29560b2c..476a2401b 100644 --- a/tests/tools/test_delegate.py +++ b/tests/tools/test_delegate.py @@ -12,6 +12,7 @@ Run with: python -m pytest tests/test_delegate.py -v import json import os import sys +import threading import unittest from unittest.mock import MagicMock, patch @@ -44,6 +45,7 @@ def _make_mock_parent(depth=0): parent._session_db = None parent._delegate_depth = depth parent._active_children = [] + parent._active_children_lock = threading.Lock() return parent @@ -722,7 +724,12 @@ class TestDelegationProviderIntegration(unittest.TestCase): } parent = _make_mock_parent(depth=0) - with patch("tools.delegate_tool._run_single_child") as mock_run: + # Patch _build_child_agent since credentials are now passed there + # (agents are built in the main thread before being handed to workers) + with patch("tools.delegate_tool._build_child_agent") as mock_build, \ + patch("tools.delegate_tool._run_single_child") as mock_run: + mock_child = MagicMock() + mock_build.return_value = mock_child mock_run.return_value = { "task_index": 0, "status": "completed", "summary": "Done", "api_calls": 1, "duration_seconds": 1.0 @@ -731,7 +738,8 @@ class TestDelegationProviderIntegration(unittest.TestCase): tasks = [{"goal": "Task A"}, {"goal": "Task B"}] delegate_task(tasks=tasks, parent_agent=parent) - for call in mock_run.call_args_list: + self.assertEqual(mock_build.call_count, 2) + for call in mock_build.call_args_list: self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout") self.assertEqual(call.kwargs.get("override_provider"), "openrouter") self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1") diff --git a/tools/delegate_tool.py b/tools/delegate_tool.py index 1ac75ea88..2ef505dab 100644 --- a/tools/delegate_tool.py +++ b/tools/delegate_tool.py @@ -16,13 +16,10 @@ The parent's context only sees the delegation call and the summary result, never the child's intermediate tool calls or reasoning. """ -import contextlib -import io import json import logging logger = logging.getLogger(__name__) import os -import sys import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict, List, Optional @@ -150,7 +147,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in return _callback -def _run_single_child( +def _build_child_agent( task_index: int, goal: str, context: Optional[str], @@ -158,16 +155,15 @@ def _run_single_child( model: Optional[str], max_iterations: int, parent_agent, - task_count: int = 1, # Credential overrides from delegation config (provider:model resolution) override_provider: Optional[str] = None, override_base_url: Optional[str] = None, override_api_key: Optional[str] = None, override_api_mode: Optional[str] = None, -) -> Dict[str, Any]: +): """ - Spawn and run a single child agent. Called from within a thread. - Returns a structured result dict. + Build a child AIAgent on the main thread (thread-safe construction). + Returns the constructed child agent without running it. When override_* params are set (from delegation config), the child uses those credentials instead of inheriting from the parent. This enables @@ -176,8 +172,6 @@ def _run_single_child( """ from run_agent import AIAgent - child_start = time.monotonic() - # When no explicit toolsets given, inherit from parent's enabled toolsets # so disabled tools (e.g. web) don't leak to subagents. if toolsets: @@ -188,65 +182,84 @@ def _run_single_child( child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS) child_prompt = _build_child_system_prompt(goal, context) + # Extract parent's API key so subagents inherit auth (e.g. Nous Portal). + parent_api_key = getattr(parent_agent, "api_key", None) + if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"): + parent_api_key = parent_agent._client_kwargs.get("api_key") - try: - # Extract parent's API key so subagents inherit auth (e.g. Nous Portal). - parent_api_key = getattr(parent_agent, "api_key", None) - if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"): - parent_api_key = parent_agent._client_kwargs.get("api_key") + # Build progress callback to relay tool calls to parent display + child_progress_cb = _build_child_progress_callback(task_index, parent_agent) - # Build progress callback to relay tool calls to parent display - child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count) + # Share the parent's iteration budget so subagent tool calls + # count toward the session-wide limit. + shared_budget = getattr(parent_agent, "iteration_budget", None) - # Share the parent's iteration budget so subagent tool calls - # count toward the session-wide limit. - shared_budget = getattr(parent_agent, "iteration_budget", None) + # Resolve effective credentials: config override > parent inherit + effective_model = model or parent_agent.model + effective_provider = override_provider or getattr(parent_agent, "provider", None) + effective_base_url = override_base_url or parent_agent.base_url + effective_api_key = override_api_key or parent_api_key + effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None) - # Resolve effective credentials: config override > parent inherit - effective_model = model or parent_agent.model - effective_provider = override_provider or getattr(parent_agent, "provider", None) - effective_base_url = override_base_url or parent_agent.base_url - effective_api_key = override_api_key or parent_api_key - effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None) + child = AIAgent( + base_url=effective_base_url, + api_key=effective_api_key, + model=effective_model, + provider=effective_provider, + api_mode=effective_api_mode, + max_iterations=max_iterations, + max_tokens=getattr(parent_agent, "max_tokens", None), + reasoning_config=getattr(parent_agent, "reasoning_config", None), + prefill_messages=getattr(parent_agent, "prefill_messages", None), + enabled_toolsets=child_toolsets, + quiet_mode=True, + ephemeral_system_prompt=child_prompt, + log_prefix=f"[subagent-{task_index}]", + platform=parent_agent.platform, + skip_context_files=True, + skip_memory=True, + clarify_callback=None, + session_db=getattr(parent_agent, '_session_db', None), + providers_allowed=parent_agent.providers_allowed, + providers_ignored=parent_agent.providers_ignored, + providers_order=parent_agent.providers_order, + provider_sort=parent_agent.provider_sort, + tool_progress_callback=child_progress_cb, + iteration_budget=shared_budget, + ) - child = AIAgent( - base_url=effective_base_url, - api_key=effective_api_key, - model=effective_model, - provider=effective_provider, - api_mode=effective_api_mode, - max_iterations=max_iterations, - max_tokens=getattr(parent_agent, "max_tokens", None), - reasoning_config=getattr(parent_agent, "reasoning_config", None), - prefill_messages=getattr(parent_agent, "prefill_messages", None), - enabled_toolsets=child_toolsets, - quiet_mode=True, - ephemeral_system_prompt=child_prompt, - log_prefix=f"[subagent-{task_index}]", - platform=parent_agent.platform, - skip_context_files=True, - skip_memory=True, - clarify_callback=None, - session_db=getattr(parent_agent, '_session_db', None), - providers_allowed=parent_agent.providers_allowed, - providers_ignored=parent_agent.providers_ignored, - providers_order=parent_agent.providers_order, - provider_sort=parent_agent.provider_sort, - tool_progress_callback=child_progress_cb, - iteration_budget=shared_budget, - ) + # Set delegation depth so children can't spawn grandchildren + child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1 - # Set delegation depth so children can't spawn grandchildren - child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1 - - # Register child for interrupt propagation - if hasattr(parent_agent, '_active_children'): + # Register child for interrupt propagation + if hasattr(parent_agent, '_active_children'): + lock = getattr(parent_agent, '_active_children_lock', None) + if lock: + with lock: + parent_agent._active_children.append(child) + else: parent_agent._active_children.append(child) - # Run with stdout/stderr suppressed to prevent interleaved output - devnull = io.StringIO() - with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull): - result = child.run_conversation(user_message=goal) + return child + +def _run_single_child( + task_index: int, + goal: str, + child=None, + parent_agent=None, + **_kwargs, +) -> Dict[str, Any]: + """ + Run a pre-built child agent. Called from within a thread. + Returns a structured result dict. + """ + child_start = time.monotonic() + + # Get the progress callback from the child agent + child_progress_cb = getattr(child, 'tool_progress_callback', None) + + try: + result = child.run_conversation(user_message=goal) # Flush any remaining batched progress to gateway if child_progress_cb and hasattr(child_progress_cb, '_flush'): @@ -355,11 +368,15 @@ def _run_single_child( # Unregister child from interrupt propagation if hasattr(parent_agent, '_active_children'): try: - parent_agent._active_children.remove(child) + lock = getattr(parent_agent, '_active_children_lock', None) + if lock: + with lock: + parent_agent._active_children.remove(child) + else: + parent_agent._active_children.remove(child) except (ValueError, UnboundLocalError) as e: logger.debug("Could not remove child from active_children: %s", e) - def delegate_task( goal: Optional[str] = None, context: Optional[str] = None, @@ -428,51 +445,38 @@ def delegate_task( # Track goal labels for progress display (truncated for readability) task_labels = [t["goal"][:40] for t in task_list] - if n_tasks == 1: - # Single task -- run directly (no thread pool overhead) - t = task_list[0] - result = _run_single_child( - task_index=0, - goal=t["goal"], - context=t.get("context"), - toolsets=t.get("toolsets") or toolsets, - model=creds["model"], - max_iterations=effective_max_iter, - parent_agent=parent_agent, - task_count=1, - override_provider=creds["provider"], - override_base_url=creds["base_url"], + # Build all child agents on the main thread (thread-safe construction) + children = [] + for i, t in enumerate(task_list): + child = _build_child_agent( + task_index=i, goal=t["goal"], context=t.get("context"), + toolsets=t.get("toolsets") or toolsets, model=creds["model"], + max_iterations=effective_max_iter, parent_agent=parent_agent, + override_provider=creds["provider"], override_base_url=creds["base_url"], override_api_key=creds["api_key"], override_api_mode=creds["api_mode"], ) + children.append((i, t, child)) + + if n_tasks == 1: + # Single task -- run directly (no thread pool overhead) + _i, _t, child = children[0] + result = _run_single_child(0, _t["goal"], child, parent_agent) results.append(result) else: # Batch -- run in parallel with per-task progress lines completed_count = 0 spinner_ref = getattr(parent_agent, '_delegate_spinner', None) - # Save stdout/stderr before the executor — redirect_stdout in child - # threads races on sys.stdout and can leave it as devnull permanently. - _saved_stdout = sys.stdout - _saved_stderr = sys.stderr - with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor: futures = {} - for i, t in enumerate(task_list): + for i, t, child in children: future = executor.submit( _run_single_child, task_index=i, goal=t["goal"], - context=t.get("context"), - toolsets=t.get("toolsets") or toolsets, - model=creds["model"], - max_iterations=effective_max_iter, + child=child, parent_agent=parent_agent, - task_count=n_tasks, - override_provider=creds["provider"], - override_base_url=creds["base_url"], - override_api_key=creds["api_key"], - override_api_mode=creds["api_mode"], ) futures[future] = i @@ -515,10 +519,6 @@ def delegate_task( except Exception as e: logger.debug("Spinner update_text failed: %s", e) - # Restore stdout/stderr in case redirect_stdout race left them as devnull - sys.stdout = _saved_stdout - sys.stderr = _saved_stderr - # Sort by task_index so results match input order results.sort(key=lambda r: r["task_index"])