feat: show estimated tool token context in hermes tools checklist (#3805)

* feat: show estimated tool token context in hermes tools checklist

Adds a live token estimate indicator to the bottom of the interactive
tool configuration checklist (hermes tools / hermes setup). As users
toggle toolsets on/off, the total estimated context cost updates in
real time.

Implementation:
- tools/registry.py: Add get_schema() for check_fn-free schema access
- hermes_cli/curses_ui.py: Add optional status_fn callback to
  curses_checklist — renders at bottom-right of terminal, stays fixed
  while items scroll
- hermes_cli/tools_config.py: Add _estimate_tool_tokens() using
  tiktoken (cl100k_base, already installed) to count tokens in the
  JSON-serialised OpenAI-format tool schemas. Results are cached
  per-process. The status function deduplicates overlapping tools
  (e.g. browser includes web_search) for accurate totals.
- 12 new tests covering estimation, caching, graceful degradation
  when tiktoken is unavailable, status_fn wiring, deduplication,
  and the numbered fallback display

* fix: use effective toolsets (includes plugins) for token estimation index mapping

The status_fn closure built ts_keys from CONFIGURABLE_TOOLSETS but the
checklist uses _get_effective_configurable_toolsets() which appends plugin
toolsets. With plugins present, the indices would mismatch, causing
IndexError when selecting a plugin toolset.
This commit is contained in:
Teknium
2026-03-29 15:36:56 -07:00
committed by GitHub
parent 475205e30b
commit ee3d2941cc
4 changed files with 382 additions and 4 deletions

View File

@@ -4,7 +4,7 @@ Used by `hermes tools` and `hermes skills` for interactive checklists.
Provides a curses multi-select with keyboard navigation, plus a
text-based numbered fallback for terminals without curses support.
"""
from typing import List, Set
from typing import Callable, List, Optional, Set
from hermes_cli.colors import Colors, color
@@ -15,6 +15,7 @@ def curses_checklist(
selected: Set[int],
*,
cancel_returns: Set[int] | None = None,
status_fn: Optional[Callable[[Set[int]], str]] = None,
) -> Set[int]:
"""Curses multi-select checklist. Returns set of selected indices.
@@ -23,6 +24,9 @@ def curses_checklist(
items: Display labels for each row.
selected: Indices that start checked (pre-selected).
cancel_returns: Returned on ESC/q. Defaults to the original *selected*.
status_fn: Optional callback ``f(chosen_indices) -> str`` whose return
value is rendered on the bottom row of the terminal. Use this for
live aggregate info (e.g. estimated token counts).
"""
if cancel_returns is None:
cancel_returns = set(selected)
@@ -47,6 +51,9 @@ def curses_checklist(
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
# Reserve bottom row for status bar when status_fn provided
footer_rows = 1 if status_fn else 0
# Header
try:
hattr = curses.A_BOLD
@@ -62,7 +69,7 @@ def curses_checklist(
pass
# Scrollable item list
visible_rows = max_y - 3
visible_rows = max_y - 3 - footer_rows
if cursor < scroll_offset:
scroll_offset = cursor
elif cursor >= scroll_offset + visible_rows:
@@ -72,7 +79,7 @@ def curses_checklist(
range(scroll_offset, min(len(items), scroll_offset + visible_rows))
):
y = draw_i + 3
if y >= max_y - 1:
if y >= max_y - 1 - footer_rows:
break
check = "" if i in chosen else " "
arrow = "" if i == cursor else " "
@@ -87,6 +94,20 @@ def curses_checklist(
except curses.error:
pass
# Status bar (bottom row, right-aligned)
if status_fn:
try:
status_text = status_fn(chosen)
if status_text:
# Right-align on the bottom row
sx = max(0, max_x - len(status_text) - 1)
sattr = curses.A_DIM
if curses.has_colors():
sattr |= curses.color_pair(3)
stdscr.addnstr(max_y - 1, sx, status_text, max_x - sx - 1, sattr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
@@ -107,7 +128,7 @@ def curses_checklist(
return result_holder[0] if result_holder[0] is not None else cancel_returns
except Exception:
return _numbered_fallback(title, items, selected, cancel_returns)
return _numbered_fallback(title, items, selected, cancel_returns, status_fn)
def _numbered_fallback(
@@ -115,6 +136,7 @@ def _numbered_fallback(
items: List[str],
selected: Set[int],
cancel_returns: Set[int],
status_fn: Optional[Callable[[Set[int]], str]] = None,
) -> Set[int]:
"""Text-based toggle fallback for terminals without curses."""
chosen = set(selected)
@@ -125,6 +147,10 @@ def _numbered_fallback(
for i, label in enumerate(items):
marker = color("[✓]", Colors.GREEN) if i in chosen else "[ ]"
print(f" {marker} {i + 1:>2}. {label}")
if status_fn:
status_text = status_fn(chosen)
if status_text:
print(color(f"\n {status_text}", Colors.DIM))
print()
try:
val = input(color(" Toggle # (or Enter to confirm): ", Colors.DIM)).strip()

View File

@@ -9,6 +9,8 @@ Saves per-platform tool configuration to ~/.hermes/config.yaml under
the `platform_toolsets` key.
"""
import json as _json
import logging
import sys
from pathlib import Path
from typing import Dict, List, Optional, Set
@@ -19,6 +21,8 @@ from hermes_cli.config import (
)
from hermes_cli.colors import Colors, color
logger = logging.getLogger(__name__)
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
@@ -653,9 +657,61 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
return default
# ─── Token Estimation ────────────────────────────────────────────────────────

# Module-level cache so discovery + tokenization runs at most once per process.
_tool_token_cache: Optional[Dict[str, int]] = None


def _estimate_tool_tokens() -> Dict[str, int]:
    """Estimate the token cost of each registered tool's schema.

    Each schema is serialised exactly as it is sent to the API
    (``{"type": "function", "function": <schema>}``) and tokenized with
    tiktoken's cl100k_base encoding. The first call triggers full tool
    discovery; the result is cached at module level for the remainder of
    the process.

    Returns:
        Mapping of tool name -> estimated token count, or an empty dict
        when tiktoken or the tool registry cannot be imported.
    """
    global _tool_token_cache
    if _tool_token_cache is not None:
        return _tool_token_cache

    try:
        import tiktoken

        encoder = tiktoken.get_encoding("cl100k_base")
    except Exception:
        logger.debug("tiktoken unavailable; skipping tool token estimation")
        _tool_token_cache = {}
        return _tool_token_cache

    try:
        # Importing model_tools has the side effect of registering every tool.
        import model_tools  # noqa: F401
        from tools.registry import registry
    except Exception:
        logger.debug("Tool registry unavailable; skipping token estimation")
        _tool_token_cache = {}
        return _tool_token_cache

    estimates: Dict[str, int] = {}
    for tool_name in registry.get_all_tool_names():
        schema = registry.get_schema(tool_name)
        if not schema:
            continue
        # Mirror what gets sent to the API:
        # {"type": "function", "function": <schema>}
        payload = _json.dumps({"type": "function", "function": schema})
        estimates[tool_name] = len(encoder.encode(payload))
    _tool_token_cache = estimates
    return _tool_token_cache
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
from hermes_cli.curses_ui import curses_checklist
from toolsets import resolve_toolset
# Pre-compute per-tool token counts (cached after first call).
tool_tokens = _estimate_tool_tokens()
effective = _get_effective_configurable_toolsets()
@@ -671,11 +727,27 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
if ts_key in enabled
}
# Build a live status function that shows deduplicated total token cost.
status_fn = None
if tool_tokens:
ts_keys = [ts_key for ts_key, _, _ in effective]
def status_fn(chosen: set) -> str:
# Collect unique tool names across all selected toolsets
all_tools: set = set()
for idx in chosen:
all_tools.update(resolve_toolset(ts_keys[idx]))
total = sum(tool_tokens.get(name, 0) for name in all_tools)
if total >= 1000:
return f"Est. tool context: ~{total / 1000:.1f}k tokens"
return f"Est. tool context: ~{total} tokens"
chosen = curses_checklist(
f"Tools for {platform_label}",
labels,
pre_selected,
cancel_returns=pre_selected,
status_fn=status_fn,
)
return {effective[i][0] for i in chosen}

View File

@@ -0,0 +1,271 @@
"""Tests for tool token estimation and curses_ui status_fn support."""
from unittest.mock import patch
import pytest
# ─── Token Estimation Tests ──────────────────────────────────────────────────
def test_estimate_tool_tokens_returns_positive_counts():
    """_estimate_tool_tokens should return a non-empty dict with positive values."""
    # NOTE: importing `_tool_token_cache` by name (as the original did) only
    # binds its value at import time — resetting must go through the module.
    import hermes_cli.tools_config as tc

    # Clear cache to force fresh computation
    tc._tool_token_cache = None
    tokens = tc._estimate_tool_tokens()
    assert isinstance(tokens, dict)
    assert len(tokens) > 0
    for name, count in tokens.items():
        assert isinstance(name, str)
        assert isinstance(count, int)
        assert count > 0, f"Tool {name} has non-positive token count: {count}"
def test_estimate_tool_tokens_is_cached():
    """Second call should return the same cached dict object."""
    import hermes_cli.tools_config as tc

    tc._tool_token_cache = None
    result_a = tc._estimate_tool_tokens()
    result_b = tc._estimate_tool_tokens()
    # Identity, not equality: the cache must hand back the same object.
    assert result_a is result_b
def test_estimate_tool_tokens_returns_empty_when_tiktoken_unavailable(monkeypatch):
    """Graceful degradation when tiktoken cannot be imported."""
    import builtins

    import hermes_cli.tools_config as tc

    tc._tool_token_cache = None
    real_import = builtins.__import__

    def mock_import(name, *args, **kwargs):
        if name == "tiktoken":
            raise ImportError("mocked")
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", mock_import)
    try:
        result = tc._estimate_tool_tokens()
        assert result == {}
    finally:
        # Reset cache even if the assertion fails, so a failure here cannot
        # leak a poisoned (empty) cache into unrelated tests.
        tc._tool_token_cache = None
def test_estimate_tool_tokens_covers_known_tools():
    """Should include schemas for well-known tools like terminal, web_search."""
    import hermes_cli.tools_config as tc

    tc._tool_token_cache = None
    estimates = tc._estimate_tool_tokens()
    # These tools should always be discoverable
    for expected in ("terminal", "web_search", "read_file"):
        assert expected in estimates, f"Expected {expected!r} in token estimates"
# ─── Status Function Tests ───────────────────────────────────────────────────
def test_prompt_toolset_checklist_passes_status_fn(monkeypatch):
    """_prompt_toolset_checklist should pass a status_fn to curses_checklist."""
    import hermes_cli.tools_config as tc

    seen = {}

    def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
        seen["status_fn"] = status_fn
        seen["title"] = title
        # Return pre-selected unchanged
        return selected

    monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
    tc._prompt_toolset_checklist("CLI", {"web", "terminal"})
    assert "status_fn" in seen
    # If tiktoken is available, status_fn should be set
    if tc._estimate_tool_tokens():
        assert seen["status_fn"] is not None
def test_status_fn_returns_formatted_token_count(monkeypatch):
    """The status_fn should return a human-readable token count string."""
    import hermes_cli.tools_config as tc
    from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS

    holder = {}

    def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
        holder["status_fn"] = status_fn
        return selected

    monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
    tc._prompt_toolset_checklist("CLI", {"web", "terminal"})
    status_fn = holder.get("status_fn")
    if status_fn is None:
        pytest.skip("tiktoken unavailable; status_fn not created")
    # Find the indices for web and terminal
    indices = {key: i for i, (key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)}
    # Call status_fn with web + terminal selected
    rendered = status_fn({indices["web"], indices["terminal"]})
    assert "tokens" in rendered
    assert "Est. tool context" in rendered
def test_status_fn_deduplicates_overlapping_tools(monkeypatch):
    """When toolsets overlap (browser includes web_search), tokens should not double-count."""
    import re

    import hermes_cli.tools_config as tc
    from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS

    captured = {}

    def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
        captured["status_fn"] = status_fn
        return selected

    monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
    tc._prompt_toolset_checklist("CLI", {"web"})
    status_fn = captured.get("status_fn")
    if status_fn is None:
        pytest.skip("tiktoken unavailable; status_fn not created")
    idx_map = {ts_key: i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)}
    # web alone
    web_only = status_fn({idx_map["web"]})
    # browser includes web_search, so browser + web should not double-count web_search
    browser_only = status_fn({idx_map["browser"]})
    both = status_fn({idx_map["web"], idx_map["browser"]})

    # Extract numeric token counts from strings like "~8.3k tokens" or "~350 tokens".
    # The "k" suffix must be captured as its own group: checking for the letter
    # "k" anywhere in the match text would always succeed, because the match
    # includes the word "tokens", which itself contains a "k".
    def parse_tokens(s):
        m = re.search(r"~([\d.]+)(k?)\s+tokens", s)
        if not m:
            return 0
        val = float(m.group(1))
        if m.group(2) == "k":
            val *= 1000
        return val

    web_tok = parse_tokens(web_only)
    browser_tok = parse_tokens(browser_only)
    both_tok = parse_tokens(both)
    # Both together should be LESS than naive sum (due to web_search dedup)
    naive_sum = web_tok + browser_tok
    assert both_tok < naive_sum, (
        f"Expected deduplication: web({web_tok}) + browser({browser_tok}) = {naive_sum} "
        f"but combined = {both_tok}"
    )
