feat: show estimated tool token context in hermes tools checklist (#3805)
* feat: show estimated tool token context in hermes tools checklist Adds a live token estimate indicator to the bottom of the interactive tool configuration checklist (hermes tools / hermes setup). As users toggle toolsets on/off, the total estimated context cost updates in real time. Implementation: - tools/registry.py: Add get_schema() for check_fn-free schema access - hermes_cli/curses_ui.py: Add optional status_fn callback to curses_checklist — renders at bottom-right of terminal, stays fixed while items scroll - hermes_cli/tools_config.py: Add _estimate_tool_tokens() using tiktoken (cl100k_base, already installed) to count tokens in the JSON-serialised OpenAI-format tool schemas. Results are cached per-process. The status function deduplicates overlapping tools (e.g. browser includes web_search) for accurate totals. - 12 new tests covering estimation, caching, graceful degradation when tiktoken is unavailable, status_fn wiring, deduplication, and the numbered fallback display * fix: use effective toolsets (includes plugins) for token estimation index mapping The status_fn closure built ts_keys from CONFIGURABLE_TOOLSETS but the checklist uses _get_effective_configurable_toolsets() which appends plugin toolsets. With plugins present, the indices would mismatch, causing IndexError when selecting a plugin toolset.
This commit is contained in:
@@ -4,7 +4,7 @@ Used by `hermes tools` and `hermes skills` for interactive checklists.
|
||||
Provides a curses multi-select with keyboard navigation, plus a
|
||||
text-based numbered fallback for terminals without curses support.
|
||||
"""
|
||||
from typing import List, Set
|
||||
from typing import Callable, List, Optional, Set
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
@@ -15,6 +15,7 @@ def curses_checklist(
|
||||
selected: Set[int],
|
||||
*,
|
||||
cancel_returns: Set[int] | None = None,
|
||||
status_fn: Optional[Callable[[Set[int]], str]] = None,
|
||||
) -> Set[int]:
|
||||
"""Curses multi-select checklist. Returns set of selected indices.
|
||||
|
||||
@@ -23,6 +24,9 @@ def curses_checklist(
|
||||
items: Display labels for each row.
|
||||
selected: Indices that start checked (pre-selected).
|
||||
cancel_returns: Returned on ESC/q. Defaults to the original *selected*.
|
||||
status_fn: Optional callback ``f(chosen_indices) -> str`` whose return
|
||||
value is rendered on the bottom row of the terminal. Use this for
|
||||
live aggregate info (e.g. estimated token counts).
|
||||
"""
|
||||
if cancel_returns is None:
|
||||
cancel_returns = set(selected)
|
||||
@@ -47,6 +51,9 @@ def curses_checklist(
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Reserve bottom row for status bar when status_fn provided
|
||||
footer_rows = 1 if status_fn else 0
|
||||
|
||||
# Header
|
||||
try:
|
||||
hattr = curses.A_BOLD
|
||||
@@ -62,7 +69,7 @@ def curses_checklist(
|
||||
pass
|
||||
|
||||
# Scrollable item list
|
||||
visible_rows = max_y - 3
|
||||
visible_rows = max_y - 3 - footer_rows
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible_rows:
|
||||
@@ -72,7 +79,7 @@ def curses_checklist(
|
||||
range(scroll_offset, min(len(items), scroll_offset + visible_rows))
|
||||
):
|
||||
y = draw_i + 3
|
||||
if y >= max_y - 1:
|
||||
if y >= max_y - 1 - footer_rows:
|
||||
break
|
||||
check = "✓" if i in chosen else " "
|
||||
arrow = "→" if i == cursor else " "
|
||||
@@ -87,6 +94,20 @@ def curses_checklist(
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Status bar (bottom row, right-aligned)
|
||||
if status_fn:
|
||||
try:
|
||||
status_text = status_fn(chosen)
|
||||
if status_text:
|
||||
# Right-align on the bottom row
|
||||
sx = max(0, max_x - len(status_text) - 1)
|
||||
sattr = curses.A_DIM
|
||||
if curses.has_colors():
|
||||
sattr |= curses.color_pair(3)
|
||||
stdscr.addnstr(max_y - 1, sx, status_text, max_x - sx - 1, sattr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
@@ -107,7 +128,7 @@ def curses_checklist(
|
||||
return result_holder[0] if result_holder[0] is not None else cancel_returns
|
||||
|
||||
except Exception:
|
||||
return _numbered_fallback(title, items, selected, cancel_returns)
|
||||
return _numbered_fallback(title, items, selected, cancel_returns, status_fn)
|
||||
|
||||
|
||||
def _numbered_fallback(
|
||||
@@ -115,6 +136,7 @@ def _numbered_fallback(
|
||||
items: List[str],
|
||||
selected: Set[int],
|
||||
cancel_returns: Set[int],
|
||||
status_fn: Optional[Callable[[Set[int]], str]] = None,
|
||||
) -> Set[int]:
|
||||
"""Text-based toggle fallback for terminals without curses."""
|
||||
chosen = set(selected)
|
||||
@@ -125,6 +147,10 @@ def _numbered_fallback(
|
||||
for i, label in enumerate(items):
|
||||
marker = color("[✓]", Colors.GREEN) if i in chosen else "[ ]"
|
||||
print(f" {marker} {i + 1:>2}. {label}")
|
||||
if status_fn:
|
||||
status_text = status_fn(chosen)
|
||||
if status_text:
|
||||
print(color(f"\n {status_text}", Colors.DIM))
|
||||
print()
|
||||
try:
|
||||
val = input(color(" Toggle # (or Enter to confirm): ", Colors.DIM)).strip()
|
||||
|
||||
@@ -9,6 +9,8 @@ Saves per-platform tool configuration to ~/.hermes/config.yaml under
|
||||
the `platform_toolsets` key.
|
||||
"""
|
||||
|
||||
import json as _json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
@@ -19,6 +21,8 @@ from hermes_cli.config import (
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
@@ -653,9 +657,61 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
return default
|
||||
|
||||
|
||||
# ─── Token Estimation ────────────────────────────────────────────────────────
|
||||
|
||||
# Module-level cache so discovery + tokenization runs at most once per process.
|
||||
_tool_token_cache: Optional[Dict[str, int]] = None
|
||||
|
||||
|
||||
def _estimate_tool_tokens() -> Dict[str, int]:
|
||||
"""Return estimated token counts per individual tool name.
|
||||
|
||||
Uses tiktoken (cl100k_base) to count tokens in the JSON-serialised
|
||||
OpenAI-format tool schema. Triggers tool discovery on first call,
|
||||
then caches the result for the rest of the process.
|
||||
|
||||
Returns an empty dict when tiktoken or the registry is unavailable.
|
||||
"""
|
||||
global _tool_token_cache
|
||||
if _tool_token_cache is not None:
|
||||
return _tool_token_cache
|
||||
|
||||
try:
|
||||
import tiktoken
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
except Exception:
|
||||
logger.debug("tiktoken unavailable; skipping tool token estimation")
|
||||
_tool_token_cache = {}
|
||||
return _tool_token_cache
|
||||
|
||||
try:
|
||||
# Trigger full tool discovery (imports all tool modules).
|
||||
import model_tools # noqa: F401
|
||||
from tools.registry import registry
|
||||
except Exception:
|
||||
logger.debug("Tool registry unavailable; skipping token estimation")
|
||||
_tool_token_cache = {}
|
||||
return _tool_token_cache
|
||||
|
||||
counts: Dict[str, int] = {}
|
||||
for name in registry.get_all_tool_names():
|
||||
schema = registry.get_schema(name)
|
||||
if schema:
|
||||
# Mirror what gets sent to the API:
|
||||
# {"type": "function", "function": <schema>}
|
||||
text = _json.dumps({"type": "function", "function": schema})
|
||||
counts[name] = len(enc.encode(text))
|
||||
_tool_token_cache = counts
|
||||
return _tool_token_cache
|
||||
|
||||
|
||||
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
# Pre-compute per-tool token counts (cached after first call).
|
||||
tool_tokens = _estimate_tool_tokens()
|
||||
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
|
||||
@@ -671,11 +727,27 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
if ts_key in enabled
|
||||
}
|
||||
|
||||
# Build a live status function that shows deduplicated total token cost.
|
||||
status_fn = None
|
||||
if tool_tokens:
|
||||
ts_keys = [ts_key for ts_key, _, _ in effective]
|
||||
|
||||
def status_fn(chosen: set) -> str:
|
||||
# Collect unique tool names across all selected toolsets
|
||||
all_tools: set = set()
|
||||
for idx in chosen:
|
||||
all_tools.update(resolve_toolset(ts_keys[idx]))
|
||||
total = sum(tool_tokens.get(name, 0) for name in all_tools)
|
||||
if total >= 1000:
|
||||
return f"Est. tool context: ~{total / 1000:.1f}k tokens"
|
||||
return f"Est. tool context: ~{total} tokens"
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Tools for {platform_label}",
|
||||
labels,
|
||||
pre_selected,
|
||||
cancel_returns=pre_selected,
|
||||
status_fn=status_fn,
|
||||
)
|
||||
return {effective[i][0] for i in chosen}
|
||||
|
||||
|
||||
271
tests/hermes_cli/test_tool_token_estimation.py
Normal file
271
tests/hermes_cli/test_tool_token_estimation.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Tests for tool token estimation and curses_ui status_fn support."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ─── Token Estimation Tests ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_returns_positive_counts():
|
||||
"""_estimate_tool_tokens should return a non-empty dict with positive values."""
|
||||
from hermes_cli.tools_config import _estimate_tool_tokens, _tool_token_cache
|
||||
|
||||
# Clear cache to force fresh computation
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
tokens = _estimate_tool_tokens()
|
||||
|
||||
assert isinstance(tokens, dict)
|
||||
assert len(tokens) > 0
|
||||
for name, count in tokens.items():
|
||||
assert isinstance(name, str)
|
||||
assert isinstance(count, int)
|
||||
assert count > 0, f"Tool {name} has non-positive token count: {count}"
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_is_cached():
|
||||
"""Second call should return the same cached dict object."""
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
first = tc._estimate_tool_tokens()
|
||||
second = tc._estimate_tool_tokens()
|
||||
|
||||
assert first is second
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_returns_empty_when_tiktoken_unavailable(monkeypatch):
|
||||
"""Graceful degradation when tiktoken cannot be imported."""
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "tiktoken":
|
||||
raise ImportError("mocked")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", mock_import)
|
||||
|
||||
result = tc._estimate_tool_tokens()
|
||||
|
||||
assert result == {}
|
||||
|
||||
# Reset cache for other tests
|
||||
tc._tool_token_cache = None
|
||||
|
||||
|
||||
def test_estimate_tool_tokens_covers_known_tools():
|
||||
"""Should include schemas for well-known tools like terminal, web_search."""
|
||||
import hermes_cli.tools_config as tc
|
||||
tc._tool_token_cache = None
|
||||
|
||||
tokens = tc._estimate_tool_tokens()
|
||||
|
||||
# These tools should always be discoverable
|
||||
for expected in ("terminal", "web_search", "read_file"):
|
||||
assert expected in tokens, f"Expected {expected!r} in token estimates"
|
||||
|
||||
|
||||
# ─── Status Function Tests ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_prompt_toolset_checklist_passes_status_fn(monkeypatch):
|
||||
"""_prompt_toolset_checklist should pass a status_fn to curses_checklist."""
|
||||
import hermes_cli.tools_config as tc
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
|
||||
captured_kwargs["status_fn"] = status_fn
|
||||
captured_kwargs["title"] = title
|
||||
return selected # Return pre-selected unchanged
|
||||
|
||||
monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
|
||||
|
||||
tc._prompt_toolset_checklist("CLI", {"web", "terminal"})
|
||||
|
||||
assert "status_fn" in captured_kwargs
|
||||
# If tiktoken is available, status_fn should be set
|
||||
tokens = tc._estimate_tool_tokens()
|
||||
if tokens:
|
||||
assert captured_kwargs["status_fn"] is not None
|
||||
|
||||
|
||||
def test_status_fn_returns_formatted_token_count(monkeypatch):
|
||||
"""The status_fn should return a human-readable token count string."""
|
||||
import hermes_cli.tools_config as tc
|
||||
from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
|
||||
captured["status_fn"] = status_fn
|
||||
return selected
|
||||
|
||||
monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
|
||||
|
||||
tc._prompt_toolset_checklist("CLI", {"web", "terminal"})
|
||||
|
||||
status_fn = captured.get("status_fn")
|
||||
if status_fn is None:
|
||||
pytest.skip("tiktoken unavailable; status_fn not created")
|
||||
|
||||
# Find the indices for web and terminal
|
||||
idx_map = {ts_key: i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)}
|
||||
|
||||
# Call status_fn with web + terminal selected
|
||||
result = status_fn({idx_map["web"], idx_map["terminal"]})
|
||||
assert "tokens" in result
|
||||
assert "Est. tool context" in result
|
||||
|
||||
|
||||
def test_status_fn_deduplicates_overlapping_tools(monkeypatch):
|
||||
"""When toolsets overlap (browser includes web_search), tokens should not double-count."""
|
||||
import hermes_cli.tools_config as tc
|
||||
from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None):
|
||||
captured["status_fn"] = status_fn
|
||||
return selected
|
||||
|
||||
monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist)
|
||||
|
||||
tc._prompt_toolset_checklist("CLI", {"web"})
|
||||
|
||||
status_fn = captured.get("status_fn")
|
||||
if status_fn is None:
|
||||
pytest.skip("tiktoken unavailable; status_fn not created")
|
||||
|
||||
idx_map = {ts_key: i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)}
|
||||
|
||||
# web alone
|
||||
web_only = status_fn({idx_map["web"]})
|
||||
# browser includes web_search, so browser + web should not double-count web_search
|
||||
browser_only = status_fn({idx_map["browser"]})
|
||||
both = status_fn({idx_map["web"], idx_map["browser"]})
|
||||
|
||||
# Extract numeric token counts from strings like "~8.3k tokens" or "~350 tokens"
|
||||
import re
|
||||
|
||||
def parse_tokens(s):
|
||||
m = re.search(r"~([\d.]+)k?\s+tokens", s)
|
||||
if not m:
|
||||
return 0
|
||||
val = float(m.group(1))
|
||||
if "k" in s[m.start():m.end()]:
|
||||
val *= 1000
|
||||
return val
|
||||
|
||||
web_tok = parse_tokens(web_only)
|
||||
browser_tok = parse_tokens(browser_only)
|
||||
both_tok = parse_tokens(both)
|
||||
|
||||
# Both together should be LESS than naive sum (due to web_search dedup)
|
||||
naive_sum = web_tok + browser_tok
|
||||
assert both_tok < naive_sum, (
|
||||
f"Expected deduplication: web({web_tok}) + browser({browser_tok}) = {naive_sum} "
|
||||
f"but combined = {both_tok}"
|
||||
)
|
||||
|
||||
|
||||
def test_status_fn_empty_selection():
|
||||
"""Status function with no tools selected should return ~0 tokens."""
|
||||
import hermes_cli.tools_config as tc
|
||||
|
||||
tc._tool_token_cache = None
|
||||
tokens = tc._estimate_tool_tokens()
|
||||
if not tokens:
|
||||
pytest.skip("tiktoken unavailable")
|
||||
|
||||
from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
ts_keys = [ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS]
|
||||
|
||||
def status_fn(chosen: set) -> str:
|
||||
all_tools: set = set()
|
||||
for idx in chosen:
|
||||
all_tools.update(resolve_toolset(ts_keys[idx]))
|
||||
total = sum(tokens.get(name, 0) for name in all_tools)
|
||||
if total >= 1000:
|
||||
return f"Est. tool context: ~{total / 1000:.1f}k tokens"
|
||||
return f"Est. tool context: ~{total} tokens"
|
||||
|
||||
result = status_fn(set())
|
||||
assert "~0 tokens" in result
|
||||
|
||||
|
||||
# ─── Curses UI Status Bar Tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_curses_checklist_numbered_fallback_shows_status(monkeypatch, capsys):
|
||||
"""The numbered fallback should print the status_fn output."""
|
||||
from hermes_cli.curses_ui import _numbered_fallback
|
||||
|
||||
def my_status(chosen):
|
||||
return f"Selected {len(chosen)} items"
|
||||
|
||||
# Simulate user pressing Enter immediately (empty input → confirm)
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": "")
|
||||
|
||||
result = _numbered_fallback(
|
||||
"Test title",
|
||||
["Item A", "Item B", "Item C"],
|
||||
{0, 2},
|
||||
{0, 2},
|
||||
status_fn=my_status,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Selected 2 items" in captured.out
|
||||
assert result == {0, 2}
|
||||
|
||||
|
||||
def test_curses_checklist_numbered_fallback_without_status(monkeypatch, capsys):
|
||||
"""The numbered fallback should work fine without status_fn."""
|
||||
from hermes_cli.curses_ui import _numbered_fallback
|
||||
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": "")
|
||||
|
||||
result = _numbered_fallback(
|
||||
"Test title",
|
||||
["Item A", "Item B"],
|
||||
{0},
|
||||
{0},
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Est. tool context" not in captured.out
|
||||
assert result == {0}
|
||||
|
||||
|
||||
# ─── Registry get_schema Tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_registry_get_schema_returns_schema():
|
||||
"""registry.get_schema() should return a tool's schema dict."""
|
||||
from tools.registry import registry
|
||||
|
||||
# Import to trigger discovery
|
||||
import model_tools # noqa: F401
|
||||
|
||||
schema = registry.get_schema("terminal")
|
||||
assert schema is not None
|
||||
assert "name" in schema
|
||||
assert schema["name"] == "terminal"
|
||||
assert "parameters" in schema
|
||||
|
||||
|
||||
def test_registry_get_schema_returns_none_for_unknown():
|
||||
"""registry.get_schema() should return None for unknown tools."""
|
||||
from tools.registry import registry
|
||||
|
||||
assert registry.get_schema("nonexistent_tool_xyz") is None
|
||||
@@ -149,6 +149,15 @@ class ToolRegistry:
|
||||
"""Return sorted list of all registered tool names."""
|
||||
return sorted(self._tools.keys())
|
||||
|
||||
def get_schema(self, name: str) -> Optional[dict]:
|
||||
"""Return a tool's raw schema dict, bypassing check_fn filtering.
|
||||
|
||||
Useful for token estimation and introspection where availability
|
||||
doesn't matter — only the schema content does.
|
||||
"""
|
||||
entry = self._tools.get(name)
|
||||
return entry.schema if entry else None
|
||||
|
||||
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
||||
"""Return the toolset a tool belongs to, or None."""
|
||||
entry = self._tools.get(name)
|
||||
|
||||
Reference in New Issue
Block a user