feat: @ context references — inline file, folder, diff, git, and URL injection
Add @file:path, @folder:dir, @diff, @staged, @git:N, and @url: references that expand inline before the message reaches the LLM. Supports line ranges (@file:main.py:10-50), token budget enforcement (soft warn at 25%, hard block at 50%), and path sandboxing for gateway. Core module from PR #2090 by @kshitijk4poor. CLI and gateway wiring rewritten against current main. Fixed asyncio.run() crash when called from inside a running event loop (gateway). Closes #682.
This commit is contained in:
440
agent/context_references.py
Normal file
440
agent/context_references.py
Normal file
@@ -0,0 +1,440 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
# Matches @ context references: bare @diff / @staged, or typed
# @file:/@folder:/@git:/@url: followed by a non-space payload.
# The (?<![\w/]) negative lookbehind rejects an @ preceded by a word
# character or slash, so e-mail addresses (user@example.com) and
# path-embedded @ signs are not picked up.
REFERENCE_PATTERN = re.compile(
    r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
)
# Sentence punctuation trimmed from the tail of a captured reference target.
TRAILING_PUNCTUATION = ",.;!?"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ContextReference:
    """One parsed @ reference found in a user message."""

    # Exact matched text, e.g. "@file:src/main.py:1-2".
    raw: str
    # Reference type: "file", "folder", "diff", "staged", "git", or "url".
    kind: str
    # Payload after the colon (path, commit count, or URL); "" for @diff/@staged.
    target: str
    # Character offsets of the match within the original message.
    start: int
    end: int
    # Optional 1-based inclusive line range for @file references.
    line_start: int | None = None
    line_end: int | None = None
|
||||
|
||||
|
||||
@dataclass
class ContextReferenceResult:
    """Outcome of expanding @ references in a message."""

    # Message to hand to the LLM (expanded; or the original when blocked).
    message: str
    # The message exactly as the user typed it.
    original_message: str
    # All references parsed from the message, in order of appearance.
    references: list[ContextReference] = field(default_factory=list)
    # Human-readable warnings (missing files, budget overruns, ...).
    warnings: list[str] = field(default_factory=list)
    # Rough token count of all injected context blocks.
    injected_tokens: int = 0
    # True when any expansion or warning modified the message.
    expanded: bool = False
    # True when the 50% hard budget refused the injection.
    blocked: bool = False
|
||||
|
||||
|
||||
def parse_context_references(message: str) -> list[ContextReference]:
    """Scan *message* for @ context references, in order of appearance.

    Handles bare references (@diff, @staged) and typed ones
    (@file:..., @folder:..., @git:N, @url:...). For @file references an
    optional trailing ``:START`` or ``:START-END`` suffix is parsed as a
    1-based line range.
    """
    refs: list[ContextReference] = []
    if not message:
        return refs

    for match in REFERENCE_PATTERN.finditer(message):
        simple = match.group("simple")
        if simple:
            # @diff / @staged carry no target payload.
            refs.append(
                ContextReference(
                    raw=match.group(0),
                    kind=simple,
                    target="",
                    start=match.start(),
                    end=match.end(),
                )
            )
            continue

        kind = match.group("kind")
        # \S+ may swallow sentence punctuation after the target; trim it.
        # Note: raw/start/end still cover the full match including that tail.
        value = _strip_trailing_punctuation(match.group("value") or "")
        line_start = None
        line_end = None
        target = value

        if kind == "file":
            # Optional @file:path:START[-END] line-range suffix.
            range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
            if range_match:
                target = range_match.group("path")
                line_start = int(range_match.group("start"))
                # A bare :START means a single-line range START-START.
                line_end = int(range_match.group("end") or range_match.group("start"))

        refs.append(
            ContextReference(
                raw=match.group(0),
                kind=kind,
                target=target,
                start=match.start(),
                end=match.end(),
                line_start=line_start,
                line_end=line_end,
            )
        )

    return refs
