Compare commits
2 Commits
claude/iss
...
fix/923
| Author | SHA1 | Date | |
|---|---|---|---|
| d27ca6d39a | |||
| c6f2855745 |
@@ -8,7 +8,6 @@ Handles loading and validating configuration for:
|
||||
- Delivery preferences
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
@@ -680,26 +679,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
return config
|
||||
|
||||
|
||||
def _is_network_accessible(host: str) -> bool:
|
||||
"""Return True if *host* would expose a server beyond the loopback interface.
|
||||
|
||||
Duplicates the logic in ``gateway.platforms.base.is_network_accessible``
|
||||
without creating a circular import (base.py imports from this module).
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
if addr.is_loopback:
|
||||
return False
|
||||
# ::ffff:127.x.x.x — Python's is_loopback returns False for
|
||||
# IPv4-mapped loopback; unwrap and check the underlying IPv4.
|
||||
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
|
||||
return False
|
||||
return True
|
||||
except ValueError:
|
||||
# Hostname: assume it could be network-accessible.
|
||||
return True
|
||||
|
||||
|
||||
def _validate_gateway_config(config: "GatewayConfig") -> None:
|
||||
"""Validate and sanitize a loaded GatewayConfig in place.
|
||||
|
||||
@@ -768,22 +747,6 @@ def _validate_gateway_config(config: "GatewayConfig") -> None:
|
||||
)
|
||||
pconfig.enabled = False
|
||||
|
||||
# Warn when the API server is enabled on a network-accessible address
|
||||
# without an auth key. The adapter will refuse to start anyway, but
|
||||
# surfacing this at config-load time lets operators see the problem in
|
||||
# the startup log before any platform adapter initialisation runs.
|
||||
api_cfg = config.platforms.get(Platform.API_SERVER)
|
||||
if api_cfg and api_cfg.enabled:
|
||||
key = api_cfg.extra.get("key", "")
|
||||
host = api_cfg.extra.get("host", "127.0.0.1")
|
||||
if not key and _is_network_accessible(host):
|
||||
logger.warning(
|
||||
"API Server is enabled on %s but API_SERVER_KEY is not set. "
|
||||
"The adapter will refuse to start on a network-accessible address. "
|
||||
"Set API_SERVER_KEY or bind to 127.0.0.1 for local-only access.",
|
||||
host,
|
||||
)
|
||||
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
@@ -10,7 +10,6 @@ from gateway.config import (
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
_apply_env_overrides,
|
||||
_validate_gateway_config,
|
||||
load_gateway_config,
|
||||
)
|
||||
|
||||
@@ -295,151 +294,3 @@ class TestHomeChannelEnvOverrides:
|
||||
home = config.platforms[platform].home_channel
|
||||
assert home is not None, f"{platform.value}: home_channel should not be None"
|
||||
assert (home.chat_id, home.name) == expected, platform.value
|
||||
|
||||
|
||||
class TestValidateGatewayConfig:
|
||||
"""Tests for _validate_gateway_config — in-place sanitisation of loaded config."""
|
||||
|
||||
# -- idle_minutes validation --
|
||||
|
||||
def test_idle_minutes_zero_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = 0
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 1440
|
||||
|
||||
def test_idle_minutes_negative_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = -60
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 1440
|
||||
|
||||
def test_idle_minutes_none_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = None # type: ignore[assignment]
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 1440
|
||||
|
||||
def test_valid_idle_minutes_is_unchanged(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.idle_minutes = 90
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.idle_minutes == 90
|
||||
|
||||
# -- at_hour validation --
|
||||
|
||||
def test_at_hour_too_high_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = 24
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == 4
|
||||
|
||||
def test_at_hour_negative_is_corrected_to_default(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = -1
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == 4
|
||||
|
||||
def test_valid_at_hour_is_unchanged(self):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = 3
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == 3
|
||||
|
||||
def test_at_hour_boundary_values_are_valid(self):
|
||||
for valid_hour in (0, 23):
|
||||
config = GatewayConfig()
|
||||
config.default_reset_policy.at_hour = valid_hour
|
||||
_validate_gateway_config(config)
|
||||
assert config.default_reset_policy.at_hour == valid_hour
|
||||
|
||||
# -- empty-token warning (enabled platforms) --
|
||||
|
||||
def test_empty_string_token_logs_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token=""),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert any(
|
||||
"TELEGRAM_BOT_TOKEN" in r.message and "empty" in r.message
|
||||
for r in caplog.records
|
||||
)
|
||||
|
||||
def test_disabled_platform_with_empty_token_no_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=False, token=""),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any("TELEGRAM_BOT_TOKEN" in r.message for r in caplog.records)
|
||||
|
||||
# -- API Server key / binding warnings --
|
||||
|
||||
def test_api_server_network_binding_without_key_logs_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "0.0.0.0"},
|
||||
),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
def test_api_server_loopback_without_key_no_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "127.0.0.1"},
|
||||
),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
def test_api_server_network_binding_with_key_no_warning(self, caplog):
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "0.0.0.0", "key": "sk-real-key-here"},
|
||||
),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
def test_api_server_default_loopback_without_key_no_warning(self, caplog):
|
||||
"""API server with no explicit host defaults to 127.0.0.1 — no warning."""
|
||||
import logging
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.API_SERVER: PlatformConfig(enabled=True),
|
||||
}
|
||||
)
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.config"):
|
||||
_validate_gateway_config(config)
|
||||
assert not any(
|
||||
"API_SERVER_KEY" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
122
tools/skill_edit_guard.py
Normal file
122
tools/skill_edit_guard.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Skill Edit Guard — Poka-yoke auto-revert for incomplete skill edits.
|
||||
|
||||
Creates atomic skill edits with automatic rollback on failure.
|
||||
Prevents broken skills from corrupting future sessions.
|
||||
|
||||
Usage:
|
||||
from tools.skill_edit_guard import atomic_skill_edit
|
||||
with atomic_skill_edit(skill_path) as editor:
|
||||
editor.write(new_content)
|
||||
# If exception occurs, file is automatically reverted
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillEditGuard:
|
||||
"""Atomic skill file editing with auto-revert on failure."""
|
||||
|
||||
def __init__(self, skill_path: str):
|
||||
self._path = Path(skill_path)
|
||||
self._backup: Optional[Path] = None
|
||||
self._committed = False
|
||||
|
||||
def backup(self) -> bool:
|
||||
"""Create backup before editing."""
|
||||
if not self._path.exists():
|
||||
return True # New file, nothing to backup
|
||||
|
||||
backup_dir = self._path.parent / ".skill_backups"
|
||||
backup_dir.mkdir(exist_ok=True)
|
||||
|
||||
ts = int(time.time() * 1000)
|
||||
self._backup = backup_dir / f"{self._path.name}.{ts}.bak"
|
||||
shutil.copy2(self._path, self._backup)
|
||||
logger.debug("Skill backup created: %s", self._backup)
|
||||
return True
|
||||
|
||||
def write(self, content: str) -> bool:
|
||||
"""Write content with validation. Returns True if valid."""
|
||||
# Validate YAML frontmatter
|
||||
if content.startswith("---"):
|
||||
end = content.find("---", 3)
|
||||
if end < 0:
|
||||
logger.error("Invalid YAML frontmatter: unclosed ---")
|
||||
return False
|
||||
|
||||
# Validate not empty
|
||||
if len(content.strip()) < 10:
|
||||
logger.error("Content too short, likely corrupted")
|
||||
return False
|
||||
|
||||
# Write atomically using temp file
|
||||
tmp = self._path.with_suffix(".tmp")
|
||||
try:
|
||||
tmp.write_text(content, encoding="utf-8")
|
||||
tmp.rename(self._path)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Write failed: %s", e)
|
||||
if tmp.exists():
|
||||
tmp.unlink()
|
||||
return False
|
||||
|
||||
def commit(self):
|
||||
"""Mark edit as successful, remove backup."""
|
||||
self._committed = True
|
||||
if self._backup and self._backup.exists():
|
||||
self._backup.unlink()
|
||||
logger.debug("Skill backup removed: %s", self._backup)
|
||||
|
||||
def rollback(self) -> bool:
|
||||
"""Revert to backup."""
|
||||
if self._backup and self._backup.exists():
|
||||
shutil.copy2(self._backup, self._path)
|
||||
self._backup.unlink()
|
||||
logger.warning("Skill reverted from backup: %s", self._path)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __enter__(self):
|
||||
self.backup()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is not None:
|
||||
self.rollback()
|
||||
return False # Re-raise exception
|
||||
if not self._committed:
|
||||
self.rollback()
|
||||
return False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def atomic_skill_edit(skill_path: str):
|
||||
"""Context manager for atomic skill editing.
|
||||
|
||||
Usage:
|
||||
with atomic_skill_edit("/path/to/skill/SKILL.md") as editor:
|
||||
success = editor.write(new_content)
|
||||
if not success:
|
||||
raise ValueError("Write failed")
|
||||
# __exit__ commits on success, reverts on exception
|
||||
"""
|
||||
guard = SkillEditGuard(skill_path)
|
||||
guard.backup()
|
||||
try:
|
||||
yield guard
|
||||
guard.commit()
|
||||
except Exception:
|
||||
guard.rollback()
|
||||
raise
|
||||
@@ -44,6 +44,34 @@ from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_error(
|
||||
message: str,
|
||||
skill_name: str = None,
|
||||
file_path: str = None,
|
||||
suggestion: str = None,
|
||||
context: dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Format an error with rich context for better debugging."""
|
||||
parts = [message]
|
||||
if skill_name:
|
||||
parts.append(f"Skill: {skill_name}")
|
||||
if file_path:
|
||||
parts.append(f"File: {file_path}")
|
||||
if suggestion:
|
||||
parts.append(f"Suggestion: {suggestion}")
|
||||
if context:
|
||||
for key, value in context.items():
|
||||
parts.append(f"{key}: {value}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": " | ".join(parts),
|
||||
"skill_name": skill_name,
|
||||
"file_path": file_path,
|
||||
"suggestion": suggestion,
|
||||
}
|
||||
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user