Enhance BatchRunner and AIAgent with new configuration options, default model now opus 4.6, default summarizer gemini flash 3
- Added `max_tokens`, `reasoning_config`, and `prefill_messages` parameters to `BatchRunner` and `AIAgent` for improved model response control. - Updated CLI to support new options for reasoning effort and prefill messages from a JSON file. - Modified example configuration files to reflect changes in default model and summary model. - Improved error handling for loading prefill messages and reasoning configurations in the CLI. - Updated documentation to include new parameters and usage examples.
This commit is contained in:
@@ -244,6 +244,9 @@ def _process_single_prompt(
|
||||
providers_ignored=config.get("providers_ignored"),
|
||||
providers_order=config.get("providers_order"),
|
||||
provider_sort=config.get("provider_sort"),
|
||||
max_tokens=config.get("max_tokens"),
|
||||
reasoning_config=config.get("reasoning_config"),
|
||||
prefill_messages=config.get("prefill_messages"),
|
||||
)
|
||||
|
||||
# Run the agent with task_id to ensure each task gets its own isolated VM
|
||||
@@ -428,6 +431,9 @@ class BatchRunner:
|
||||
providers_ignored: List[str] = None,
|
||||
providers_order: List[str] = None,
|
||||
provider_sort: str = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the batch runner.
|
||||
@@ -449,6 +455,9 @@ class BatchRunner:
|
||||
providers_ignored (List[str]): OpenRouter providers to ignore (optional)
|
||||
providers_order (List[str]): OpenRouter providers to try in order (optional)
|
||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
|
||||
prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming)
|
||||
"""
|
||||
self.dataset_file = Path(dataset_file)
|
||||
self.batch_size = batch_size
|
||||
@@ -466,6 +475,9 @@ class BatchRunner:
|
||||
self.providers_ignored = providers_ignored
|
||||
self.providers_order = providers_order
|
||||
self.provider_sort = provider_sort
|
||||
self.max_tokens = max_tokens
|
||||
self.reasoning_config = reasoning_config
|
||||
self.prefill_messages = prefill_messages
|
||||
|
||||
# Validate distribution
|
||||
if not validate_distribution(distribution):
|
||||
@@ -735,6 +747,9 @@ class BatchRunner:
|
||||
"providers_ignored": self.providers_ignored,
|
||||
"providers_order": self.providers_order,
|
||||
"provider_sort": self.provider_sort,
|
||||
"max_tokens": self.max_tokens,
|
||||
"reasoning_config": self.reasoning_config,
|
||||
"prefill_messages": self.prefill_messages,
|
||||
}
|
||||
|
||||
# For backward compatibility, still track by index (but this is secondary to content matching)
|
||||
@@ -956,6 +971,10 @@ def main(
|
||||
providers_ignored: str = None,
|
||||
providers_order: str = None,
|
||||
provider_sort: str = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_effort: str = None,
|
||||
reasoning_disabled: bool = False,
|
||||
prefill_messages_file: str = None,
|
||||
):
|
||||
"""
|
||||
Run batch processing of agent prompts from a dataset.
|
||||
@@ -979,6 +998,10 @@ def main(
|
||||
providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
|
||||
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
|
||||
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
@@ -990,9 +1013,13 @@ def main(
|
||||
# Use specific distribution
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen
|
||||
|
||||
# With ephemeral system prompt (not saved to dataset)
|
||||
# With disabled reasoning and max tokens
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
--reasoning_disabled --max_tokens=128000
|
||||
|
||||
# With prefill messages from file
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
|
||||
--prefill_messages_file=configs/prefill_opus.json
|
||||
|
||||
# List available distributions
|
||||
python batch_runner.py --list_distributions
|
||||
@@ -1031,6 +1058,36 @@ def main(
|
||||
providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
|
||||
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
|
||||
|
||||
# Build reasoning_config from CLI flags
|
||||
# --reasoning_disabled takes priority, then --reasoning_effort, then default (xhigh)
|
||||
reasoning_config = None
|
||||
if reasoning_disabled:
|
||||
# Completely disable reasoning/thinking tokens
|
||||
reasoning_config = {"effort": "none"}
|
||||
print("🧠 Reasoning: DISABLED (effort=none)")
|
||||
elif reasoning_effort:
|
||||
# Use specified effort level
|
||||
valid_efforts = ["xhigh", "high", "medium", "low", "minimal", "none"]
|
||||
if reasoning_effort not in valid_efforts:
|
||||
print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
|
||||
return
|
||||
reasoning_config = {"enabled": True, "effort": reasoning_effort}
|
||||
print(f"🧠 Reasoning effort: {reasoning_effort}")
|
||||
|
||||
# Load prefill messages from JSON file if provided
|
||||
prefill_messages = None
|
||||
if prefill_messages_file:
|
||||
try:
|
||||
with open(prefill_messages_file, 'r', encoding='utf-8') as f:
|
||||
prefill_messages = json.load(f)
|
||||
if not isinstance(prefill_messages, list):
|
||||
print(f"❌ Error: prefill_messages_file must contain a JSON array of messages")
|
||||
return
|
||||
print(f"💬 Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading prefill messages: {e}")
|
||||
return
|
||||
|
||||
# Initialize and run batch runner
|
||||
try:
|
||||
runner = BatchRunner(
|
||||
@@ -1050,6 +1107,9 @@ def main(
|
||||
providers_ignored=providers_ignored_list,
|
||||
providers_order=providers_order_list,
|
||||
provider_sort=provider_sort,
|
||||
max_tokens=max_tokens,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
)
|
||||
|
||||
runner.run(resume=resume)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
model:
|
||||
# Default model to use (can be overridden with --model flag)
|
||||
default: "anthropic/claude-sonnet-4"
|
||||
default: "anthropic/claude-opus-4.6"
|
||||
|
||||
# API configuration (falls back to OPENROUTER_API_KEY env var)
|
||||
# api_key: "your-key-here" # Uncomment to set here instead of .env
|
||||
@@ -140,7 +140,7 @@ compression:
|
||||
|
||||
# Model to use for generating summaries (fast/cheap recommended)
|
||||
# This model compresses the middle turns into a concise summary
|
||||
summary_model: "google/gemini-2.0-flash-001"
|
||||
summary_model: "google/gemini-3-flash-preview"
|
||||
|
||||
# =============================================================================
|
||||
# Agent Behavior
|
||||
|
||||
31
cli.py
31
cli.py
@@ -83,7 +83,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
# Default configuration
|
||||
defaults = {
|
||||
"model": {
|
||||
"default": "anthropic/claude-opus-4-20250514",
|
||||
"default": "anthropic/claude-opus-4.6",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
"terminal": {
|
||||
@@ -101,7 +101,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"compression": {
|
||||
"enabled": True, # Auto-compress when approaching context limit
|
||||
"threshold": 0.85, # Compress at 85% of model's context limit
|
||||
"summary_model": "google/gemini-2.0-flash-001", # Fast/cheap model for summaries
|
||||
"summary_model": "google/gemini-3-flash-preview", # Fast/cheap model for summaries
|
||||
},
|
||||
"agent": {
|
||||
"max_turns": 60, # Default max tool-calling iterations
|
||||
@@ -1332,6 +1332,11 @@ class HermesCLI:
|
||||
# Get the final response
|
||||
response = result.get("final_response", "") if result else ""
|
||||
|
||||
# Handle failed results (e.g., non-retryable errors like invalid model)
|
||||
if result and result.get("failed") and not response:
|
||||
error_detail = result.get("error", "Unknown error")
|
||||
response = f"Error: {error_detail}"
|
||||
|
||||
# Handle interrupt - check if we were interrupted
|
||||
pending_message = None
|
||||
if result and result.get("interrupted"):
|
||||
@@ -1403,6 +1408,7 @@ class HermesCLI:
|
||||
self._agent_running = False
|
||||
self._pending_input = queue.Queue()
|
||||
self._should_exit = False
|
||||
self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit
|
||||
|
||||
# Create a persistent input area using prompt_toolkit Application
|
||||
input_buffer = Buffer()
|
||||
@@ -1422,11 +1428,28 @@ class HermesCLI:
|
||||
|
||||
@kb.add('c-c')
|
||||
def handle_ctrl_c(event):
|
||||
"""Handle Ctrl+C - interrupt or exit."""
|
||||
"""Handle Ctrl+C - interrupt agent or force exit on double press.
|
||||
|
||||
First Ctrl+C: interrupt the running agent gracefully.
|
||||
Second Ctrl+C within 2 seconds (or when agent is idle): force exit.
|
||||
"""
|
||||
import time as _time
|
||||
now = _time.time()
|
||||
|
||||
if self._agent_running and self.agent:
|
||||
print("\n⚡ Interrupting agent...")
|
||||
# Check for double Ctrl+C (second press within 2 seconds)
|
||||
if now - self._last_ctrl_c_time < 2.0:
|
||||
print("\n⚡ Force exiting...")
|
||||
self._should_exit = True
|
||||
event.app.exit()
|
||||
return
|
||||
|
||||
# First Ctrl+C: try graceful interrupt
|
||||
self._last_ctrl_c_time = now
|
||||
print("\n⚡ Interrupting agent... (press Ctrl+C again to force exit)")
|
||||
self.agent.interrupt()
|
||||
else:
|
||||
# Agent not running, exit immediately
|
||||
self._should_exit = True
|
||||
event.app.exit()
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def run_job(job: dict) -> tuple[bool, str, Optional[str]]:
|
||||
# Create agent with default settings
|
||||
# Jobs run in isolated sessions (no prior context)
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-sonnet-4"),
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-opus-4.6"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
@@ -481,7 +481,7 @@ class GatewayRunner:
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-sonnet-4"),
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-opus-4.6"),
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=[toolset],
|
||||
|
||||
@@ -71,7 +71,7 @@ def ensure_hermes_home():
|
||||
# =============================================================================
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"model": "anthropic/claude-sonnet-4.5",
|
||||
"model": "anthropic/claude-opus-4.6",
|
||||
"toolsets": ["hermes-cli"],
|
||||
"max_turns": 100,
|
||||
|
||||
@@ -91,7 +91,7 @@ DEFAULT_CONFIG = {
|
||||
"compression": {
|
||||
"enabled": True,
|
||||
"threshold": 0.85,
|
||||
"summary_model": "google/gemini-2.0-flash-001",
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
},
|
||||
|
||||
"display": {
|
||||
@@ -555,7 +555,7 @@ def show_config():
|
||||
print(f" Enabled: {'yes' if enabled else 'no'}")
|
||||
if enabled:
|
||||
print(f" Threshold: {compression.get('threshold', 0.85) * 100:.0f}%")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-2.0-flash-001')}")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}")
|
||||
|
||||
# Messaging
|
||||
print()
|
||||
|
||||
255
run_agent.py
255
run_agent.py
@@ -66,6 +66,7 @@ _MODEL_CACHE_TTL = 3600 # 1 hour cache TTL
|
||||
DEFAULT_CONTEXT_LENGTHS = {
|
||||
"anthropic/claude-opus-4": 200000,
|
||||
"anthropic/claude-opus-4.5": 200000,
|
||||
"anthropic/claude-opus-4.6": 200000,
|
||||
"anthropic/claude-sonnet-4": 200000,
|
||||
"anthropic/claude-sonnet-4-20250514": 200000,
|
||||
"anthropic/claude-haiku-4.5": 200000,
|
||||
@@ -206,7 +207,7 @@ class ContextCompressor:
|
||||
self,
|
||||
model: str,
|
||||
threshold_percent: float = 0.85,
|
||||
summary_model: str = "google/gemini-2.0-flash-001",
|
||||
summary_model: str = "google/gemini-3-flash-preview",
|
||||
protect_first_n: int = 3,
|
||||
protect_last_n: int = 4,
|
||||
summary_target_tokens: int = 500,
|
||||
@@ -584,7 +585,7 @@ class AIAgent:
|
||||
self,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514", # OpenRouter format
|
||||
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
|
||||
max_iterations: int = 60, # Default tool-calling iterations
|
||||
tool_delay: float = 1.0,
|
||||
enabled_toolsets: List[str] = None,
|
||||
@@ -601,6 +602,9 @@ class AIAgent:
|
||||
provider_sort: str = None,
|
||||
session_id: str = None,
|
||||
tool_progress_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -625,6 +629,12 @@ class AIAgent:
|
||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
session_id (str): Pre-generated session ID for logging (optional, auto-generated if not provided)
|
||||
tool_progress_callback (callable): Callback function(tool_name, args_preview) for progress notifications
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
|
||||
If None, defaults to {"enabled": True, "effort": "xhigh"} for OpenRouter. Set to disable/customize reasoning.
|
||||
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
|
||||
Useful for injecting a few-shot example or priming the model's response style.
|
||||
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
@@ -653,6 +663,11 @@ class AIAgent:
|
||||
self.enabled_toolsets = enabled_toolsets
|
||||
self.disabled_toolsets = disabled_toolsets
|
||||
|
||||
# Model response configuration
|
||||
self.max_tokens = max_tokens # None = use model default
|
||||
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
|
||||
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
|
||||
|
||||
# Configure logging
|
||||
if self.verbose_logging:
|
||||
logging.basicConfig(
|
||||
@@ -781,7 +796,7 @@ class AIAgent:
|
||||
# Compresses conversation when approaching model's context limit
|
||||
# Configuration via environment variables (can be set in .env or cli-config.yaml)
|
||||
compression_threshold = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", "0.85"))
|
||||
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-2.0-flash-001")
|
||||
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-3-flash-preview")
|
||||
compression_enabled = os.getenv("CONTEXT_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
self.context_compressor = ContextCompressor(
|
||||
@@ -1086,6 +1101,25 @@ class AIAgent:
|
||||
|
||||
return json.dumps(formatted_tools, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _convert_scratchpad_to_think(content: str) -> str:
|
||||
"""
|
||||
Convert <REASONING_SCRATCHPAD> tags to <think> tags in content.
|
||||
|
||||
When native thinking/reasoning is disabled and the model is prompted to
|
||||
reason inside <REASONING_SCRATCHPAD> XML tags instead, this converts those
|
||||
to the standard <think> format used in our trajectory storage.
|
||||
|
||||
Args:
|
||||
content: Assistant message content that may contain scratchpad tags
|
||||
|
||||
Returns:
|
||||
Content with scratchpad tags replaced by think tags
|
||||
"""
|
||||
if not content or "<REASONING_SCRATCHPAD>" not in content:
|
||||
return content
|
||||
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
|
||||
|
||||
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert internal message format to trajectory format for saving.
|
||||
@@ -1120,14 +1154,19 @@ class AIAgent:
|
||||
"value": system_msg
|
||||
})
|
||||
|
||||
# Add the initial user message
|
||||
# Add the actual user prompt (from the dataset) as the first human message
|
||||
trajectory.append({
|
||||
"from": "human",
|
||||
"value": user_query
|
||||
})
|
||||
|
||||
# Process remaining messages
|
||||
i = 1 # Skip the first user message as we already added it
|
||||
# Calculate where agent responses start in the messages list.
|
||||
# Prefill messages are ephemeral (only used to prime model response style)
|
||||
# so we skip them entirely in the saved trajectory.
|
||||
# Layout: [*prefill_msgs, actual_user_msg, ...agent_responses...]
|
||||
num_prefill = len(self.prefill_messages) if self.prefill_messages else 0
|
||||
i = num_prefill + 1 # Skip prefill messages + the actual user message (already added above)
|
||||
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
@@ -1138,12 +1177,14 @@ class AIAgent:
|
||||
# Add <think> tags around reasoning for trajectory storage
|
||||
content = ""
|
||||
|
||||
# Prepend reasoning in <think> tags if available
|
||||
# Prepend reasoning in <think> tags if available (native thinking tokens)
|
||||
if msg.get("reasoning") and msg["reasoning"].strip():
|
||||
content = f"<think>\n{msg['reasoning']}\n</think>\n"
|
||||
|
||||
if msg.get("content") and msg["content"].strip():
|
||||
content += msg["content"] + "\n"
|
||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||
# (used when native thinking is disabled and model reasons via XML)
|
||||
content += self._convert_scratchpad_to_think(msg["content"]) + "\n"
|
||||
|
||||
# Add tool calls wrapped in XML tags
|
||||
for tool_call in msg["tool_calls"]:
|
||||
@@ -1206,11 +1247,14 @@ class AIAgent:
|
||||
# Add <think> tags around reasoning for trajectory storage
|
||||
content = ""
|
||||
|
||||
# Prepend reasoning in <think> tags if available
|
||||
# Prepend reasoning in <think> tags if available (native thinking tokens)
|
||||
if msg.get("reasoning") and msg["reasoning"].strip():
|
||||
content = f"<think>\n{msg['reasoning']}\n</think>\n"
|
||||
|
||||
content += msg["content"] or ""
|
||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||
# (used when native thinking is disabled and model reasons via XML)
|
||||
raw_content = msg["content"] or ""
|
||||
content += self._convert_scratchpad_to_think(raw_content)
|
||||
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
@@ -1261,6 +1305,66 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def _log_api_payload(self, turn_number: int, api_kwargs: Dict[str, Any], response=None):
|
||||
"""
|
||||
[TEMPORARY DEBUG] Log the full API payload and response token metrics
|
||||
for each agent turn to a per-session JSONL file for inspection.
|
||||
|
||||
Writes one JSON line per turn to logs/payload_<session_id>.jsonl.
|
||||
Tool schemas are summarized (just names) to keep logs readable.
|
||||
|
||||
Args:
|
||||
turn_number: Which API call this is (1-indexed)
|
||||
api_kwargs: The full kwargs dict being passed to chat.completions.create
|
||||
response: The API response object (optional, added after the call completes)
|
||||
"""
|
||||
try:
|
||||
payload_log_file = self.logs_dir / f"payload_{self.session_id}.jsonl"
|
||||
|
||||
# Build a serializable copy of the request payload
|
||||
payload = {
|
||||
"turn": turn_number,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": api_kwargs.get("model"),
|
||||
"max_tokens": api_kwargs.get("max_tokens"),
|
||||
"extra_body": api_kwargs.get("extra_body"),
|
||||
"num_tools": len(api_kwargs.get("tools") or []),
|
||||
"tool_names": [t["function"]["name"] for t in (api_kwargs.get("tools") or [])],
|
||||
"messages": api_kwargs.get("messages", []),
|
||||
}
|
||||
|
||||
# Add response token metrics if available
|
||||
if response is not None:
|
||||
try:
|
||||
usage_raw = response.usage.model_dump() if hasattr(response.usage, 'model_dump') else {}
|
||||
payload["response"] = {
|
||||
# Core token counts
|
||||
"prompt_tokens": usage_raw.get("prompt_tokens"),
|
||||
"completion_tokens": usage_raw.get("completion_tokens"),
|
||||
"total_tokens": usage_raw.get("total_tokens"),
|
||||
# Completion breakdown (reasoning tokens, etc.)
|
||||
"completion_tokens_details": usage_raw.get("completion_tokens_details"),
|
||||
# Prompt breakdown (cached tokens, etc.)
|
||||
"prompt_tokens_details": usage_raw.get("prompt_tokens_details"),
|
||||
# Cost tracking
|
||||
"cost": usage_raw.get("cost"),
|
||||
"is_byok": usage_raw.get("is_byok"),
|
||||
"cost_details": usage_raw.get("cost_details"),
|
||||
# Provider info (top-level field from OpenRouter)
|
||||
"provider": getattr(response, 'provider', None),
|
||||
"response_model": getattr(response, 'model', None),
|
||||
}
|
||||
except Exception:
|
||||
payload["response"] = {"error": "failed to extract usage"}
|
||||
|
||||
with open(payload_log_file, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(payload, ensure_ascii=False, default=str) + "\n")
|
||||
|
||||
except Exception as e:
|
||||
# Silent fail - don't interrupt the agent for debug logging
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to log API payload: {e}")
|
||||
|
||||
def _save_session_log(self, messages: List[Dict[str, Any]] = None):
|
||||
"""
|
||||
Save the current session trajectory to the logs directory.
|
||||
@@ -1276,10 +1380,12 @@ class AIAgent:
|
||||
return
|
||||
|
||||
try:
|
||||
# Extract the first user message for the trajectory format
|
||||
# The first message should be the user's initial query
|
||||
# Extract the actual user query for the trajectory format.
|
||||
# Skip prefill messages (they're ephemeral and shouldn't appear in trajectories)
|
||||
# so the first user message we find is the real task prompt.
|
||||
first_user_query = ""
|
||||
for msg in messages:
|
||||
start_idx = len(self.prefill_messages) if self.prefill_messages else 0
|
||||
for msg in messages[start_idx:]:
|
||||
if msg.get("role") == "user":
|
||||
first_user_query = msg.get("content", "")
|
||||
break
|
||||
@@ -1373,6 +1479,12 @@ class AIAgent:
|
||||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
# Inject prefill messages at the start of conversation (before user's actual prompt)
|
||||
# This is used for few-shot priming, e.g., a greeting exchange to set response style
|
||||
if self.prefill_messages and not conversation_history:
|
||||
for prefill_msg in self.prefill_messages:
|
||||
messages.append(prefill_msg.copy())
|
||||
|
||||
# Add user message
|
||||
messages.append({
|
||||
"role": "user",
|
||||
@@ -1493,6 +1605,10 @@ class AIAgent:
|
||||
"timeout": 600.0 # 10 minute timeout for very long responses
|
||||
}
|
||||
|
||||
# Add max_tokens if configured (overrides model default)
|
||||
if self.max_tokens is not None:
|
||||
api_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Add extra_body for OpenRouter (provider preferences + reasoning)
|
||||
extra_body = {}
|
||||
|
||||
@@ -1500,12 +1616,17 @@ class AIAgent:
|
||||
if provider_preferences:
|
||||
extra_body["provider"] = provider_preferences
|
||||
|
||||
# Enable reasoning with xhigh effort for OpenRouter
|
||||
# Configure reasoning for OpenRouter
|
||||
# If reasoning_config is explicitly provided, use it (allows disabling/customizing)
|
||||
# Otherwise, default to xhigh effort for OpenRouter models
|
||||
if "openrouter" in self.base_url.lower():
|
||||
extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
if self.reasoning_config is not None:
|
||||
extra_body["reasoning"] = self.reasoning_config
|
||||
else:
|
||||
extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
|
||||
if extra_body:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
@@ -1527,6 +1648,9 @@ class AIAgent:
|
||||
# Log response with provider info if available
|
||||
resp_model = getattr(response, 'model', 'N/A') if response else 'N/A'
|
||||
logging.debug(f"API Response received - Model: {resp_model}, Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
# [DEBUG] Log the full API payload + response token metrics
|
||||
self._log_api_payload(api_call_count, api_kwargs, response=response)
|
||||
|
||||
# Validate response has valid choices before proceeding
|
||||
if response is None or not hasattr(response, 'choices') or response.choices is None or len(response.choices) == 0:
|
||||
@@ -1589,7 +1713,20 @@ class AIAgent:
|
||||
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s
|
||||
print(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...")
|
||||
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Sleep in small increments to stay responsive to interrupts
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
continue # Retry the API call
|
||||
|
||||
# Check finish_reason before proceeding
|
||||
@@ -1668,6 +1805,41 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}")
|
||||
print(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools")
|
||||
|
||||
# Check for interrupt before deciding to retry
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
|
||||
# Check for non-retryable client errors (4xx HTTP status codes).
|
||||
# These indicate a problem with the request itself (bad model ID,
|
||||
# invalid API key, forbidden, etc.) and will never succeed on retry.
|
||||
is_client_error = any(phrase in error_msg for phrase in [
|
||||
'error code: 400', 'error code: 401', 'error code: 403',
|
||||
'error code: 404', 'error code: 422',
|
||||
'is not a valid model', 'invalid model', 'model not found',
|
||||
'invalid api key', 'invalid_api_key', 'authentication',
|
||||
'unauthorized', 'forbidden', 'not found',
|
||||
])
|
||||
|
||||
if is_client_error:
|
||||
print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.")
|
||||
print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.")
|
||||
logging.error(f"{self.log_prefix}Non-retryable client error: {api_error}")
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"failed": True,
|
||||
"error": str(api_error),
|
||||
}
|
||||
|
||||
# Check for non-retryable errors (context length exceeded)
|
||||
is_context_length_error = any(phrase in error_msg for phrase in [
|
||||
'context length', 'maximum context', 'token limit',
|
||||
@@ -1708,7 +1880,21 @@ class AIAgent:
|
||||
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f"⏳ Retrying in {wait_time}s...")
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Sleep in small increments so we can respond to interrupts quickly
|
||||
# instead of blocking the entire wait_time in one sleep() call
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
time.sleep(0.2) # Check interrupt every 200ms
|
||||
|
||||
try:
|
||||
assistant_message = response.choices[0].message
|
||||
@@ -2069,13 +2255,28 @@ class AIAgent:
|
||||
if self.ephemeral_system_prompt:
|
||||
api_messages = [{"role": "system", "content": self.ephemeral_system_prompt}] + api_messages
|
||||
|
||||
summary_response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
# Build extra_body for summary call (same reasoning config as main loop)
|
||||
summary_extra_body = {}
|
||||
if "openrouter" in self.base_url.lower():
|
||||
if self.reasoning_config is not None:
|
||||
summary_extra_body["reasoning"] = self.reasoning_config
|
||||
else:
|
||||
summary_extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
|
||||
summary_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
# No tools parameter - forces text response
|
||||
extra_headers=self.extra_headers,
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
}
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs["max_tokens"] = self.max_tokens
|
||||
if summary_extra_body:
|
||||
summary_kwargs["extra_body"] = summary_extra_body
|
||||
|
||||
summary_response = self.client.chat.completions.create(**summary_kwargs)
|
||||
|
||||
if summary_response.choices and summary_response.choices[0].message.content:
|
||||
final_response = summary_response.choices[0].message.content
|
||||
@@ -2151,7 +2352,7 @@ class AIAgent:
|
||||
|
||||
def main(
|
||||
query: str = None,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
model: str = "anthropic/claude-opus-4.6",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_turns: int = 10,
|
||||
|
||||
Reference in New Issue
Block a user