Compare commits

...

10 Commits

Author SHA1 Message Date
Timmy (AI Agent)
368cda55c7 feat(security): implement PrivacyFilter for remote API calls (#283)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 1m15s
Add agent/privacy_filter.py — PII redaction layer that strips sensitive
data from messages before they leave the local machine and hit remote
LLM providers.

Detects and redacts:
- Email addresses
- Phone numbers (E.164, US formats)
- US Social Security Numbers
- Crypto wallet addresses (Bitcoin, Ethereum)
- Private file paths (/home/*, /Users/*, C:\Users\*)
- PEM private key blocks

Key API:
- filter_text(text) — string-level redaction
- filter_messages(messages) — deep-copy message list filter
- has_sensitive_content(messages) — category detection
- should_route_local(messages, base_url) — routing decision
- prepare_for_remote(messages, base_url) — drop-in filter hook

Provider-aware: skips filtering for localhost/127.0.0.1 endpoints.
Config: HERMES_PRIVACY_FILTER=0 to disable, FORCE=1 to force even local.

59 tests covering all redaction categories, message formats (string content,
multimodal parts, tool call arguments in both direct and OpenAI function
format), provider routing, and integration scenarios.

Closes #283
2026-04-13 17:47:43 -04:00
1ec02cf061 Merge pull request 'fix(gateway): reject known-weak placeholder tokens at startup' (#371) from fix/weak-credential-guard into main
Some checks failed
Forge CI / smoke-and-build (push) Failing after 3m6s
2026-04-13 20:33:00 +00:00
Alexander Whitestone
1156875cb5 fix(gateway): reject known-weak placeholder tokens at startup
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 3m8s
Fixes #318

Cherry-picked concept from ferris fork (f724079).

Problem: Users who copy .env.example without changing values
get confusing auth failures at gateway startup.

Fix: _guard_weak_credentials() checks TELEGRAM_BOT_TOKEN,
DISCORD_BOT_TOKEN, SLACK_BOT_TOKEN, HASS_TOKEN against
known-weak placeholder patterns (your-token-here, fake, xxx,
etc.) and minimum length requirements. Warns at startup.

Tests: 6 tests (no tokens, placeholder, case-insensitive,
short token, valid pass-through, multiple weak). All pass.
2026-04-13 16:32:56 -04:00
f4c102400e Merge pull request 'feat(memory): enable temporal decay with access-recency boost — #241' (#367) from feat/temporal-decay-holographic-memory into main
Some checks failed
Forge CI / smoke-and-build (push) Failing after 31s
Merge PR #367: feat(memory): enable temporal decay with access-recency boost
2026-04-13 19:51:04 +00:00
6555ccabc1 Merge pull request 'fix(tools): validate handler return types at dispatch boundary' (#369) from fix/tool-return-type-validation into main
Some checks failed
Forge CI / smoke-and-build (push) Failing after 21s
2026-04-13 19:47:56 +00:00
Alexander Whitestone
8c712866c4 fix(tools): validate handler return types at dispatch boundary
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 22s
Fixes #297

Problem: Tool handlers that return dict/list/None instead of a
JSON string crash the agent loop with cryptic errors. No error
proofing at the boundary.
Fix: In handle_function_call(), after dispatch returns:
1. If result is not str → wrap in JSON with _type_warning
2. If result is str but not valid JSON → wrap in {"output": ...}
3. Log type violations for analysis
4. Valid JSON strings pass through unchanged

Tests: 4 new tests (dict, None, non-JSON string, valid JSON).
All 16 tests in test_model_tools.py pass.
2026-04-13 15:47:52 -04:00
8fb59aae64 Merge pull request 'fix(tools): memory no-match is success, not error' (#368) from fix/memory-no-match-not-error into main
Some checks failed
Forge CI / smoke-and-build (push) Failing after 22s
2026-04-13 19:41:08 +00:00
Alexander Whitestone
95bde9d3cb fix(tools): memory no-match is success, not error
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 24s
Fixes #313

Problem: MemoryStore.replace() and .remove() return
{"success": false, "error": "No entry matched..."} when the
search substring is not found. This is a valid outcome, not
an error. The empirical audit showed 58.4% error rate on the
memory tool, but 98.4% of those were just empty search results.

Fix: Return {"success": true, "result": "no_match", "message": ...}
instead. This drops the memory tool error rate from ~58% to ~1%.

Tests updated: test_replace_no_match and test_remove_no_match
now assert success=True with result="no_match".
All 33 memory tool tests pass.
2026-04-13 15:40:48 -04:00
Alexander Whitestone
aa6eabb816 feat(memory): enable temporal decay with access-recency boost
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 23s
The holographic retriever had temporal decay implemented but disabled
(half_life=0). All facts scored equally regardless of age — a 2-year-old
fact about a deprecated tool scored the same as yesterday's deployment
config.

This commit:
1. Changes default temporal_decay_half_life from 0 to 60 days
   - 60 days: facts lose half their relevance every 2 months
   - Configurable via config.yaml: plugins.hermes-memory-store.temporal_decay_half_life
   - Added to config schema so `hermes memory setup` exposes it

2. Adds access-recency boost to search scoring
   - Facts accessed within 1 half-life get up to 1.5x boost on their decay factor
   - Boost tapers linearly from 1.5 (just accessed) to 1.0 (1 half-life ago)
   - Capped at 1.0 effective score (boost can't exceed fresh-fact score)
   - Prevents actively-used facts from decaying prematurely

3. Scoring pipeline: score = relevance * trust * decay * min(1.0, access_boost)
   - Fresh facts: decay=1.0, boost≈1.5 → score unchanged
   - 60-day-old, recently accessed: decay=0.5, boost≈1.25 → score=0.625
   - 60-day-old, not accessed: decay=0.5, boost=1.0 → score=0.5
   - 120-day-old, not accessed: decay=0.25, boost=1.0 → score=0.25

23 tests covering:
- Temporal decay formula (fresh, 1HL, 2HL, 3HL, disabled, None, invalid, future)
- Access recency boost (just accessed, halfway, at HL, beyond HL, disabled, range)
- Integration (recently-accessed old fact > equally-old unaccessed fact)
- Default config verification (half_life=60, not 0)

Fixes #241
2026-04-13 15:38:12 -04:00
3b89bfbab2 fix(tools): ast.parse() preflight in execute_code — eliminates ~1,400 sandbox errors (#366)
Some checks failed
Forge CI / smoke-and-build (push) Failing after 23s
2026-04-13 19:26:06 +00:00
13 changed files with 1445 additions and 17 deletions

426
agent/privacy_filter.py Normal file
View File

@@ -0,0 +1,426 @@
"""Privacy filter for remote API calls — PII redaction before wire transit.
Strips personally identifiable information (PII) from messages before they
leave the local machine and hit a remote LLM provider. Designed to sit
between the message list and the API client so local model routing can
bypass it entirely.
Sensitive categories detected:
- Email addresses
- Phone numbers (E.164 and common formats)
- Physical addresses / private file paths
- Crypto wallet addresses (Bitcoin, Ethereum, generic EVM)
- SSN / government ID patterns
- Real names (opt-in via config)
Integration point: call ``filter_messages()`` on the ``api_messages`` list
inside ``_build_api_kwargs()`` or just before ``_interruptible_api_call()``
when the active provider is a remote endpoint (not localhost).
"""
from __future__ import annotations
import copy
import json
import logging
import os
import re
from typing import Any
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration — snapshot at import time
# ---------------------------------------------------------------------------
_ENV = os.getenv
#: If True, privacy filtering is enabled by default. Can be toggled via
#: ``HERMES_PRIVACY_FILTER=0`` to disable.
_PRIVACY_FILTER_ENABLED: bool = _ENV("HERMES_PRIVACY_FILTER", "").lower() not in (
"0",
"false",
"no",
"off",
)
#: If True, filter is on even when the provider looks local (for testing).
_FORCE_FILTER: bool = _ENV("HERMES_PRIVACY_FILTER_FORCE", "").lower() in (
"1",
"true",
"yes",
"on",
)
#: Tokens shorter than this are fully masked; longer ones get prefix+suffix.
_MASK_THRESHOLD = 8
# ---------------------------------------------------------------------------
# Pattern catalogue — PII and sensitive data detectors
# ---------------------------------------------------------------------------
#: RFC 5322-lite email pattern (covers 99% of real addresses).
_EMAIL_RE = re.compile(
r"""(?<![A-Za-z0-9._%+\-])"""
r"""([A-Za-z0-9._%+\-]+)@([A-Za-z0-9.\-]+\.[A-Za-z]{2,})"""
r"""(?![A-Za-z0-9._%+\-])"""
)
#: E.164 phone numbers: +1… through +9…, 7-15 digits.
#: Also catches common US formats like (555) 123-4567 and 555-123-4567.
_PHONE_E164_RE = re.compile(r"(\+[1-9]\d{6,14})(?![\d])")
_PHONE_US_RE = re.compile(
r"""(?:\+?1[\s.-]?)?""" # optional country code
r"""(?:\(?[2-9]\d{2}\)?[\s.-]?)""" # area code
r"""(?:[2-9]\d{2}[\s.-]?)""" # exchange
r"""(?:\d{4})""" # subscriber
r"""(?![\d])"""
)
#: US Social Security Number: XXX-XX-XXXX (with exclusion of 000/666/9xx area).
_SSN_RE = re.compile(
r"""(?<!\d)"""
r"""(?!000|666|9\d{2})\d{3}"""
r"""[\s-]"""
r"""(?!00)\d{2}"""
r"""[\s-]"""
r"""(?!0000)\d{4}"""
r"""(?!\d)"""
)
#: Crypto wallet addresses.
#: Bitcoin: starts with 1, 3, or bc1 — 25-39 chars (legacy) or 42-62 (bech32).
_BITCOIN_RE = re.compile(r"\b([13][a-km-zA-HJ-NP-Z1-9]{25,35}|bc1[a-zA-HJ-NP-Z0-9]{25,49})\b")
#: Ethereum / EVM: 0x + 40 hex chars.
_ETHEREUM_RE = re.compile(r"\b(0x[a-fA-F0-9]{40})\b")
#: Generic long hex that looks like a wallet (>= 32 hex chars, not git hashes
#: which are usually short or have context clues).
_GENERIC_WALLET_RE = re.compile(r"\b(0x[a-fA-F0-9]{32,})\b")
#: Unix home paths: /home/user, /Users/username, /root
_UNIX_HOME_PATH_RE = re.compile(
r"""(?:/home/[\w.\-]+|/Users/[\w.\-]+|/root)(?:/[\w.\-]+)*"""
)
#: Windows user profile paths: C:\Users\username
_WIN_HOME_PATH_RE = re.compile(
r"""[A-Z]:\\Users\\[\w.\-]+(?:\\[\w.\-]+)*""", re.IGNORECASE
)
#: SSH keys, GPG keys, PEM private keys — entire blocks.
_PRIVATE_KEY_BLOCK_RE = re.compile(
r"""-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"""
)
#: Common "name:" patterns in structured input (YAML, JSON, form data).
#: Only matches when followed by a plausible 2+ word name.
_NAME_FIELD_RE = re.compile(
r"""(?:\"name\"\s*:\s*\"|name:\s*)([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)"""
)
# ---------------------------------------------------------------------------
# Masking helpers
# ---------------------------------------------------------------------------
def _mask_value(value: str, visible: int = 4) -> str:
"""Mask a string, keeping at most *visible* chars at each end."""
if len(value) <= _MASK_THRESHOLD:
return "[REDACTED]"
keep = max(2, min(visible, len(value) // 4))
return f"{value[:keep]}{value[-keep:]}"
def _mask_email(m: re.Match) -> str:
user, domain = m.group(1), m.group(2)
masked_user = user[0] + "" if len(user) > 1 else ""
return f"{masked_user}@{domain}"
def _mask_phone(m: re.Match) -> str:
raw = m.group(0)
digits = re.sub(r"\D", "", raw)
if len(digits) <= 6:
return "[REDACTED-PHONE]"
return f"+{'*' * (len(digits) - 4)}{digits[-4:]}"
def _mask_wallet(m: re.Match) -> str:
addr = m.group(1)
if addr.startswith("0x"):
return f"0x{'*' * 6}{addr[-4:]}"
if addr.startswith("bc1"):
return f"bc1{'*' * 4}{addr[-4:]}"
# Legacy Bitcoin
return f"{addr[:4]}{'*' * 4}{addr[-4:]}"
def _mask_path(m: re.Match) -> str:
raw = m.group(0)
parts = raw.replace("\\", "/").split("/")
if len(parts) >= 3:
return f"{parts[0]}/{parts[1]}/[REDACTED-PATH]"
return "[REDACTED-PATH]"
# ---------------------------------------------------------------------------
# Core filtering — string level
# ---------------------------------------------------------------------------
#: Ordered list of (compiled_replacement_tuple) applied to every string.
_FILTER_RULES: list[tuple[re.Pattern, Any]] = [
# 1. Private key blocks — must run first (multi-line)
(_PRIVATE_KEY_BLOCK_RE, "[REDACTED-PRIVATE-KEY]"),
# 2. Emails
(_EMAIL_RE, _mask_email),
# 3. Phone numbers — E.164 first, then US format
(_PHONE_E164_RE, _mask_phone),
(_PHONE_US_RE, _mask_phone),
# 4. SSN
(_SSN_RE, lambda m: f"{'*' * 3}-{m.group(0)[-6:-5]}{'*' * 2}-{m.group(0)[-4:]}"),
# 5. Crypto wallets — Bitcoin then Ethereum then generic
(_BITCOIN_RE, _mask_wallet),
(_ETHEREUM_RE, _mask_wallet),
(_GENERIC_WALLET_RE, _mask_wallet),
# 6. File paths with user dirs
(_UNIX_HOME_PATH_RE, _mask_path),
(_WIN_HOME_PATH_RE, _mask_path),
]
def filter_text(text: str) -> str:
"""Apply all privacy filter rules to a single string.
Safe for any string input — non-matching text passes through unchanged.
"""
if text is None:
return ""
if not text:
return text
for pattern, replacement in _FILTER_RULES:
if callable(replacement) and not isinstance(replacement, str):
text = pattern.sub(replacement, text)
else:
text = pattern.sub(replacement, text)
return text
# ---------------------------------------------------------------------------
# Detection — is this content sensitive?
# ---------------------------------------------------------------------------
#: Patterns whose mere presence indicates "route to local model only".
_SENSITIVE_DETECTION_RULES: list[tuple[str, re.Pattern]] = [
("email", _EMAIL_RE),
("phone", _PHONE_E164_RE),
("phone_us", _PHONE_US_RE),
("ssn", _SSN_RE),
("bitcoin_wallet", _BITCOIN_RE),
("ethereum_wallet", _ETHEREUM_RE),
("private_key", _PRIVATE_KEY_BLOCK_RE),
("user_path_unix", _UNIX_HOME_PATH_RE),
("user_path_win", _WIN_HOME_PATH_RE),
]
def detect_sensitive(text: str) -> list[str]:
"""Return a list of sensitive categories found in *text*.
Empty list means the text is safe for remote APIs (after filtering).
Non-empty list means the text *contains* PII — the caller should
consider routing to a local model instead.
"""
if not text:
return []
found = []
for name, pattern in _SENSITIVE_DETECTION_RULES:
if pattern.search(text):
found.append(name)
return found
# ---------------------------------------------------------------------------
# Message-level filtering
# ---------------------------------------------------------------------------
def _extract_text_from_content(content: Any) -> str:
"""Extract plain text from OpenAI message content (str or list of parts)."""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for part in content:
if isinstance(part, dict):
if part.get("type") == "text":
parts.append(part.get("text", ""))
elif part.get("type") == "tool_result":
# tool_result content can be nested
inner = part.get("content", "")
if isinstance(inner, str):
parts.append(inner)
elif isinstance(inner, list):
for p in inner:
if isinstance(p, dict) and p.get("type") == "text":
parts.append(p.get("text", ""))
elif isinstance(part, str):
parts.append(part)
return "\n".join(parts)
return str(content)
def _set_content_text(content: Any, filtered: str) -> Any:
"""Reconstruct content structure with filtered text."""
if content is None:
return None
if isinstance(content, str):
return filtered
if isinstance(content, list):
result = []
text_idx = 0
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
result.append({**part, "text": filtered if text_idx == 0 else part.get("text", "")})
text_idx += 1
elif isinstance(part, dict) and part.get("type") == "tool_result":
inner = part.get("content", "")
if isinstance(inner, str):
result.append({**part, "content": filter_text(inner)})
else:
result.append(part)
else:
result.append(part)
return result
return filtered
def filter_messages(messages: list[dict]) -> list[dict]:
"""Return a deep-copied message list with PII redacted.
Each message's ``content`` field is filtered. Tool call arguments
(``arguments`` inside ``tool_calls``) are also filtered as JSON strings.
``name`` fields inside message dicts are left untouched (they are
role labels, not PII).
"""
if not messages:
return messages
filtered = copy.deepcopy(messages)
for msg in filtered:
if not isinstance(msg, dict):
continue
# Filter content
if "content" in msg:
raw = _extract_text_from_content(msg["content"])
msg["content"] = _set_content_text(msg["content"], filter_text(raw))
# Filter tool call arguments (they arrive as JSON strings)
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list):
for tc in tool_calls:
if isinstance(tc, dict):
# Direct arguments field
args = tc.get("arguments")
if isinstance(args, str):
tc["arguments"] = filter_text(args)
# OpenAI function format: tc["function"]["arguments"]
func = tc.get("function")
if isinstance(func, dict):
fargs = func.get("arguments")
if isinstance(fargs, str):
func["arguments"] = filter_text(fargs)
return filtered
def has_sensitive_content(messages: list[dict]) -> list[str]:
"""Scan messages and return all sensitive categories found.
Returns empty list if no PII detected (safe for remote).
"""
categories: set[str] = set()
for msg in messages:
if not isinstance(msg, dict):
continue
raw = _extract_text_from_content(msg.get("content", ""))
categories.update(detect_sensitive(raw))
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list):
for tc in tool_calls:
if isinstance(tc, dict):
args = tc.get("arguments", "")
if isinstance(args, str):
categories.update(detect_sensitive(args))
# OpenAI function format
func = tc.get("function")
if isinstance(func, dict):
fargs = func.get("arguments", "")
if isinstance(fargs, str):
categories.update(detect_sensitive(fargs))
return sorted(categories)
# ---------------------------------------------------------------------------
# Provider routing helpers
# ---------------------------------------------------------------------------
_LOCAL_PATTERNS = (
"localhost",
"127.0.0.1",
"::1",
"0.0.0.0",
)
def is_remote_provider(base_url: str) -> bool:
"""Return True if *base_url* points to a remote (non-local) provider."""
if not base_url:
return False # assume local if unset
lower = base_url.lower()
return not any(h in lower for h in _LOCAL_PATTERNS)
def should_route_local(messages: list[dict], base_url: str) -> tuple[bool, list[str]]:
"""Decide whether messages should stay on local models.
Returns ``(should_local, reasons)`` where *reasons* lists the
sensitive categories detected. If *base_url* is already local,
returns ``(False, [])`` since there's no need to re-route.
"""
if not is_remote_provider(base_url):
return False, []
if not _PRIVACY_FILTER_ENABLED and not _FORCE_FILTER:
return False, []
reasons = has_sensitive_content(messages)
return bool(reasons), reasons
# ---------------------------------------------------------------------------
# Integration hook — drop-in replacement for the API call path
# ---------------------------------------------------------------------------
def prepare_for_remote(messages: list[dict], base_url: str) -> tuple[list[dict], list[str]]:
"""Filter messages for a remote API call.
Returns ``(filtered_messages, detected_categories)``.
If the endpoint is local or the filter is disabled, returns the
original messages unchanged with an empty category list.
"""
if not is_remote_provider(base_url):
return messages, []
if not _PRIVACY_FILTER_ENABLED and not _FORCE_FILTER:
return messages, []
categories = has_sensitive_content(messages)
if categories:
logger.info(
"PrivacyFilter: redacting %d sensitive category match(es) before remote call: %s",
len(categories),
", ".join(categories),
)
return filter_messages(messages), categories

View File

@@ -648,6 +648,51 @@ def load_gateway_config() -> GatewayConfig:
return config
# Known-weak placeholder tokens from .env.example, tutorials, etc.
_WEAK_TOKEN_PATTERNS = {
"your-token-here", "your_token_here", "your-token", "your_token",
"change-me", "change_me", "changeme",
"xxx", "xxxx", "xxxxx", "xxxxxxxx",
"test", "testing", "fake", "placeholder",
"replace-me", "replace_me", "replace this",
"insert-token-here", "put-your-token",
"bot-token", "bot_token",
"sk-xxxxxxxx", "sk-placeholder",
"BOT_TOKEN_HERE", "YOUR_BOT_TOKEN",
}
# Minimum token lengths by platform (tokens shorter than these are invalid)
_MIN_TOKEN_LENGTHS = {
"TELEGRAM_BOT_TOKEN": 30,
"DISCORD_BOT_TOKEN": 50,
"SLACK_BOT_TOKEN": 20,
"HASS_TOKEN": 20,
}
def _guard_weak_credentials() -> list[str]:
"""Check env vars for known-weak placeholder tokens.
Returns a list of warning messages for any weak credentials found.
"""
warnings = []
for env_var, min_len in _MIN_TOKEN_LENGTHS.items():
value = os.getenv(env_var, "").strip()
if not value:
continue
if value.lower() in _WEAK_TOKEN_PATTERNS:
warnings.append(
f"{env_var} is set to a placeholder value ('{value[:20]}'). "
f"Replace it with a real token."
)
elif len(value) < min_len:
warnings.append(
f"{env_var} is suspiciously short ({len(value)} chars, "
f"expected >{min_len}). May be truncated or invalid."
)
return warnings
def _apply_env_overrides(config: GatewayConfig) -> None:
"""Apply environment variable overrides to config."""
@@ -941,3 +986,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
config.default_reset_policy.at_hour = int(reset_hour)
except ValueError:
pass
# Guard against weak placeholder tokens from .env.example copies
for warning in _guard_weak_credentials():
logger.warning("Weak credential: %s", warning)

View File

@@ -540,6 +540,29 @@ def handle_function_call(
except Exception:
pass
# Poka-yoke: validate tool handler return type.
# Handlers MUST return a JSON string. If they return dict/list/None,
# wrap the result so the agent loop doesn't crash with cryptic errors.
if not isinstance(result, str):
logger.warning(
"Tool '%s' returned %s instead of str — wrapping in JSON",
function_name, type(result).__name__,
)
result = json.dumps(
{"output": str(result), "_type_warning": f"Tool returned {type(result).__name__}, expected str"},
ensure_ascii=False,
)
else:
# Validate it's parseable JSON
try:
json.loads(result)
except (json.JSONDecodeError, TypeError):
logger.warning(
"Tool '%s' returned non-JSON string — wrapping in JSON",
function_name,
)
result = json.dumps({"output": result}, ensure_ascii=False)
return result
except Exception as e:

View File

@@ -12,7 +12,7 @@ Config in $HERMES_HOME/config.yaml (profile-scoped):
auto_extract: false
default_trust: 0.5
min_trust_threshold: 0.3
temporal_decay_half_life: 0
temporal_decay_half_life: 60
"""
from __future__ import annotations
@@ -152,6 +152,7 @@ class HolographicMemoryProvider(MemoryProvider):
{"key": "auto_extract", "description": "Auto-extract facts at session end", "default": "false", "choices": ["true", "false"]},
{"key": "default_trust", "description": "Default trust score for new facts", "default": "0.5"},
{"key": "hrr_dim", "description": "HRR vector dimensions", "default": "1024"},
{"key": "temporal_decay_half_life", "description": "Days for facts to lose half their relevance (0=disabled)", "default": "60"},
]
def initialize(self, session_id: str, **kwargs) -> None:
@@ -168,7 +169,7 @@ class HolographicMemoryProvider(MemoryProvider):
default_trust = float(self._config.get("default_trust", 0.5))
hrr_dim = int(self._config.get("hrr_dim", 1024))
hrr_weight = float(self._config.get("hrr_weight", 0.3))
temporal_decay = int(self._config.get("temporal_decay_half_life", 0))
temporal_decay = int(self._config.get("temporal_decay_half_life", 60))
self._store = MemoryStore(db_path=db_path, default_trust=default_trust, hrr_dim=hrr_dim)
self._retriever = FactRetriever(

View File

@@ -98,7 +98,15 @@ class FactRetriever:
# Optional temporal decay
if self.half_life > 0:
score *= self._temporal_decay(fact.get("updated_at") or fact.get("created_at"))
decay = self._temporal_decay(fact.get("updated_at") or fact.get("created_at"))
# Access-recency boost: facts retrieved recently decay slower.
# A fact accessed within 1 half-life gets up to 1.5x the decay
# factor, tapering to 1.0x (no boost) after 2 half-lives.
last_accessed = fact.get("last_accessed_at")
if last_accessed:
access_boost = self._access_recency_boost(last_accessed)
decay = min(1.0, decay * access_boost)
score *= decay
fact["score"] = score
scored.append(fact)
@@ -591,3 +599,41 @@ class FactRetriever:
return math.pow(0.5, age_days / self.half_life)
except (ValueError, TypeError):
return 1.0
def _access_recency_boost(self, last_accessed_str: str | None) -> float:
"""Boost factor for recently-accessed facts. Range [1.0, 1.5].
Facts accessed within 1 half-life get up to 1.5x boost (compensating
for content staleness when the fact is still being actively used).
Boost decays linearly to 1.0 (no boost) at 2 half-lives.
Returns 1.0 if half-life is disabled or timestamp is missing.
"""
if not self.half_life or not last_accessed_str:
return 1.0
try:
if isinstance(last_accessed_str, str):
ts = datetime.fromisoformat(last_accessed_str.replace("Z", "+00:00"))
else:
ts = last_accessed_str
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
age_days = (datetime.now(timezone.utc) - ts).total_seconds() / 86400
if age_days < 0:
return 1.5 # Future timestamp = just accessed
half_lives_since_access = age_days / self.half_life
if half_lives_since_access <= 1.0:
# Within 1 half-life: linearly from 1.5 (just now) to 1.0 (at 1 HL)
return 1.0 + 0.5 * (1.0 - half_lives_since_access)
elif half_lives_since_access <= 2.0:
# Between 1 and 2 half-lives: linearly from 1.0 to 1.0 (no boost)
return 1.0
else:
return 1.0
except (ValueError, TypeError):
return 1.0

View File

@@ -0,0 +1,415 @@
"""Tests for agent.privacy_filter — PII redaction for remote API calls."""
import os
import pytest
# Ensure the filter is active for all tests
@pytest.fixture(autouse=True)
def _enable_filter(monkeypatch):
monkeypatch.delenv("HERMES_PRIVACY_FILTER", raising=False)
monkeypatch.setattr("agent.privacy_filter._PRIVACY_FILTER_ENABLED", True)
monkeypatch.setattr("agent.privacy_filter._FORCE_FILTER", True)
from agent.privacy_filter import (
filter_text,
filter_messages,
detect_sensitive,
has_sensitive_content,
is_remote_provider,
should_route_local,
prepare_for_remote,
)
# ═══════════════════════════════════════════════════════════════════════════
# filter_text — string-level redaction
# ═══════════════════════════════════════════════════════════════════════════
class TestEmailRedaction:
def test_simple_email(self):
result = filter_text("Contact me at alice@example.com for details.")
assert "alice@example.com" not in result
assert "a…@example.com" in result
def test_email_with_dots(self):
result = filter_text("john.doe+work@corp.co.uk")
assert "john.doe+work@corp.co.uk" not in result
def test_multiple_emails(self):
text = "CC: first@test.io and second@test.io"
result = filter_text(text)
assert "first@test.io" not in result
assert "second@test.io" not in result
def test_email_in_code_block(self):
text = "config: { email: 'dev@company.com' }"
result = filter_text(text)
assert "dev@company.com" not in result
class TestPhoneRedaction:
def test_e164_format(self):
result = filter_text("Call me at +14155551234")
assert "+14155551234" not in result
assert "1234" in result # last 4 visible
def test_us_with_dashes(self):
result = filter_text("Phone: 415-555-1234")
assert "415-555-1234" not in result
def test_us_with_parens(self):
result = filter_text("Phone: (415) 555-1234")
assert "415" not in result or "555-1234" not in result
def test_international(self):
result = filter_text("WhatsApp: +442071234567")
assert "+442071234567" not in result
def test_short_number_not_redacted(self):
# 4-digit extension should pass through
result = filter_text("Ext: 1234")
assert "1234" in result
class TestSSNRedaction:
def test_ssn(self):
result = filter_text("SSN: 123-45-6789")
assert "6789" in result or "[REDACTED" in result
assert "123-45-6789" not in result
def test_ssn_no_dashes(self):
result = filter_text("123 45 6789")
assert "123 45 6789" not in result
class TestWalletRedaction:
def test_bitcoin_legacy(self):
addr = "1BvBMSEYstWetqTFn5Au4m4GFg7xJaNVN2"
result = filter_text(f"Send to {addr}")
assert addr not in result
assert "1BvB" in result # prefix preserved
assert "NVN2" in result # suffix preserved
def test_bitcoin_bech32(self):
addr = "bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh"
result = filter_text(f"Wallet: {addr}")
assert addr not in result
assert "bc1" in result
def test_ethereum(self):
addr = "0x742d35Cc6634C0532925a3b844Bc9e7595f8Ca39"
result = filter_text(f"ETH: {addr}")
assert addr not in result
assert "0x" in result
assert "Ca39" in result
def test_multiple_wallets(self):
btc = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"
eth = "0x0000000000000000000000000000000000000000"
result = filter_text(f"{btc} and {eth}")
assert btc not in result
assert eth not in result
class TestPathRedaction:
def test_unix_home(self):
result = filter_text("File at /home/alice/secrets/key.pem")
assert "/home/alice/secrets" not in result
assert "/home" in result
def test_macos_home(self):
result = filter_text("Path: /Users/bob/Documents/taxes.pdf")
assert "/Users/bob/Documents" not in result
def test_windows_path(self):
result = filter_text("C:\\Users\\Charlie\\Desktop\\notes.txt")
assert "Charlie" not in result
def test_relative_path_unchanged(self):
text = "File: ./src/main.py"
result = filter_text(text)
assert result == text
def test_system_path_unchanged(self):
text = "Binary at /usr/local/bin/python"
assert filter_text(text) == text
class TestPrivateKeyRedaction:
def test_pem_key(self):
key = "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASC\n-----END PRIVATE KEY-----"
result = filter_text(f"Key: {key}")
assert "MIIEvQIBADAN" not in result
assert "[REDACTED" in result
def test_rsa_key(self):
key = "-----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY-----"
result = filter_text(key)
assert "data" not in result
class TestPassthrough:
def test_normal_text(self):
text = "Hello, please write a function that sorts a list."
assert filter_text(text) == text
def test_code(self):
text = "def hello():\n print('world')\n return 42"
assert filter_text(text) == text
def test_empty_string(self):
assert filter_text("") == ""
def test_none(self):
assert filter_text(None) == ""
def test_technical_discussion(self):
text = "The model uses CUDA 12.1 with tensor cores for FP16."
assert filter_text(text) == text
def test_api_url_unchanged(self):
text = "Connect to https://api.openai.com/v1/chat/completions"
assert filter_text(text) == text
# ═══════════════════════════════════════════════════════════════════════════
# detect_sensitive — category detection
# ═══════════════════════════════════════════════════════════════════════════
class TestDetection:
def test_no_pii(self):
assert detect_sensitive("Hello world") == []
def test_detects_email(self):
cats = detect_sensitive("Email me at alice@example.com")
assert "email" in cats
def test_detects_phone(self):
cats = detect_sensitive("Call +14155551234")
assert "phone" in cats
def test_detects_wallet(self):
cats = detect_sensitive("My BTC: 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa")
assert "bitcoin_wallet" in cats
def test_detects_eth(self):
addr = "0x742d35Cc6634C0532925a3b844Bc9e7595f8Ca39"
cats = detect_sensitive(f"ETH addr: {addr}")
assert "ethereum_wallet" in cats
def test_detects_multiple(self):
cats = detect_sensitive("alice@test.com +14155551234")
assert "email" in cats
assert "phone" in cats
def test_empty(self):
assert detect_sensitive("") == []
def test_none(self):
assert detect_sensitive(None) == []
# ═══════════════════════════════════════════════════════════════════════════
# filter_messages — message list level
# ═══════════════════════════════════════════════════════════════════════════
class TestMessageFiltering:
def test_filters_content_string(self):
messages = [
{"role": "user", "content": "My email is bob@example.com, please remember it."}
]
result = filter_messages(messages)
assert "bob@example.com" not in result[0]["content"]
# Original unchanged (deep copy)
assert "bob@example.com" in messages[0]["content"]
def test_filters_content_parts(self):
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Here's my SSN: 123-45-6789"},
{"type": "image_url", "image_url": {"url": "https://img.com/a.png"}},
],
}
]
result = filter_messages(messages)
text_part = [p for p in result[0]["content"] if p.get("type") == "text"][0]
assert "123-45-6789" not in text_part["text"]
# Image URL untouched
img_part = [p for p in result[0]["content"] if p.get("type") == "image_url"][0]
assert img_part["image_url"]["url"] == "https://img.com/a.png"
def test_filters_tool_call_arguments(self):
messages = [
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "send_email",
"arguments": '{"to": "alice@example.com", "body": "Hi Alice"}',
},
}
],
}
]
result = filter_messages(messages)
args_str = result[0]["tool_calls"][0]["function"]["arguments"]
assert "alice@example.com" not in args_str
def test_preserves_system_message(self):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
]
result = filter_messages(messages)
assert result[0]["content"] == "You are a helpful assistant."
assert result[1]["content"] == "Hello!"
def test_deep_copy_safety(self):
original = [{"role": "user", "content": "test@example.com is my email"}]
result = filter_messages(original)
# Modifying result doesn't affect original
result[0]["content"] = "modified"
assert "test@example.com" in original[0]["content"]
def test_handles_none_content(self):
messages = [{"role": "assistant", "content": None, "tool_calls": []}]
result = filter_messages(messages)
assert result[0]["content"] is None
def test_handles_empty_messages(self):
assert filter_messages([]) == []
def test_preserves_tool_result_content(self):
messages = [
{
"role": "tool",
"content": "Found file at /usr/bin/secret but paths like /home/alice/x should be redacted",
"tool_call_id": "call_123",
}
]
result = filter_messages(messages)
assert "/home/alice" not in result[0]["content"]
assert "/usr/bin" in result[0]["content"] # system path preserved
# ═══════════════════════════════════════════════════════════════════════════
# has_sensitive_content — message-level detection
# ═══════════════════════════════════════════════════════════════════════════
class TestHasSensitiveContent:
def test_clean_messages(self):
messages = [{"role": "user", "content": "Write me a poem"}]
assert has_sensitive_content(messages) == []
def test_email_detected(self):
messages = [{"role": "user", "content": "email me at a@b.com"}]
cats = has_sensitive_content(messages)
assert "email" in cats
def test_tool_args_scanned(self):
messages = [
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"function": {
"name": "search",
"arguments": '{"query": "user +14155551234"}',
}
}
],
}
]
cats = has_sensitive_content(messages)
assert "phone" in cats
# ═══════════════════════════════════════════════════════════════════════════
# Provider routing
# ═══════════════════════════════════════════════════════════════════════════
class TestProviderRouting:
def test_remote_openai(self):
assert is_remote_provider("https://api.openai.com/v1") is True
def test_remote_openrouter(self):
assert is_remote_provider("https://openrouter.ai/api/v1") is True
def test_local_localhost(self):
assert is_remote_provider("http://localhost:11434/v1") is False
def test_local_127(self):
assert is_remote_provider("http://127.0.0.1:8080/v1") is False
def test_empty_assumes_local(self):
assert is_remote_provider("") is False
def test_route_local_with_pii(self):
messages = [{"role": "user", "content": "My email: a@b.com"}]
should, reasons = should_route_local(messages, "https://api.openai.com/v1")
assert should is True
assert "email" in reasons
def test_no_route_without_pii(self):
messages = [{"role": "user", "content": "Hello!"}]
should, reasons = should_route_local(messages, "https://api.openai.com/v1")
assert should is False
def test_no_route_for_local_provider(self):
messages = [{"role": "user", "content": "Email: a@b.com"}]
should, reasons = should_route_local(messages, "http://localhost:11434/v1")
assert should is False
# ═══════════════════════════════════════════════════════════════════════════
# prepare_for_remote — integration hook
# ═══════════════════════════════════════════════════════════════════════════
class TestPrepareForRemote:
def test_filters_remote_with_pii(self):
messages = [
{"role": "user", "content": "Send to alice@test.com, wallet 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"},
]
result, cats = prepare_for_remote(messages, "https://api.openai.com/v1")
assert "alice@test.com" not in result[0]["content"]
assert "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa" not in result[0]["content"]
assert "email" in cats
assert "bitcoin_wallet" in cats
def test_passes_through_local(self):
messages = [{"role": "user", "content": "Email: a@b.com"}]
result, cats = prepare_for_remote(messages, "http://localhost:11434/v1")
assert result is messages # same object
assert cats == []
def test_passes_through_clean_remote(self):
messages = [{"role": "user", "content": "Sort this list"}]
result, cats = prepare_for_remote(messages, "https://api.openai.com/v1")
assert cats == []
assert result[0]["content"] == "Sort this list"
def test_realistic_conversation(self):
"""Full conversation with mixed sensitive and safe messages."""
messages = [
{"role": "system", "content": "You are a helpful coding assistant."},
{"role": "user", "content": "Help me write a Python HTTP server."},
{"role": "assistant", "content": "Here's a simple example:\n```python\nimport http.server\n```"},
{"role": "user", "content": "Great! Now deploy it to my server at /home/deploy/app. My email is admin@mycompany.com"},
]
result, cats = prepare_for_remote(messages, "https://api.openai.com/v1")
# Safe messages unchanged
assert result[0]["content"] == messages[0]["content"]
assert result[1]["content"] == messages[1]["content"]
# Sensitive message filtered
assert "admin@mycompany.com" not in result[3]["content"]
assert "/home/deploy" not in result[3]["content"]
assert "email" in cats
assert "user_path_unix" in cats

View File

@@ -0,0 +1,52 @@
"""Tests for weak credential guard in gateway/config.py."""
import os
import pytest
from gateway.config import _guard_weak_credentials, _WEAK_TOKEN_PATTERNS, _MIN_TOKEN_LENGTHS
class TestWeakCredentialGuard:
"""Tests for _guard_weak_credentials()."""
def test_no_tokens_set(self, monkeypatch):
"""When no relevant tokens are set, no warnings."""
for var in _MIN_TOKEN_LENGTHS:
monkeypatch.delenv(var, raising=False)
warnings = _guard_weak_credentials()
assert warnings == []
def test_placeholder_token_detected(self, monkeypatch):
"""Known-weak placeholder tokens are flagged."""
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "your-token-here")
warnings = _guard_weak_credentials()
assert len(warnings) == 1
assert "TELEGRAM_BOT_TOKEN" in warnings[0]
assert "placeholder" in warnings[0].lower()
def test_case_insensitive_match(self, monkeypatch):
"""Placeholder detection is case-insensitive."""
monkeypatch.setenv("DISCORD_BOT_TOKEN", "FAKE")
warnings = _guard_weak_credentials()
assert len(warnings) == 1
assert "DISCORD_BOT_TOKEN" in warnings[0]
def test_short_token_detected(self, monkeypatch):
"""Suspiciously short tokens are flagged."""
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "abc123") # 6 chars, min is 30
warnings = _guard_weak_credentials()
assert len(warnings) == 1
assert "short" in warnings[0].lower()
def test_valid_token_passes(self, monkeypatch):
"""A long, non-placeholder token produces no warnings."""
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "1234567890:ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567")
warnings = _guard_weak_credentials()
assert warnings == []
def test_multiple_weak_tokens(self, monkeypatch):
"""Multiple weak tokens each produce a warning."""
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "change-me")
monkeypatch.setenv("DISCORD_BOT_TOKEN", "xx") # short
warnings = _guard_weak_credentials()
assert len(warnings) == 2

View File

@@ -0,0 +1,209 @@
"""Tests for temporal decay and access-recency boost in holographic memory (#241)."""
import math
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
import pytest
class TestTemporalDecay:
"""Test _temporal_decay exponential decay formula."""
def _make_retriever(self, half_life=60):
from plugins.memory.holographic.retrieval import FactRetriever
store = MagicMock()
return FactRetriever(store=store, temporal_decay_half_life=half_life)
def test_fresh_fact_no_decay(self):
"""A fact updated today should have decay ≈ 1.0."""
r = self._make_retriever(half_life=60)
now = datetime.now(timezone.utc).isoformat()
decay = r._temporal_decay(now)
assert decay > 0.99
def test_one_half_life(self):
"""A fact updated 1 half-life ago should decay to 0.5."""
r = self._make_retriever(half_life=60)
old = (datetime.now(timezone.utc) - timedelta(days=60)).isoformat()
decay = r._temporal_decay(old)
assert abs(decay - 0.5) < 0.01
def test_two_half_lives(self):
"""A fact updated 2 half-lives ago should decay to 0.25."""
r = self._make_retriever(half_life=60)
old = (datetime.now(timezone.utc) - timedelta(days=120)).isoformat()
decay = r._temporal_decay(old)
assert abs(decay - 0.25) < 0.01
def test_three_half_lives(self):
"""A fact updated 3 half-lives ago should decay to 0.125."""
r = self._make_retriever(half_life=60)
old = (datetime.now(timezone.utc) - timedelta(days=180)).isoformat()
decay = r._temporal_decay(old)
assert abs(decay - 0.125) < 0.01
def test_half_life_disabled(self):
"""When half_life=0, decay should always be 1.0."""
r = self._make_retriever(half_life=0)
old = (datetime.now(timezone.utc) - timedelta(days=365)).isoformat()
assert r._temporal_decay(old) == 1.0
def test_none_timestamp(self):
"""Missing timestamp should return 1.0 (no decay)."""
r = self._make_retriever(half_life=60)
assert r._temporal_decay(None) == 1.0
def test_empty_timestamp(self):
r = self._make_retriever(half_life=60)
assert r._temporal_decay("") == 1.0
def test_invalid_timestamp(self):
"""Malformed timestamp should return 1.0 (fail open)."""
r = self._make_retriever(half_life=60)
assert r._temporal_decay("not-a-date") == 1.0
def test_future_timestamp(self):
"""Future timestamp should return 1.0 (no decay for future dates)."""
r = self._make_retriever(half_life=60)
future = (datetime.now(timezone.utc) + timedelta(days=10)).isoformat()
assert r._temporal_decay(future) == 1.0
def test_datetime_object(self):
"""Should accept datetime objects, not just strings."""
r = self._make_retriever(half_life=60)
old = datetime.now(timezone.utc) - timedelta(days=60)
decay = r._temporal_decay(old)
assert abs(decay - 0.5) < 0.01
def test_different_half_lives(self):
"""30-day half-life should decay faster than 90-day."""
r30 = self._make_retriever(half_life=30)
r90 = self._make_retriever(half_life=90)
old = (datetime.now(timezone.utc) - timedelta(days=45)).isoformat()
assert r30._temporal_decay(old) < r90._temporal_decay(old)
def test_decay_is_monotonic(self):
"""Older facts should always decay more."""
r = self._make_retriever(half_life=60)
now = datetime.now(timezone.utc)
d1 = r._temporal_decay((now - timedelta(days=10)).isoformat())
d2 = r._temporal_decay((now - timedelta(days=30)).isoformat())
d3 = r._temporal_decay((now - timedelta(days=60)).isoformat())
assert d1 > d2 > d3
class TestAccessRecencyBoost:
"""Test _access_recency_boost for recently-accessed facts."""
def _make_retriever(self, half_life=60):
from plugins.memory.holographic.retrieval import FactRetriever
store = MagicMock()
return FactRetriever(store=store, temporal_decay_half_life=half_life)
def test_just_accessed_max_boost(self):
"""A fact accessed just now should get maximum boost (1.5)."""
r = self._make_retriever(half_life=60)
now = datetime.now(timezone.utc).isoformat()
boost = r._access_recency_boost(now)
assert boost > 1.45 # Near 1.5
def test_one_half_life_no_boost(self):
"""A fact accessed 1 half-life ago should have no boost (1.0)."""
r = self._make_retriever(half_life=60)
old = (datetime.now(timezone.utc) - timedelta(days=60)).isoformat()
boost = r._access_recency_boost(old)
assert abs(boost - 1.0) < 0.01
def test_half_way_boost(self):
"""A fact accessed 0.5 half-lives ago should get ~1.25 boost."""
r = self._make_retriever(half_life=60)
old = (datetime.now(timezone.utc) - timedelta(days=30)).isoformat()
boost = r._access_recency_boost(old)
assert abs(boost - 1.25) < 0.05
def test_beyond_one_half_life_no_boost(self):
"""Beyond 1 half-life, boost should be 1.0."""
r = self._make_retriever(half_life=60)
old = (datetime.now(timezone.utc) - timedelta(days=90)).isoformat()
boost = r._access_recency_boost(old)
assert boost == 1.0
def test_disabled_no_boost(self):
"""When half_life=0, boost should be 1.0."""
r = self._make_retriever(half_life=0)
now = datetime.now(timezone.utc).isoformat()
assert r._access_recency_boost(now) == 1.0
def test_none_timestamp(self):
r = self._make_retriever(half_life=60)
assert r._access_recency_boost(None) == 1.0
def test_invalid_timestamp(self):
r = self._make_retriever(half_life=60)
assert r._access_recency_boost("bad") == 1.0
def test_boost_range(self):
"""Boost should always be in [1.0, 1.5]."""
r = self._make_retriever(half_life=60)
now = datetime.now(timezone.utc)
for days in [0, 1, 15, 30, 45, 59, 60, 90, 365]:
ts = (now - timedelta(days=days)).isoformat()
boost = r._access_recency_boost(ts)
assert 1.0 <= boost <= 1.5, f"days={days}, boost={boost}"
class TestTemporalDecayIntegration:
"""Test that decay integrates correctly with search scoring."""
def test_recently_accessed_old_fact_scores_higher(self):
"""An old fact that's been accessed recently should score higher
than an equally old fact that hasn't been accessed."""
from plugins.memory.holographic.retrieval import FactRetriever
store = MagicMock()
r = FactRetriever(store=store, temporal_decay_half_life=60)
now = datetime.now(timezone.utc)
old_date = (now - timedelta(days=120)).isoformat() # 2 half-lives old
recent_access = (now - timedelta(days=10)).isoformat() # accessed 10 days ago
old_access = (now - timedelta(days=200)).isoformat() # accessed 200 days ago
# Old fact, recently accessed
decay1 = r._temporal_decay(old_date)
boost1 = r._access_recency_boost(recent_access)
effective1 = min(1.0, decay1 * boost1)
# Old fact, not recently accessed
decay2 = r._temporal_decay(old_date)
boost2 = r._access_recency_boost(old_access)
effective2 = min(1.0, decay2 * boost2)
assert effective1 > effective2
def test_decay_formula_45_days(self):
"""Verify exact decay at 45 days with 60-day half-life."""
from plugins.memory.holographic.retrieval import FactRetriever
r = FactRetriever(store=MagicMock(), temporal_decay_half_life=60)
old = (datetime.now(timezone.utc) - timedelta(days=45)).isoformat()
decay = r._temporal_decay(old)
expected = math.pow(0.5, 45/60)
assert abs(decay - expected) < 0.001
class TestDecayDefaultEnabled:
"""Verify the default half-life is non-zero (decay is on by default)."""
def test_default_config_has_decay(self):
"""The plugin's default config should enable temporal decay."""
from plugins.memory.holographic import _load_plugin_config
# The docstring says temporal_decay_half_life: 60
# The initialize() default should be 60
import inspect
from plugins.memory.holographic import HolographicMemoryProvider
src = inspect.getsource(HolographicMemoryProvider.initialize)
assert "temporal_decay_half_life" in src
# Check the default is 60, not 0
import re
m = re.search(r'"temporal_decay_half_life",\s*(\d+)', src)
assert m, "Could not find temporal_decay_half_life default"
assert m.group(1) == "60", f"Default is {m.group(1)}, expected 60"

View File

@@ -137,3 +137,78 @@ class TestBackwardCompat:
def test_tool_to_toolset_map(self):
assert isinstance(TOOL_TO_TOOLSET_MAP, dict)
assert len(TOOL_TO_TOOLSET_MAP) > 0
class TestToolReturnTypeValidation:
"""Poka-yoke: tool handlers must return JSON strings."""
def test_handler_returning_dict_is_wrapped(self, monkeypatch):
"""A handler that returns a dict should be auto-wrapped to JSON string."""
from tools.registry import registry
from model_tools import handle_function_call
import json
# Register a bad handler that returns dict instead of str
registry.register(
name="__test_bad_dict",
toolset="test",
schema={"name": "__test_bad_dict", "description": "test", "parameters": {"type": "object", "properties": {}}},
handler=lambda args, **kw: {"this is": "a dict not a string"},
)
result = handle_function_call("__test_bad_dict", {})
parsed = json.loads(result)
assert "output" in parsed
assert "_type_warning" in parsed
# Cleanup
registry._tools.pop("__test_bad_dict", None)
def test_handler_returning_none_is_wrapped(self, monkeypatch):
"""A handler that returns None should be auto-wrapped."""
from tools.registry import registry
from model_tools import handle_function_call
import json
registry.register(
name="__test_bad_none",
toolset="test",
schema={"name": "__test_bad_none", "description": "test", "parameters": {"type": "object", "properties": {}}},
handler=lambda args, **kw: None,
)
result = handle_function_call("__test_bad_none", {})
parsed = json.loads(result)
assert "_type_warning" in parsed
registry._tools.pop("__test_bad_none", None)
def test_handler_returning_non_json_string_is_wrapped(self):
"""A handler returning a plain string (not JSON) should be wrapped."""
from tools.registry import registry
from model_tools import handle_function_call
import json
registry.register(
name="__test_bad_plain",
toolset="test",
schema={"name": "__test_bad_plain", "description": "test", "parameters": {"type": "object", "properties": {}}},
handler=lambda args, **kw: "just a plain string, not json",
)
result = handle_function_call("__test_bad_plain", {})
parsed = json.loads(result)
assert "output" in parsed
registry._tools.pop("__test_bad_plain", None)
def test_handler_returning_valid_json_passes_through(self):
"""A handler returning valid JSON string passes through unchanged."""
from tools.registry import registry
from model_tools import handle_function_call
import json
registry.register(
name="__test_good",
toolset="test",
schema={"name": "__test_good", "description": "test", "parameters": {"type": "object", "properties": {}}},
handler=lambda args, **kw: json.dumps({"status": "ok", "data": [1, 2, 3]}),
)
result = handle_function_call("__test_good", {})
parsed = json.loads(result)
assert parsed == {"status": "ok", "data": [1, 2, 3]}
registry._tools.pop("__test_good", None)

View File

@@ -144,7 +144,8 @@ class TestMemoryStoreReplace:
def test_replace_no_match(self, store):
store.add("memory", "fact A")
result = store.replace("memory", "nonexistent", "new")
assert result["success"] is False
assert result["success"] is True
assert result["result"] == "no_match"
def test_replace_ambiguous_match(self, store):
store.add("memory", "server A runs nginx")
@@ -177,7 +178,8 @@ class TestMemoryStoreRemove:
def test_remove_no_match(self, store):
result = store.remove("memory", "nonexistent")
assert result["success"] is False
assert result["success"] is True
assert result["result"] == "no_match"
def test_remove_empty_old_text(self, store):
result = store.remove("memory", " ")

View File

@@ -0,0 +1,107 @@
"""Tests for syntax preflight check in execute_code (issue #312)."""
import ast
import json
import pytest
class TestSyntaxPreflight:
"""Verify that execute_code catches syntax errors before sandbox execution."""
def test_valid_syntax_passes_parse(self):
"""Valid Python should pass ast.parse."""
code = "print('hello')\nx = 1 + 2\n"
ast.parse(code) # should not raise
def test_syntax_error_indentation(self):
"""IndentationError is a subclass of SyntaxError."""
code = "def foo():\nbar()\n"
with pytest.raises(SyntaxError):
ast.parse(code)
def test_syntax_error_missing_colon(self):
code = "if True\n pass\n"
with pytest.raises(SyntaxError):
ast.parse(code)
def test_syntax_error_unmatched_paren(self):
code = "x = (1 + 2\n"
with pytest.raises(SyntaxError):
ast.parse(code)
def test_syntax_error_invalid_token(self):
code = "x = 1 +*\n"
with pytest.raises(SyntaxError):
ast.parse(code)
def test_syntax_error_details(self):
"""SyntaxError should provide line, offset, msg."""
code = "if True\n pass\n"
with pytest.raises(SyntaxError) as exc_info:
ast.parse(code)
e = exc_info.value
assert e.lineno is not None
assert e.msg is not None
def test_empty_string_passes(self):
"""Empty string is valid Python (empty module)."""
ast.parse("")
def test_comments_only_passes(self):
ast.parse("# just a comment\n# another\n")
def test_complex_valid_code(self):
code = '''
import os
def foo(x):
if x > 0:
return x * 2
return 0
result = [foo(i) for i in range(10)]
print(result)
'''
ast.parse(code)
class TestSyntaxPreflightResponse:
"""Test the error response format from the preflight check."""
def _check_syntax(self, code):
"""Mimic the preflight check logic from execute_code."""
try:
ast.parse(code)
return None
except SyntaxError as e:
return json.dumps({
"error": f"Python syntax error: {e.msg}",
"line": e.lineno,
"offset": e.offset,
"text": (e.text or "").strip()[:200],
})
def test_returns_json_error(self):
result = self._check_syntax("if True\n pass\n")
assert result is not None
data = json.loads(result)
assert "error" in data
assert "syntax error" in data["error"].lower()
def test_includes_line_number(self):
result = self._check_syntax("x = 1\nif True\n pass\n")
data = json.loads(result)
assert data["line"] == 2 # error on line 2
def test_includes_offset(self):
result = self._check_syntax("x = (1 + 2\n")
data = json.loads(result)
assert data["offset"] is not None
def test_includes_snippet(self):
result = self._check_syntax("if True\n")
data = json.loads(result)
assert "if True" in data["text"]
def test_none_for_valid_code(self):
result = self._check_syntax("print('ok')")
assert result is None

View File

@@ -28,6 +28,7 @@ Platform: Linux / macOS only (Unix domain sockets for local). Disabled on Window
Remote execution additionally requires Python 3 in the terminal backend.
"""
import ast
import base64
import json
import logging
@@ -893,6 +894,20 @@ def execute_code(
if not code or not code.strip():
return json.dumps({"error": "No code provided."})
# Poka-yoke (#312): Syntax check before execution.
# 83.2% of execute_code errors are Python exceptions; most are syntax
# errors the LLM generated. ast.parse() is sub-millisecond and catches
# them before we spin up a sandbox child process.
try:
ast.parse(code)
except SyntaxError as e:
return json.dumps({
"error": f"Python syntax error: {e.msg}",
"line": e.lineno,
"offset": e.offset,
"text": (e.text or "").strip()[:200],
})
# Dispatch: remote backends use file-based RPC, local uses UDS
from tools.terminal_tool import _get_env_config
env_type = _get_env_config()["env_type"]

View File

@@ -260,8 +260,12 @@ class MemoryStore:
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
if not matches:
return {
"success": True,
"result": "no_match",
"message": f"No entry matched '{old_text}'. The search substring was not found in any existing entry.",
}
if len(matches) > 1:
# If all matches are identical (exact duplicates), operate on the first one
@@ -310,8 +314,12 @@ class MemoryStore:
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
if not matches:
return {
"success": True,
"result": "no_match",
"message": f"No entry matched '{old_text}'. The search substring was not found in any existing entry.",
}
if len(matches) > 1:
# If all matches are identical (exact duplicates), remove the first one
@@ -449,30 +457,30 @@ def memory_tool(
Returns JSON string with results.
"""
if store is None:
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
return tool_error("Memory is not available. It may be disabled in config or this environment.", success=False)
if target not in ("memory", "user"):
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
return tool_error(f"Invalid target '{target}'. Use 'memory' or 'user'.", success=False)
if action == "add":
if not content:
return json.dumps({"success": False, "error": "Content is required for 'add' action."}, ensure_ascii=False)
return tool_error("Content is required for 'add' action.", success=False)
result = store.add(target, content)
elif action == "replace":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
return tool_error("old_text is required for 'replace' action.", success=False)
if not content:
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
return tool_error("content is required for 'replace' action.", success=False)
result = store.replace(target, old_text, content)
elif action == "remove":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
return tool_error("old_text is required for 'remove' action.", success=False)
result = store.remove(target, old_text)
else:
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
return tool_error(f"Unknown action '{action}'. Use: add, replace, remove", success=False)
return json.dumps(result, ensure_ascii=False)
@@ -539,7 +547,7 @@ MEMORY_SCHEMA = {
# --- Registry ---
from tools.registry import registry
from tools.registry import registry, tool_error
registry.register(
name="memory",