logger = logging.getLogger(__name__)


class ContextCompressor:
    """Compress conversation context when approaching the model's context limit.

    Algorithm: keep the first ``protect_first_n`` and the last
    ``protect_last_n`` messages verbatim and replace everything in between
    with a single summary message. Token tracking uses the counts reported
    by API responses (see ``update_from_response``).
    """

    # Used when no summarization client is configured.
    _NO_CLIENT_SUMMARY = (
        "[CONTEXT SUMMARY]: Previous conversation turns have been compressed to save space. "
        "The assistant performed various actions and received responses."
    )
    # Used when the summarization request fails or returns nothing.
    _FAILED_SUMMARY = (
        "[CONTEXT SUMMARY]: Previous conversation turns have been compressed. "
        "The assistant performed tool calls and received responses."
    )

    def __init__(
        self,
        model: str,
        threshold_percent: float = 0.85,
        summary_model: str = "google/gemini-3-flash-preview",
        protect_first_n: int = 3,
        protect_last_n: int = 4,
        summary_target_tokens: int = 500,
        quiet_mode: bool = False,
    ):
        """Configure thresholds and create the summarization client.

        The client is only created when ``OPENROUTER_API_KEY`` is set;
        otherwise ``self.client`` is ``None`` and a canned summary is used.
        """
        self.model = model
        self.threshold_percent = threshold_percent
        self.summary_model = summary_model
        self.protect_first_n = protect_first_n
        self.protect_last_n = protect_last_n
        self.summary_target_tokens = summary_target_tokens
        self.quiet_mode = quiet_mode

        self.context_length = get_model_context_length(model)
        self.threshold_tokens = int(self.context_length * threshold_percent)
        self.compression_count = 0

        self.last_prompt_tokens = 0
        self.last_completion_tokens = 0
        self.last_total_tokens = 0

        api_key = os.getenv("OPENROUTER_API_KEY", "")
        self.client = OpenAI(api_key=api_key, base_url=OPENROUTER_BASE_URL) if api_key else None

    def update_from_response(self, usage: Dict[str, Any]):
        """Record token usage from an API response's ``usage`` mapping.

        Missing or null counters are stored as 0 so that later numeric
        comparisons in ``should_compress`` cannot fail on ``None``.
        """
        self.last_prompt_tokens = usage.get("prompt_tokens") or 0
        self.last_completion_tokens = usage.get("completion_tokens") or 0
        self.last_total_tokens = usage.get("total_tokens") or 0

    def should_compress(self, prompt_tokens: int = None) -> bool:
        """Return True when the given (or last recorded) prompt size reaches the threshold."""
        tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
        return tokens >= self.threshold_tokens

    def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
        """Quick pre-flight check using a rough estimate (before any API call)."""
        return estimate_messages_tokens_rough(messages) >= self.threshold_tokens

    def get_status(self) -> Dict[str, Any]:
        """Return current compression status for display/logging."""
        if self.context_length:
            usage_percent = self.last_prompt_tokens / self.context_length * 100
        else:
            usage_percent = 0
        return {
            "last_prompt_tokens": self.last_prompt_tokens,
            "threshold_tokens": self.threshold_tokens,
            "context_length": self.context_length,
            "usage_percent": usage_percent,
            "compression_count": self.compression_count,
        }

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a message ``content`` value (None, str, or list of parts) to text."""
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Multi-part content: keep only the textual parts.
            return "\n".join(
                part["text"]
                for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )
        return str(content)

    @staticmethod
    def _format_turn(msg: Dict[str, Any]) -> str:
        """Render one message as ``[ROLE]: text`` for the summarization prompt."""
        role = msg.get("role") or "unknown"
        content = ContextCompressor._content_to_text(msg.get("content"))
        if len(content) > 2000:
            # Keep the head and tail of very long turns; the middle is dropped.
            content = content[:1000] + "\n...[truncated]...\n" + content[-500:]
        tool_calls = msg.get("tool_calls") or []
        if tool_calls:
            tool_names = [
                (tc.get("function") or {}).get("name", "?")
                for tc in tool_calls
                if isinstance(tc, dict)
            ]
            content += f"\n[Tool calls: {', '.join(tool_names)}]"
        return f"[{str(role).upper()}]: {content}"

    def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> str:
        """Generate a concise summary of conversation turns using a fast model.

        Returns a string that always starts with ``[CONTEXT SUMMARY]:``.
        Falls back to a canned summary when no client is configured or the
        request fails; this method never raises for malformed messages.
        """
        if not self.client:
            return self._NO_CLIENT_SUMMARY

        content_to_summarize = "\n\n".join(self._format_turn(msg) for msg in turns_to_summarize)
        prompt = f"""Summarize these conversation turns concisely. This summary will replace these turns in the conversation history.

Write from a neutral perspective describing:
1. What actions were taken (tool calls, searches, file operations)
2. Key information or results obtained
3. Important decisions or findings
4. Relevant data, file names, or outputs

Keep factual and informative. Target ~{self.summary_target_tokens} tokens.