|
||||
|
||||
|
||||
def preprocess_context_references(
    message: str,
    *,
    cwd: str | Path,
    context_length: int,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
    """Synchronously expand @ references in *message*.

    Thin wrapper around :func:`preprocess_context_references_async` that is
    safe to call both from plain synchronous code (CLI) and from code that
    already runs inside an event loop (gateway). Arguments mirror the async
    variant; see its docstring for details.
    """
    coro = preprocess_context_references_async(
        message,
        cwd=cwd,
        context_length=context_length,
        url_fetcher=url_fetcher,
        allowed_root=allowed_root,
    )
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread (CLI path): drive the coroutine directly.
        return asyncio.run(coro)
    # A loop is already running (gateway path): asyncio.run() would raise
    # "cannot be called from a running event loop", so finish the coroutine
    # on a private loop in a worker thread and block on the result.
    # (get_running_loop() only ever returns a *running* loop, so the previous
    # `loop.is_running()` re-check was redundant and has been dropped.)
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
|
||||
|
||||
|
||||
async def preprocess_context_references_async(
    message: str,
    *,
    cwd: str | Path,
    context_length: int,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
    """Expand @ references in *message* and enforce the injection token budget.

    Args:
        message: Raw user message, possibly containing @ references.
        cwd: Directory that relative paths and git commands resolve against.
        context_length: Model context window; the budgets derive from it.
        url_fetcher: Optional override for @url fetching (sync or async).
        allowed_root: When set, file/folder targets must resolve under it.

    Returns:
        A ContextReferenceResult. ``blocked`` is True (and the message is
        returned unchanged) when injected content exceeds 50% of the context
        window; a warning is appended above the 25% soft limit.
    """
    refs = parse_context_references(message)
    if not refs:
        # Fast path: nothing to expand.
        return ContextReferenceResult(message=message, original_message=message)

    cwd_path = Path(cwd).expanduser().resolve()
    allowed_root_path = Path(allowed_root).expanduser().resolve() if allowed_root is not None else None
    warnings: list[str] = []
    blocks: list[str] = []
    injected_tokens = 0

    for ref in refs:
        # Each reference yields either a warning or an expanded block,
        # never both (and never an exception — see _expand_reference).
        warning, block = await _expand_reference(
            ref,
            cwd_path,
            url_fetcher=url_fetcher,
            allowed_root=allowed_root_path,
        )
        if warning:
            warnings.append(warning)
        if block:
            blocks.append(block)
            injected_tokens += estimate_tokens_rough(block)

    # Budget thresholds derived from the model's context window.
    hard_limit = max(1, int(context_length * 0.50))
    soft_limit = max(1, int(context_length * 0.25))
    if injected_tokens > hard_limit:
        warnings.append(
            f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})."
        )
        # Hard block: hand back the untouched message so the caller can refuse.
        return ContextReferenceResult(
            message=message,
            original_message=message,
            references=refs,
            warnings=warnings,
            injected_tokens=injected_tokens,
            expanded=False,
            blocked=True,
        )

    if injected_tokens > soft_limit:
        warnings.append(
            f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})."
        )

    # Remove the @ tokens from the message body, then append warnings and
    # the expanded context blocks below it.
    stripped = _remove_reference_tokens(message, refs)
    final = stripped
    if warnings:
        final = f"{final}\n\n--- Context Warnings ---\n" + "\n".join(f"- {warning}" for warning in warnings)
    if blocks:
        final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks)

    return ContextReferenceResult(
        message=final.strip(),
        original_message=message,
        references=refs,
        warnings=warnings,
        injected_tokens=injected_tokens,
        expanded=bool(blocks or warnings),
        blocked=False,
    )