def test_status_fn_empty_selection():
    """Status function with no tools selected should return ~0 tokens."""
    import hermes_cli.tools_config as tc
    from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS
    from toolsets import resolve_toolset

    tc._tool_token_cache = None
    tokens = tc._estimate_tool_tokens()
    if not tokens:
        pytest.skip("tiktoken unavailable")
    ts_keys = [ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS]

    # Local mirror of the production closure built in _prompt_toolset_checklist.
    def status_fn(chosen: set) -> str:
        selected_tools: set = set()
        for idx in chosen:
            selected_tools.update(resolve_toolset(ts_keys[idx]))
        total = sum(tokens.get(name, 0) for name in selected_tools)
        if total >= 1000:
            return f"Est. tool context: ~{total / 1000:.1f}k tokens"
        return f"Est. tool context: ~{total} tokens"

    assert "~0 tokens" in status_fn(set())
# ─── Curses UI Status Bar Tests ──────────────────────────────────────────────
def test_curses_checklist_numbered_fallback_shows_status(monkeypatch, capsys):
    """The numbered fallback should print the status_fn output."""
    from hermes_cli.curses_ui import _numbered_fallback

    def my_status(chosen):
        return f"Selected {len(chosen)} items"

    # Simulate user pressing Enter immediately (empty input → confirm)
    monkeypatch.setattr("builtins.input", lambda _prompt="": "")
    outcome = _numbered_fallback(
        "Test title",
        ["Item A", "Item B", "Item C"],
        {0, 2},
        {0, 2},
        status_fn=my_status,
    )
    printed = capsys.readouterr().out
    assert "Selected 2 items" in printed
    assert outcome == {0, 2}
