"""
|
|
HermesAgentBaseEnv -- Abstract Base Environment for Hermes-Agent + Atropos
|
|
|
|
Provides the Atropos integration plumbing that all hermes-agent environments share:
|
|
- Two-mode operation (OpenAI server for Phase 1, VLLM ManagedServer for Phase 2)
|
|
- Per-group toolset/distribution resolution
|
|
- Agent loop orchestration via HermesAgentLoop
|
|
- ToolContext creation for reward functions
|
|
- ScoredDataGroup construction from ManagedServer state
|
|
|
|
Subclasses only need to implement:
|
|
setup() -- Load dataset, initialize state
|
|
get_next_item() -- Return the next item from the dataset
|
|
format_prompt() -- Convert a dataset item into the user message
|
|
compute_reward() -- Score the rollout (has full ToolContext access)
|
|
evaluate() -- Periodic evaluation
|
|
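
A minimal subclass sketch (illustrative only -- the dataset, prompt, and the
substring reward check below are placeholders, not a real task):

    class HelloEnv(HermesAgentBaseEnv):
        async def setup(self):
            self.dataset = [{"prompt": "Reply with the word hello."}]
            self.iter = 0

        async def get_next_item(self):
            item = self.dataset[self.iter % len(self.dataset)]
            self.iter += 1
            return item

        def format_prompt(self, item):
            return item["prompt"]

        async def compute_reward(self, item, result, ctx):
            # Last message is the agent's final reply in a completed rollout
            final = result.messages[-1].get("content") or ""
            return 1.0 if "hello" in final.lower() else 0.0

        async def evaluate(self, *args, **kwargs):
            pass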
"""

import asyncio
import json
import logging
import os
import sys
import uuid
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

# Ensure the hermes-agent repo root is on sys.path so that imports like
# `from model_tools import ...` and `from environments.X import ...` work
# regardless of where the script is invoked from.
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from dotenv import load_dotenv
from pydantic import Field

# Load API keys from hermes-agent/.env so all environments can access them
_env_path = _repo_root / ".env"
if _env_path.exists():
    load_dotenv(dotenv_path=_env_path)

# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
# This patches SwerexModalEnvironment to use a background thread instead of
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
from environments.patches import apply_patches
apply_patches()

from atroposlib.envs.base import (
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
    ScoredDataItem,
)
from atroposlib.envs.server_handling.server_manager import (
    APIServerConfig,
    ServerBaseline,
    ServerManager,
)
from atroposlib.type_definitions import Item

from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext

# Import hermes-agent toolset infrastructure
from model_tools import get_tool_definitions
from toolset_distributions import sample_toolsets_from_distribution

logger = logging.getLogger(__name__)


class HermesAgentEnvConfig(BaseEnvConfig):
    """
    Configuration for hermes-agent Atropos environments.

    Extends BaseEnvConfig with agent-specific settings for toolsets,
    terminal backend, dataset loading, and tool call parsing.
    """

    # --- Toolset configuration ---
    # Mutually exclusive: use either enabled_toolsets OR distribution
    enabled_toolsets: Optional[List[str]] = Field(
        default=None,
        description="Explicit list of hermes toolsets to enable (e.g., ['terminal', 'file', 'web']). "
        "If None and distribution is also None, all available toolsets are enabled.",
    )
    disabled_toolsets: Optional[List[str]] = Field(
        default=None,
        description="Toolsets to disable. Applied as a filter on top of enabled_toolsets or distribution.",
    )
    distribution: Optional[str] = Field(
        default=None,
        description="Name of a toolset distribution from toolset_distributions.py "
        "(e.g., 'development', 'terminal_tasks'). Sampled once per group. "
        "Mutually exclusive with enabled_toolsets.",
    )
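    # Example YAML for the three toolset fields above (values are hypothetical;
    # same convention as the extra_body example further down):
    #   enabled_toolsets: ["terminal", "file"]
    #   disabled_toolsets: ["web"]
    #   # ...or, instead of enabled_toolsets, sample a set per group:
    #   # distribution: "development"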

    # --- Agent loop configuration ---
    max_agent_turns: int = Field(
        default=30,
        description="Maximum number of LLM calls (tool-calling iterations) per rollout.",
    )
    system_prompt: Optional[str] = Field(
        default=None,
        description="System prompt for the agent. Tools are handled via the tools= parameter, "
        "not embedded in the prompt text.",
    )
    agent_temperature: float = Field(
        default=1.0,
        description="Sampling temperature for agent generation during rollouts.",
    )

    # --- Terminal backend ---
    terminal_backend: str = Field(
        default="local",
        description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
        "Modal or Daytona is recommended for production RL (cloud isolation per rollout).",
    )
    terminal_timeout: int = Field(
        default=120,
        description="Per-command timeout in seconds for terminal tool calls. "
        "Commands exceeding this are killed. Increase for tasks with long-running "
        "commands (compilation, pip install, etc.).",
    )
    terminal_lifetime: int = Field(
        default=3600,
        description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
        "sandboxes that have been idle longer than this. Must be longer than "
        "the longest gap between tool calls (e.g., waiting for an LLM response).",
    )

    # --- Dataset ---
    dataset_name: Optional[str] = Field(
        default=None,
        description="HuggingFace dataset name. Optional if tasks are defined inline.",
    )
    dataset_split: str = Field(
        default="train",
        description="Dataset split to use.",
    )
    prompt_field: str = Field(
        default="prompt",
        description="Which field in the dataset contains the prompt.",
    )

    # --- Thread pool ---
    tool_pool_size: int = Field(
        default=128,
        description="Thread pool size for tool execution. Each concurrent task needs a "
        "thread for tool calls. Must be large enough for parallel evaluation; "
        "too small causes thread-pool starvation.",
    )

    # --- Phase 2: Tool call parsing ---
    tool_call_parser: str = Field(
        default="hermes",
        description="Tool call parser name for Phase 2 (VLLM server type). "
        "Ignored in Phase 1 (OpenAI server type, where the server parses tool calls natively). "
        "Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
    )

    # --- Provider-specific parameters ---
    # Passed as extra_body to the OpenAI client's chat.completions.create() call.
    # Useful for OpenRouter provider preferences, transforms, route settings, etc.
    # Example YAML:
    #   extra_body:
    #     provider:
    #       ignore: ["DeepInfra", "Fireworks"]
    #       order: ["Together"]
    #     transforms: ["middle-out"]
    extra_body: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Extra body parameters passed to the OpenAI client's "
        "chat.completions.create(). Used for OpenRouter provider preferences, "
        "transforms, and other provider-specific settings.",
    )


class HermesAgentBaseEnv(BaseEnv):
    """
    Abstract base environment for hermes-agent Atropos integration.

    Handles two modes of operation:

    - Phase 1 (OpenAI server type): Uses server.chat_completion() directly.
      The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing
      and reasoning extraction natively. DummyManagedServer provides placeholder
      tokens. Good for SFT data gen, verifier testing, evaluation.

    - Phase 2 (VLLM server type): Uses ManagedServer for exact token IDs + logprobs
      via /generate. A client-side tool call parser reconstructs structured
      tool_calls from the raw output. Full RL training capability.

    Subclasses must implement:
        setup()          -- Load dataset, initialize state
        get_next_item()  -- Return the next item to roll out
        format_prompt()  -- Convert a dataset item into the user message string
        compute_reward() -- Score the rollout using ToolContext
        evaluate()       -- Periodic evaluation
    """

    name: Optional[str] = "hermes-agent"
    env_config_cls = HermesAgentEnvConfig

    def __init__(
        self,
        config: HermesAgentEnvConfig,
        server_configs: Union[ServerBaseline, List[APIServerConfig]],
        slurm=False,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)

        # Set terminal environment variables so hermes tools pick them up.
        # These can all be overridden per-environment via config fields instead
        # of requiring users to set shell env vars.
        if config.terminal_backend:
            os.environ["TERMINAL_ENV"] = config.terminal_backend
            os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
            os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
            print(
                f"🖥️ Terminal: backend={config.terminal_backend}, "
                f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
            )

        # Resize the agent loop's thread pool for tool execution.
        # This must be large enough for the number of concurrent tasks
        # (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
        from environments.agent_loop import resize_tool_pool
        resize_tool_pool(config.tool_pool_size)

        # Set tool_parser on the ServerManager so ManagedServer uses it
        # for bidirectional tool call translation (raw text ↔ OpenAI tool_calls).
        if hasattr(self.server, "tool_parser"):
            self.server.tool_parser = config.tool_call_parser
            print(f"🔧 Tool parser: {config.tool_call_parser}")

        # Current group's resolved tools (set in collect_trajectories)
        self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None

        # Tool error tracking for wandb logging
        self._tool_error_buffer: List[Dict[str, Any]] = []

    # =========================================================================
    # Toolset resolution (per-group)
    # =========================================================================

    def _resolve_tools_for_group(self) -> Tuple[List[Dict[str, Any]], Set[str]]:
        """
        Resolve toolsets for a group. Called once in collect_trajectories(),
        then shared by all collect_trajectory() calls in the group.

        If distribution is set, samples probabilistically.
        If enabled_toolsets is set, uses that explicit list.
        disabled_toolsets is applied as a filter on top.

        Returns:
            (tool_schemas, valid_tool_names) tuple
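
        Each schema follows the OpenAI function-tool shape that the rest of
        this file relies on (t["function"]["name"]); illustrative example:

            {"type": "function",
             "function": {"name": "terminal", "parameters": {...}}}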
"""
|
|
config = self.config
|
|
|
|
if config.distribution:
|
|
group_toolsets = sample_toolsets_from_distribution(config.distribution)
|
|
logger.info("Sampled toolsets from '%s': %s", config.distribution, group_toolsets)
|
|
else:
|
|
group_toolsets = config.enabled_toolsets # None means "all available"
|
|
if group_toolsets is None:
|
|
logger.warning(
|
|
"enabled_toolsets is None -- loading ALL tools including messaging. "
|
|
"Set explicit enabled_toolsets for RL training."
|
|
)
|
|
|
|
tools = get_tool_definitions(
|
|
enabled_toolsets=group_toolsets,
|
|
disabled_toolsets=config.disabled_toolsets,
|
|
quiet_mode=True,
|
|
)
|
|
|
|
valid_names = {t["function"]["name"] for t in tools} if tools else set()
|
|
logger.info("Resolved %d tools for group: %s", len(valid_names), sorted(valid_names))
|
|
return tools, valid_names
|
|
|
|
    # =========================================================================
    # Server mode detection
    # =========================================================================

    def _use_managed_server(self) -> bool:
        """
        Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).

        Phase 2 (ManagedServer) is used when the server type is 'vllm' or 'sglang',
        which go through the /generate endpoint for exact token tracking.

        Phase 1 (direct server) is used for the 'openai' server type, which uses
        /v1/chat/completions with native tool call parsing.
        """
        if not self.server.servers:
            return False

        server = self.server.servers[0]
        # If the server is an OpenAI server (not VLLM/SGLang), use direct mode
        from atroposlib.envs.server_handling.openai_server import OpenAIServer
        return not isinstance(server, OpenAIServer)

    # =========================================================================
    # Core Atropos integration
    # =========================================================================

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[
        Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
        List[Item],
    ]:
        """
        Override collect_trajectories to resolve toolsets once per group,
        then delegate to the standard group-level collection.

        The default BaseEnv.collect_trajectories() calls collect_trajectory()
        group_size times in parallel. We resolve tools once here and store
        them for all those calls to use.
        """
        # Resolve toolsets for this group (shared by all rollouts in the group)
        self._current_group_tools = self._resolve_tools_for_group()

        # Delegate to the default implementation which calls collect_trajectory()
        # group_size times via asyncio.gather
        return await super().collect_trajectories(item)

    # =========================================================================
    # Wandb rollout display -- format trajectories nicely
    # =========================================================================

    @staticmethod
    def _format_trajectory_for_display(messages: List[Dict[str, Any]]) -> str:
        """
        Format a conversation's messages into a readable trajectory string
        for wandb rollout tables. Shows tool calls, tool results, and reasoning
        in a structured way instead of raw token decoding.
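
        Example output (illustrative; the tool name is hypothetical):

            [USER]
            Create hello.txt in the workdir.

            [ASSISTANT thinking]
            I should touch the file via the terminal tool...

            [TOOL CALL] terminal({"command": "touch hello.txt"})

            [TOOL RESULT] (exit 0)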
"""
|
|
parts = []
|
|
for msg in messages:
|
|
role = msg.get("role", "unknown")
|
|
content = msg.get("content", "")
|
|
|
|
if role == "system":
|
|
parts.append(f"[SYSTEM]\n{content}")
|
|
|
|
elif role == "user":
|
|
parts.append(f"[USER]\n{content}")
|
|
|
|
elif role == "assistant":
|
|
# Show reasoning if present
|
|
reasoning = msg.get("reasoning_content", "")
|
|
if reasoning:
|
|
# Truncate long reasoning for display
|
|
if len(reasoning) > 300:
|
|
reasoning = reasoning[:300] + "..."
|
|
parts.append(f"[ASSISTANT thinking]\n{reasoning}")
|
|
|
|
# Show content
|
|
if content:
|
|
parts.append(f"[ASSISTANT]\n{content}")
|
|
|
|
# Show tool calls
|
|
tool_calls = msg.get("tool_calls", [])
|
|
for tc in tool_calls:
|
|
func = tc.get("function", {})
|
|
name = func.get("name", "?")
|
|
args = func.get("arguments", "{}")
|
|
# Truncate long arguments for display
|
|
if len(args) > 200:
|
|
args = args[:200] + "..."
|
|
parts.append(f"[TOOL CALL] {name}({args})")
|
|
|
|
elif role == "tool":
|
|
tool_id = msg.get("tool_call_id", "")
|
|
result = content
|
|
# Truncate long tool results for display
|
|
if len(result) > 500:
|
|
result = result[:500] + "..."
|
|
parts.append(f"[TOOL RESULT] {result}")
|
|
|
|
return "\n\n".join(parts)
|
|
|
|
    async def add_rollouts_for_wandb(
        self,
        scored_data,
        item=None,
    ):
        """
        Override to show formatted trajectories with tool calls visible,
        instead of raw token decoding which loses all structure.
        """
        num_keep = self.config.num_rollouts_per_group_for_logging
        if num_keep == -1:
            num_keep = self.config.group_size

        group = []
        for i in range(min(num_keep, len(scored_data.get("scores", [])))):
            score = scored_data["scores"][i]

            # Use messages if available for rich display
            messages = None
            if scored_data.get("messages") and i < len(scored_data["messages"]):
                messages = scored_data["messages"][i]

            if messages:
                text = self._format_trajectory_for_display(messages)
            elif scored_data.get("tokens") and i < len(scored_data["tokens"]):
                text = self.tokenizer.decode(scored_data["tokens"][i])
            else:
                text = "(no data)"

            group.append((text, score))

        self.rollouts_for_wandb.append(group)
        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log base metrics including tool errors to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Log tool error stats
        if self._tool_error_buffer:
            wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)

            # Log error details as a summary string (tables can crash wandb on tmp cleanup)
            error_summaries = []
            for err in self._tool_error_buffer:
                error_summaries.append(
                    f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
                )
            wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)

            # Also print to stdout for immediate visibility
            for summary in error_summaries:
                print(f"  Tool Error: {summary}")

            self._tool_error_buffer = []
        else:
            wandb_metrics["train/tool_errors_count"] = 0

        await super().wandb_log(wandb_metrics)

    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
        """
        Run a single rollout: agent loop + reward computation.

        This is called group_size times in parallel by collect_trajectories().
        Each call gets its own task_id for terminal/browser session isolation.
        """
        task_id = str(uuid.uuid4())

        # Get group-level tools (resolved once in collect_trajectories)
        if self._current_group_tools is None:
            # Fallback: resolve per-trajectory if called outside collect_trajectories
            tools, valid_names = self._resolve_tools_for_group()
        else:
            tools, valid_names = self._current_group_tools

        # Build initial messages
        messages: List[Dict[str, Any]] = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": self.format_prompt(item)})

        # Run the agent loop
        result: AgentResult
        if self._use_managed_server():
            # Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs.
            # tool_parser is set on ServerManager in __init__ and passed through
            # to ManagedServer, which uses ToolCallTranslator for bidirectional
            # translation between raw text and OpenAI tool_calls.
            try:
                async with self.server.managed_server(
                    tokenizer=self.tokenizer,
                    preserve_think_blocks=bool(self.config.thinking_mode),
                ) as managed:
                    agent = HermesAgentLoop(
                        server=managed,
                        tool_schemas=tools,
                        valid_tool_names=valid_names,
                        max_turns=self.config.max_agent_turns,
                        task_id=task_id,
                        temperature=self.config.agent_temperature,
                        max_tokens=self.config.max_token_length,
                        extra_body=self.config.extra_body,
                    )
                    result = await agent.run(messages)
            except NotImplementedError:
                # DummyManagedServer not allowed -- fall back to Phase 1
                logger.warning(
                    "ManagedServer not available (OpenAI server?). "
                    "Falling back to direct server mode."
                )
                agent = HermesAgentLoop(
                    server=self.server,
                    tool_schemas=tools,
                    valid_tool_names=valid_names,
                    max_turns=self.config.max_agent_turns,
                    task_id=task_id,
                    temperature=self.config.agent_temperature,
                    max_tokens=self.config.max_token_length,
                    extra_body=self.config.extra_body,
                )
                result = await agent.run(messages)
        else:
            # Phase 1: OpenAI server -- native tool_calls, placeholder tokens
            agent = HermesAgentLoop(
                server=self.server,
                tool_schemas=tools,
                valid_tool_names=valid_names,
                max_turns=self.config.max_agent_turns,
                task_id=task_id,
                temperature=self.config.agent_temperature,
                max_tokens=self.config.max_token_length,
                extra_body=self.config.extra_body,
            )
            result = await agent.run(messages)

        # Skip reward computation if the agent loop produced no meaningful work
        # (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
        # just to verify files that were never created.
        only_system_and_user = all(
            msg.get("role") in ("system", "user") for msg in result.messages
        )
        if result.turns_used == 0 or only_system_and_user:
            logger.warning(
                "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
                result.turns_used,
                len(result.messages),
            )
            reward = 0.0
        else:
            # Compute reward using ToolContext (gives verifier full tool access)
            ctx = ToolContext(task_id)
            try:
                reward = await self.compute_reward(item, result, ctx)
            except Exception as e:
                logger.error("compute_reward failed: %s", e)
                reward = 0.0
            finally:
                ctx.cleanup()

        # Track tool errors for wandb logging
        if result.tool_errors:
            for err in result.tool_errors:
                self._tool_error_buffer.append({
                    "turn": err.turn,
                    "tool": err.tool_name,
                    "args": err.arguments[:150],
                    "error": err.error[:300],
                    "result": err.tool_result[:300],
                })

        # Build ScoredDataItem from ManagedServer state
        # Phase 2: real tokens/masks/logprobs from SequenceNodes
        # Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
        nodes = (result.managed_state or {}).get("nodes", [])

        if nodes:
            # Phase 2 (or DummyManagedServer): use actual node data
            node = nodes[-1]  # Final sequence node = full trajectory
            scored_item: Dict[str, Any] = {
                "tokens": node.tokens,
                "masks": node.masked_tokens,
                "scores": reward,
            }

            # Include logprobs if available (Phase 2)
            if hasattr(node, "logprobs") and node.logprobs:
                scored_item["advantages"] = None  # Computed by trainer
                scored_item["ref_logprobs"] = None
        else:
            # Phase 1 with no managed state: create placeholder tokens
            # so the data pipeline doesn't break. These are NOT suitable
            # for training but allow process mode (SFT data gen) to work.
            # Tokenize the full conversation to get approximate tokens.
            full_text = "\n".join(
                msg.get("content", "") for msg in result.messages if msg.get("content")
            )
            if self.tokenizer:
                tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
            else:
                tokens = list(range(min(len(full_text) // 4, 128)))

            scored_item = {
                "tokens": tokens,
                "masks": [-100] + tokens[1:],  # Mask first token as prompt
                "scores": reward,
            }

        # Always include messages for wandb rollout display and data logging
        scored_item["messages"] = result.messages

        return scored_item, []

    # =========================================================================
    # Abstract methods -- subclasses must implement
    # =========================================================================

    @abstractmethod
    async def setup(self):
        """
        Load dataset, initialize state.

        Called once when the environment starts. Typical implementation:

            self.dataset = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
            self.iter = 0
        """
        raise NotImplementedError

    @abstractmethod
    async def get_next_item(self) -> Item:
        """
        Return the next item from the dataset for rollout.

        Called by the base env's main loop to get items for workers.
        Should cycle through the dataset, as in the sketch below.
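
        A minimal round-robin sketch (assumes setup() stored self.dataset and
        self.iter as in the setup() docstring above):

            item = self.dataset[self.iter % len(self.dataset)]
            self.iter += 1
            return item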
"""
|
|
raise NotImplementedError
|
|
|
|
    @abstractmethod
    def format_prompt(self, item: Item) -> str:
        """
        Convert a dataset item into the user message for the agent.

        Args:
            item: Dataset item (dict, tuple, etc.)

        Returns:
            The prompt string to send to the agent
        """
        raise NotImplementedError

    @abstractmethod
    async def compute_reward(
        self, item: Item, result: AgentResult, ctx: ToolContext
    ) -> float:
        """
        Score the rollout. Has full access to:
        - item: the original dataset item (ground truth, test commands, etc.)
        - result: AgentResult with full messages, turn count, reasoning, etc.
        - ctx: ToolContext -- call ANY hermes-agent tool (terminal, file, web,
          browser, vision...) scoped to this rollout's sandbox. Nothing
          is off-limits.

        Args:
            item: The dataset item that was rolled out
            result: The agent's rollout result
            ctx: ToolContext with full tool access for verification

        Returns:
            Reward float (typically 0.0 to 1.0, but any float is valid)
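
        A minimal message-based sketch (no ToolContext needed; the "DONE"
        marker is purely illustrative):

            final = next(
                (m.get("content") or ""
                 for m in reversed(result.messages)
                 if m.get("role") == "assistant"),
                "",
            )
            return 1.0 if "DONE" in final else 0.0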
"""
|
|
raise NotImplementedError
|
|
|
|
    @abstractmethod
    async def evaluate(self, *args, **kwargs):
        """
        Periodic evaluation. Called every steps_per_eval steps.

        Typical implementation runs the agent on a held-out eval set
        and logs metrics via wandb/evaluate_log.
        """
        raise NotImplementedError