|
||||
|
||||
|
||||
async def _expand_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Dispatch one reference to its kind-specific expander.

    Returns a ``(warning, block)`` pair where exactly one side is non-None.
    Any exception raised by an expander is downgraded to a warning so a bad
    reference can never abort preprocessing.
    """
    try:
        if ref.kind == "file":
            return _expand_file_reference(ref, cwd, allowed_root=allowed_root)
        if ref.kind == "folder":
            return _expand_folder_reference(ref, cwd, allowed_root=allowed_root)
        if ref.kind == "diff":
            return _expand_git_reference(ref, cwd, ["diff"], "git diff")
        if ref.kind == "staged":
            return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged")
        if ref.kind == "git":
            # Clamp the requested commit count to 1..10.
            count = max(1, min(int(ref.target or "1"), 10))
            return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p")
        if ref.kind == "url":
            content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher)
            if not content:
                return f"{ref.raw}: no content extracted", None
            return None, f"🌐 {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}"
    except Exception as exc:
        # Expansion failures surface as per-reference warnings.
        return f"{ref.raw}: {exc}", None

    # Defensive fallback; the parser only emits the kinds handled above.
    return f"{ref.raw}: unsupported reference type", None
|
||||
|
||||
|
||||
def _expand_file_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Inline a @file reference as a fenced code block, honoring an optional line range.

    Returns a ``(warning, block)`` pair; exactly one side is non-None.
    """
    resolved = _resolve_path(cwd, ref.target, allowed_root=allowed_root)

    # Guard clauses: unreadable targets become warnings, not exceptions.
    if not resolved.exists():
        return f"{ref.raw}: file not found", None
    if not resolved.is_file():
        return f"{ref.raw}: path is not a file", None
    if _is_binary_file(resolved):
        return f"{ref.raw}: binary files are not supported", None

    content = resolved.read_text(encoding="utf-8")
    if ref.line_start is not None:
        # Slice the requested 1-based, inclusive line range.
        all_lines = content.splitlines()
        first = max(ref.line_start - 1, 0)
        last = min(ref.line_end or ref.line_start, len(all_lines))
        content = "\n".join(all_lines[first:last])

    fence = _code_fence_language(resolved)
    return None, f"📄 {ref.raw} ({estimate_tokens_rough(content)} tokens)\n```{fence}\n{content}\n```"
