Add @file:path, @folder:dir, @diff, @staged, @git:N, and @url: references that expand inline before the message reaches the LLM. Supports line ranges (@file:main.py:10-50), token budget enforcement (soft warn at 25%, hard block at 50%), and path sandboxing for gateway. Core module from PR #2090 by @kshitijk4poor. CLI and gateway wiring rewritten against current main. Fixed asyncio.run() crash when called from inside a running event loop (gateway). Closes #682.
222 lines
6.7 KiB
Python
222 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import subprocess
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
|
|
def _git(cwd: Path, *args: str) -> str:
|
|
result = subprocess.run(
|
|
["git", *args],
|
|
cwd=cwd,
|
|
check=True,
|
|
capture_output=True,
|
|
text=True,
|
|
)
|
|
return result.stdout.strip()
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_repo(tmp_path: Path) -> Path:
|
|
repo = tmp_path / "repo"
|
|
repo.mkdir()
|
|
_git(repo, "init")
|
|
_git(repo, "config", "user.name", "Hermes Tests")
|
|
_git(repo, "config", "user.email", "tests@example.com")
|
|
|
|
(repo / "src").mkdir()
|
|
(repo / "src" / "main.py").write_text(
|
|
"def alpha():\n"
|
|
" return 'a'\n\n"
|
|
"def beta():\n"
|
|
" return 'b'\n",
|
|
encoding="utf-8",
|
|
)
|
|
(repo / "src" / "helper.py").write_text("VALUE = 1\n", encoding="utf-8")
|
|
(repo / "README.md").write_text("# Demo\n", encoding="utf-8")
|
|
(repo / "blob.bin").write_bytes(b"\x00\x01\x02binary")
|
|
|
|
_git(repo, "add", ".")
|
|
_git(repo, "commit", "-m", "initial")
|
|
|
|
(repo / "src" / "main.py").write_text(
|
|
"def alpha():\n"
|
|
" return 'changed'\n\n"
|
|
"def beta():\n"
|
|
" return 'b'\n",
|
|
encoding="utf-8",
|
|
)
|
|
(repo / "src" / "helper.py").write_text("VALUE = 2\n", encoding="utf-8")
|
|
_git(repo, "add", "src/helper.py")
|
|
return repo
|
|
|
|
|
|
def test_parse_typed_references_ignores_emails_and_handles():
|
|
from agent.context_references import parse_context_references
|
|
|
|
message = (
|
|
"email me at user@example.com and ping @teammate "
|
|
"but include @file:src/main.py:1-2 plus @diff and @git:2 "
|
|
"and @url:https://example.com/docs"
|
|
)
|
|
|
|
refs = parse_context_references(message)
|
|
|
|
assert [ref.kind for ref in refs] == ["file", "diff", "git", "url"]
|
|
assert refs[0].target == "src/main.py"
|
|
assert refs[0].line_start == 1
|
|
assert refs[0].line_end == 2
|
|
assert refs[2].target == "2"
|
|
|
|
|
|
def test_parse_references_strips_trailing_punctuation():
|
|
from agent.context_references import parse_context_references
|
|
|
|
refs = parse_context_references(
|
|
"review @file:README.md, then see (@url:https://example.com/docs)."
|
|
)
|
|
|
|
assert [ref.kind for ref in refs] == ["file", "url"]
|
|
assert refs[0].target == "README.md"
|
|
assert refs[1].target == "https://example.com/docs"
|
|
|
|
|
|
def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
|
from agent.context_references import preprocess_context_references
|
|
|
|
result = preprocess_context_references(
|
|
"Review @file:src/main.py:1-2 and @folder:src/",
|
|
cwd=sample_repo,
|
|
context_length=100_000,
|
|
)
|
|
|
|
assert result.expanded
|
|
assert "Review and" in result.message
|
|
assert "Review @file:src/main.py:1-2" not in result.message
|
|
assert "--- Attached Context ---" in result.message
|
|
assert "def alpha():" in result.message
|
|
assert "return 'changed'" in result.message
|
|
assert "def beta():" not in result.message
|
|
assert "src/" in result.message
|
|
assert "main.py" in result.message
|
|
assert "helper.py" in result.message
|
|
assert result.injected_tokens > 0
|
|
assert not result.warnings
|
|
|
|
|
|
def test_expand_git_diff_staged_and_log(sample_repo: Path):
|
|
from agent.context_references import preprocess_context_references
|
|
|
|
result = preprocess_context_references(
|
|
"Inspect @diff and @staged and @git:1",
|
|
cwd=sample_repo,
|
|
context_length=100_000,
|
|
)
|
|
|
|
assert result.expanded
|
|
assert "git diff" in result.message
|
|
assert "git diff --staged" in result.message
|
|
assert "git log -1 -p" in result.message
|
|
assert "initial" in result.message
|
|
assert "return 'changed'" in result.message
|
|
assert "VALUE = 2" in result.message
|
|
|
|
|
|
def test_binary_and_missing_files_become_warnings(sample_repo: Path):
|
|
from agent.context_references import preprocess_context_references
|
|
|
|
result = preprocess_context_references(
|
|
"Check @file:blob.bin and @file:nope.txt",
|
|
cwd=sample_repo,
|
|
context_length=100_000,
|
|
)
|
|
|
|
assert result.expanded
|
|
assert len(result.warnings) == 2
|
|
assert "binary" in result.message.lower()
|
|
assert "not found" in result.message.lower()
|
|
|
|
|
|
def test_soft_budget_warns_and_hard_budget_refuses(sample_repo: Path):
|
|
from agent.context_references import preprocess_context_references
|
|
|
|
soft = preprocess_context_references(
|
|
"Check @file:src/main.py",
|
|
cwd=sample_repo,
|
|
context_length=100,
|
|
)
|
|
assert soft.expanded
|
|
assert any("25%" in warning for warning in soft.warnings)
|
|
|
|
hard = preprocess_context_references(
|
|
"Check @file:src/main.py and @file:README.md",
|
|
cwd=sample_repo,
|
|
context_length=20,
|
|
)
|
|
assert not hard.expanded
|
|
assert hard.blocked
|
|
assert "@file:src/main.py" in hard.message
|
|
assert any("50%" in warning for warning in hard.warnings)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_url_expansion_uses_fetcher(sample_repo: Path):
|
|
from agent.context_references import preprocess_context_references_async
|
|
|
|
async def fake_fetch(url: str) -> str:
|
|
assert url == "https://example.com/spec"
|
|
return "# Spec\n\nImportant details."
|
|
|
|
result = await preprocess_context_references_async(
|
|
"Use @url:https://example.com/spec",
|
|
cwd=sample_repo,
|
|
context_length=100_000,
|
|
url_fetcher=fake_fetch,
|
|
)
|
|
|
|
assert result.expanded
|
|
assert "Important details." in result.message
|
|
assert result.injected_tokens > 0
|
|
|
|
|
|
def test_sync_url_expansion_uses_async_fetcher(sample_repo: Path):
|
|
from agent.context_references import preprocess_context_references
|
|
|
|
async def fake_fetch(url: str) -> str:
|
|
await asyncio.sleep(0)
|
|
return f"Content for {url}"
|
|
|
|
result = preprocess_context_references(
|
|
"Use @url:https://example.com/spec",
|
|
cwd=sample_repo,
|
|
context_length=100_000,
|
|
url_fetcher=fake_fetch,
|
|
)
|
|
|
|
assert result.expanded
|
|
assert "Content for https://example.com/spec" in result.message
|
|
|
|
|
|
def test_restricts_paths_to_allowed_root(tmp_path: Path):
|
|
from agent.context_references import preprocess_context_references
|
|
|
|
workspace = tmp_path / "workspace"
|
|
workspace.mkdir()
|
|
(workspace / "notes.txt").write_text("inside\n", encoding="utf-8")
|
|
secret = tmp_path / "secret.txt"
|
|
secret.write_text("outside\n", encoding="utf-8")
|
|
|
|
result = preprocess_context_references(
|
|
"read @file:../secret.txt and @file:notes.txt",
|
|
cwd=workspace,
|
|
context_length=100_000,
|
|
allowed_root=workspace,
|
|
)
|
|
|
|
assert result.expanded
|
|
assert "```\noutside\n```" not in result.message
|
|
assert "inside" in result.message
|
|
assert any("outside the allowed workspace" in warning for warning in result.warnings)
|