prevent leakage of morph instances between tasks

This commit is contained in:
hjc-puro
2025-11-04 03:32:43 -05:00
parent a4db3fdee5
commit fbd3a2fdb8
4 changed files with 177 additions and 27 deletions

View File

@@ -156,7 +156,6 @@ def _process_single_prompt(
print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}")
# Initialize agent with sampled toolsets
# Use prompt_index as task_id to ensure each task gets its own isolated VM
agent = AIAgent(
base_url=config.get("base_url"),
api_key=config.get("api_key"),
@@ -165,12 +164,11 @@ def _process_single_prompt(
enabled_toolsets=selected_toolsets,
save_trajectories=False, # We handle saving ourselves
verbose_logging=config.get("verbose", False),
ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
task_id=f"task_{prompt_index}"
ephemeral_system_prompt=config.get("ephemeral_system_prompt")
)
# Run the agent
result = agent.run_conversation(prompt)
# Run the agent with task_id to ensure each task gets its own isolated VM
result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
# Extract tool usage statistics
tool_stats = _extract_tool_stats(result["messages"])

View File

@@ -28,7 +28,7 @@ Usage:
import json
import asyncio
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION

View File

@@ -43,6 +43,7 @@ else:
# Import our tool system
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
from tools.terminal_tool import cleanup_vm
class AIAgent:
@@ -64,8 +65,7 @@ class AIAgent:
disabled_toolsets: List[str] = None,
save_trajectories: bool = False,
verbose_logging: bool = False,
ephemeral_system_prompt: str = None,
task_id: str = None
ephemeral_system_prompt: str = None
):
"""
Initialize the AI Agent.
@@ -81,7 +81,6 @@ class AIAgent:
save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False)
verbose_logging (bool): Enable verbose logging for debugging (default: False)
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
"""
self.model = model
self.max_iterations = max_iterations
@@ -90,10 +89,6 @@ class AIAgent:
self.verbose_logging = verbose_logging
self.ephemeral_system_prompt = ephemeral_system_prompt
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
import uuid
self.task_id = task_id or str(uuid.uuid4())
# Store toolset filtering options
self.enabled_toolsets = enabled_toolsets
self.disabled_toolsets = disabled_toolsets
@@ -348,22 +343,27 @@ class AIAgent:
print(f"⚠️ Failed to save trajectory: {e}")
def run_conversation(
self,
user_message: str,
system_message: str = None,
conversation_history: List[Dict[str, Any]] = None
self,
user_message: str,
system_message: str = None,
conversation_history: List[Dict[str, Any]] = None,
task_id: str = None
) -> Dict[str, Any]:
"""
Run a complete conversation with tool calling until completion.
Args:
user_message (str): The user's message/question
system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided)
conversation_history (List[Dict]): Previous conversation messages (optional)
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional, auto-generated if not provided)
Returns:
Dict: Complete conversation result with final response and message history
"""
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
import uuid
effective_task_id = task_id or str(uuid.uuid4())
# Initialize conversation
messages = conversation_history or []
@@ -479,7 +479,7 @@ class AIAgent:
tool_start_time = time.time()
# Execute the tool with task_id to isolate VMs between concurrent tasks
function_result = handle_function_call(function_name, function_args, self.task_id)
function_result = handle_function_call(function_name, function_args, effective_task_id)
tool_duration = time.time() - tool_start_time
result_preview = function_result[:200] if len(function_result) > 200 else function_result
@@ -543,10 +543,17 @@ class AIAgent:
# Determine if conversation completed successfully
completed = final_response is not None and api_call_count < self.max_iterations
# Save trajectory if enabled
self._save_trajectory(messages, user_message, completed)
# Clean up VM for this task after conversation completes
try:
cleanup_vm(effective_task_id)
except Exception as e:
if self.verbose_logging:
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
return {
"final_response": final_response,
"messages": messages,

View File

@@ -4,8 +4,12 @@ Terminal Tool Module
This module provides a single terminal tool using Hecate's VM infrastructure.
It wraps Hecate's functionality to provide a simple interface for executing commands
on Morph VMs with automatic lifecycle management. VMs live for 5 minutes after last use.
Timer resets with each use.
on Morph VMs with automatic lifecycle management.
VM Lifecycle:
- VMs have a TTL (time to live) set at creation (default: 20 minutes)
- VMs are also cleaned up locally after 5 minutes of inactivity
- Timer resets with each use
Available tool:
- terminal_tool: Execute commands with optional interactive session support
@@ -24,6 +28,8 @@ import json
import os
import uuid
import threading
import time
import atexit
from typing import Optional, Dict, Any
# Detailed description for the terminal tool based on Hermes Terminal system prompt
@@ -78,7 +84,134 @@ When commands enter interactive mode (vim, nano, less, git prompts, package mana
# Changed to dictionaries keyed by task_id to prevent leakage between concurrent tasks
_active_instances: Dict[str, Any] = {}
_active_contexts: Dict[str, Any] = {}
_last_activity: Dict[str, float] = {} # Track last activity time for each VM
_instance_lock = threading.Lock()
_cleanup_thread = None
_cleanup_running = False
def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300):
"""
Clean up VMs that have been inactive for longer than vm_lifetime_seconds.
This function should be called periodically by a background thread.
Args:
vm_lifetime_seconds: Maximum lifetime in seconds for inactive VMs (default: 300)
"""
global _active_instances, _active_contexts, _last_activity
current_time = time.time()
tasks_to_cleanup = []
with _instance_lock:
# Find all VMs that have been inactive for too long
for task_id, last_time in list(_last_activity.items()):
if current_time - last_time > vm_lifetime_seconds:
tasks_to_cleanup.append(task_id)
# Clean up the inactive VMs
for task_id in tasks_to_cleanup:
try:
if task_id in _active_instances:
instance = _active_instances[task_id]
# Terminate the VM instance
if hasattr(instance, 'terminate'):
instance.terminate()
elif hasattr(instance, 'stop'):
instance.stop()
elif hasattr(instance, 'delete'):
instance.delete()
# Remove from tracking dictionaries
del _active_instances[task_id]
print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}")
if task_id in _active_contexts:
del _active_contexts[task_id]
if task_id in _last_activity:
del _last_activity[task_id]
except Exception as e:
print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}")
def _cleanup_thread_worker():
"""
Background thread worker that periodically cleans up inactive VMs.
Runs every 60 seconds.
"""
global _cleanup_running
while _cleanup_running:
try:
vm_lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
_cleanup_inactive_vms(vm_lifetime)
except Exception as e:
print(f"[VM Cleanup] Error in cleanup thread: {e}")
# Sleep for 60 seconds, but check every second if we should stop
for _ in range(60):
if not _cleanup_running:
break
time.sleep(1)
def _start_cleanup_thread():
"""
Start the background cleanup thread if it's not already running.
"""
global _cleanup_thread, _cleanup_running
with _instance_lock:
if _cleanup_thread is None or not _cleanup_thread.is_alive():
_cleanup_running = True
_cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True)
_cleanup_thread.start()
def _stop_cleanup_thread():
"""
Stop the background cleanup thread.
"""
global _cleanup_running
_cleanup_running = False
if _cleanup_thread is not None:
_cleanup_thread.join(timeout=5)
def cleanup_vm(task_id: str):
"""
Manually clean up a specific VM by task_id.
This should be called when a task is completed.
Args:
task_id: The task ID of the VM to clean up
"""
global _active_instances, _active_contexts, _last_activity
with _instance_lock:
try:
if task_id in _active_instances:
instance = _active_instances[task_id]
# Terminate the VM instance
if hasattr(instance, 'terminate'):
instance.terminate()
elif hasattr(instance, 'stop'):
instance.stop()
elif hasattr(instance, 'delete'):
instance.delete()
# Remove from tracking dictionaries
del _active_instances[task_id]
print(f"[VM Cleanup] Manually terminated VM for task: {task_id}")
if task_id in _active_contexts:
del _active_contexts[task_id]
if task_id in _last_activity:
del _last_activity[task_id]
except Exception as e:
print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}")
# Register cleanup on program exit
atexit.register(_stop_cleanup_thread)
def terminal_tool(
command: Optional[str] = None,
@@ -144,6 +277,7 @@ def terminal_tool(
# Get configuration from environment
vm_lifetime_seconds = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200")) # 20 minutes default
snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg")
# Check API key
@@ -160,17 +294,27 @@ def terminal_tool(
# If no task_id provided, use "default" for backward compatibility
effective_task_id = task_id or "default"
# Start the cleanup thread if not already running
_start_cleanup_thread()
# Get or create VM instance and execution context per task
# This is critical for interactive session support - the context must persist!
with _instance_lock:
if effective_task_id not in _active_instances:
morph_client = MorphCloudClient(api_key=morph_api_key)
_active_instances[effective_task_id] = morph_client.instances.start(snapshot_id=snapshot_id)
_active_instances[effective_task_id] = morph_client.instances.start(
snapshot_id=snapshot_id,
ttl_seconds=vm_ttl_seconds,
ttl_action="stop"
)
# Get or create persistent execution context per task
if effective_task_id not in _active_contexts:
_active_contexts[effective_task_id] = ExecutionContext()
# Update last activity time for this VM (resets the inactivity timer)
_last_activity[effective_task_id] = time.time()
instance = _active_instances[effective_task_id]
ctx = _active_contexts[effective_task_id]
@@ -303,5 +447,6 @@ if __name__ == "__main__":
print("\nEnvironment Variables:")
print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}")
print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}")
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300)")
print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)")
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)")
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')} (default: snapshot_defv9tjg)")