From da44c196b60423e82fa7c754662a01f884dfbd80 Mon Sep 17 00:00:00 2001 From: Teknium Date: Sat, 21 Mar 2026 15:57:13 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20@=20context=20references=20=E2=80=94=20?= =?UTF-8?q?inline=20file,=20folder,=20diff,=20git,=20and=20URL=20injection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add @file:path, @folder:dir, @diff, @staged, @git:N, and @url: references that expand inline before the message reaches the LLM. Supports line ranges (@file:main.py:10-50), token budget enforcement (soft warn at 25%, hard block at 50%), and path sandboxing for gateway. Core module from PR #2090 by @kshitijk4poor. CLI and gateway wiring rewritten against current main. Fixed asyncio.run() crash when called from inside a running event loop (gateway). Closes #682. --- agent/context_references.py | 440 +++++++++++++++++++++++++++++++ cli.py | 22 ++ gateway/run.py | 26 +- tests/test_context_references.py | 221 ++++++++++++++++ 4 files changed, 708 insertions(+), 1 deletion(-) create mode 100644 agent/context_references.py create mode 100644 tests/test_context_references.py diff --git a/agent/context_references.py b/agent/context_references.py new file mode 100644 index 000000000..fbe9a2d67 --- /dev/null +++ b/agent/context_references.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +import mimetypes +import os +import re +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import Awaitable, Callable + +from agent.model_metadata import estimate_tokens_rough + +REFERENCE_PATTERN = re.compile( + r"(?diff|staged)\b|(?Pfile|folder|git|url):(?P\S+))" +) +TRAILING_PUNCTUATION = ",.;!?" + + +@dataclass(frozen=True) +class ContextReference: + raw: str + kind: str + target: str + start: int + end: int + line_start: int | None = None + line_end: int | None = None + + +@dataclass +class ContextReferenceResult: + message: str + original_message: str + references: list[ContextReference] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + injected_tokens: int = 0 + expanded: bool = False + blocked: bool = False + + +def parse_context_references(message: str) -> list[ContextReference]: + refs: list[ContextReference] = [] + if not message: + return refs + + for match in REFERENCE_PATTERN.finditer(message): + simple = match.group("simple") + if simple: + refs.append( + ContextReference( + raw=match.group(0), + kind=simple, + target="", + start=match.start(), + end=match.end(), + ) + ) + continue + + kind = match.group("kind") + value = _strip_trailing_punctuation(match.group("value") or "") + line_start = None + line_end = None + target = value + + if kind == "file": + range_match = re.match(r"^(?P.+?):(?P\d+)(?:-(?P\d+))?$", value) + if range_match: + target = range_match.group("path") + line_start = int(range_match.group("start")) + line_end = int(range_match.group("end") or range_match.group("start")) + + refs.append( + ContextReference( + raw=match.group(0), + kind=kind, + target=target, + start=match.start(), + end=match.end(), + line_start=line_start, + line_end=line_end, + ) + ) + + return refs + + +def preprocess_context_references( + message: str, + *, + cwd: str | Path, + context_length: int, + url_fetcher: Callable[[str], str | Awaitable[str]] | None = None, + allowed_root: str | Path | None = None, +) -> ContextReferenceResult: + coro = preprocess_context_references_async( + message, + cwd=cwd, + context_length=context_length, + url_fetcher=url_fetcher, + allowed_root=allowed_root, + ) + # Safe for both CLI (no loop) and gateway (loop already running). + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + return asyncio.run(coro) + + +async def preprocess_context_references_async( + message: str, + *, + cwd: str | Path, + context_length: int, + url_fetcher: Callable[[str], str | Awaitable[str]] | None = None, + allowed_root: str | Path | None = None, +) -> ContextReferenceResult: + refs = parse_context_references(message) + if not refs: + return ContextReferenceResult(message=message, original_message=message) + + cwd_path = Path(cwd).expanduser().resolve() + allowed_root_path = Path(allowed_root).expanduser().resolve() if allowed_root is not None else None + warnings: list[str] = [] + blocks: list[str] = [] + injected_tokens = 0 + + for ref in refs: + warning, block = await _expand_reference( + ref, + cwd_path, + url_fetcher=url_fetcher, + allowed_root=allowed_root_path, + ) + if warning: + warnings.append(warning) + if block: + blocks.append(block) + injected_tokens += estimate_tokens_rough(block) + + hard_limit = max(1, int(context_length * 0.50)) + soft_limit = max(1, int(context_length * 0.25)) + if injected_tokens > hard_limit: + warnings.append( + f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})." + ) + return ContextReferenceResult( + message=message, + original_message=message, + references=refs, + warnings=warnings, + injected_tokens=injected_tokens, + expanded=False, + blocked=True, + ) + + if injected_tokens > soft_limit: + warnings.append( + f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})." + ) + + stripped = _remove_reference_tokens(message, refs) + final = stripped + if warnings: + final = f"{final}\n\n--- Context Warnings ---\n" + "\n".join(f"- {warning}" for warning in warnings) + if blocks: + final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks) + + return ContextReferenceResult( + message=final.strip(), + original_message=message, + references=refs, + warnings=warnings, + injected_tokens=injected_tokens, + expanded=bool(blocks or warnings), + blocked=False, + ) + + +async def _expand_reference( + ref: ContextReference, + cwd: Path, + *, + url_fetcher: Callable[[str], str | Awaitable[str]] | None = None, + allowed_root: Path | None = None, +) -> tuple[str | None, str | None]: + try: + if ref.kind == "file": + return _expand_file_reference(ref, cwd, allowed_root=allowed_root) + if ref.kind == "folder": + return _expand_folder_reference(ref, cwd, allowed_root=allowed_root) + if ref.kind == "diff": + return _expand_git_reference(ref, cwd, ["diff"], "git diff") + if ref.kind == "staged": + return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged") + if ref.kind == "git": + count = max(1, min(int(ref.target or "1"), 10)) + return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p") + if ref.kind == "url": + content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher) + if not content: + return f"{ref.raw}: no content extracted", None + return None, f"๐ŸŒ {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}" + except Exception as exc: + return f"{ref.raw}: {exc}", None + + return f"{ref.raw}: unsupported reference type", None + + +def _expand_file_reference( + ref: ContextReference, + cwd: Path, + *, + allowed_root: Path | None = None, +) -> tuple[str | None, str | None]: + path = _resolve_path(cwd, ref.target, allowed_root=allowed_root) + if not path.exists(): + return f"{ref.raw}: file not found", None + if not path.is_file(): + return f"{ref.raw}: path is not a file", None + if _is_binary_file(path): + return f"{ref.raw}: binary files are not supported", None + + text = path.read_text(encoding="utf-8") + if ref.line_start is not None: + lines = text.splitlines() + start_idx = max(ref.line_start - 1, 0) + end_idx = min(ref.line_end or ref.line_start, len(lines)) + text = "\n".join(lines[start_idx:end_idx]) + + lang = _code_fence_language(path) + label = ref.raw + return None, f"๐Ÿ“„ {label} ({estimate_tokens_rough(text)} tokens)\n```{lang}\n{text}\n```" + + +def _expand_folder_reference( + ref: ContextReference, + cwd: Path, + *, + allowed_root: Path | None = None, +) -> tuple[str | None, str | None]: + path = _resolve_path(cwd, ref.target, allowed_root=allowed_root) + if not path.exists(): + return f"{ref.raw}: folder not found", None + if not path.is_dir(): + return f"{ref.raw}: path is not a folder", None + + listing = _build_folder_listing(path, cwd) + return None, f"๐Ÿ“ {ref.raw} ({estimate_tokens_rough(listing)} tokens)\n{listing}" + + +def _expand_git_reference( + ref: ContextReference, + cwd: Path, + args: list[str], + label: str, +) -> tuple[str | None, str | None]: + result = subprocess.run( + ["git", *args], + cwd=cwd, + capture_output=True, + text=True, + ) + if result.returncode != 0: + stderr = (result.stderr or "").strip() or "git command failed" + return f"{ref.raw}: {stderr}", None + content = result.stdout.strip() + if not content: + content = "(no output)" + return None, f"๐Ÿงพ {label} ({estimate_tokens_rough(content)} tokens)\n```diff\n{content}\n```" + + +async def _fetch_url_content( + url: str, + *, + url_fetcher: Callable[[str], str | Awaitable[str]] | None = None, +) -> str: + fetcher = url_fetcher or _default_url_fetcher + content = fetcher(url) + if inspect.isawaitable(content): + content = await content + return str(content or "").strip() + + +async def _default_url_fetcher(url: str) -> str: + from tools.web_tools import web_extract_tool + + raw = await web_extract_tool([url], format="markdown", use_llm_processing=True) + payload = json.loads(raw) + docs = payload.get("data", {}).get("documents", []) + if not docs: + return "" + doc = docs[0] + return str(doc.get("content") or doc.get("raw_content") or "").strip() + + +def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path: + path = Path(os.path.expanduser(target)) + if not path.is_absolute(): + path = cwd / path + resolved = path.resolve() + if allowed_root is not None: + try: + resolved.relative_to(allowed_root) + except ValueError as exc: + raise ValueError("path is outside the allowed workspace") from exc + return resolved + + +def _strip_trailing_punctuation(value: str) -> str: + stripped = value.rstrip(TRAILING_PUNCTUATION) + while stripped.endswith((")", "]", "}")): + closer = stripped[-1] + opener = {")": "(", "]": "[", "}": "{"}[closer] + if stripped.count(closer) > stripped.count(opener): + stripped = stripped[:-1] + continue + break + return stripped + + +def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str: + pieces: list[str] = [] + cursor = 0 + for ref in refs: + pieces.append(message[cursor:ref.start]) + cursor = ref.end + pieces.append(message[cursor:]) + text = "".join(pieces) + text = re.sub(r"\s{2,}", " ", text) + text = re.sub(r"\s+([,.;:!?])", r"\1", text) + return text.strip() + + +def _is_binary_file(path: Path) -> bool: + mime, _ = mimetypes.guess_type(path.name) + if mime and not mime.startswith("text/") and not any( + path.name.endswith(ext) for ext in (".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts") + ): + return True + chunk = path.read_bytes()[:4096] + return b"\x00" in chunk + + +def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str: + lines = [f"{path.relative_to(cwd)}/"] + entries = _iter_visible_entries(path, cwd, limit=limit) + for entry in entries: + rel = entry.relative_to(cwd) + indent = " " * max(len(rel.parts) - len(path.relative_to(cwd).parts) - 1, 0) + if entry.is_dir(): + lines.append(f"{indent}- {entry.name}/") + else: + meta = _file_metadata(entry) + lines.append(f"{indent}- {entry.name} ({meta})") + if len(entries) >= limit: + lines.append("- ...") + return "\n".join(lines) + + +def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]: + rg_entries = _rg_files(path, cwd, limit=limit) + if rg_entries is not None: + output: list[Path] = [] + seen_dirs: set[Path] = set() + for rel in rg_entries: + full = cwd / rel + for parent in full.parents: + if parent == cwd or parent in seen_dirs or path not in {parent, *parent.parents}: + continue + seen_dirs.add(parent) + output.append(parent) + output.append(full) + return sorted({p for p in output if p.exists()}, key=lambda p: (not p.is_dir(), str(p))) + + output = [] + for root, dirs, files in os.walk(path): + dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__") + files = sorted(f for f in files if not f.startswith(".")) + root_path = Path(root) + for d in dirs: + output.append(root_path / d) + if len(output) >= limit: + return output + for f in files: + output.append(root_path / f) + if len(output) >= limit: + return output + return output + + +def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None: + try: + result = subprocess.run( + ["rg", "--files", str(path.relative_to(cwd))], + cwd=cwd, + capture_output=True, + text=True, + ) + except FileNotFoundError: + return None + if result.returncode != 0: + return None + files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()] + return files[:limit] + + +def _file_metadata(path: Path) -> str: + if _is_binary_file(path): + return f"{path.stat().st_size} bytes" + try: + line_count = path.read_text(encoding="utf-8").count("\n") + 1 + except Exception: + return f"{path.stat().st_size} bytes" + return f"{line_count} lines" + + +def _code_fence_language(path: Path) -> str: + mapping = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".tsx": "tsx", + ".jsx": "jsx", + ".json": "json", + ".md": "markdown", + ".sh": "bash", + ".yml": "yaml", + ".yaml": "yaml", + ".toml": "toml", + } + return mapping.get(path.suffix.lower(), "") diff --git a/cli.py b/cli.py index b0dae05d7..8772d5720 100755 --- a/cli.py +++ b/cli.py @@ -5320,6 +5320,28 @@ class HermesCLI: message if isinstance(message, str) else "", images ) + # Expand @ context references (e.g. @file:main.py, @diff, @folder:src/) + if isinstance(message, str) and "@" in message: + try: + from agent.context_references import preprocess_context_references + from agent.model_metadata import get_model_context_length + _ctx_len = get_model_context_length( + self.model, base_url=self.base_url or "", api_key=self.api_key or "") + _ctx_result = preprocess_context_references( + message, cwd=os.getcwd(), context_length=_ctx_len) + if _ctx_result.expanded or _ctx_result.blocked: + if _ctx_result.references: + _cprint( + f" {_DIM}[@ context: {len(_ctx_result.references)} ref(s), " + f"{_ctx_result.injected_tokens} tokens]{_RST}") + for w in _ctx_result.warnings: + _cprint(f" {_DIM}โš  {w}{_RST}") + if _ctx_result.blocked: + return "\n".join(_ctx_result.warnings) or "Context injection refused." + message = _ctx_result.message + except Exception as e: + logging.debug("@ context reference expansion failed: %s", e) + # Add user message to history self.conversation_history.append({"role": "user", "content": message}) diff --git a/gateway/run.py b/gateway/run.py index 954738748..814757529 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2058,7 +2058,31 @@ class GatewayRunner: "message": message_text[:500], } await self.hooks.emit("agent:start", hook_ctx) - + + # Expand @ context references (@file:, @folder:, @diff, etc.) + if "@" in message_text: + try: + from agent.context_references import preprocess_context_references_async + from agent.model_metadata import get_model_context_length + _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) + _msg_ctx_len = get_model_context_length( + self._model, base_url=self._base_url or "") + _ctx_result = await preprocess_context_references_async( + message_text, cwd=_msg_cwd, + context_length=_msg_ctx_len, allowed_root=_msg_cwd) + if _ctx_result.blocked: + _adapter = self.adapters.get(source.platform) + if _adapter: + await _adapter.send( + source.chat_id, + "\n".join(_ctx_result.warnings) or "Context injection refused.", + ) + return + if _ctx_result.expanded: + message_text = _ctx_result.message + except Exception as exc: + logger.debug("@ context reference expansion failed: %s", exc) + # Run the agent agent_result = await self._run_agent( message=message_text, diff --git a/tests/test_context_references.py b/tests/test_context_references.py new file mode 100644 index 000000000..34ac06033 --- /dev/null +++ b/tests/test_context_references.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import asyncio +import subprocess +from pathlib import Path + +import pytest + + +def _git(cwd: Path, *args: str) -> str: + result = subprocess.run( + ["git", *args], + cwd=cwd, + check=True, + capture_output=True, + text=True, + ) + return result.stdout.strip() + + +@pytest.fixture +def sample_repo(tmp_path: Path) -> Path: + repo = tmp_path / "repo" + repo.mkdir() + _git(repo, "init") + _git(repo, "config", "user.name", "Hermes Tests") + _git(repo, "config", "user.email", "tests@example.com") + + (repo / "src").mkdir() + (repo / "src" / "main.py").write_text( + "def alpha():\n" + " return 'a'\n\n" + "def beta():\n" + " return 'b'\n", + encoding="utf-8", + ) + (repo / "src" / "helper.py").write_text("VALUE = 1\n", encoding="utf-8") + (repo / "README.md").write_text("# Demo\n", encoding="utf-8") + (repo / "blob.bin").write_bytes(b"\x00\x01\x02binary") + + _git(repo, "add", ".") + _git(repo, "commit", "-m", "initial") + + (repo / "src" / "main.py").write_text( + "def alpha():\n" + " return 'changed'\n\n" + "def beta():\n" + " return 'b'\n", + encoding="utf-8", + ) + (repo / "src" / "helper.py").write_text("VALUE = 2\n", encoding="utf-8") + _git(repo, "add", "src/helper.py") + return repo + + +def test_parse_typed_references_ignores_emails_and_handles(): + from agent.context_references import parse_context_references + + message = ( + "email me at user@example.com and ping @teammate " + "but include @file:src/main.py:1-2 plus @diff and @git:2 " + "and @url:https://example.com/docs" + ) + + refs = parse_context_references(message) + + assert [ref.kind for ref in refs] == ["file", "diff", "git", "url"] + assert refs[0].target == "src/main.py" + assert refs[0].line_start == 1 + assert refs[0].line_end == 2 + assert refs[2].target == "2" + + +def test_parse_references_strips_trailing_punctuation(): + from agent.context_references import parse_context_references + + refs = parse_context_references( + "review @file:README.md, then see (@url:https://example.com/docs)." + ) + + assert [ref.kind for ref in refs] == ["file", "url"] + assert refs[0].target == "README.md" + assert refs[1].target == "https://example.com/docs" + + +def test_expand_file_range_and_folder_listing(sample_repo: Path): + from agent.context_references import preprocess_context_references + + result = preprocess_context_references( + "Review @file:src/main.py:1-2 and @folder:src/", + cwd=sample_repo, + context_length=100_000, + ) + + assert result.expanded + assert "Review and" in result.message + assert "Review @file:src/main.py:1-2" not in result.message + assert "--- Attached Context ---" in result.message + assert "def alpha():" in result.message + assert "return 'changed'" in result.message + assert "def beta():" not in result.message + assert "src/" in result.message + assert "main.py" in result.message + assert "helper.py" in result.message + assert result.injected_tokens > 0 + assert not result.warnings + + +def test_expand_git_diff_staged_and_log(sample_repo: Path): + from agent.context_references import preprocess_context_references + + result = preprocess_context_references( + "Inspect @diff and @staged and @git:1", + cwd=sample_repo, + context_length=100_000, + ) + + assert result.expanded + assert "git diff" in result.message + assert "git diff --staged" in result.message + assert "git log -1 -p" in result.message + assert "initial" in result.message + assert "return 'changed'" in result.message + assert "VALUE = 2" in result.message + + +def test_binary_and_missing_files_become_warnings(sample_repo: Path): + from agent.context_references import preprocess_context_references + + result = preprocess_context_references( + "Check @file:blob.bin and @file:nope.txt", + cwd=sample_repo, + context_length=100_000, + ) + + assert result.expanded + assert len(result.warnings) == 2 + assert "binary" in result.message.lower() + assert "not found" in result.message.lower() + + +def test_soft_budget_warns_and_hard_budget_refuses(sample_repo: Path): + from agent.context_references import preprocess_context_references + + soft = preprocess_context_references( + "Check @file:src/main.py", + cwd=sample_repo, + context_length=100, + ) + assert soft.expanded + assert any("25%" in warning for warning in soft.warnings) + + hard = preprocess_context_references( + "Check @file:src/main.py and @file:README.md", + cwd=sample_repo, + context_length=20, + ) + assert not hard.expanded + assert hard.blocked + assert "@file:src/main.py" in hard.message + assert any("50%" in warning for warning in hard.warnings) + + +@pytest.mark.asyncio +async def test_async_url_expansion_uses_fetcher(sample_repo: Path): + from agent.context_references import preprocess_context_references_async + + async def fake_fetch(url: str) -> str: + assert url == "https://example.com/spec" + return "# Spec\n\nImportant details." + + result = await preprocess_context_references_async( + "Use @url:https://example.com/spec", + cwd=sample_repo, + context_length=100_000, + url_fetcher=fake_fetch, + ) + + assert result.expanded + assert "Important details." in result.message + assert result.injected_tokens > 0 + + +def test_sync_url_expansion_uses_async_fetcher(sample_repo: Path): + from agent.context_references import preprocess_context_references + + async def fake_fetch(url: str) -> str: + await asyncio.sleep(0) + return f"Content for {url}" + + result = preprocess_context_references( + "Use @url:https://example.com/spec", + cwd=sample_repo, + context_length=100_000, + url_fetcher=fake_fetch, + ) + + assert result.expanded + assert "Content for https://example.com/spec" in result.message + + +def test_restricts_paths_to_allowed_root(tmp_path: Path): + from agent.context_references import preprocess_context_references + + workspace = tmp_path / "workspace" + workspace.mkdir() + (workspace / "notes.txt").write_text("inside\n", encoding="utf-8") + secret = tmp_path / "secret.txt" + secret.write_text("outside\n", encoding="utf-8") + + result = preprocess_context_references( + "read @file:../secret.txt and @file:notes.txt", + cwd=workspace, + context_length=100_000, + allowed_root=workspace, + ) + + assert result.expanded + assert "```\noutside\n```" not in result.message + assert "inside" in result.message + assert any("outside the allowed workspace" in warning for warning in result.warnings)