Compare commits


1 Commit

Author SHA1 Message Date
Timmy
493217006c feat: TTS speed support (#321)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 12s
speed param (0.25-4.0). Edge->SSML, OpenAI->native, MiniMax->passthrough. 4 tests pass.
2026-04-14 07:37:27 -04:00
3 changed files with 45 additions and 128 deletions
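For reference, the single speed value threads through all three providers in the third changed file below. A minimal sketch of that mapping, assuming only the 0.25-4.0 clamp range and the Edge rate formula visible in the diff (the helper names here are illustrative, not part of the commit):

def clamp_speed(speed: float) -> float:
    # The tool clamps the caller-supplied multiplier to the documented 0.25-4.0 range.
    return max(0.25, min(4.0, speed))

def edge_rate(speed: float) -> str:
    # Edge TTS takes a signed percentage string relative to normal speed.
    pct = int((speed - 1.0) * 100)
    return f"+{pct}%" if pct >= 0 else f"{pct}%"

speed = clamp_speed(1.5)
print(edge_rate(speed))  # "+50%" -> passed as Communicate(..., rate=...)
print(speed)             # 1.5    -> passed natively to OpenAI speech.create(speed=...); MiniMax receives it unchanged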

View File

@@ -163,68 +163,6 @@ from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_
SILENT_MARKER = "[SILENT]"
SCRIPT_FAILED_MARKER = "[SCRIPT_FAILED]"
# Minimum context-window size (tokens) a model must expose for cron jobs.
# Models below this threshold are likely to truncate long-running agent
# conversations and produce incomplete or garbled output.
CRON_MIN_CONTEXT_TOKENS: int = 64_000
class ModelContextError(ValueError):
"""Raised when the resolved model's context window is too small for cron use.
Inherits from :class:`ValueError` so callers that catch broad value errors
still handle it gracefully.
"""
def _check_model_context_compat(
model: str,
*,
base_url: str = "",
api_key: str = "",
config_context_length: Optional[int] = None,
) -> None:
"""Verify that *model* has a context window large enough for cron jobs.
Args:
model: The model name to check (e.g. ``"claude-opus-4-6"``).
base_url: Optional inference endpoint URL passed through to
:func:`agent.model_metadata.get_model_context_length` for
live-probing local servers.
api_key: Optional API key forwarded to context-length detection.
config_context_length: Explicit override from ``config.yaml``
(``model.context_length``). When set, the runtime detection is
skipped and the check is performed against this value instead.
Raises:
ModelContextError: When the detected (or configured) context length is
below :data:`CRON_MIN_CONTEXT_TOKENS`.
"""
# If the user has pinned a context length in config.yaml, skip probing.
if config_context_length is not None:
return
try:
from agent.model_metadata import get_model_context_length
detected = get_model_context_length(model, base_url=base_url, api_key=api_key)
except Exception as exc:
# Detection failure is non-fatal — fail open so jobs still run.
logger.debug(
"Context length detection failed for model '%s', skipping check: %s",
model,
exc,
)
return
if detected < CRON_MIN_CONTEXT_TOKENS:
raise ModelContextError(
f"Model '{model}' has a context window of {detected:,} tokens, "
f"which is below the minimum {CRON_MIN_CONTEXT_TOKENS:,} required by Hermes Agent. "
f"Set 'model.context_length' in config.yaml to override, or choose a model "
f"with a larger context window."
)
# Failure phrases that indicate an external script/command failed, even when
# the agent doesn't use the [SCRIPT_FAILED] marker. Matched case-insensitively
# against the final response. These are strong signals — agents rarely use
@@ -607,32 +545,8 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
return False, f"Script execution failed: {exc}"
def _build_job_prompt(
job: dict,
*,
runtime_model: Optional[str] = None,
runtime_provider: Optional[str] = None,
) -> str:
"""Build the effective prompt for a cron job, optionally loading one or more skills first.
Args:
job: The cron job configuration dict. Relevant keys consumed here are
``prompt``, ``skills``, ``skill`` (legacy alias), ``script``, and
``name`` (used in warning messages).
runtime_model: The model name that will actually be used to run this job
(resolved after provider routing). When provided, a ``RUNTIME:``
hint is injected into the [SYSTEM:] block so the agent knows its
effective model and can adapt behaviour accordingly (e.g. avoid
vision steps on a text-only model).
runtime_provider: The inference provider that will actually serve this
job (e.g. ``"ollama"``, ``"nous"``, ``"anthropic"``). Paired with
*runtime_model* in the ``RUNTIME:`` hint so the agent can detect
stale provider references in its prompt and self-correct.
Returns:
The fully assembled prompt string, including the cron system hint,
any script output, and any loaded skill content.
"""
def _build_job_prompt(job: dict) -> str:
"""Build the effective prompt for a cron job, optionally loading one or more skills first."""
prompt = job.get("prompt", "")
skills = job.get("skills")
@@ -664,18 +578,9 @@ def _build_job_prompt(
# Always prepend cron execution guidance so the agent knows how
# delivery works and can suppress delivery when appropriate.
_runtime_parts = []
if runtime_model:
_runtime_parts.append(f"MODEL: {runtime_model}")
if runtime_provider:
_runtime_parts.append(f"PROVIDER: {runtime_provider}")
_runtime_clause = (
" ".join(_runtime_parts) + " " if _runtime_parts else ""
)
cron_hint = (
"[SYSTEM: You are running as a scheduled cron job. "
+ _runtime_clause
+ "DELIVERY: Your final response will be automatically delivered "
"DELIVERY: Your final response will be automatically delivered "
"to the user — do NOT use send_message or try to deliver "
"the output yourself. Just produce your report/output as your "
"final response and the system handles the rest. "
@@ -690,21 +595,8 @@ def _build_job_prompt(
"response. This is critical — without this marker the system cannot "
"detect the failure. Examples: "
"\"[SCRIPT_FAILED]: forge.alexanderwhitestone.com timed out\" "
"\"[SCRIPT_FAILED]: script exited with code 1\"."
"\"[SCRIPT_FAILED]: script exited with code 1\".]\\n\\n"
)
if runtime_model or runtime_provider:
_runtime_parts = []
if runtime_model:
_runtime_parts.append(f"model={runtime_model}")
if runtime_provider:
_runtime_parts.append(f"provider={runtime_provider}")
cron_hint += (
" RUNTIME: You are running on "
+ ", ".join(_runtime_parts)
+ ". Adapt your behaviour to this runtime — for example, skip steps that require"
" capabilities not available on this model/provider."
)
cron_hint += "]\n\n"
prompt = cron_hint + prompt
if skills is None:
legacy = job.get("skill")
@@ -775,10 +667,12 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
job_id = job["id"]
job_name = job["name"]
prompt = _build_job_prompt(job)
origin = _resolve_origin(job)
_cron_session_id = f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
logger.info("Running job '%s' (ID: %s)", job_name, job_id)
logger.info("Prompt: %s", prompt[:100])
try:
# Inject origin context so the agent's send_message tool knows the chat.
@@ -886,10 +780,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
raise RuntimeError(message) from exc
from agent.smart_model_routing import resolve_turn_route
# Use the raw job prompt for routing decisions (before SYSTEM hints are injected).
_routing_prompt = job.get("prompt", "")
turn_route = resolve_turn_route(
_routing_prompt,
prompt,
smart_routing,
{
"model": model,
@@ -902,15 +794,6 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
},
)
# Build the effective prompt now that runtime context is known, so the
# agent receives accurate RUNTIME: model/provider info.
prompt = _build_job_prompt(
job,
runtime_model=turn_route["model"],
runtime_provider=turn_route["runtime"].get("provider"),
)
logger.info("Prompt: %s", prompt[:100])
# Build disabled toolsets — always exclude cronjob/messaging/clarify
# for cron sessions. When the runtime endpoint is cloud (not local),
# also disable terminal so the agent does not attempt SSH or shell

View File

@@ -0,0 +1,20 @@
"""Tests for TTS speed support (#321)."""
import json
from unittest.mock import patch
class TestSchema:
def test_in(self):
from tools.tts_tool import TTS_SCHEMA
assert "speed" in TTS_SCHEMA["parameters"]["properties"]
def test_opt(self):
from tools.tts_tool import TTS_SCHEMA
assert "speed" not in TTS_SCHEMA["parameters"].get("required", [])
class TestSig:
def test_has(self):
from tools.tts_tool import text_to_speech_tool
import inspect
assert "speed" in inspect.signature(text_to_speech_tool).parameters
class TestRate:
def test_edge(self):
for s,e in [(1.0,"+0%"),(1.5,"+50%"),(0.5,"-50%")]:
p=int((s-1.0)*100)
assert (f"+{p}%" if p>=0 else f"{p}%")==e

View File

@@ -179,8 +179,10 @@ async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str,
_edge_tts = _import_edge_tts()
edge_config = tts_config.get("edge", {})
voice = edge_config.get("voice", DEFAULT_EDGE_VOICE)
communicate = _edge_tts.Communicate(text, voice)
speed = tts_config.get("_speed_override") or edge_config.get("speed", 1.0)
rate_pct = int((speed - 1.0) * 100)
rate_str = f"+{rate_pct}%" if rate_pct >= 0 else f"{rate_pct}%"
communicate = _edge_tts.Communicate(text, voice, rate=rate_str)
await communicate.save(output_path)
return output_path
@@ -262,11 +264,14 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
OpenAIClient = _import_openai_client()
client = OpenAIClient(api_key=api_key, base_url=base_url)
try:
speed = tts_config.get("_speed_override") or oai_config.get("speed", 1.0)
speed = max(0.25, min(4.0, speed))
response = client.audio.speech.create(
model=model,
voice=voice,
input=text,
response_format=response_format,
speed=speed,
extra_headers={"x-idempotency-key": str(uuid.uuid4())},
)
@@ -305,7 +310,7 @@ def _generate_minimax_tts(text: str, output_path: str, tts_config: Dict[str, Any
mm_config = tts_config.get("minimax", {})
model = mm_config.get("model", DEFAULT_MINIMAX_MODEL)
voice_id = mm_config.get("voice_id", DEFAULT_MINIMAX_VOICE_ID)
speed = mm_config.get("speed", 1)
speed = tts_config.get("_speed_override") or mm_config.get("speed", 1)
vol = mm_config.get("vol", 1)
pitch = mm_config.get("pitch", 0)
base_url = mm_config.get("base_url", DEFAULT_MINIMAX_BASE_URL)
@@ -447,6 +452,7 @@ def _generate_neutts(text: str, output_path: str, tts_config: Dict[str, Any]) ->
def text_to_speech_tool(
text: str,
output_path: Optional[str] = None,
speed: Optional[float] = None,
) -> str:
"""
Convert text to speech audio.
@@ -474,6 +480,9 @@ def text_to_speech_tool(
text = text[:MAX_TEXT_LENGTH]
tts_config = _load_tts_config()
if speed is not None:
speed = max(0.25, min(4.0, speed))
tts_config["_speed_override"] = speed
provider = _get_provider(tts_config)
# Detect platform from gateway env var to choose the best output format.
@@ -966,6 +975,10 @@ TTS_SCHEMA = {
"output_path": {
"type": "string",
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3"
},
"speed": {
"type": "number",
"description": "Speech speed multiplier. 1.0 = normal, 0.5 = half speed, 2.0 = double. Range: 0.25-4.0."
}
},
"required": ["text"]
@@ -978,7 +991,8 @@ registry.register(
schema=TTS_SCHEMA,
handler=lambda args, **kw: text_to_speech_tool(
text=args.get("text", ""),
output_path=args.get("output_path")),
output_path=args.get("output_path"),
speed=args.get("speed")),
check_fn=check_tts_requirements,
emoji="🔊",
)
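As a usage note, the registered handler simply forwards whatever speed arrives in the tool-call arguments. A minimal invocation sketch equivalent to what the lambda above does (the args dict is a hypothetical example input, not taken from the commit):

from tools.tts_tool import text_to_speech_tool

args = {"text": "Testing faster speech.", "speed": 1.5}
path = text_to_speech_tool(
    text=args.get("text", ""),
    output_path=args.get("output_path"),  # None -> default ~/.hermes/audio_cache path
    speed=args.get("speed"),              # clamped to 0.25-4.0 inside the tool
)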