|
||||
|
||||
|
||||
def _expand_folder_reference(
    ref: ContextReference,
    cwd: Path,
    *,
    allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
    """Inline a @folder reference as a tree-style listing.

    Returns a ``(warning, block)`` pair; exactly one side is non-None.
    """
    resolved = _resolve_path(cwd, ref.target, allowed_root=allowed_root)

    # Guard clauses mirror the file expander: bad targets become warnings.
    if not resolved.exists():
        return f"{ref.raw}: folder not found", None
    if not resolved.is_dir():
        return f"{ref.raw}: path is not a folder", None

    tree = _build_folder_listing(resolved, cwd)
    return None, f"📁 {ref.raw} ({estimate_tokens_rough(tree)} tokens)\n{tree}"
|
||||
|
||||
|
||||
def _expand_git_reference(
    ref: ContextReference,
    cwd: Path,
    args: list[str],
    label: str,
) -> tuple[str | None, str | None]:
    """Run ``git <args>`` in *cwd* and wrap its output in a diff fence.

    Returns a ``(warning, block)`` pair; exactly one side is non-None.
    """
    proc = subprocess.run(
        ["git", *args],
        cwd=cwd,
        capture_output=True,
        text=True,
    )

    if proc.returncode != 0:
        # Prefer git's own stderr; fall back to a generic message.
        reason = (proc.stderr or "").strip() or "git command failed"
        return f"{ref.raw}: {reason}", None

    output = proc.stdout.strip() or "(no output)"
    return None, f"🧾 {label} ({estimate_tokens_rough(output)} tokens)\n```diff\n{output}\n```"
|
||||
|
||||
|
||||
async def _fetch_url_content(
|
||||
url: str,
|
||||
*,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
) -> str:
|
||||
fetcher = url_fetcher or _default_url_fetcher
|
||||
content = fetcher(url)
|
||||
if inspect.isawaitable(content):
|
||||
content = await content
|
||||
return str(content or "").strip()
|
||||
|
||||
|
||||
async def _default_url_fetcher(url: str) -> str:
    """Default @url fetcher: extract *url* as markdown via the web-extract tool.

    Returns "" when the tool reports no documents.
    """
    from tools.web_tools import web_extract_tool

    raw = await web_extract_tool([url], format="markdown", use_llm_processing=True)
    documents = json.loads(raw).get("data", {}).get("documents", [])
    if not documents:
        return ""
    first = documents[0]
    return str(first.get("content") or first.get("raw_content") or "").strip()
|
||||
|
||||
|
||||
def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path:
|
||||
path = Path(os.path.expanduser(target))
|
||||
if not path.is_absolute():
|
||||
path = cwd / path
|
||||
resolved = path.resolve()
|
||||
if allowed_root is not None:
|
||||
try:
|
||||
resolved.relative_to(allowed_root)
|
||||
except ValueError as exc:
|
||||
raise ValueError("path is outside the allowed workspace") from exc
|
||||
return resolved
|
||||
|
||||
|
||||
def _strip_trailing_punctuation(value: str) -> str:
    """Trim sentence punctuation and unbalanced closing brackets from *value*.

    A closer is removed only while it outnumbers its matching opener, so
    balanced pairs inside the target (e.g. "fn(x)") are preserved.
    """
    pairs = {")": "(", "]": "[", "}": "{"}
    text = value.rstrip(TRAILING_PUNCTUATION)
    while text and text[-1] in pairs:
        closer = text[-1]
        if text.count(closer) <= text.count(pairs[closer]):
            break
        text = text[:-1]
    return text
|
||||
|
||||
|
||||
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
|
||||
pieces: list[str] = []
|
||||
cursor = 0
|
||||
for ref in refs:
|
||||
pieces.append(message[cursor:ref.start])
|
||||
cursor = ref.end
|
||||
pieces.append(message[cursor:])
|
||||
text = "".join(pieces)
|
||||
text = re.sub(r"\s{2,}", " ", text)
|
||||
text = re.sub(r"\s+([,.;:!?])", r"\1", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _is_binary_file(path: Path) -> bool:
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
if mime and not mime.startswith("text/") and not any(
|
||||
path.name.endswith(ext) for ext in (".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts")
|
||||
):
|
||||
return True
|
||||
chunk = path.read_bytes()[:4096]
|
||||
return b"\x00" in chunk
|
||||
|
||||
|
||||
def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str:
    """Render a tree-style listing of *path*, capped at *limit* entries.

    Files are annotated with line counts (or byte size for binary files);
    a trailing "- ..." marks a truncated listing.

    NOTE(review): ``path.relative_to(cwd)`` raises ValueError when *path* is
    not under *cwd* (e.g. an absolute @folder target outside the workspace);
    the caller's exception handler turns that into a per-reference warning —
    confirm that is the intended behavior for absolute paths.
    """
    lines = [f"{path.relative_to(cwd)}/"]
    entries = _iter_visible_entries(path, cwd, limit=limit)
    for entry in entries:
        rel = entry.relative_to(cwd)
        # Indent proportionally to the entry's depth below the listed folder.
        indent = " " * max(len(rel.parts) - len(path.relative_to(cwd).parts) - 1, 0)
        if entry.is_dir():
            lines.append(f"{indent}- {entry.name}/")
        else:
            meta = _file_metadata(entry)
            lines.append(f"{indent}- {entry.name} ({meta})")
    if len(entries) >= limit:
        lines.append("- ...")
    return "\n".join(lines)
|
||||
|
||||
|
||||
def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]:
    """Collect up to *limit* directory entries under *path*.

    Prefers ripgrep (which honors .gitignore); when rg is unavailable, falls
    back to os.walk with hidden names and __pycache__ filtered out.
    """
    rg_entries = _rg_files(path, cwd, limit=limit)
    if rg_entries is not None:
        output: list[Path] = []
        seen_dirs: set[Path] = set()
        for rel in rg_entries:
            full = cwd / rel
            # rg emits a flat file list; materialize each intermediate
            # directory between *path* and the file so the listing shows
            # folders too. Parents at or above cwd, already-seen parents,
            # and parents not under *path* are skipped.
            for parent in full.parents:
                if parent == cwd or parent in seen_dirs or path not in {parent, *parent.parents}:
                    continue
                seen_dirs.add(parent)
                output.append(parent)
            output.append(full)
        # Dedupe, drop vanished paths, sort directories first then by name.
        return sorted({p for p in output if p.exists()}, key=lambda p: (not p.is_dir(), str(p)))

    output = []
    for root, dirs, files in os.walk(path):
        # Prune hidden and cache directories in place so os.walk skips them.
        dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__")
        files = sorted(f for f in files if not f.startswith("."))
        root_path = Path(root)
        for d in dirs:
            output.append(root_path / d)
            if len(output) >= limit:
                return output
        for f in files:
            output.append(root_path / f)
            if len(output) >= limit:
                return output
    return output
|
||||
|
||||
|
||||
def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["rg", "--files", str(path.relative_to(cwd))],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
|
||||
return files[:limit]
|
||||
|
||||
|
||||
def _file_metadata(path: Path) -> str:
    """Summarize *path* for folder listings: line count for readable text,
    byte size for binary or undecodable files."""
    if not _is_binary_file(path):
        try:
            text = path.read_text(encoding="utf-8")
        except Exception:
            # Unreadable as UTF-8 text: fall through to the byte size.
            pass
        else:
            return f"{text.count(chr(10)) + 1} lines"
    return f"{path.stat().st_size} bytes"
|
||||
|
||||
|
||||
def _code_fence_language(path: Path) -> str:
|
||||
mapping = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".tsx": "tsx",
|
||||
".jsx": "jsx",
|
||||
".json": "json",
|
||||
".md": "markdown",
|
||||
".sh": "bash",
|
||||
".yml": "yaml",
|
||||
".yaml": "yaml",
|
||||
".toml": "toml",
|
||||
}
|
||||
return mapping.get(path.suffix.lower(), "")
|
||||
22
cli.py
22
cli.py
@@ -5320,6 +5320,28 @@ class HermesCLI:
|
||||
message if isinstance(message, str) else "", images
|
||||
)
|
||||
|
||||
# Expand @ context references (e.g. @file:main.py, @diff, @folder:src/)
|
||||
if isinstance(message, str) and "@" in message:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references
|
||||
from agent.model_metadata import get_model_context_length
|
||||
_ctx_len = get_model_context_length(
|
||||
self.model, base_url=self.base_url or "", api_key=self.api_key or "")
|
||||
_ctx_result = preprocess_context_references(
|
||||
message, cwd=os.getcwd(), context_length=_ctx_len)
|
||||
if _ctx_result.expanded or _ctx_result.blocked:
|
||||
if _ctx_result.references:
|
||||
_cprint(
|
||||
f" {_DIM}[@ context: {len(_ctx_result.references)} ref(s), "
|
||||
f"{_ctx_result.injected_tokens} tokens]{_RST}")
|
||||
for w in _ctx_result.warnings:
|
||||
_cprint(f" {_DIM}⚠ {w}{_RST}")
|
||||
if _ctx_result.blocked:
|
||||
return "\n".join(_ctx_result.warnings) or "Context injection refused."
|
||||
message = _ctx_result.message
|
||||
except Exception as e:
|
||||
logging.debug("@ context reference expansion failed: %s", e)
|
||||
|
||||
# Add user message to history
|
||||
self.conversation_history.append({"role": "user", "content": message})
|
||||
|
||||
|
||||
@@ -2058,7 +2058,31 @@ class GatewayRunner:
|
||||
"message": message_text[:500],
|
||||
}
|
||||
await self.hooks.emit("agent:start", hook_ctx)
|
||||
|
||||
|
||||
# Expand @ context references (@file:, @folder:, @diff, etc.)
|
||||
if "@" in message_text:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
from agent.model_metadata import get_model_context_length
|
||||
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
|
||||
_msg_ctx_len = get_model_context_length(
|
||||
self._model, base_url=self._base_url or "")
|
||||
_ctx_result = await preprocess_context_references_async(
|
||||
message_text, cwd=_msg_cwd,
|
||||
context_length=_msg_ctx_len, allowed_root=_msg_cwd)
|
||||
if _ctx_result.blocked:
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
await _adapter.send(
|
||||
source.chat_id,
|
||||
"\n".join(_ctx_result.warnings) or "Context injection refused.",
|
||||
)
|
||||
return
|
||||
if _ctx_result.expanded:
|
||||
message_text = _ctx_result.message
|
||||
except Exception as exc:
|
||||
logger.debug("@ context reference expansion failed: %s", exc)
|
||||
|
||||
# Run the agent
|
||||
agent_result = await self._run_agent(
|
||||
message=message_text,
|
||||
|
||||
221
tests/test_context_references.py
Normal file
221
tests/test_context_references.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _git(cwd: Path, *args: str) -> str:
    """Run a git command in *cwd* and return its stripped stdout.

    Raises subprocess.CalledProcessError on a non-zero exit (check=True).
    """
    result = subprocess.run(
        ["git", *args],
        cwd=cwd,
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout.strip()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_repo(tmp_path: Path) -> Path:
    """Build a git repo with one commit, an unstaged edit to src/main.py,
    a staged edit to src/helper.py, and a committed binary blob."""
    repo = tmp_path / "repo"
    repo.mkdir()
    _git(repo, "init")
    _git(repo, "config", "user.name", "Hermes Tests")
    _git(repo, "config", "user.email", "tests@example.com")

    (repo / "src").mkdir()
    (repo / "src" / "main.py").write_text(
        "def alpha():\n"
        "    return 'a'\n\n"
        "def beta():\n"
        "    return 'b'\n",
        encoding="utf-8",
    )
    (repo / "src" / "helper.py").write_text("VALUE = 1\n", encoding="utf-8")
    (repo / "README.md").write_text("# Demo\n", encoding="utf-8")
    # NUL byte guarantees _is_binary_file() treats this as binary.
    (repo / "blob.bin").write_bytes(b"\x00\x01\x02binary")

    _git(repo, "add", ".")
    _git(repo, "commit", "-m", "initial")

    # Unstaged change (visible to @diff) ...
    (repo / "src" / "main.py").write_text(
        "def alpha():\n"
        "    return 'changed'\n\n"
        "def beta():\n"
        "    return 'b'\n",
        encoding="utf-8",
    )
    # ... and a staged change (visible to @staged).
    (repo / "src" / "helper.py").write_text("VALUE = 2\n", encoding="utf-8")
    _git(repo, "add", "src/helper.py")
    return repo
|
||||
|
||||
|
||||
def test_parse_typed_references_ignores_emails_and_handles():
    """Emails and bare @handles are skipped; typed refs yield kind/target/range."""
    from agent.context_references import parse_context_references

    message = (
        "email me at user@example.com and ping @teammate "
        "but include @file:src/main.py:1-2 plus @diff and @git:2 "
        "and @url:https://example.com/docs"
    )

    refs = parse_context_references(message)

    assert [ref.kind for ref in refs] == ["file", "diff", "git", "url"]
    assert refs[0].target == "src/main.py"
    assert refs[0].line_start == 1
    assert refs[0].line_end == 2
    assert refs[2].target == "2"
|
||||
|
||||
|
||||
def test_parse_references_strips_trailing_punctuation():
    """Trailing commas/periods and unbalanced closing parens are trimmed from targets."""
    from agent.context_references import parse_context_references

    refs = parse_context_references(
        "review @file:README.md, then see (@url:https://example.com/docs)."
    )

    assert [ref.kind for ref in refs] == ["file", "url"]
    assert refs[0].target == "README.md"
    assert refs[1].target == "https://example.com/docs"
|
||||
|
||||
|
||||
def test_expand_file_range_and_folder_listing(sample_repo: Path):
    """@file line ranges slice the file; @folder attaches a listing; tokens counted."""
    from agent.context_references import preprocess_context_references

    result = preprocess_context_references(
        "Review @file:src/main.py:1-2 and @folder:src/",
        cwd=sample_repo,
        context_length=100_000,
    )

    assert result.expanded
    # Reference tokens are stripped from the message body...
    assert "Review and" in result.message
    assert "Review @file:src/main.py:1-2" not in result.message
    # ...and the expanded content is appended beneath it.
    assert "--- Attached Context ---" in result.message
    # Lines 1-2 of the working-tree file are included; line 4+ is not.
    assert "def alpha():" in result.message
    assert "return 'changed'" in result.message
    assert "def beta():" not in result.message
    # Folder listing covers the directory and both files.
    assert "src/" in result.message
    assert "main.py" in result.message
    assert "helper.py" in result.message
    assert result.injected_tokens > 0
    assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_git_diff_staged_and_log(sample_repo: Path):
    """@diff, @staged, and @git:N inject the respective git command output."""
    from agent.context_references import preprocess_context_references

    result = preprocess_context_references(
        "Inspect @diff and @staged and @git:1",
        cwd=sample_repo,
        context_length=100_000,
    )

    assert result.expanded
    # Each block is labeled with the git command that produced it.
    assert "git diff" in result.message
    assert "git diff --staged" in result.message
    assert "git log -1 -p" in result.message
    # The commit subject, the unstaged edit, and the staged edit all appear.
    assert "initial" in result.message
    assert "return 'changed'" in result.message
    assert "VALUE = 2" in result.message
|
||||
|
||||
|
||||
def test_binary_and_missing_files_become_warnings(sample_repo: Path):
    """Unreadable @file targets surface as warnings, not exceptions."""
    from agent.context_references import preprocess_context_references

    result = preprocess_context_references(
        "Check @file:blob.bin and @file:nope.txt",
        cwd=sample_repo,
        context_length=100_000,
    )

    # Warnings still count as an expansion (the message was modified).
    assert result.expanded
    assert len(result.warnings) == 2
    assert "binary" in result.message.lower()
    assert "not found" in result.message.lower()
|
||||
|
||||
|
||||
def test_soft_budget_warns_and_hard_budget_refuses(sample_repo: Path):
    """Above 25% of context the injection warns; above 50% it is refused."""
    from agent.context_references import preprocess_context_references

    # Tiny context window forces the soft (25%) limit.
    soft = preprocess_context_references(
        "Check @file:src/main.py",
        cwd=sample_repo,
        context_length=100,
    )
    assert soft.expanded
    assert any("25%" in warning for warning in soft.warnings)

    # Even smaller window trips the hard (50%) limit and blocks injection.
    hard = preprocess_context_references(
        "Check @file:src/main.py and @file:README.md",
        cwd=sample_repo,
        context_length=20,
    )
    assert not hard.expanded
    assert hard.blocked
    # The original message is returned untouched when blocked.
    assert "@file:src/main.py" in hard.message
    assert any("50%" in warning for warning in hard.warnings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_async_url_expansion_uses_fetcher(sample_repo: Path):
    """The async entry point routes @url fetching through the injected fetcher."""
    from agent.context_references import preprocess_context_references_async

    async def fake_fetch(url: str) -> str:
        # The fetcher must receive the parsed target URL verbatim.
        assert url == "https://example.com/spec"
        return "# Spec\n\nImportant details."

    result = await preprocess_context_references_async(
        "Use @url:https://example.com/spec",
        cwd=sample_repo,
        context_length=100_000,
        url_fetcher=fake_fetch,
    )

    assert result.expanded
    assert "Important details." in result.message
    assert result.injected_tokens > 0
|
||||
|
||||
|
||||
def test_sync_url_expansion_uses_async_fetcher(sample_repo: Path):
    """The sync wrapper can drive an async fetcher without a pre-existing loop."""
    from agent.context_references import preprocess_context_references

    async def fake_fetch(url: str) -> str:
        # Yield control once to prove a real event loop drives the coroutine.
        await asyncio.sleep(0)
        return f"Content for {url}"

    result = preprocess_context_references(
        "Use @url:https://example.com/spec",
        cwd=sample_repo,
        context_length=100_000,
        url_fetcher=fake_fetch,
    )

    assert result.expanded
    assert "Content for https://example.com/spec" in result.message
|
||||
|
||||
|
||||
def test_restricts_paths_to_allowed_root(tmp_path: Path):
    """With allowed_root set, traversal outside the workspace is refused per-reference."""
    from agent.context_references import preprocess_context_references

    workspace = tmp_path / "workspace"
    workspace.mkdir()
    (workspace / "notes.txt").write_text("inside\n", encoding="utf-8")
    # A file that sits OUTSIDE the sandbox root.
    secret = tmp_path / "secret.txt"
    secret.write_text("outside\n", encoding="utf-8")

    result = preprocess_context_references(
        "read @file:../secret.txt and @file:notes.txt",
        cwd=workspace,
        context_length=100_000,
        allowed_root=workspace,
    )

    assert result.expanded
    # The out-of-root file's content must not be injected...
    assert "```\noutside\n```" not in result.message
    # ...while the in-root file is, and the escape attempt becomes a warning.
    assert "inside" in result.message
    assert any("outside the allowed workspace" in warning for warning in result.warnings)
|
||||
Reference in New Issue
Block a user