def test_curses_checklist_numbered_fallback_without_status(monkeypatch, capsys):
    """The numbered fallback should work fine without status_fn."""
    from hermes_cli.curses_ui import _numbered_fallback

    monkeypatch.setattr("builtins.input", lambda _prompt="": "")
    outcome = _numbered_fallback("Test title", ["Item A", "Item B"], {0}, {0})
    printed = capsys.readouterr().out
    assert "Est. tool context" not in printed
    assert outcome == {0}
# ─── Registry get_schema Tests ───────────────────────────────────────────────
def test_registry_get_schema_returns_schema():
    """registry.get_schema() should return a tool's schema dict."""
    # Import to trigger discovery
    import model_tools  # noqa: F401
    from tools.registry import registry

    schema = registry.get_schema("terminal")
    assert schema is not None
    assert "name" in schema
    assert "parameters" in schema
    assert schema["name"] == "terminal"
def test_registry_get_schema_returns_none_for_unknown():
    """registry.get_schema() should return None for unknown tools."""
    from tools.registry import registry

    result = registry.get_schema("nonexistent_tool_xyz")
    assert result is None

View File

@@ -149,6 +149,15 @@ class ToolRegistry:
"""Return sorted list of all registered tool names."""
return sorted(self._tools.keys())
def get_schema(self, name: str) -> Optional[dict]:
"""Return a tool's raw schema dict, bypassing check_fn filtering.
Useful for token estimation and introspection where availability
doesn't matter — only the schema content does.
"""
entry = self._tools.get(name)
return entry.schema if entry else None
def get_toolset_for_tool(self, name: str) -> Optional[str]:
"""Return the toolset a tool belongs to, or None."""
entry = self._tools.get(name)