---
TURNS TO SUMMARIZE:
{content_to_summarize}
---

Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""

        try:
            response = self.client.chat.completions.create(
                model=self.summary_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.3,
                max_tokens=self.summary_target_tokens * 2,
                timeout=30.0,
            )
            summary = (response.choices[0].message.content or "").strip()
            if not summary:
                raise ValueError("summary model returned empty content")
            if not summary.startswith("[CONTEXT SUMMARY]:"):
                summary = "[CONTEXT SUMMARY]: " + summary
            return summary
        except Exception as e:
            # Best-effort: compression must still proceed without a real summary.
            logger.warning("Failed to generate context summary: %s", e)
            return self._FAILED_SUMMARY

    def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
        """Compress conversation messages by summarizing the middle turns.

        Keeps the first N and last N messages and replaces everything in
        between with one ``user`` message holding the summary. Returns the
        input list unchanged when there are too few messages to compress.
        Kept messages are shallow-copied; the input list is not modified.
        """
        n_messages = len(messages)
        minimum = self.protect_first_n + self.protect_last_n + 1
        if n_messages <= minimum:
            if not self.quiet_mode:
                print(f"āš ļø Cannot compress: only {n_messages} messages (need > {minimum})")
            return messages

        compress_start = self.protect_first_n
        compress_end = n_messages - self.protect_last_n
        if compress_start >= compress_end:
            return messages

        turns_to_summarize = messages[compress_start:compress_end]
        display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)

        if not self.quiet_mode:
            print(f"\nšŸ“¦ Context compression triggered ({display_tokens:,} tokens ≄ {self.threshold_tokens:,} threshold)")
            print(f" šŸ“Š Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
            print(f" šŸ—œļø Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")

        summary = self._generate_summary(turns_to_summarize)

        compressed = []
        for i in range(compress_start):
            msg = messages[i].copy()
            # Tell the model once (first compression only) that history may be summarized.
            # Only plain-string system prompts are annotated; structured content is left as-is.
            if (
                i == 0
                and msg.get("role") == "system"
                and self.compression_count == 0
                and isinstance(msg.get("content", ""), str)
            ):
                msg["content"] = msg.get("content", "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
            compressed.append(msg)

        compressed.append({"role": "user", "content": summary})
        compressed.extend(messages[i].copy() for i in range(compress_end, n_messages))

        self.compression_count += 1

        if not self.quiet_mode:
            new_estimate = estimate_messages_tokens_rough(compressed)
            saved_estimate = display_tokens - new_estimate
            print(f" āœ… Compressed: {n_messages} → {len(compressed)} messages (~{saved_estimate:,} tokens saved)")
            print(f" šŸ’” Compression #{self.compression_count} complete")

        return compressed
# =========================================================================
# Tool preview (one-line summary of a tool call's primary argument)
# =========================================================================

def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
    """Build a short preview of a tool call's primary argument for display.

    Args:
        tool_name: Name of the tool being called.
        args: The tool call's argument mapping.
        max_len: Maximum length of the generic preview; longer values are
            cut and suffixed with ``...``.

    Returns:
        A short preview string, or ``None`` when there is nothing to show.
        Null-valued arguments are treated as missing instead of raising.
    """
    primary_args = {
        "terminal": "command", "web_search": "query", "web_extract": "urls",
        "read_file": "path", "write_file": "path", "patch": "path",
        "search_files": "pattern", "browser_navigate": "url",
        "browser_click": "ref", "browser_type": "text",
        "image_generate": "prompt", "text_to_speech": "text",
        "vision_analyze": "question", "mixture_of_agents": "user_prompt",
        "skill_view": "name", "skills_list": "category",
        "schedule_cronjob": "name",
    }

    def _text(key, default=""):
        # Missing and null values both fall back to the default; everything else is stringified.
        value = args.get(key)
        return default if value is None else str(value)

    if tool_name == "process":
        action = _text("action")
        sid = _text("session_id")
        data = _text("data")
        timeout_val = args.get("timeout")
        # Skip an empty action so the preview has no leading space and can be None.
        parts = [action] if action else []
        if sid:
            parts.append(sid[:16])
        if data:
            parts.append(f'"{data[:20]}"')
        if timeout_val and action == "wait":
            parts.append(f"{timeout_val}s")
        return " ".join(parts) if parts else None

    if tool_name == "todo":
        todos_arg = args.get("todos")
        merge = args.get("merge", False)
        if todos_arg is None:
            return "reading task list"
        elif merge:
            return f"updating {len(todos_arg)} task(s)"
        else:
            return f"planning {len(todos_arg)} task(s)"

    if tool_name == "session_search":
        query = _text("query")
        return f"recall: \"{query[:25]}{'...' if len(query) > 25 else ''}\""

    if tool_name == "memory":
        action = _text("action")
        target = _text("target")
        if action == "add":
            content = _text("content")
            return f"+{target}: \"{content[:25]}{'...' if len(content) > 25 else ''}\""
        elif action == "replace":
            return f"~{target}: \"{_text('old_text')[:20]}\""
        elif action == "remove":
            return f"-{target}: \"{_text('old_text')[:20]}\""
        return action

    if tool_name == "send_message":
        target = _text("target", "?")
        msg = _text("message")
        if len(msg) > 20:
            msg = msg[:17] + "..."
        return f"to {target}: \"{msg}\""

    if tool_name.startswith("rl_"):
        rl_previews = {
            "rl_list_environments": "listing envs",
            "rl_select_environment": _text("name"),
            "rl_get_current_config": "reading config",
            "rl_edit_config": f"{args.get('field', '')}={args.get('value', '')}",
            "rl_start_training": "starting",
            "rl_check_status": _text("run_id")[:16],
            "rl_stop_training": f"stopping {_text('run_id')[:16]}",
            "rl_get_results": _text("run_id")[:16],
            "rl_list_runs": "listing runs",
            "rl_test_inference": f"{args.get('num_steps', 3)} steps",
        }
        return rl_previews.get(tool_name)

    key = primary_args.get(tool_name)
    if not key:
        # Unknown tool: use the first commonly named argument that is present.
        for fallback_key in ("query", "text", "command", "path", "name", "prompt"):
            if fallback_key in args:
                key = fallback_key
                break

    if not key or key not in args:
        return None

    value = args[key]
    if isinstance(value, list):
        value = value[0] if value else ""

    preview = str(value).strip()
    if not preview:
        return None
    if len(preview) > max_len:
        preview = preview[:max_len - 3] + "..."
    return preview


# =========================================================================
# KawaiiSpinner
# =========================================================================

class KawaiiSpinner:
    """Animated spinner with kawaii faces for CLI feedback during tool execution.

    The animation runs on a daemon thread started by ``start()`` and ended
    by ``stop()``; the class is also usable as a context manager.
    """

    # Frame sequences selectable through ``spinner_type``.
    SPINNERS = {
        'dots': ['ā ‹', 'ā ™', 'ā ¹', 'ā ø', 'ā ¼', 'ā “', 'ā ¦', 'ā §', 'ā ‡', 'ā '],
        'bounce': ['⠁', 'ā ‚', 'ā „', '─', '⢀', 'ā  ', '⠐', '⠈'],
        'grow': ['▁', 'ā–‚', 'ā–ƒ', 'ā–„', 'ā–…', 'ā–†', 'ā–‡', 'ā–ˆ', 'ā–‡', 'ā–†', 'ā–…', 'ā–„', 'ā–ƒ', 'ā–‚'],
        'arrows': ['←', '↖', '↑', '↗', '→', 'ā†˜', '↓', '↙'],
        'star': ['✶', '✷', '✸', '✹', '✺', '✹', '✸', '✷'],
        'moon': ['šŸŒ‘', 'šŸŒ’', 'šŸŒ“', 'šŸŒ”', 'šŸŒ•', 'šŸŒ–', 'šŸŒ—', '🌘'],
        'pulse': ['ā—œ', 'ā— ', 'ā—', 'ā—ž', 'ā—”', 'ā—Ÿ'],
        'brain': ['🧠', 'šŸ’­', 'šŸ’”', '✨', 'šŸ’«', '🌟', 'šŸ’”', 'šŸ’­'],
        'sparkle': ['⁺', '˚', '*', '✧', '✦', '✧', '*', '˚'],
    }

    KAWAII_WAITING = [
        "(t◕‿◕t)", "(ā—•ā€æā—•āœæ)", "Ł©(◕‿◕t)Ū¶", "(āœæā— ā€æā— )", "( Ė˜ā–½Ė˜)っ",
        "♪(“ε` )", "(ā—•į“—ā—•āœæ)", "ヾ(ļ¼¾āˆ‡ļ¼¾)", "(≧◔≦)", "(ā˜…Ļ‰ā˜…)",
    ]

    KAWAII_THINKING = [
        "(t•́︿•̀t)", "(ā—”_ā—”)", "(¬‿¬)", "( •_•)>āŒā– -ā– ", "(āŒā– _ā– )",
        "(“d_d`)", "ā—‰_ā—‰", "(°ロ°)", "( ˘⌣˘)ā™”", "ヽ(>āˆ€<ā˜†)ā˜†",
        "Ł©(ą¹‘ā›į“—ā›ą¹‘)Ū¶", "(āŠ™_āŠ™)", "(¬_¬)", "( ͔° ĶœŹ– ͔°)", "ą² _ą² ",
    ]

    THINKING_VERBS = [
        "pondering", "contemplating", "musing", "cogitating", "ruminating",
        "deliberating", "mulling", "reflecting", "processing", "reasoning",
        "analyzing", "computing", "synthesizing", "formulating", "brainstorming",
    ]

    def __init__(self, message: str = "", spinner_type: str = 'dots'):
        """Create a stopped spinner; unknown ``spinner_type`` falls back to ``'dots'``."""
        self.message = message
        self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
        self.running = False
        self.thread = None
        self.frame_idx = 0
        self.start_time = None
        self.last_line_len = 0

    def _animate(self):
        """Thread body: redraw the spinner line until ``running`` is cleared."""
        while self.running:
            # While HERMES_SPINNER_PAUSE is set, keep the thread alive but draw nothing.
            if os.getenv("HERMES_SPINNER_PAUSE"):
                time.sleep(0.1)
                continue
            frame = self.spinner_frames[self.frame_idx % len(self.spinner_frames)]
            elapsed = time.time() - self.start_time
            line = f" {frame} {self.message} ({elapsed:.1f}s)"
            # Blank out the previous line first so a shorter line leaves no residue.
            clear = '\r' + ' ' * self.last_line_len + '\r'
            print(clear + line, end='', flush=True)
            self.last_line_len = len(line)
            self.frame_idx += 1
            time.sleep(0.12)

    def start(self):
        """Start the animation thread; a no-op if already running."""
        if self.running:
            return
        self.running = True
        self.start_time = time.time()
        self.thread = threading.Thread(target=self._animate, daemon=True)
        self.thread.start()

    def update_text(self, new_message: str):
        """Replace the message shown next to the spinner."""
        self.message = new_message

    def stop(self, final_message: str = None):
        """Stop the animation, clear the line, and optionally print ``final_message``."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=0.5)
        print('\r' + ' ' * (self.last_line_len + 5) + '\r', end='', flush=True)
        if final_message:
            print(f" {final_message}", flush=True)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        # Never suppress exceptions raised inside the ``with`` body.
        return False


# =========================================================================
# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text)
# =========================================================================

KAWAII_SEARCH = [
    "♪(“ε` )", "(t◕‿◕t)", "ヾ(ļ¼¾āˆ‡ļ¼¾)", "(ā—•į“—ā—•āœæ)", "( Ė˜ā–½Ė˜)っ",
    "Ł©(◕‿◕t)Ū¶", "(āœæā— ā€æā— )", "ā™Ŗļ½ž(“ε` )", "(ćƒŽĀ“ćƒ®`)惎*:ćƒ»ć‚šāœ§", "ļ¼¼(ā—Žoā—Ž)ļ¼",
]
KAWAII_READ = [
    "φ(ć‚œā–½ć‚œ*)♪", "( Ė˜ā–½Ė˜)っ", "(āŒā– _ā– )", "Ł©(t•́‿•̀t)Ū¶", "(ā—•ā€æā—•āœæ)",
    "ヾ(ļ¼ āŒ’ćƒ¼āŒ’ļ¼ )惎", "(āœ§Ļ‰āœ§)", "♪(๑ᓖ◔ᓖ๑)♪", "(≧◔≦)", "( Ā“ ā–½ ` )惎",
]
KAWAII_TERMINAL = [
    "ヽ(>āˆ€<ā˜†)惎", "(ćƒŽĀ°āˆ€Ā°)惎", "Ł©(^į“—^)Ū¶", "ヾ(āŒā– _ā– )ćƒŽā™Ŗ", "(•̀ᓗ•́)و",
    "ā”—(ļ¼¾0ļ¼¾)┓", "(ļ½€ćƒ»Ļ‰ćƒ»Ā“)", "ļ¼¼( ̄▽ ̄)ļ¼", "(ąø‡ •̀_•́)ąø‡", "ヽ(“▽`)/",
]
KAWAII_BROWSER = [
    "(ćƒŽĀ°āˆ€Ā°)惎", "(ā˜žć‚šćƒ®ć‚š)ā˜ž", "( ͔° ĶœŹ– ͔°)", "ā”Œ( ą² _ą² )ā”˜", "(āŠ™_āŠ™)?",
    "ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᓄᵔ)", "ļ¼¼(ā—Žoā—Ž)ļ¼",
]
KAWAII_CREATE = [
    "✧*。٩(ĖŠį—œĖ‹*)و✧", "(ļ¾‰ā—•ćƒ®ā—•)ノ*:・゚✧", "ヽ(>āˆ€<ā˜†)惎", "Ł©(♔ε♔)Ū¶", "(◕‿◕)ā™”",
    "āœæā—• ‿ ā—•āœæ", "(*≧▽≦)", "ヾ(ļ¼¾-ļ¼¾)惎", "(ā˜†ā–½ā˜†)", "Ā°Ė–āœ§ā—(⁰▿⁰)ā—œāœ§Ė–Ā°",
]
KAWAII_SKILL = [
    "ヾ(ļ¼ āŒ’ćƒ¼āŒ’ļ¼ )惎", "(ą¹‘Ėƒį“—Ė‚)ļ»­", "Ł©(◕‿◕t)Ū¶", "(āœæā•¹ā—”ā•¹)", "ヽ(ćƒ»āˆ€ćƒ»)惎",
    "(ćƒŽĀ“ćƒ®`)惎*:・゚✧", "♪(๑ᓖ◔ᓖ๑)♪", "(◠‿◠)", "Ł©(ĖŠį—œĖ‹*)و", "(^▽^)",
    "ヾ(ļ¼¾āˆ‡ļ¼¾)", "(ā˜…Ļ‰ā˜…)/", "Ł©(t•́‿•̀t)Ū¶", "(ā—•į“—ā—•āœæ)", "ļ¼¼(ā—Žoā—Ž)ļ¼",
    "(āœ§Ļ‰āœ§)", "ヽ(>āˆ€<ā˜†)惎", "( Ė˜ā–½Ė˜)っ", "(≧◔≦) ā™”", "ヾ( ̄▽ ̄)",
]
KAWAII_THINK = [
    "(っ°Д°;)っ", "(ļ¼›ā€²āŒ’`)", "(惻_・ヾ", "( Ā“_悝`)", "( ̄ヘ ̄)",
    "(怂-`ω“-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )惎", "(;一_äø€)",
]
KAWAII_GENERIC = [
    "♪(“ε` )", "(ā—•ā€æā—•āœæ)", "ヾ(ļ¼¾āˆ‡ļ¼¾)", "Ł©(◕‿◕t)Ū¶", "(āœæā— ā€æā— )",
    "(ćƒŽĀ“ćƒ®`)惎*:・゚✧", "ヽ(>āˆ€<ā˜†)惎", "(ā˜†ā–½ā˜†)", "( Ė˜ā–½Ė˜)っ", "(≧◔≦)",
]


# =========================================================================
# Cute tool message (completion line that replaces the spinner)
# =========================================================================

def get_cute_tool_message(tool_name: str, args: dict, duration: float) -> str:
    """Generate a formatted tool completion line for CLI quiet mode.

    Format: ``| {emoji} {verb:9} {detail} {duration}``

    Args:
        tool_name: Name of the tool that finished.
        args: The tool call's argument mapping.
        duration: Elapsed seconds, rendered with one decimal place.

    Returns:
        One display line. Tools without a dedicated branch fall back to
        ``build_tool_preview`` for the detail text.
    """
    dur = f"{duration:.1f}s"

    def _trunc(s, n=40):
        s = str(s)
        return (s[:n-3] + "...") if len(s) > n else s

    def _path(p, n=35):
        # Paths keep their tail (the file name) rather than their head.
        p = str(p)
        return ("..." + p[-(n-3):]) if len(p) > n else p

    def _domain(url):
        # Strip the scheme and keep everything up to the first "/".
        url = str(url or "")
        return url.replace("https://", "").replace("http://", "").split("/")[0]

    def _short_id(key, default, n=12):
        # Null ids are treated like missing ids instead of failing on slicing.
        value = args.get(key)
        return (default if value is None else str(value))[:n]

    if tool_name == "web_search":
        return f"ā”Š šŸ” search {_trunc(args.get('query', ''), 42)} {dur}"
    if tool_name == "web_extract":
        urls = args.get("urls", [])
        if urls:
            # A bare string is one URL; only real sequences contribute to the "+N" count.
            url_list = list(urls) if isinstance(urls, (list, tuple)) else [urls]
            domain = _domain(url_list[0])
            extra = f" +{len(url_list)-1}" if len(url_list) > 1 else ""
            return f"ā”Š šŸ“„ fetch {_trunc(domain, 35)}{extra} {dur}"
        return f"ā”Š šŸ“„ fetch pages {dur}"
    if tool_name == "web_crawl":
        return f"ā”Š šŸ•øļø crawl {_trunc(_domain(args.get('url', '')), 35)} {dur}"
    if tool_name == "terminal":
        return f"ā”Š šŸ’» $ {_trunc(args.get('command', ''), 42)} {dur}"
    if tool_name == "process":
        action = args.get("action", "?")
        sid = _short_id("session_id", "")
        labels = {"list": "ls processes", "poll": f"poll {sid}", "log": f"log {sid}",
                  "wait": f"wait {sid}", "kill": f"kill {sid}", "write": f"write {sid}", "submit": f"submit {sid}"}
        return f"ā”Š āš™ļø proc {labels.get(action, f'{action} {sid}')} {dur}"
    if tool_name == "read_file":
        return f"ā”Š šŸ“– read {_path(args.get('path', ''))} {dur}"
    if tool_name == "write_file":
        return f"ā”Š āœļø write {_path(args.get('path', ''))} {dur}"
    if tool_name == "patch":
        return f"ā”Š šŸ”§ patch {_path(args.get('path', ''))} {dur}"
    if tool_name == "search_files":
        pattern = _trunc(args.get("pattern", ""), 35)
        target = args.get("target", "content")
        verb = "find" if target == "files" else "grep"
        return f"ā”Š šŸ”Ž {verb:9} {pattern} {dur}"
    if tool_name == "browser_navigate":
        return f"ā”Š 🌐 navigate {_trunc(_domain(args.get('url', '')), 35)} {dur}"
    if tool_name == "browser_snapshot":
        mode = "full" if args.get("full") else "compact"
        return f"ā”Š šŸ“ø snapshot {mode} {dur}"
    if tool_name == "browser_click":
        return f"ā”Š šŸ‘† click {args.get('ref', '?')} {dur}"
    if tool_name == "browser_type":
        return f"ā”Š āŒØļø type \"{_trunc(args.get('text', ''), 30)}\" {dur}"
    if tool_name == "browser_scroll":
        d = args.get("direction", "down")
        arrow = {"down": "↓", "up": "↑", "right": "→", "left": "←"}.get(d, "↓")
        return f"ā”Š {arrow} scroll {d} {dur}"
    if tool_name == "browser_back":
        return f"ā”Š ā—€ļø back {dur}"
    if tool_name == "browser_press":
        return f"ā”Š āŒØļø press {args.get('key', '?')} {dur}"
    if tool_name == "browser_close":
        return f"ā”Š 🚪 close browser {dur}"
    if tool_name == "browser_get_images":
        return f"ā”Š šŸ–¼ļø images extracting {dur}"
    if tool_name == "browser_vision":
        return f"ā”Š šŸ‘ļø vision analyzing page {dur}"
    if tool_name == "todo":
        todos_arg = args.get("todos")
        merge = args.get("merge", False)
        if todos_arg is None:
            return f"ā”Š šŸ“‹ plan reading tasks {dur}"
        elif merge:
            return f"ā”Š šŸ“‹ plan update {len(todos_arg)} task(s) {dur}"
        else:
            return f"ā”Š šŸ“‹ plan {len(todos_arg)} task(s) {dur}"
    if tool_name == "session_search":
        return f"ā”Š šŸ” recall \"{_trunc(args.get('query', ''), 35)}\" {dur}"
    if tool_name == "memory":
        action = args.get("action", "?")
        target = args.get("target", "")
        if action == "add":
            return f"ā”Š 🧠 memory +{target}: \"{_trunc(args.get('content', ''), 30)}\" {dur}"
        elif action == "replace":
            return f"ā”Š 🧠 memory ~{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}"
        elif action == "remove":
            return f"ā”Š 🧠 memory -{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}"
        return f"ā”Š 🧠 memory {action} {dur}"
    if tool_name == "skills_list":
        return f"ā”Š šŸ“š skills list {args.get('category', 'all')} {dur}"
    if tool_name == "skill_view":
        return f"ā”Š šŸ“š skill {_trunc(args.get('name', ''), 30)} {dur}"
    if tool_name == "image_generate":
        return f"ā”Š šŸŽØ create {_trunc(args.get('prompt', ''), 35)} {dur}"
    if tool_name == "text_to_speech":
        return f"ā”Š šŸ”Š speak {_trunc(args.get('text', ''), 30)} {dur}"
    if tool_name == "vision_analyze":
        return f"ā”Š šŸ‘ļø vision {_trunc(args.get('question', ''), 30)} {dur}"
    if tool_name == "mixture_of_agents":
        return f"ā”Š 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}"
    if tool_name == "send_message":
        return f"ā”Š šŸ“Ø send {args.get('target', '?')}: \"{_trunc(args.get('message', ''), 25)}\" {dur}"
    if tool_name == "schedule_cronjob":
        return f"ā”Š ā° schedule {_trunc(args.get('name', args.get('prompt', 'task')), 30)} {dur}"
    if tool_name == "list_cronjobs":
        return f"ā”Š ā° jobs listing {dur}"
    if tool_name == "remove_cronjob":
        return f"ā”Š ā° remove job {args.get('job_id', '?')} {dur}"
    if tool_name.startswith("rl_"):
        run_id = _short_id("run_id", "?")
        rl = {
            "rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
            "rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
            "rl_start_training": "start training", "rl_check_status": f"status {run_id}",
            "rl_stop_training": f"stop {run_id}", "rl_get_results": f"results {run_id}",
            "rl_list_runs": "list runs", "rl_test_inference": "test inference",
        }
        return f"ā”Š 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}"
    if tool_name == "execute_code":
        code = args.get("code", "")
        first_line = code.strip().split("\n")[0] if code.strip() else ""
        return f"ā”Š šŸ exec {_trunc(first_line, 35)} {dur}"
    if tool_name == "delegate_task":
        tasks = args.get("tasks")
        if tasks and isinstance(tasks, list):
            return f"ā”Š šŸ”€ delegate {len(tasks)} parallel tasks {dur}"
        return f"ā”Š šŸ”€ delegate {_trunc(args.get('goal', ''), 35)} {dur}"

    preview = build_tool_preview(tool_name, args) or ""
    return f"ā”Š ⚔ {tool_name[:9]:9} {_trunc(preview, 35)} {dur}"
logger = logging.getLogger(__name__)

# Module-level cache of OpenRouter model metadata, refreshed at most once per TTL.
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600

# Context length assumed when neither the API nor the defaults know the model.
_FALLBACK_CONTEXT_LENGTH = 128000
# Completion limit assumed when the API omits it.
_FALLBACK_MAX_COMPLETION_TOKENS = 4096

DEFAULT_CONTEXT_LENGTHS = {
    "anthropic/claude-opus-4": 200000,
    "anthropic/claude-opus-4.5": 200000,
    "anthropic/claude-opus-4.6": 200000,
    "anthropic/claude-sonnet-4": 200000,
    "anthropic/claude-sonnet-4-20250514": 200000,
    "anthropic/claude-haiku-4.5": 200000,
    "openai/gpt-4o": 128000,
    "openai/gpt-4-turbo": 128000,
    "openai/gpt-4o-mini": 128000,
    "google/gemini-2.0-flash": 1048576,
    "google/gemini-2.5-pro": 1048576,
    "meta-llama/llama-3.3-70b-instruct": 131072,
    "deepseek/deepseek-chat-v3": 65536,
    "qwen/qwen-2.5-72b-instruct": 32768,
}


def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from OpenRouter (cached for 1 hour).

    Args:
        force_refresh: Bypass the cache and query the API again.

    Returns:
        Mapping of model id (and canonical slug, when different) to a dict
        with ``context_length``, ``max_completion_tokens``, ``name`` and
        ``pricing``. On any failure the previous cache (possibly empty) is
        returned and a warning is logged; this function does not raise.
    """
    global _model_metadata_cache, _model_metadata_cache_time

    cache_is_fresh = (time.time() - _model_metadata_cache_time) < _MODEL_CACHE_TTL
    if not force_refresh and _model_metadata_cache and cache_is_fresh:
        return _model_metadata_cache

    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
        response.raise_for_status()
        data = response.json()

        cache = {}
        for model in data.get("data") or []:
            model_id = model.get("id", "")
            # JSON nulls arrive as None, so ``or`` is used instead of .get() defaults.
            top_provider = model.get("top_provider") or {}
            entry = {
                "context_length": model.get("context_length") or _FALLBACK_CONTEXT_LENGTH,
                "max_completion_tokens": top_provider.get("max_completion_tokens") or _FALLBACK_MAX_COMPLETION_TOKENS,
                "name": model.get("name") or model_id,
                "pricing": model.get("pricing") or {},
            }
            cache[model_id] = entry
            canonical = model.get("canonical_slug", "")
            if canonical and canonical != model_id:
                cache[canonical] = entry

        _model_metadata_cache = cache
        _model_metadata_cache_time = time.time()
        logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
        return cache

    except Exception as e:
        # Best-effort lookup: callers fall back to DEFAULT_CONTEXT_LENGTHS.
        logger.warning("Failed to fetch model metadata from OpenRouter: %s", e)
        return _model_metadata_cache or {}


def get_model_context_length(model: str) -> int:
    """Get the context length for a model (API first, then fallback defaults).

    Lookup order: exact id in the fetched metadata, then the longest entry
    of ``DEFAULT_CONTEXT_LENGTHS`` that contains or is contained in
    ``model``, then 128000. An empty model name goes straight to 128000.
    """
    metadata = fetch_model_metadata()
    if model in metadata:
        return metadata[model].get("context_length") or _FALLBACK_CONTEXT_LENGTH

    if not model:
        # "" is a substring of every key and would match an arbitrary default.
        return _FALLBACK_CONTEXT_LENGTH

    # Longest key first, so the most specific default wins regardless of dict order.
    for default_model in sorted(DEFAULT_CONTEXT_LENGTHS, key=len, reverse=True):
        if default_model in model or model in default_model:
            return DEFAULT_CONTEXT_LENGTHS[default_model]

    return _FALLBACK_CONTEXT_LENGTH


def estimate_tokens_rough(text: str) -> int:
    """Rough token estimate (~4 chars/token) for pre-flight checks."""
    if not text:
        return 0
    return len(text) // 4


def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
    """Rough token estimate for a message list (pre-flight only).

    Counts the characters of each message's ``str()`` form, so dict keys
    and punctuation are included in the estimate.
    """
    total_chars = sum(len(str(msg)) for msg in messages)
    return total_chars // 4
+""" + +import logging +import os +import re +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + +# ========================================================================= +# Constants +# ========================================================================= + +DEFAULT_AGENT_IDENTITY = ( + "You are Hermes Agent, an intelligent AI assistant created by Nous Research. " + "You are helpful, knowledgeable, and direct. You assist users with a wide " + "range of tasks including answering questions, writing and editing code, " + "analyzing information, creative work, and executing actions via your tools. " + "You communicate clearly, admit uncertainty when appropriate, and prioritize " + "being genuinely useful over being verbose unless otherwise directed below." +) + +PLATFORM_HINTS = { + "whatsapp": ( + "You are on a text messaging communication platform, WhatsApp. " + "Please do not use markdown as it does not render." + ), + "telegram": ( + "You are on a text messaging communication platform, Telegram. " + "Please do not use markdown as it does not render." + ), + "discord": ( + "You are in a Discord server or group chat communicating with your user." + ), + "cli": ( + "You are a CLI AI Agent. Try not to use markdown but simple text " + "renderable inside a terminal." + ), +} + +CONTEXT_FILE_MAX_CHARS = 20_000 +CONTEXT_TRUNCATE_HEAD_RATIO = 0.7 +CONTEXT_TRUNCATE_TAIL_RATIO = 0.2 + + +# ========================================================================= +# Skills index +# ========================================================================= + +def build_skills_system_prompt() -> str: + """Build a compact skill index for the system prompt. + + Scans ~/.hermes/skills/ for SKILL.md files grouped by category so the + model can match skills at a glance without extra tool calls. 
+ """ + hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) + skills_dir = hermes_home / "skills" + + if not skills_dir.exists(): + return "" + + skills_by_category = {} + for skill_file in skills_dir.rglob("SKILL.md"): + rel_path = skill_file.relative_to(skills_dir) + parts = rel_path.parts + if len(parts) >= 2: + category = parts[0] + skill_name = parts[-2] + else: + category = "general" + skill_name = skill_file.parent.name + skills_by_category.setdefault(category, []).append(skill_name) + + if not skills_by_category: + return "" + + category_descriptions = {} + for category in skills_by_category: + desc_file = skills_dir / category / "DESCRIPTION.md" + if desc_file.exists(): + try: + content = desc_file.read_text(encoding="utf-8") + match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL) + if match: + category_descriptions[category] = match.group(1).strip() + except Exception as e: + logger.debug("Could not read skill description %s: %s", desc_file, e) + + index_lines = [] + for category in sorted(skills_by_category.keys()): + desc = category_descriptions.get(category, "") + names = ", ".join(sorted(set(skills_by_category[category]))) + if desc: + index_lines.append(f" {category}: {desc}") + else: + index_lines.append(f" {category}:") + index_lines.append(f" skills: {names}") + + return ( + "## Skills (mandatory)\n" + "Before replying, scan the skills below. If one clearly matches your task, " + "load it with skill_view(name) and follow its instructions. " + "If a skill has issues, fix it with skill_manage(action='patch').\n" + "\n" + "\n" + + "\n".join(index_lines) + "\n" + "\n" + "\n" + "If none match, proceed normally without loading a skill." 
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
    """Return *content* cut down to roughly *max_chars* characters.

    Content that already fits is returned unchanged. Otherwise the result is
    the leading ``max_chars * CONTEXT_TRUNCATE_HEAD_RATIO`` characters, a
    marker naming *filename* together with the kept and original sizes, and
    the trailing ``max_chars * CONTEXT_TRUNCATE_TAIL_RATIO`` characters.

    Args:
        content: Text to shorten.
        filename: Name shown inside the truncation marker.
        max_chars: Size budget used to derive the head and tail lengths.

    Returns:
        The original text, or ``head + marker + tail``.
    """
    if len(content) <= max_chars:
        return content
    head_chars = int(max_chars * CONTEXT_TRUNCATE_HEAD_RATIO)
    tail_chars = int(max_chars * CONTEXT_TRUNCATE_TAIL_RATIO)
    head = content[:head_chars]
    # content[-0:] is the *whole* string, so a zero-length tail (tiny
    # max_chars or a zero tail ratio) must be handled explicitly.
    tail = content[-tail_chars:] if tail_chars > 0 else ""
    marker = f"\n\n[...truncated {filename}: kept {head_chars}+{tail_chars} of {len(content)} chars. Use file tools to read the full file.]\n\n"
    return head + marker + tail


def build_context_files_prompt(cwd: Optional[str] = None) -> str:
    """Discover project context files and render them for the system prompt.

    Sources, in output order:
      1. AGENTS.md -- only when one exists directly in *cwd*; then every
         AGENTS.md found below *cwd* is included, shallowest path first.
         Hidden directories, node_modules, __pycache__ and venv are skipped.
      2. .cursorrules and .cursor/rules/*.mdc (sorted by path).
      3. SOUL.md from *cwd*, falling back to ~/.hermes/SOUL.md.

    Each of the three groups is truncated independently via
    _truncate_content. Files that cannot be read are logged at debug level
    and skipped.

    Args:
        cwd: Directory to search; defaults to the current working directory.

    Returns:
        The assembled "# Project Context" text, or "" when nothing was found.
    """
    if cwd is None:
        cwd = os.getcwd()

    cwd_path = Path(cwd).resolve()
    sections = []

    # AGENTS.md (hierarchical, recursive) -- gated on a top-level file.
    top_level_agents = None
    for name in ("AGENTS.md", "agents.md"):
        candidate = cwd_path / name
        if candidate.exists():
            top_level_agents = candidate
            break

    if top_level_agents:
        agents_files = []
        for root, dirs, files in os.walk(cwd_path):
            # Prune in place so os.walk does not descend into these directories.
            # (Dot-directories such as .venv are already covered by startswith.)
            dirs[:] = [
                d for d in dirs
                if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv')
            ]
            agents_files.extend(Path(root) / f for f in files if f.lower() == "agents.md")
        agents_files.sort(key=lambda p: len(p.parts))

        agents_parts = []
        for agents_path in agents_files:
            try:
                content = agents_path.read_text(encoding="utf-8").strip()
                if content:
                    rel_path = agents_path.relative_to(cwd_path)
                    agents_parts.append(f"## {rel_path}\n\n{content}\n\n")
            except Exception as e:
                logger.debug("Could not read %s: %s", agents_path, e)

        total_agents_content = "".join(agents_parts)
        if total_agents_content:
            sections.append(_truncate_content(total_agents_content, "AGENTS.md"))

    # .cursorrules plus .cursor/rules/*.mdc
    cursor_parts = []
    cursorrules_file = cwd_path / ".cursorrules"
    if cursorrules_file.exists():
        try:
            content = cursorrules_file.read_text(encoding="utf-8").strip()
            if content:
                cursor_parts.append(f"## .cursorrules\n\n{content}\n\n")
        except Exception as e:
            logger.debug("Could not read .cursorrules: %s", e)

    cursor_rules_dir = cwd_path / ".cursor" / "rules"
    if cursor_rules_dir.exists() and cursor_rules_dir.is_dir():
        for mdc_file in sorted(cursor_rules_dir.glob("*.mdc")):
            try:
                content = mdc_file.read_text(encoding="utf-8").strip()
                if content:
                    cursor_parts.append(f"## .cursor/rules/{mdc_file.name}\n\n{content}\n\n")
            except Exception as e:
                logger.debug("Could not read %s: %s", mdc_file, e)

    cursorrules_content = "".join(cursor_parts)
    if cursorrules_content:
        sections.append(_truncate_content(cursorrules_content, ".cursorrules"))

    # SOUL.md (cwd first, then ~/.hermes/ fallback)
    soul_path = None
    for name in ("SOUL.md", "soul.md"):
        candidate = cwd_path / name
        if candidate.exists():
            soul_path = candidate
            break
    if not soul_path:
        global_soul = Path.home() / ".hermes" / "SOUL.md"
        if global_soul.exists():
            soul_path = global_soul

    if soul_path:
        try:
            content = soul_path.read_text(encoding="utf-8").strip()
            if content:
                content = _truncate_content(content, "SOUL.md")
                sections.append(
                    f"## SOUL.md\n\nIf SOUL.md is present, embody its persona and tone. "
                    f"Avoid stiff, generic replies; follow its guidance unless higher-priority "
                    f"instructions override it.\n\n{content}"
                )
        except Exception as e:
            logger.debug("Could not read SOUL.md from %s: %s", soul_path, e)

    if not sections:
        return ""
    return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections)
import copy
from typing import Any, Dict, List


def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
    """Attach a cache_control breakpoint to *msg* in place.

    Placement depends on the message shape:
      * role "tool", or ``content is None``: set on the message itself.
      * string content: content is rewritten as a single text part that
        carries the marker.
      * non-empty list content whose last element is a dict: set on that
        last part.
    Any other shape (empty list, non-dict last part) is left unmarked.

    Args:
        msg: Message dict, mutated in place.
        cache_marker: Marker template; a fresh copy is stored so that
            breakpoints on different messages never share one dict object.
    """
    marker = dict(cache_marker)
    content = msg.get("content")

    if msg.get("role", "") == "tool" or content is None:
        msg["cache_control"] = marker
        return

    if isinstance(content, str):
        msg["content"] = [{"type": "text", "text": content, "cache_control": marker}]
        return

    if isinstance(content, list) and content:
        last = content[-1]
        if isinstance(last, dict):
            last["cache_control"] = marker


def apply_anthropic_cache_control(
    api_messages: List[Dict[str, Any]],
    cache_ttl: str = "5m",
) -> List[Dict[str, Any]]:
    """Apply the system_and_3 caching strategy to a message list.

    Up to four breakpoints are placed: one on the first message when it is
    a system message, and the rest on the last non-system messages (three
    when a system breakpoint was used, otherwise four).

    Args:
        api_messages: Messages to annotate; never mutated.
        cache_ttl: "1h" adds ``"ttl": "1h"`` to each marker; any other value
            yields a plain ``{"type": "ephemeral"}`` marker.

    Returns:
        Deep copy of messages with cache_control breakpoints injected.
    """
    messages = copy.deepcopy(api_messages)
    if not messages:
        return messages

    marker = {"type": "ephemeral"}
    if cache_ttl == "1h":
        marker["ttl"] = "1h"

    breakpoints_used = 0
    if messages[0].get("role") == "system":
        _apply_cache_marker(messages[0], marker)
        breakpoints_used += 1

    # 4 is the total breakpoint budget; `remaining` is always 3 or 4 here.
    remaining = 4 - breakpoints_used
    non_sys = [i for i, m in enumerate(messages) if m.get("role") != "system"]
    for idx in non_sys[-remaining:]:
        _apply_cache_marker(messages[idx], marker)

    return messages
import json
import logging
from datetime import datetime
from typing import Any, Dict, List

logger = logging.getLogger(__name__)

# NOTE(review): the tag literals below were reconstructed. In the reviewed
# text they had been stripped to empty strings, which made both helpers
# no-ops. Confirm the exact tag names against the model's output format.
_SCRATCHPAD_OPEN = "<REASONING_SCRATCHPAD>"
_SCRATCHPAD_CLOSE = "</REASONING_SCRATCHPAD>"
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"


def convert_scratchpad_to_think(content: str) -> str:
    """Rewrite scratchpad tags in *content* as think tags.

    Empty content, or content without an opening scratchpad tag, is
    returned unchanged.
    """
    if not content or _SCRATCHPAD_OPEN not in content:
        return content
    return (
        content
        .replace(_SCRATCHPAD_OPEN, _THINK_OPEN)
        .replace(_SCRATCHPAD_CLOSE, _THINK_CLOSE)
    )


def has_incomplete_scratchpad(content: str) -> bool:
    """Return True when *content* opens a scratchpad but never closes it."""
    if not content:
        return False
    return _SCRATCHPAD_OPEN in content and _SCRATCHPAD_CLOSE not in content


def save_trajectory(trajectory: List[Dict[str, Any]], model: str,
                    completed: bool, filename: str = None):
    """Append a trajectory entry to a JSONL file.

    The entry holds the conversation, a local-time ISO timestamp, the model
    name and the completion flag. Failures are logged as warnings rather
    than raised, so saving is best-effort.

    Args:
        trajectory: The ShareGPT-format conversation list.
        model: Model name for metadata.
        completed: Whether the conversation completed successfully.
        filename: Override output filename. Defaults to trajectory_samples.jsonl
            or failed_trajectories.jsonl based on ``completed``.
    """
    if filename is None:
        filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"

    entry = {
        "conversations": trajectory,
        "timestamp": datetime.now().isoformat(),
        "model": model,
        "completed": completed,
    }

    try:
        with open(filename, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        logger.info("Trajectory saved to %s", filename)
    except Exception as e:
        logger.warning("Failed to save trajectory: %s", e)
who it is and behaves consistently across platforms. -DEFAULT_AGENT_IDENTITY = ( - "You are Hermes Agent, an intelligent AI assistant created by Nous Research. " - "You are helpful, knowledgeable, and direct. You assist users with a wide " - "range of tasks including answering questions, writing and editing code, " - "analyzing information, creative work, and executing actions via your tools. " - "You communicate clearly, admit uncertainty when appropriate, and prioritize " - "being genuinely useful over being verbose unless otherwise directed below." +# Agent internals extracted to agent/ package for modularity +from agent.prompt_builder import DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS +from agent.model_metadata import ( + fetch_model_metadata, get_model_context_length, + estimate_tokens_rough, estimate_messages_tokens_rough, +) +from agent.context_compressor import ContextCompressor +from agent.prompt_caching import apply_anthropic_cache_control +from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt +from agent.display import ( + KawaiiSpinner, build_tool_preview as _build_tool_preview, + get_cute_tool_message as _get_cute_tool_message_impl, + KAWAII_SEARCH, KAWAII_READ, KAWAII_TERMINAL, KAWAII_BROWSER, + KAWAII_CREATE, KAWAII_SKILL, KAWAII_THINK, KAWAII_GENERIC, +) +from agent.trajectory import ( + convert_scratchpad_to_think, has_incomplete_scratchpad, + save_trajectory as _save_trajectory_to_file, ) -# Platform-specific formatting hints appended to the system prompt. -# These tell the agent how to format its output for the current interface. -PLATFORM_HINTS = { - "whatsapp": ( - "You are on a text messaging communication platform, WhatsApp. " - "Please do not use markdown as it does not render." - ), - "telegram": ( - "You are on a text messaging communication platform, Telegram. " - "Please do not use markdown as it does not render." - ), - "discord": ( - "You are in a Discord server or group chat communicating with your user." 
- ), - "cli": ( - "You are a CLI AI Agent. Try not to use markdown but simple text " - "renderable inside a terminal." - ), -} - # ============================================================================= -# Model Context Management +# Model Context Management (extracted to agent/model_metadata.py) +# The functions below are re-imported above; these stubs maintain the +# module-level names for any internal references that use the unqualified name. # ============================================================================= -# Cache for model metadata from OpenRouter -_model_metadata_cache: Dict[str, Dict[str, Any]] = {} -_model_metadata_cache_time: float = 0 -_MODEL_CACHE_TTL = 3600 # 1 hour cache TTL - -# Default context lengths for common models (fallback if API fails) DEFAULT_CONTEXT_LENGTHS = { "anthropic/claude-opus-4": 200000, "anthropic/claude-opus-4.5": 200000, diff --git a/tools/approval.py b/tools/approval.py index 1f3c7e054..2db8424cb 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -1,45 +1,89 @@ -"""Thread-safe per-session approval management for dangerous commands. +"""Dangerous command approval -- detection, prompting, and per-session state. -Replaces the module-level globals (_last_pending_approval, _session_approved_patterns) -that were previously in terminal_tool.py. Those globals were shared across all -concurrent gateway sessions, creating race conditions where one session's approval -could overwrite another's. - -This module provides session-scoped state keyed by session_key, with proper locking. 
import re

# (regex, human-readable description) pairs, checked in order; the first
# match wins. Matching is case-insensitive.
# NOTE(review): the "recursive delete" pattern also matches commands such as
# `rm report.txt` (an `r` right after `rm `). It is left as-is because the
# approval key is derived from the pattern text and may already be stored.
DANGEROUS_PATTERNS = [
    (r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
    (r'\brm\s+(-[^\s]*)?r', "recursive delete"),
    (r'\brm\s+--recursive\b', "recursive delete (long flag)"),
    (r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
    (r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
    (r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
    (r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
    (r'\bmkfs\b', "format filesystem"),
    (r'\bdd\s+.*if=', "disk copy"),
    (r'>\s*/dev/sd', "write to block device"),
    (r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
    (r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
    (r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
    (r'>\s*/etc/', "overwrite system config"),
    (r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
    (r'\bkill\s+-9\s+-1\b', "kill all processes"),
    (r'\bpkill\s+-9\b', "force kill processes"),
    # Parentheses and braces are escaped: an unescaped `()` is an empty regex
    # group, so the previous pattern never matched a real `:(){ :|:& };:`.
    (r':\(\)\s*\{\s*:\s*\|\s*:&\s*\}\s*;:', "fork bomb"),
    (r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"),
    (r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
    (r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
    (r'\bxargs\s+.*\brm\b', "xargs with rm"),
    (r'\bfind\b.*-exec\s+rm\b', "find -exec rm"),
    (r'\bfind\b.*-delete\b', "find -delete"),
]


# =========================================================================
# Detection
# =========================================================================

def detect_dangerous_command(command: str) -> tuple:
    """Check if a command matches any dangerous patterns.

    The pattern key identifies the matched rule for approval bookkeeping:
    it is the text between the first two ``\\b`` markers of the regex, or
    the first 20 characters of the regex when it contains no ``\\b``.

    Args:
        command: Shell command text; an empty command is never dangerous.

    Returns:
        (is_dangerous, pattern_key, description) or (False, None, None)
    """
    if not command:
        return (False, None, None)

    command_lower = command.lower()
    for pattern, description in DANGEROUS_PATTERNS:
        if re.search(pattern, command_lower, re.IGNORECASE):
            pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
            return (True, pattern_key, description)
    return (False, None, None)
- """ + """Retrieve and remove a pending approval for a session.""" with _lock: return _pending.pop(session_key, None) @@ -51,10 +95,7 @@ def has_pending(session_key: str) -> bool: def approve_session(session_key: str, pattern_key: str): - """Approve a dangerous command pattern for this session only. - - The approval is scoped to the session -- other sessions are unaffected. - """ + """Approve a pattern for this session only.""" with _lock: _session_approved.setdefault(session_key, set()).add(pattern_key) @@ -68,7 +109,7 @@ def is_approved(session_key: str, pattern_key: str) -> bool: def approve_permanent(pattern_key: str): - """Add a pattern to the permanent (cross-session) allowlist.""" + """Add a pattern to the permanent allowlist.""" with _lock: _permanent_approved.add(pattern_key) @@ -80,7 +121,173 @@ def load_permanent(patterns: set): def clear_session(session_key: str): - """Clear all approvals and pending requests for a session (e.g., on /reset).""" + """Clear all approvals and pending requests for a session.""" with _lock: _session_approved.pop(session_key, None) _pending.pop(session_key, None) + + +# ========================================================================= +# Config persistence for permanent allowlist +# ========================================================================= + +def load_permanent_allowlist() -> set: + """Load permanently allowed command patterns from config. + + Also syncs them into the approval module so is_approved() works for + patterns added via 'always' in a previous session. 
+ """ + try: + from hermes_cli.config import load_config + config = load_config() + patterns = set(config.get("command_allowlist", []) or []) + if patterns: + load_permanent(patterns) + return patterns + except Exception: + return set() + + +def save_permanent_allowlist(patterns: set): + """Save permanently allowed command patterns to config.""" + try: + from hermes_cli.config import load_config, save_config + config = load_config() + config["command_allowlist"] = list(patterns) + save_config(config) + except Exception as e: + logger.warning("Could not save allowlist: %s", e) + + +# ========================================================================= +# Approval prompting + orchestration +# ========================================================================= + +def prompt_dangerous_approval(command: str, description: str, + timeout_seconds: int = 60, + approval_callback=None) -> str: + """Prompt the user to approve a dangerous command (CLI only). + + Args: + approval_callback: Optional callback registered by the CLI for + prompt_toolkit integration. Signature: (command, description) -> str. + + Returns: 'once', 'session', 'always', or 'deny' + """ + if approval_callback is not None: + try: + return approval_callback(command, description) + except Exception: + return "deny" + + os.environ["HERMES_SPINNER_PAUSE"] = "1" + try: + print() + print(f" āš ļø DANGEROUS COMMAND: {description}") + print(f" {command[:80]}{'...' 
if len(command) > 80 else ''}") + print() + print(f" [o]nce | [s]ession | [a]lways | [d]eny") + print() + sys.stdout.flush() + + result = {"choice": ""} + + def get_input(): + try: + result["choice"] = input(" Choice [o/s/a/D]: ").strip().lower() + except (EOFError, OSError): + result["choice"] = "" + + thread = threading.Thread(target=get_input, daemon=True) + thread.start() + thread.join(timeout=timeout_seconds) + + if thread.is_alive(): + print("\n ā± Timeout - denying command") + return "deny" + + choice = result["choice"] + if choice in ('o', 'once'): + print(" āœ“ Allowed once") + return "once" + elif choice in ('s', 'session'): + print(" āœ“ Allowed for this session") + return "session" + elif choice in ('a', 'always'): + print(" āœ“ Added to permanent allowlist") + return "always" + else: + print(" āœ— Denied") + return "deny" + + except (EOFError, KeyboardInterrupt): + print("\n āœ— Cancelled") + return "deny" + finally: + if "HERMES_SPINNER_PAUSE" in os.environ: + del os.environ["HERMES_SPINNER_PAUSE"] + print() + sys.stdout.flush() + + +def check_dangerous_command(command: str, env_type: str, + approval_callback=None) -> dict: + """Check if a command is dangerous and handle approval. + + This is the main entry point called by terminal_tool before executing + any command. It orchestrates detection, session checks, and prompting. + + Args: + command: The shell command to check. + env_type: Terminal backend type ('local', 'ssh', 'docker', etc.). + approval_callback: Optional CLI callback for interactive prompts. 
+ + Returns: + {"approved": True/False, "message": str or None, ...} + """ + if env_type in ("docker", "singularity", "modal"): + return {"approved": True, "message": None} + + is_dangerous, pattern_key, description = detect_dangerous_command(command) + if not is_dangerous: + return {"approved": True, "message": None} + + session_key = os.getenv("HERMES_SESSION_KEY", "default") + if is_approved(session_key, pattern_key): + return {"approved": True, "message": None} + + is_cli = os.getenv("HERMES_INTERACTIVE") + is_gateway = os.getenv("HERMES_GATEWAY_SESSION") + + if not is_cli and not is_gateway: + return {"approved": True, "message": None} + + if is_gateway or os.getenv("HERMES_EXEC_ASK"): + submit_pending(session_key, { + "command": command, + "pattern_key": pattern_key, + "description": description, + }) + return { + "approved": False, + "pattern_key": pattern_key, + "status": "approval_required", + "command": command, + "description": description, + "message": f"āš ļø This command is potentially dangerous ({description}). Asking the user for approval...", + } + + choice = prompt_dangerous_approval(command, description, + approval_callback=approval_callback) + + if choice == "deny": + return {"approved": False, "message": "BLOCKED: User denied this potentially dangerous command. Do NOT retry this command - the user has explicitly rejected it."} + + if choice == "session": + approve_session(session_key, pattern_key) + elif choice == "always": + approve_session(session_key, pattern_key) + approve_permanent(pattern_key) + save_permanent_allowlist(load_permanent_allowlist() | {pattern_key}) + + return {"approved": True, "message": None} diff --git a/tools/environments/__init__.py b/tools/environments/__init__.py new file mode 100644 index 000000000..42b49b6f2 --- /dev/null +++ b/tools/environments/__init__.py @@ -0,0 +1,13 @@ +"""Hermes execution environment backends. 
+ +Each backend provides the same interface (BaseEnvironment ABC) for running +shell commands in a specific execution context: local, Docker, Singularity, +SSH, or Modal. + +The terminal_tool.py factory (_create_environment) selects the backend +based on the TERMINAL_ENV configuration. +""" + +from tools.environments.base import BaseEnvironment + +__all__ = ["BaseEnvironment"] diff --git a/tools/environments/base.py b/tools/environments/base.py new file mode 100644 index 000000000..72240953d --- /dev/null +++ b/tools/environments/base.py @@ -0,0 +1,72 @@ +"""Base class for all Hermes execution environment backends.""" + +from abc import ABC, abstractmethod +import subprocess + + +class BaseEnvironment(ABC): + """Common interface for all Hermes execution backends. + + Subclasses implement execute() and cleanup(). Shared helpers eliminate + duplicated subprocess boilerplate across backends. + """ + + def __init__(self, cwd: str, timeout: int, env: dict = None): + self.cwd = cwd + self.timeout = timeout + self.env = env or {} + + @abstractmethod + def execute(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + """Execute a command, return {"output": str, "returncode": int}.""" + ... + + @abstractmethod + def cleanup(self): + """Release backend resources (container, instance, connection).""" + ... 
+ + def stop(self): + """Alias for cleanup (compat with older callers).""" + self.cleanup() + + def __del__(self): + try: + self.cleanup() + except Exception: + pass + + # ------------------------------------------------------------------ + # Shared helpers (eliminate duplication across backends) + # ------------------------------------------------------------------ + + def _prepare_command(self, command: str) -> str: + """Transform sudo commands if SUDO_PASSWORD is available.""" + from tools.terminal_tool import _transform_sudo_command + return _transform_sudo_command(command) + + def _build_run_kwargs(self, timeout: int | None, + stdin_data: str | None = None) -> dict: + """Build common subprocess.run kwargs for non-interactive execution.""" + kw = { + "text": True, + "timeout": timeout or self.timeout, + "encoding": "utf-8", + "errors": "replace", + "stdout": subprocess.PIPE, + "stderr": subprocess.STDOUT, + } + if stdin_data is not None: + kw["input"] = stdin_data + else: + kw["stdin"] = subprocess.DEVNULL + return kw + + def _timeout_result(self, timeout: int | None) -> dict: + """Standard return dict when a command times out.""" + return { + "output": f"Command timed out after {timeout or self.timeout}s", + "returncode": 124, + } diff --git a/tools/environments/docker.py b/tools/environments/docker.py new file mode 100644 index 000000000..969c57e60 --- /dev/null +++ b/tools/environments/docker.py @@ -0,0 +1,47 @@ +"""Docker execution environment wrapping mini-swe-agent's DockerEnvironment.""" + +import os +import subprocess + +from tools.environments.base import BaseEnvironment + + +class DockerEnvironment(BaseEnvironment): + """Docker container execution via mini-swe-agent. + + Wraps the upstream DockerEnvironment and adds non-blocking stdin + and sudo -S support. 
import os
import signal
import subprocess
import threading
import time

from tools.environments.base import BaseEnvironment


class LocalEnvironment(BaseEnvironment):
    """Run commands directly on the host machine.

    Features:
    - Popen + polling for interrupt support (user can cancel mid-command)
    - Background stdout drain thread to prevent pipe buffer deadlocks
    - stdin_data support for piping content (bypasses ARG_MAX limits)
    - sudo -S transform via SUDO_PASSWORD env var
    """

    def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
        super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)

    @staticmethod
    def _terminate_process_group(proc: subprocess.Popen, grace_seconds: float = 2.0) -> None:
        """Stop *proc* and its process group, then reap the child.

        Sends SIGTERM to the group, waits up to *grace_seconds*, and
        escalates to SIGKILL if the child is still running. If the group
        cannot be signalled, only the child itself is killed.
        """
        for sig in (signal.SIGTERM, signal.SIGKILL):
            try:
                os.killpg(os.getpgid(proc.pid), sig)
            except (ProcessLookupError, PermissionError):
                proc.kill()
            try:
                # wait() also reaps the child so it does not linger as a zombie.
                proc.wait(timeout=grace_seconds)
                return
            except subprocess.TimeoutExpired:
                continue

    def execute(self, command: str, cwd: str = "", *,
                timeout: int | None = None,
                stdin_data: str | None = None) -> dict:
        """Run *command* through the shell and return its combined output.

        The child is polled every 0.2 s so that an interrupt signalled via
        terminal_tool's ``_interrupt_event`` or an elapsed timeout can stop it.

        Args:
            command: Shell command line (sudo-transformed before running).
            cwd: Working directory; falls back to the instance cwd.
            timeout: Seconds before the command is stopped; a falsy value
                uses the instance default.
            stdin_data: Optional text written to the child's stdin.

        Returns:
            {"output": str, "returncode": int}. returncode is 130 after an
            interrupt, 124 after a timeout and 1 when launching fails.
        """
        from tools.terminal_tool import _interrupt_event

        work_dir = cwd or self.cwd or os.getcwd()
        effective_timeout = timeout or self.timeout
        exec_command = self._prepare_command(command)

        try:
            proc = subprocess.Popen(
                exec_command,
                shell=True,
                text=True,
                cwd=work_dir,
                env=os.environ | self.env,
                encoding="utf-8",
                errors="replace",
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
                # Same effect as preexec_fn=os.setsid (own session / process
                # group) but safe to use while other threads are running.
                start_new_session=True,
            )

            if stdin_data is not None:
                def _write_stdin():
                    try:
                        proc.stdin.write(stdin_data)
                        proc.stdin.close()
                    except (BrokenPipeError, OSError):
                        pass
                # Written from a thread so a child that does not read stdin
                # cannot block the polling loop below.
                threading.Thread(target=_write_stdin, daemon=True).start()

            _output_chunks: list[str] = []

            def _drain_stdout():
                try:
                    for line in proc.stdout:
                        _output_chunks.append(line)
                except ValueError:
                    # Raised when the pipe is closed underneath the reader.
                    pass
                finally:
                    try:
                        proc.stdout.close()
                    except Exception:
                        pass

            reader = threading.Thread(target=_drain_stdout, daemon=True)
            reader.start()
            deadline = time.monotonic() + effective_timeout

            while proc.poll() is None:
                if _interrupt_event.is_set():
                    self._terminate_process_group(proc)
                    reader.join(timeout=2)
                    return {
                        "output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
                        "returncode": 130,
                    }
                if time.monotonic() > deadline:
                    self._terminate_process_group(proc)
                    reader.join(timeout=2)
                    return self._timeout_result(effective_timeout)
                time.sleep(0.2)

            reader.join(timeout=5)
            return {"output": "".join(_output_chunks), "returncode": proc.returncode}

        except Exception as e:
            return {"output": f"Execution error: {str(e)}", "returncode": 1}

    def cleanup(self):
        """Nothing to release: each command runs in its own short-lived child."""
        pass
import logging
import os
import shutil
import subprocess
import tempfile
import threading
import uuid
from pathlib import Path

from tools.environments.base import BaseEnvironment

logger = logging.getLogger(__name__)


# -------------------------------------------------------------------------
# Singularity helpers (scratch dir, SIF cache, SIF building)
# -------------------------------------------------------------------------

def _get_scratch_dir() -> Path:
    """Return the directory used for Singularity sandboxes.

    Preference order: $TERMINAL_SCRATCH_DIR (created if missing), then
    /scratch/<USER>/hermes-agent when /scratch is writable, then the
    system temp directory.
    """
    custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR")
    if custom_scratch:
        scratch_path = Path(custom_scratch)
        scratch_path.mkdir(parents=True, exist_ok=True)
        return scratch_path

    scratch = Path("/scratch")
    if scratch.exists() and os.access(scratch, os.W_OK):
        user_scratch = scratch / os.getenv("USER", "hermes") / "hermes-agent"
        user_scratch.mkdir(parents=True, exist_ok=True)
        logger.info("Using /scratch for sandboxes: %s", user_scratch)
        return user_scratch

    logger.debug("/scratch not available, using /tmp for sandboxes")
    return Path(tempfile.gettempdir())


def _get_apptainer_cache_dir() -> Path:
    """Return (creating it if needed) the cache directory for SIF images.

    Uses $APPTAINER_CACHEDIR when set, otherwise ``.apptainer`` inside the
    scratch directory.
    """
    cache_dir = os.getenv("APPTAINER_CACHEDIR")
    if cache_dir:
        cache_path = Path(cache_dir)
    else:
        cache_path = _get_scratch_dir() / ".apptainer"
    cache_path.mkdir(parents=True, exist_ok=True)
    return cache_path


# Serializes SIF builds within this process.
_sif_build_lock = threading.Lock()


def _discard_partial_sif(sif_path: Path) -> None:
    """Remove a possibly half-written SIF so it is never reused as a cache hit."""
    try:
        sif_path.unlink(missing_ok=True)
    except OSError as e:
        logger.warning("Could not remove partial SIF %s: %s", sif_path, e)


def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
    """Get or build a SIF image from a docker:// URL.

    Anything that is not a docker:// URL (including an existing .sif path)
    is returned unchanged. For docker:// URLs the cached SIF is returned
    when present; otherwise it is built with ``<executable> build``. When
    the build fails, times out or raises, any partial output file is
    removed and the original docker:// URL is returned as a fallback.

    Args:
        image: Image reference: .sif path, docker:// URL, or other.
        executable: Container CLI used for the build.

    Returns:
        Path to the SIF file, or *image* unchanged.
    """
    if image.endswith('.sif') and Path(image).exists():
        return image
    if not image.startswith('docker://'):
        return image

    image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
    cache_dir = _get_apptainer_cache_dir()
    sif_path = cache_dir / f"{image_name}.sif"

    if sif_path.exists():
        return str(sif_path)

    with _sif_build_lock:
        # Re-check under the lock: another thread may have just built it.
        if sif_path.exists():
            return str(sif_path)

        logger.info("Building SIF image (one-time setup)...")
        logger.info("  Source: %s", image)
        logger.info("  Target: %s", sif_path)

        tmp_dir = cache_dir / "tmp"
        tmp_dir.mkdir(parents=True, exist_ok=True)

        env = os.environ.copy()
        env["APPTAINER_TMPDIR"] = str(tmp_dir)
        env["APPTAINER_CACHEDIR"] = str(cache_dir)

        try:
            result = subprocess.run(
                [executable, "build", str(sif_path), image],
                capture_output=True, text=True, timeout=600, env=env,
            )
        except subprocess.TimeoutExpired:
            logger.warning("SIF build timed out, falling back to docker:// URL")
            _discard_partial_sif(sif_path)
            return image
        except Exception as e:
            logger.warning("SIF build error: %s, falling back to docker:// URL", e)
            _discard_partial_sif(sif_path)
            return image

        if result.returncode != 0:
            logger.warning("SIF build failed, falling back to docker:// URL")
            logger.warning("  Error: %s", result.stderr[:500])
            # Without this, a leftover file would satisfy the exists() check
            # above on the next call and be returned as a valid image.
            _discard_partial_sif(sif_path)
            return image

        logger.info("SIF image built successfully")
        return str(sif_path)


# -------------------------------------------------------------------------
# SingularityEnvironment
# -------------------------------------------------------------------------

class SingularityEnvironment(BaseEnvironment):
    """Persistent Singularity/Apptainer container environment.

    Uses ``apptainer instance`` to create a long-running container that persists
    state across all commands within a task.
    """

    def __init__(self, image: str, cwd: str = "/root", timeout: int = 60):
        super().__init__(cwd=cwd, timeout=timeout)
        # Prefer apptainer when it is on PATH, otherwise assume singularity.
        self.executable = "apptainer" if shutil.which("apptainer") else "singularity"
        self.image = _get_or_build_sif(image, self.executable)
        self.instance_id = f"hermes_{uuid.uuid4().hex[:12]}"
        self._instance_started = False
        self._start_instance()

    def _start_instance(self):
        """Start the container instance.

        Raises:
            RuntimeError: If the start command fails or exceeds 120 s.
        """
        cmd = [
            self.executable, "instance", "start",
            "--writable-tmpfs", "--containall",
            str(self.image), self.instance_id,
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        except subprocess.TimeoutExpired:
            raise RuntimeError("Instance start timed out")
        if result.returncode != 0:
            raise RuntimeError(f"Failed to start instance: {result.stderr}")
        self._instance_started = True
        logger.info("Singularity instance %s started", self.instance_id)

    def execute(self, command: str, cwd: str = "", *,
                timeout: int | None = None,
                stdin_data: str | None = None) -> dict:
        """Run *command* with ``bash -c`` inside the running instance.

        Returns:
            {"output": str, "returncode": int}; returncode is -1 when the
            instance is not running and 124 on timeout.
        """
        if not self._instance_started:
            return {"output": "Instance not started", "returncode": -1}

        cmd = [self.executable, "exec", "--pwd", cwd or self.cwd,
               f"instance://{self.instance_id}",
               "bash", "-c", self._prepare_command(command)]

        try:
            result = subprocess.run(cmd, **self._build_run_kwargs(timeout, stdin_data))
            return {"output": result.stdout, "returncode": result.returncode}
        except subprocess.TimeoutExpired:
            return self._timeout_result(timeout)

    def cleanup(self):
        """Stop the instance if it was started; failures are logged, not raised."""
        if self._instance_started:
            try:
                subprocess.run(
                    [self.executable, "instance", "stop", self.instance_id],
                    capture_output=True, text=True, timeout=30,
                )
                logger.info("Singularity instance %s stopped", self.instance_id)
            except Exception as e:
                logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
            self._instance_started = False
000000000..390af1df4 --- /dev/null +++ b/tools/environments/ssh.py @@ -0,0 +1,91 @@ +"""SSH remote execution environment with ControlMaster connection persistence.""" + +import logging +import subprocess +import tempfile +from pathlib import Path + +from tools.environments.base import BaseEnvironment + +logger = logging.getLogger(__name__) + + +class SSHEnvironment(BaseEnvironment): + """Run commands on a remote machine over SSH. + + Uses SSH ControlMaster for connection persistence so subsequent + commands are fast. Security benefit: the agent cannot modify its + own code since execution happens on a separate machine. + """ + + def __init__(self, host: str, user: str, cwd: str = "/tmp", + timeout: int = 60, port: int = 22, key_path: str = ""): + super().__init__(cwd=cwd, timeout=timeout) + self.host = host + self.user = user + self.port = port + self.key_path = key_path + + self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh" + self.control_dir.mkdir(parents=True, exist_ok=True) + self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock" + self._establish_connection() + + def _build_ssh_command(self, extra_args: list = None) -> list: + cmd = ["ssh"] + cmd.extend(["-o", f"ControlPath={self.control_socket}"]) + cmd.extend(["-o", "ControlMaster=auto"]) + cmd.extend(["-o", "ControlPersist=300"]) + cmd.extend(["-o", "BatchMode=yes"]) + cmd.extend(["-o", "StrictHostKeyChecking=accept-new"]) + cmd.extend(["-o", "ConnectTimeout=10"]) + if self.port != 22: + cmd.extend(["-p", str(self.port)]) + if self.key_path: + cmd.extend(["-i", self.key_path]) + if extra_args: + cmd.extend(extra_args) + cmd.append(f"{self.user}@{self.host}") + return cmd + + def _establish_connection(self): + cmd = self._build_ssh_command() + cmd.append("echo 'SSH connection established'") + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=15) + if result.returncode != 0: + error_msg = result.stderr.strip() or result.stdout.strip() + raise 
RuntimeError(f"SSH connection failed: {error_msg}") + except subprocess.TimeoutExpired: + raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") + + def execute(self, command: str, cwd: str = "", *, + timeout: int | None = None, + stdin_data: str | None = None) -> dict: + work_dir = cwd or self.cwd + exec_command = self._prepare_command(command) + wrapped = f'cd {work_dir} && {exec_command}' + + cmd = self._build_ssh_command() + cmd.extend(["bash", "-c", wrapped]) + + try: + result = subprocess.run(cmd, **self._build_run_kwargs(timeout, stdin_data)) + return {"output": result.stdout, "returncode": result.returncode} + except subprocess.TimeoutExpired: + return self._timeout_result(timeout) + except Exception as e: + return {"output": f"SSH execution error: {str(e)}", "returncode": 1} + + def cleanup(self): + if self.control_socket.exists(): + try: + cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", + "-O", "exit", f"{self.user}@{self.host}"] + subprocess.run(cmd, capture_output=True, timeout=5) + except (OSError, subprocess.SubprocessError): + pass + try: + self.control_socket.unlink() + except OSError: + pass diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 2f37e9f30..6b95d185d 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -75,125 +75,8 @@ if mini_swe_path.exists(): # Custom Singularity Environment with more space # ============================================================================= -def _get_scratch_dir() -> Path: - """Get the best directory for Singularity sandboxes - prefers /scratch if available.""" - # Check for configurable scratch directory first (highest priority) - custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR") - if custom_scratch: - scratch_path = Path(custom_scratch) - scratch_path.mkdir(parents=True, exist_ok=True) - return scratch_path - - # Check for /scratch (common on HPC clusters, especially GPU nodes) - scratch = Path("/scratch") - if scratch.exists() and 
os.access(scratch, os.W_OK): - # Create user-specific subdirectory - user_scratch = scratch / os.getenv("USER", "hermes") / "hermes-agent" - user_scratch.mkdir(parents=True, exist_ok=True) - logger.info("Using /scratch for sandboxes: %s", user_scratch) - return user_scratch - - # Fall back to /tmp (only relevant for Singularity/HPC sandboxes) - logger.debug("/scratch not available, using /tmp for sandboxes") - return Path(tempfile.gettempdir()) - - -def _get_apptainer_cache_dir() -> Path: - """Get the Apptainer cache directory for SIF images.""" - # Check for APPTAINER_CACHEDIR env var - cache_dir = os.getenv("APPTAINER_CACHEDIR") - if cache_dir: - cache_path = Path(cache_dir) - cache_path.mkdir(parents=True, exist_ok=True) - return cache_path - - # Use user-specific subdirectory in scratch for cache - scratch = _get_scratch_dir() - cache_path = scratch / ".apptainer" - cache_path.mkdir(parents=True, exist_ok=True) - return cache_path - - -# Lock for SIF building to prevent race conditions -_sif_build_lock = threading.Lock() - - -def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: - """ - Get or build a SIF image from a docker:// URL. - - If the image is already a .sif file, returns it as-is. - If the image is a docker:// URL, checks for cached SIF and builds if needed. - - Args: - image: Image path (docker://... 
URL or .sif path) - executable: apptainer or singularity - - Returns: - Path to SIF file, or original image if not a docker:// URL - """ - # If already a .sif file, use it directly - if image.endswith('.sif') and Path(image).exists(): - return image - - # If not a docker:// URL, return as-is (could be a local sandbox or other format) - if not image.startswith('docker://'): - return image - - # Generate SIF filename from docker image name - # docker://nikolaik/python-nodejs:python3.11-nodejs20 -> python-nodejs-python3.11-nodejs20.sif - image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-') - cache_dir = _get_apptainer_cache_dir() - sif_path = cache_dir / f"{image_name}.sif" - - # Check if SIF already exists - if sif_path.exists(): - return str(sif_path) - - # Build SIF with lock to prevent multiple workers building simultaneously - with _sif_build_lock: - # Double-check after acquiring lock (another thread may have built it) - if sif_path.exists(): - return str(sif_path) - - logger.info("Building SIF image (one-time setup)...") - logger.info(" Source: %s", image) - logger.info(" Target: %s", sif_path) - - # Ensure tmp directory exists for build - tmp_dir = cache_dir / "tmp" - tmp_dir.mkdir(parents=True, exist_ok=True) - - # Set APPTAINER_TMPDIR for the build - env = os.environ.copy() - env["APPTAINER_TMPDIR"] = str(tmp_dir) - env["APPTAINER_CACHEDIR"] = str(cache_dir) - - try: - result = subprocess.run( - [executable, "build", str(sif_path), image], - capture_output=True, - text=True, - timeout=600, # 10 min timeout for pulling and building - env=env - ) - if result.returncode != 0: - logger.warning("SIF build failed, falling back to docker:// URL") - logger.warning(" Error: %s", result.stderr[:500]) - return image - - logger.info("SIF image built successfully") - return str(sif_path) - - except subprocess.TimeoutExpired: - logger.warning("SIF build timed out, falling back to docker:// URL") - # Clean up partial file - if sif_path.exists(): 
- sif_path.unlink() - return image - except Exception as e: - logger.warning("SIF build error: %s, falling back to docker:// URL", e) - return image +# Singularity helpers (scratch dir, SIF cache) now live in tools/environments/singularity.py +from tools.environments.singularity import _get_scratch_dir # Disk usage warning threshold (in GB) @@ -255,234 +138,19 @@ def set_approval_callback(cb): # Dangerous Command Approval System # ============================================================================= -from tools import approval as _approval - -# Dangerous command patterns (regex, description) -DANGEROUS_PATTERNS = [ - (r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"), - (r'\brm\s+(-[^\s]*)?r', "recursive delete"), - (r'\brm\s+--recursive\b', "recursive delete (long flag)"), - (r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"), - (r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"), - (r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"), - (r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"), - (r'\bmkfs\b', "format filesystem"), - (r'\bdd\s+.*if=', "disk copy"), - (r'>\s*/dev/sd', "write to block device"), - (r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"), - (r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"), - (r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"), - (r'>\s*/etc/', "overwrite system config"), - (r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"), - (r'\bkill\s+-9\s+-1\b', "kill all processes"), - (r'\bpkill\s+-9\b', "force kill processes"), - (r':()\s*{\s*:\s*\|\s*:&\s*}\s*;:', "fork bomb"), - # Indirect execution via command launchers - (r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"), - (r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"), - # Pipe-to-shell (remote code execution) - (r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"), - # Destructive find/xargs patterns - (r'\bxargs\s+.*\brm\b', 
"xargs with rm"), - (r'\bfind\b.*-exec\s+rm\b', "find -exec rm"), - (r'\bfind\b.*-delete\b', "find -delete"), -] - - -def _load_permanent_allowlist() -> set: - """Load permanently allowed command patterns from config. - - Also syncs them into the approval module so is_approved() works for - patterns that were added via 'always' in a previous session. - """ - try: - from hermes_cli.config import load_config - config = load_config() - patterns = set(config.get("command_allowlist", []) or []) - if patterns: - _approval.load_permanent(patterns) - return patterns - except Exception: - return set() - - -def _save_permanent_allowlist(patterns: set): - """Save permanently allowed command patterns to config.""" - try: - from hermes_cli.config import load_config, save_config - config = load_config() - config["command_allowlist"] = list(patterns) - save_config(config) - except Exception as e: - logger.warning("Could not save allowlist: %s", e) - - -def _detect_dangerous_command(command: str) -> tuple: - """ - Check if command matches any dangerous patterns. - - Returns: - (is_dangerous, pattern_key, description) or (False, None, None) - """ - import re - command_lower = command.lower() - - for pattern, description in DANGEROUS_PATTERNS: - if re.search(pattern, command_lower, re.IGNORECASE): - # Use a simplified pattern key for caching (first word + key chars) - pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20] - return (True, pattern_key, description) - - return (False, None, None) - - -def _is_command_approved(pattern_key: str) -> bool: - """Check if a pattern is approved (session or permanent).""" - session_key = os.getenv("HERMES_SESSION_KEY", "default") - return _approval.is_approved(session_key, pattern_key) - - -def _prompt_dangerous_approval(command: str, description: str, timeout_seconds: int = 60) -> str: - """ - Prompt user to approve a dangerous command (CLI only). 
- - If an _approval_callback is registered (by the CLI), delegates to it so the - prompt integrates with prompt_toolkit's UI. Otherwise falls back to the - raw input() approach (works outside the TUI, e.g. tests). - - Returns: 'once', 'session', 'always', or 'deny' - """ - import sys - import threading - - # Use the registered callback when available (prompt_toolkit-compatible) - if _approval_callback is not None: - try: - return _approval_callback(command, description) - except Exception: - return "deny" - - # Pause spinner if one is running - os.environ["HERMES_SPINNER_PAUSE"] = "1" - - try: - print() - print(f" āš ļø DANGEROUS COMMAND: {description}") - print(f" {command[:80]}{'...' if len(command) > 80 else ''}") - print() - print(f" [o]nce | [s]ession | [a]lways | [d]eny") - print() - sys.stdout.flush() - - result = {"choice": ""} - - def get_input(): - try: - result["choice"] = input(" Choice [o/s/a/D]: ").strip().lower() - except (EOFError, OSError): - result["choice"] = "" - - thread = threading.Thread(target=get_input, daemon=True) - thread.start() - thread.join(timeout=timeout_seconds) - - if thread.is_alive(): - print("\n ā± Timeout - denying command") - return "deny" - - choice = result["choice"] - - if choice in ('o', 'once'): - print(" āœ“ Allowed once") - return "once" - elif choice in ('s', 'session'): - print(" āœ“ Allowed for this session") - return "session" - elif choice in ('a', 'always'): - print(" āœ“ Added to permanent allowlist") - return "always" - else: - print(" āœ— Denied") - return "deny" - - except (EOFError, KeyboardInterrupt): - print("\n āœ— Cancelled") - return "deny" - finally: - if "HERMES_SPINNER_PAUSE" in os.environ: - del os.environ["HERMES_SPINNER_PAUSE"] - print() - sys.stdout.flush() +# Dangerous command detection + approval now consolidated in tools/approval.py +from tools.approval import ( + detect_dangerous_command as _detect_dangerous_command, + check_dangerous_command as _check_dangerous_command_impl, + 
load_permanent_allowlist as _load_permanent_allowlist, + DANGEROUS_PATTERNS, +) def _check_dangerous_command(command: str, env_type: str) -> dict: - """ - Check if command is dangerous and handle approval. - - Only applies to local/ssh backends in interactive contexts. - - Args: - command: The command to check - env_type: The terminal backend type - - Returns: - {"approved": True/False, "message": str or None} - """ - # Skip check for isolated environments (containers are disposable) - if env_type in ("docker", "singularity", "modal"): - return {"approved": True, "message": None} - - # Detect dangerous command - is_dangerous, pattern_key, description = _detect_dangerous_command(command) - - if not is_dangerous: - return {"approved": True, "message": None} - - # Check if already approved - if _is_command_approved(pattern_key): - return {"approved": True, "message": None} - - # Check context - only prompt in interactive modes - is_cli = os.getenv("HERMES_INTERACTIVE") - is_gateway = os.getenv("HERMES_GATEWAY_SESSION") - - if not is_cli and not is_gateway: - # Programmatic use - allow (user opted into local backend) - return {"approved": True, "message": None} - - if is_gateway or os.getenv("HERMES_EXEC_ASK"): - # Messaging context - return approval_required so the gateway can - # prompt the user interactively instead of just blocking - session_key = os.getenv("HERMES_SESSION_KEY", "default") - _approval.submit_pending(session_key, { - "command": command, - "pattern_key": pattern_key, - "description": description, - }) - return { - "approved": False, - "pattern_key": pattern_key, - "status": "approval_required", - "command": command, - "description": description, - "message": f"āš ļø This command is potentially dangerous ({description}). Asking the user for approval..." 
- } - - # CLI context - prompt user - choice = _prompt_dangerous_approval(command, description) - - if choice == "deny": - return {"approved": False, "message": "BLOCKED: User denied this potentially dangerous command. Do NOT retry this command - the user has explicitly rejected it."} - - session_key = os.getenv("HERMES_SESSION_KEY", "default") - if choice == "session": - _approval.approve_session(session_key, pattern_key) - elif choice == "always": - _approval.approve_session(session_key, pattern_key) - _approval.approve_permanent(pattern_key) - _save_permanent_allowlist(_load_permanent_allowlist() | {pattern_key}) - - return {"approved": True, "message": None} + """Delegate to the consolidated approval module, passing the CLI callback.""" + return _check_dangerous_command_impl(command, env_type, + approval_callback=_approval_callback) def _handle_sudo_failure(output: str, env_type: str) -> str: @@ -671,569 +339,12 @@ def _transform_sudo_command(command: str) -> str: return re.sub(r'\bsudo\b', replace_sudo, command) -class _LocalEnvironment: - """ - Local execution environment with sudo support and non-blocking stdin. - - Features: - - Uses stdin=DEVNULL to prevent hanging on interactive prompts (sudo, etc.) - - Optional SUDO_PASSWORD support: if set, transforms `sudo` commands to use `sudo -S` - - Graceful failure: sudo commands fail fast with clear error if no password configured - - Environment variables: - - SUDO_PASSWORD: If set, enables sudo commands by piping password via `sudo -S` - """ - - def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None): - self.cwd = cwd or os.getcwd() - self.timeout = timeout - self.env = env or {} - - def execute(self, command: str, cwd: str = "", *, timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """ - Execute a command locally with sudo support. 
- - Uses Popen + polling so the global interrupt event can kill the - process early when the user sends a new message, instead of - blocking for the full timeout. - - A background reader thread drains stdout continuously to prevent - pipe buffer deadlocks. Without this, commands producing >64KB of - output would block (Linux pipe buffer = 64KB) while the poll loop - waits for the process to finish — a classic deadlock. - - Args: - stdin_data: If provided, piped to the process's stdin. This - bypasses shell ARG_MAX limits for large content. - """ - work_dir = cwd or self.cwd or os.getcwd() - effective_timeout = timeout or self.timeout - - # Transform sudo commands if SUDO_PASSWORD is available - exec_command = _transform_sudo_command(command) - - try: - proc = subprocess.Popen( - exec_command, - shell=True, - text=True, - cwd=work_dir, - env=os.environ | self.env, - encoding="utf-8", - errors="replace", - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL, - # Start in a new process group so we can kill the whole tree - preexec_fn=os.setsid, - ) - - # Pipe stdin_data in a background thread to avoid deadlock - # (large writes can block if the pipe buffer fills before the - # process drains it). - if stdin_data is not None: - def _write_stdin(): - try: - proc.stdin.write(stdin_data) - proc.stdin.close() - except (BrokenPipeError, OSError): - pass - stdin_writer = threading.Thread(target=_write_stdin, daemon=True) - stdin_writer.start() - - # Drain stdout in a background thread to prevent pipe buffer - # deadlocks. The OS pipe buffer is 64KB on Linux; if the child - # writes more than that before anyone reads, it blocks forever. 
- _output_chunks: list[str] = [] - def _drain_stdout(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except ValueError: - pass # stdout closed during interrupt/timeout - finally: - try: - proc.stdout.close() - except Exception: - pass - - reader = threading.Thread(target=_drain_stdout, daemon=True) - reader.start() - - deadline = time.monotonic() + effective_timeout - - # Poll every 200ms so we notice interrupts quickly - while proc.poll() is None: - if _interrupt_event.is_set(): - # User sent a new message — kill the process tree and return - # what we have so far - try: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - output = "".join(_output_chunks) - return { - "output": output + "\n[Command interrupted — user sent a new message]", - "returncode": 130 # Standard interrupted exit code - } - - if time.monotonic() > deadline: - # Timeout — kill process tree - try: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - return {"output": f"Command timed out after {effective_timeout}s", "returncode": 124} - - # Short sleep to avoid busy-waiting - time.sleep(0.2) - - # Process finished — wait for reader to drain remaining output - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} - - except Exception as e: - return {"output": f"Execution error: {str(e)}", "returncode": 1} - - def cleanup(self): - """No cleanup needed for local environment.""" - pass - - def stop(self): - """Alias for cleanup.""" - pass - - -class _SingularityEnvironment: - """ - Persistent Singularity/Apptainer container environment. - - Uses `apptainer instance` to create a long-running container that persists - state (files, installs, env changes) across all commands within a task. - The model experiences this as a real Linux VM. 
- - Features: - - Persistent filesystem: files created in one command are visible in the next - - Package installs persist: pip/apt installs survive across tool calls - - Full isolation: --containall gives PID, IPC, and environment isolation - - Writable tmpfs overlay: full root filesystem is writable (RAM-backed) - - Automatic SIF caching: docker:// images converted to SIF once, reused forever - """ - - def __init__(self, image: str, cwd: str = "/root", timeout: int = 60): - self.cwd = cwd - self.timeout = timeout - - # Use apptainer if available, otherwise singularity - self.executable = "apptainer" if shutil.which("apptainer") else "singularity" - - # Get or build SIF from docker:// URL (fast if already cached) - self.image = _get_or_build_sif(image, self.executable) - - # Create unique instance name (must be alphanumeric + underscores) - self.instance_id = f"hermes_{uuid.uuid4().hex[:12]}" - self._instance_started = False - - # Start the persistent instance - self._start_instance() - - def _start_instance(self): - """Start a persistent apptainer instance. - - The instance runs as a background process. All subsequent execute() calls - run commands inside this same instance, so state persists across calls. 
- """ - cmd = [ - self.executable, "instance", "start", - "--writable-tmpfs", # RAM-backed writable overlay on read-only SIF - "--containall", # Full isolation: PID, IPC, environment, filesystem - str(self.image), - self.instance_id, - ] - - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=120, # 2 min for instance startup - ) - if result.returncode != 0: - raise RuntimeError(f"Failed to start instance: {result.stderr}") - - self._instance_started = True - logger.info("Singularity instance %s started (persistent container)", self.instance_id) - - except subprocess.TimeoutExpired: - raise RuntimeError("Instance start timed out") - - def execute(self, command: str, cwd: str = "", *, timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command in the persistent Singularity instance. - - All commands run in the same container, so files, installs, and - environment changes persist between calls. - """ - if not self._instance_started: - return {"output": "Instance not started", "returncode": -1} - - cmd = [self.executable, "exec"] - - # Set working directory - work_dir = cwd or self.cwd - cmd.extend(["--pwd", work_dir]) - - # Connect to the running instance - cmd.append(f"instance://{self.instance_id}") - - # Transform sudo commands if SUDO_PASSWORD is available - exec_command = _transform_sudo_command(command) - - # Execute the command - cmd.extend(["bash", "-c", exec_command]) - - run_kwargs = { - "text": True, - "timeout": timeout or self.timeout, - "encoding": "utf-8", - "errors": "replace", - "stdout": subprocess.PIPE, - "stderr": subprocess.STDOUT, - } - if stdin_data is not None: - run_kwargs["input"] = stdin_data - else: - run_kwargs["stdin"] = subprocess.DEVNULL - - try: - result = subprocess.run(cmd, **run_kwargs) - return {"output": result.stdout, "returncode": result.returncode} - except subprocess.TimeoutExpired: - return {"output": f"Command timed out after {timeout or self.timeout}s", 
"returncode": 124} - - def cleanup(self): - """Stop the persistent instance and clean up.""" - if self._instance_started: - try: - subprocess.run( - [self.executable, "instance", "stop", self.instance_id], - capture_output=True, - text=True, - timeout=30, - ) - logger.info("Singularity instance %s stopped", self.instance_id) - except Exception as e: - logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e) - self._instance_started = False - - def stop(self): - """Alias for cleanup.""" - self.cleanup() - - def __del__(self): - """Cleanup on destruction.""" - try: - self.cleanup() - except Exception: - pass - - -class _SSHEnvironment: - """ - SSH-based remote execution environment. - - Runs commands on a remote machine over SSH, keeping the agent code - completely isolated from the execution environment. Uses SSH ControlMaster - for connection persistence (faster subsequent commands). - - Security benefits: - - Agent cannot modify its own code - - Remote machine acts as a sandbox - - Clear separation between agent and execution environment - """ - - def __init__(self, host: str, user: str, cwd: str = "/tmp", timeout: int = 60, - port: int = 22, key_path: str = ""): - self.host = host - self.user = user - self.cwd = cwd - self.timeout = timeout - self.port = port - self.key_path = key_path - - # Create control socket directory for connection persistence - self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh" - self.control_dir.mkdir(parents=True, exist_ok=True) - self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock" - - # Test connection and establish ControlMaster - self._establish_connection() - - def _build_ssh_command(self, extra_args: list = None) -> list: - """Build base SSH command with connection options.""" - cmd = ["ssh"] - - # Connection multiplexing for performance - cmd.extend(["-o", f"ControlPath={self.control_socket}"]) - cmd.extend(["-o", "ControlMaster=auto"]) - cmd.extend(["-o", "ControlPersist=300"]) 
# Keep connection alive for 5 min - - # Standard options - cmd.extend(["-o", "BatchMode=yes"]) # No password prompts - cmd.extend(["-o", "StrictHostKeyChecking=accept-new"]) # Accept new hosts - cmd.extend(["-o", "ConnectTimeout=10"]) - - # Port - if self.port != 22: - cmd.extend(["-p", str(self.port)]) - - # Private key - if self.key_path: - cmd.extend(["-i", self.key_path]) - - # Extra args (like -t for TTY) - if extra_args: - cmd.extend(extra_args) - - # Target - cmd.append(f"{self.user}@{self.host}") - - return cmd - - def _establish_connection(self): - """Test SSH connection and establish ControlMaster.""" - cmd = self._build_ssh_command() - cmd.append("echo 'SSH connection established'") - - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=15 - ) - if result.returncode != 0: - error_msg = result.stderr.strip() or result.stdout.strip() - raise RuntimeError(f"SSH connection failed: {error_msg}") - except subprocess.TimeoutExpired: - raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") - - def execute(self, command: str, cwd: str = "", *, timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command on the remote host via SSH.""" - work_dir = cwd or self.cwd - effective_timeout = timeout or self.timeout - - # Transform sudo commands if SUDO_PASSWORD is available - exec_command = _transform_sudo_command(command) - - # Wrap command to run in the correct directory - wrapped_command = f'cd {work_dir} && {exec_command}' - - cmd = self._build_ssh_command() - cmd.extend(["bash", "-c", wrapped_command]) - - run_kwargs = { - "text": True, - "timeout": effective_timeout, - "encoding": "utf-8", - "errors": "replace", - "stdout": subprocess.PIPE, - "stderr": subprocess.STDOUT, - } - if stdin_data is not None: - run_kwargs["input"] = stdin_data - else: - run_kwargs["stdin"] = subprocess.DEVNULL - - try: - result = subprocess.run(cmd, **run_kwargs) - return {"output": result.stdout, 
"returncode": result.returncode} - except subprocess.TimeoutExpired: - return {"output": f"Command timed out after {effective_timeout}s", "returncode": 124} - except Exception as e: - return {"output": f"SSH execution error: {str(e)}", "returncode": 1} - - def cleanup(self): - """Close the SSH ControlMaster connection.""" - if self.control_socket.exists(): - try: - # Send exit command to ControlMaster - cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", "-O", "exit", - f"{self.user}@{self.host}"] - subprocess.run(cmd, capture_output=True, timeout=5) - except (OSError, subprocess.SubprocessError): - pass - - # Remove socket file - try: - self.control_socket.unlink() - except OSError: - pass - - def stop(self): - """Alias for cleanup.""" - self.cleanup() - - def __del__(self): - """Cleanup on destruction.""" - try: - self.cleanup() - except Exception: - pass - - -class _DockerEnvironment: - """ - Docker execution environment wrapper with sudo support and non-blocking stdin. - - Wraps mini-swe-agent's DockerEnvironment but adds: - - stdin=DEVNULL to prevent hanging on interactive prompts - - SUDO_PASSWORD support via _transform_sudo_command - """ - - def __init__(self, image: str, cwd: str = "/", timeout: int = 60): - from minisweagent.environments.docker import DockerEnvironment - self._inner = DockerEnvironment(image=image, cwd=cwd, timeout=timeout) - self.cwd = cwd - self.timeout = timeout - - def execute(self, command: str, cwd: str = "", *, timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command in the Docker container with sudo support.""" - # Transform sudo commands if SUDO_PASSWORD is available - exec_command = _transform_sudo_command(command) - - work_dir = cwd or self.cwd - effective_timeout = timeout or self.timeout - - # Get container_id from inner environment - assert self._inner.container_id, "Container not started" - - cmd = [self._inner.config.executable, "exec"] - if stdin_data is not None: - 
cmd.append("-i") # Enable stdin piping into the container - cmd.extend(["-w", work_dir]) - for key in self._inner.config.forward_env: - if (value := os.getenv(key)) is not None: - cmd.extend(["-e", f"{key}={value}"]) - for key, value in self._inner.config.env.items(): - cmd.extend(["-e", f"{key}={value}"]) - cmd.extend([self._inner.container_id, "bash", "-lc", exec_command]) - - run_kwargs = { - "text": True, - "timeout": effective_timeout, - "encoding": "utf-8", - "errors": "replace", - "stdout": subprocess.PIPE, - "stderr": subprocess.STDOUT, - } - if stdin_data is not None: - run_kwargs["input"] = stdin_data - else: - run_kwargs["stdin"] = subprocess.DEVNULL - - try: - result = subprocess.run(cmd, **run_kwargs) - return {"output": result.stdout, "returncode": result.returncode} - except subprocess.TimeoutExpired: - return {"output": f"Command timed out after {effective_timeout}s", "returncode": 124} - - def cleanup(self): - """Cleanup the Docker container.""" - self._inner.cleanup() - - def stop(self): - """Alias for cleanup.""" - self.cleanup() - - def __del__(self): - """Cleanup on destruction.""" - try: - self.cleanup() - except Exception: - pass - - -class _ModalEnvironment: - """ - Modal cloud execution environment wrapper with sudo support. - - Wraps mini-swe-agent's SwerexModalEnvironment but adds: - - SUDO_PASSWORD support via _transform_sudo_command - - Automatic async-safety patches (applied once, before first use) - - The patches replace SwerexModalEnvironment's asyncio.run() calls with a - background thread approach, making it safe to use inside any event loop - (e.g., Atropos). Applied here at the point of use rather than relying on - import-time side effects, so ALL callers get the fix automatically. 
- """ - - # Class-level flag: patches only need to be applied once - _patches_applied = False - - def __init__(self, image: str, cwd: str = "/root", timeout: int = 60): - # Ensure async-safety patches are applied before creating any - # SwerexModalEnvironment instance. This is the single authoritative - # place -- no other module needs to call apply_patches() for Modal. - if not _ModalEnvironment._patches_applied: - try: - from environments.patches import apply_patches - apply_patches() - except ImportError: - pass # patches module not available (standalone use) - _ModalEnvironment._patches_applied = True - - from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment - # Generous startup timeout: sandbox creation can take 30-60s for cold images, - # and the SWE-ReX runtime needs another 10-30s to boot inside it. - self._inner = SwerexModalEnvironment( - image=image, cwd=cwd, timeout=timeout, - startup_timeout=180.0, - runtime_timeout=3600.0, - ) - self.cwd = cwd - self.timeout = timeout - - def execute(self, command: str, cwd: str = "", *, timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command in Modal with sudo support. - - Modal uses HTTP transport (no execve), so there's no ARG_MAX limit. - When stdin_data is provided, we embed it as a heredoc since there's - no process-level stdin pipe to the cloud sandbox. 
- """ - if stdin_data is not None: - marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" - while marker in stdin_data: - marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" - command = f"{command} << '{marker}'\n{stdin_data}\n{marker}" - - # Transform sudo commands if SUDO_PASSWORD is available - exec_command = _transform_sudo_command(command) - - # Delegate to inner environment with transformed command - return self._inner.execute(exec_command, cwd=cwd, timeout=timeout) - - def cleanup(self): - """Cleanup the Modal deployment.""" - if hasattr(self._inner, 'stop'): - self._inner.stop() - - def stop(self): - """Stop the Modal deployment.""" - self.cleanup() - - def __del__(self): - """Cleanup on destruction.""" - try: - self.cleanup() - except Exception: - pass +# Environment classes now live in tools/environments/ +from tools.environments.local import LocalEnvironment as _LocalEnvironment +from tools.environments.singularity import SingularityEnvironment as _SingularityEnvironment +from tools.environments.ssh import SSHEnvironment as _SSHEnvironment +from tools.environments.docker import DockerEnvironment as _DockerEnvironment +from tools.environments.modal import ModalEnvironment as _ModalEnvironment # Tool description for LLM