Compare commits
4 Commits
fix/format
...
burn/923-1
| Author | SHA1 | Date | |
|---|---|---|---|
| 29177ba850 | |||
| 3811d470d1 | |||
| 9bc90dff02 | |||
| c6f2855745 |
146
tests/test_skill_auto_revert.py
Normal file
146
tests/test_skill_auto_revert.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Tests for poka-yoke skill edit auto-revert."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_skill_md
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateSkillMd:
|
||||
def test_valid_content(self):
|
||||
from tools.skill_manager_tool import _validate_skill_md
|
||||
content = "---\nname: test\n---\n\nSome skill content here."
|
||||
assert _validate_skill_md(content) is None
|
||||
|
||||
def test_empty_content(self):
|
||||
from tools.skill_manager_tool import _validate_skill_md
|
||||
assert "empty" in _validate_skill_md("").lower()
|
||||
assert "empty" in _validate_skill_md(" ").lower()
|
||||
|
||||
def test_too_small(self):
|
||||
from tools.skill_manager_tool import _validate_skill_md
|
||||
assert "small" in _validate_skill_md("abc").lower()
|
||||
|
||||
def test_missing_frontmatter(self):
|
||||
from tools.skill_manager_tool import _validate_skill_md
|
||||
content = "No frontmatter here. Just content."
|
||||
assert "frontmatter" in _validate_skill_md(content).lower()
|
||||
|
||||
def test_unclosed_frontmatter(self):
|
||||
from tools.skill_manager_tool import _validate_skill_md
|
||||
content = "---\nname: test\nNo closing dashes"
|
||||
assert "closing" in _validate_skill_md(content).lower()
|
||||
|
||||
def test_missing_name_field(self):
|
||||
from tools.skill_manager_tool import _validate_skill_md
|
||||
content = "---\nversion: 1.0\n---\n\nContent here."
|
||||
assert "name" in _validate_skill_md(content).lower()
|
||||
|
||||
def test_invalid_yaml(self):
|
||||
from tools.skill_manager_tool import _validate_skill_md
|
||||
content = "---\n: invalid: yaml: {{{\n---\n\nContent"
|
||||
result = _validate_skill_md(content)
|
||||
assert result is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _create_backup / _restore_backup / _cleanup_old_backups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackupOps:
|
||||
def test_create_and_restore(self, tmp_path):
|
||||
from tools.skill_manager_tool import _create_backup, _restore_backup
|
||||
target = tmp_path / "SKILL.md"
|
||||
target.write_text("original content")
|
||||
|
||||
bak = _create_backup(target)
|
||||
assert bak is not None
|
||||
assert bak.exists()
|
||||
|
||||
target.write_text("corrupted content")
|
||||
assert target.read_text() == "corrupted content"
|
||||
|
||||
_restore_backup(bak, target)
|
||||
assert target.read_text() == "original content"
|
||||
|
||||
def test_create_backup_no_file(self, tmp_path):
|
||||
from tools.skill_manager_tool import _create_backup
|
||||
target = tmp_path / "nonexistent.md"
|
||||
assert _create_backup(target) is None
|
||||
|
||||
def test_cleanup_keeps_recent(self, tmp_path):
|
||||
from tools.skill_manager_tool import _cleanup_old_backups
|
||||
target = tmp_path / "SKILL.md"
|
||||
target.write_text("content")
|
||||
|
||||
# Create 5 backups
|
||||
import time
|
||||
for i in range(5):
|
||||
bak = target.parent / f".SKILL.md.bak.2026010{i}_00000{i}"
|
||||
bak.write_text(f"backup {i}")
|
||||
time.sleep(0.01) # ensure different mtime
|
||||
|
||||
_cleanup_old_backups(target, keep=3)
|
||||
|
||||
remaining = list(target.parent.glob(".SKILL.md.bak.*"))
|
||||
assert len(remaining) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transactional_write
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTransactionalWrite:
|
||||
def test_success_no_validation(self, tmp_path):
|
||||
from tools.skill_manager_tool import _transactional_write
|
||||
target = tmp_path / "test.txt"
|
||||
ok, err, bak = _transactional_write(target, "new content")
|
||||
assert ok is True
|
||||
assert err is None
|
||||
assert target.read_text() == "new content"
|
||||
|
||||
def test_success_with_validation(self, tmp_path):
|
||||
from tools.skill_manager_tool import _transactional_write, _validate_skill_md
|
||||
target = tmp_path / "SKILL.md"
|
||||
content = "---\nname: test\n---\n\nValid content."
|
||||
ok, err, bak = _transactional_write(target, content, validate_fn=_validate_skill_md)
|
||||
assert ok is True
|
||||
assert err is None
|
||||
|
||||
def test_revert_on_validation_failure(self, tmp_path):
|
||||
from tools.skill_manager_tool import _transactional_write, _validate_skill_md
|
||||
target = tmp_path / "SKILL.md"
|
||||
target.write_text("---\nname: original\n---\n\nOriginal content.")
|
||||
|
||||
# Try to write invalid content (no frontmatter)
|
||||
ok, err, bak = _transactional_write(target, "broken content", validate_fn=_validate_skill_md)
|
||||
assert ok is False
|
||||
assert err is not None
|
||||
# Original content restored
|
||||
assert "original" in target.read_text()
|
||||
|
||||
def test_revert_empty_content(self, tmp_path):
|
||||
from tools.skill_manager_tool import _transactional_write, _validate_skill_md
|
||||
target = tmp_path / "SKILL.md"
|
||||
target.write_text("---\nname: good\n---\n\nGood content.")
|
||||
|
||||
ok, err, bak = _transactional_write(target, "", validate_fn=_validate_skill_md)
|
||||
assert ok is False
|
||||
assert "empty" in err.lower()
|
||||
assert "good" in target.read_text()
|
||||
|
||||
def test_new_file_reverted_on_failure(self, tmp_path):
|
||||
from tools.skill_manager_tool import _transactional_write, _validate_skill_md
|
||||
target = tmp_path / "new_skill" / "SKILL.md"
|
||||
|
||||
ok, err, bak = _transactional_write(target, "bad", validate_fn=_validate_skill_md)
|
||||
assert ok is False
|
||||
# File should not exist (was new, no backup to restore)
|
||||
assert not target.exists()
|
||||
@@ -44,6 +44,34 @@ from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_error(
|
||||
message: str,
|
||||
skill_name: str = None,
|
||||
file_path: str = None,
|
||||
suggestion: str = None,
|
||||
context: dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Format an error with rich context for better debugging."""
|
||||
parts = [message]
|
||||
if skill_name:
|
||||
parts.append(f"Skill: {skill_name}")
|
||||
if file_path:
|
||||
parts.append(f"File: {file_path}")
|
||||
if suggestion:
|
||||
parts.append(f"Suggestion: {suggestion}")
|
||||
if context:
|
||||
for key, value in context.items():
|
||||
parts.append(f"{key}: {value}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": " | ".join(parts),
|
||||
"skill_name": skill_name,
|
||||
"file_path": file_path,
|
||||
"suggestion": suggestion,
|
||||
}
|
||||
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
@@ -284,6 +312,103 @@ def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -
|
||||
logger.error("Failed to remove temporary file %s during atomic write", temp_path, exc_info=True)
|
||||
raise
|
||||
|
||||
# =============================================================================
|
||||
# Poka-yoke: Transactional writes with auto-revert
|
||||
# =============================================================================
|
||||
|
||||
_MAX_BACKUPS_PER_SKILL = 3
|
||||
|
||||
|
||||
def _create_backup(file_path):
|
||||
"""Create a .bak.{timestamp} backup before overwriting."""
|
||||
import time
|
||||
if not file_path.exists():
|
||||
return None
|
||||
ts = time.strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = file_path.parent / f".{file_path.name}.bak.{ts}"
|
||||
try:
|
||||
backup_path.write_bytes(file_path.read_bytes())
|
||||
return backup_path
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _restore_backup(backup_path, target_path):
|
||||
"""Restore a file from backup."""
|
||||
try:
|
||||
if backup_path.exists():
|
||||
target_path.write_bytes(backup_path.read_bytes())
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _cleanup_old_backups(file_path, keep=3):
|
||||
"""Remove old .bak files, keeping only the most recent."""
|
||||
try:
|
||||
parent = file_path.parent
|
||||
stem = file_path.name
|
||||
bak_prefix = f".{stem}.bak."
|
||||
backups = sorted(
|
||||
[f for f in parent.iterdir() if f.name.startswith(bak_prefix)],
|
||||
key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
for old in backups[keep:]:
|
||||
try:
|
||||
old.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _validate_skill_md(content):
|
||||
"""Post-write validation for SKILL.md. Returns error or None."""
|
||||
if not content or not content.strip():
|
||||
return "Content is empty after write"
|
||||
if len(content.strip()) < 10:
|
||||
return f"Content suspiciously small ({len(content.strip())} bytes) — possible truncation"
|
||||
if content.startswith("---"):
|
||||
end = content.find("---", 3)
|
||||
if end == -1:
|
||||
return "YAML frontmatter has opening --- but no closing ---"
|
||||
fm_text = content[3:end].strip()
|
||||
try:
|
||||
import yaml
|
||||
fm = yaml.safe_load(fm_text)
|
||||
if not isinstance(fm, dict):
|
||||
return "YAML frontmatter did not parse to a dict"
|
||||
if "name" not in fm:
|
||||
return "YAML frontmatter missing required 'name' field"
|
||||
except Exception as e:
|
||||
return f"YAML frontmatter parse error: {e}"
|
||||
else:
|
||||
return "Missing YAML frontmatter (must start with ---)"
|
||||
return None
|
||||
|
||||
|
||||
def _transactional_write(file_path, content, validate_fn=None, encoding="utf-8"):
|
||||
"""Write with backup, validation, and auto-revert. Returns (success, error, backup_path)."""
|
||||
backup_path = _create_backup(file_path)
|
||||
_atomic_write_text(file_path, content, encoding=encoding)
|
||||
validation_error = None
|
||||
if validate_fn:
|
||||
try:
|
||||
written = file_path.read_text(encoding=encoding)
|
||||
validation_error = validate_fn(written)
|
||||
except Exception as e:
|
||||
validation_error = f"Post-write read failed: {e}"
|
||||
if validation_error:
|
||||
if backup_path and backup_path.exists():
|
||||
_restore_backup(backup_path, file_path)
|
||||
else:
|
||||
try:
|
||||
file_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
return False, validation_error, backup_path
|
||||
_cleanup_old_backups(file_path)
|
||||
return True, None, backup_path
|
||||
|
||||
# =============================================================================
|
||||
# Core actions
|
||||
@@ -361,17 +486,15 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||
return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."}
|
||||
|
||||
skill_md = existing["path"] / "SKILL.md"
|
||||
# Back up original content for rollback
|
||||
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
||||
_atomic_write_text(skill_md, content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
ok, err, _ = _transactional_write(skill_md, content, validate_fn=_validate_skill_md)
|
||||
if not ok:
|
||||
return {"success": False, "error": f"Edit reverted: {err}"}
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(skill_md, original_content)
|
||||
backups = sorted([f for f in skill_md.parent.iterdir() if f.name.startswith(f".{skill_md.name}.bak.")], key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
if backups:
|
||||
_restore_backup(backups[0], skill_md)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Skill '{name}' updated.",
|
||||
@@ -452,15 +575,16 @@ def _patch_skill(
|
||||
"error": f"Patch would break SKILL.md structure: {err}",
|
||||
}
|
||||
|
||||
original_content = content # for rollback
|
||||
_atomic_write_text(target, new_content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
validate = _validate_skill_md if not file_path else None
|
||||
ok, err, _ = _transactional_write(target, new_content, validate_fn=validate)
|
||||
if not ok:
|
||||
return {"success": False, "error": f"Patch reverted: {err}"}
|
||||
scan_error = _security_scan_skill(skill_dir)
|
||||
if scan_error:
|
||||
_atomic_write_text(target, original_content)
|
||||
backups = sorted([f for f in target.parent.iterdir() if f.name.startswith(f".{target.name}.bak.")], key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
if backups:
|
||||
_restore_backup(backups[0], target)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Patched {'SKILL.md' if not file_path else file_path} in skill '{name}' ({match_count} replacement{'s' if match_count > 1 else ''}).",
|
||||
@@ -519,15 +643,14 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Back up for rollback
|
||||
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
||||
_atomic_write_text(target, file_content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
ok, err, _ = _transactional_write(target, file_content)
|
||||
if not ok:
|
||||
return {"success": False, "error": f"Write reverted: {err}"}
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(target, original_content)
|
||||
backups = sorted([f for f in target.parent.iterdir() if f.name.startswith(f".{target.name}.bak.")], key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
if backups:
|
||||
_restore_backup(backups[0], target)
|
||||
else:
|
||||
target.unlink(missing_ok=True)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
Reference in New Issue
Block a user