"""
Integration tests for poka-yoke auto-revert on incomplete skill edits (#923).

Verifies the transactional write-validate-commit-or-rollback pattern:

- Backup created before every write
- Post-write validation triggers revert on corrupted/empty file
- Successful writes clean up the backup
- At most MAX_BACKUPS_PER_FILE backups retained per file
"""

import time
from pathlib import Path
from unittest.mock import patch

import pytest

from tools.skill_manager_tool import (
    MAX_BACKUPS_PER_FILE,
    _backup_skill_file,
    _cleanup_old_backups,
    _edit_skill,
    _patch_skill,
    _revert_from_backup,
    _validate_written_file,
    _write_file,
)

# Dotted prefix used for every mock.patch target in this module.
_MODULE = "tools.skill_manager_tool"

VALID_SKILL_MD = """\
---
name: test-skill
description: A skill for testing auto-revert
---

## Overview

Test skill body content.
"""

VALID_UPDATED_MD = """\
---
name: test-skill
description: Updated description
---

## Overview

Updated test skill body.
"""


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_skill(tmp_path: Path, content: str = VALID_SKILL_MD) -> Path:
    """Write a minimal SKILL.md in *tmp_path* and return its path."""
    skill_md = tmp_path / "SKILL.md"
    skill_md.write_text(content, encoding="utf-8")
    return skill_md


def _truncating_write(path: Path, content: str, **kwargs) -> None:
    """Stand-in for ``_atomic_write_text`` that leaves an empty file on disk.

    Simulates a truncated/interrupted write so that post-write validation
    sees a corrupted file. *content* is deliberately ignored.
    """
    path.write_text("", encoding="utf-8")


@pytest.fixture
def skill_dir(tmp_path: Path) -> Path:
    """Create ``<tmp>/test-skill/SKILL.md`` holding VALID_SKILL_MD; return the dir.

    This fixture does not patch anything: each test patches ``_find_skill``
    itself so the skill lookup resolves to this directory.
    """
    d = tmp_path / "test-skill"
    d.mkdir()
    (d / "SKILL.md").write_text(VALID_SKILL_MD, encoding="utf-8")
    return d


# ---------------------------------------------------------------------------
# Unit tests: _backup_skill_file
# ---------------------------------------------------------------------------


class TestBackupSkillFile:
    def test_creates_bak_file(self, tmp_path):
        skill_md = _make_skill(tmp_path)
        backup = _backup_skill_file(skill_md)
        assert backup is not None
        assert backup.exists()
        assert ".bak." in backup.name

    def test_backup_preserves_content(self, tmp_path):
        skill_md = _make_skill(tmp_path)
        backup = _backup_skill_file(skill_md)
        assert backup.read_text(encoding="utf-8") == VALID_SKILL_MD

    def test_no_backup_for_nonexistent_file(self, tmp_path):
        missing = tmp_path / "SKILL.md"
        assert _backup_skill_file(missing) is None

    def test_backup_name_contains_timestamp(self, tmp_path):
        skill_md = _make_skill(tmp_path)
        before = int(time.time())
        backup = _backup_skill_file(skill_md)
        after = int(time.time())
        ts = int(backup.name.split(".bak.")[-1])
        assert before <= ts <= after


# ---------------------------------------------------------------------------
# Unit tests: _cleanup_old_backups
# ---------------------------------------------------------------------------


class TestCleanupOldBackups:
    def _create_backups(self, skill_md: Path, n: int) -> list:
        """Plant *n* backups next to *skill_md* with increasing numeric suffixes."""
        backups = []
        for i in range(n):
            bp = skill_md.parent / f"{skill_md.name}.bak.{1000 + i}"
            bp.write_text("backup content", encoding="utf-8")
            backups.append(bp)
        return backups

    def test_prunes_excess_backups(self, tmp_path):
        skill_md = _make_skill(tmp_path)
        self._create_backups(skill_md, MAX_BACKUPS_PER_FILE + 2)
        _cleanup_old_backups(skill_md)
        remaining = list(tmp_path.glob("SKILL.md.bak.*"))
        assert len(remaining) == MAX_BACKUPS_PER_FILE

    def test_keeps_backups_within_limit(self, tmp_path):
        skill_md = _make_skill(tmp_path)
        self._create_backups(skill_md, MAX_BACKUPS_PER_FILE)
        _cleanup_old_backups(skill_md)
        remaining = list(tmp_path.glob("SKILL.md.bak.*"))
        assert len(remaining) == MAX_BACKUPS_PER_FILE

    def test_noop_when_no_backups(self, tmp_path):
        skill_md = _make_skill(tmp_path)
        _cleanup_old_backups(skill_md)  # should not raise
        # Cleanup must neither invent backups nor touch the live file.
        assert list(tmp_path.glob("SKILL.md.bak.*")) == []
        assert skill_md.read_text(encoding="utf-8") == VALID_SKILL_MD


# ---------------------------------------------------------------------------
# Unit tests: _validate_written_file
# ---------------------------------------------------------------------------


class TestValidateWrittenFile:
    def test_valid_skill_md(self, tmp_path):
        skill_md = _make_skill(tmp_path)
        assert _validate_written_file(skill_md, is_skill_md=True) is None

    def test_empty_file_fails(self, tmp_path):
        skill_md = tmp_path / "SKILL.md"
        skill_md.write_text("", encoding="utf-8")
        err = _validate_written_file(skill_md, is_skill_md=False)
        assert err is not None
        assert "empty" in err.lower()

    def test_broken_frontmatter_fails(self, tmp_path):
        skill_md = tmp_path / "SKILL.md"
        skill_md.write_text("Not a skill\nno frontmatter\n", encoding="utf-8")
        err = _validate_written_file(skill_md, is_skill_md=True)
        assert err is not None

    def test_missing_required_field_fails(self, tmp_path):
        skill_md = tmp_path / "SKILL.md"
        skill_md.write_text("---\ndescription: no name\n---\nbody\n", encoding="utf-8")
        err = _validate_written_file(skill_md, is_skill_md=True)
        assert err is not None
        assert "name" in err.lower()

    def test_missing_file_returns_error(self, tmp_path):
        missing = tmp_path / "SKILL.md"
        err = _validate_written_file(missing, is_skill_md=False)
        assert err is not None

    def test_non_skill_md_only_checks_emptiness(self, tmp_path):
        ref = tmp_path / "references" / "guide.md"
        ref.parent.mkdir()
        ref.write_text("# Guide\nsome content\n", encoding="utf-8")
        assert _validate_written_file(ref, is_skill_md=False) is None


# ---------------------------------------------------------------------------
# Unit tests: _revert_from_backup
# ---------------------------------------------------------------------------


class TestRevertFromBackup:
    def test_restores_from_backup(self, tmp_path):
        original = "original content"
        skill_md = tmp_path / "SKILL.md"
        skill_md.write_text(original, encoding="utf-8")
        backup = tmp_path / "SKILL.md.bak.99999"
        backup.write_text(original, encoding="utf-8")
        skill_md.write_text("corrupted content", encoding="utf-8")

        _revert_from_backup(skill_md, backup)

        assert skill_md.read_text(encoding="utf-8") == original

    def test_removes_file_when_no_backup(self, tmp_path):
        skill_md = tmp_path / "SKILL.md"
        skill_md.write_text("corrupted", encoding="utf-8")

        _revert_from_backup(skill_md, None)

        assert not skill_md.exists()


# ---------------------------------------------------------------------------
# Integration tests: _edit_skill auto-revert
# ---------------------------------------------------------------------------


class TestEditSkillAutoRevert:
    def test_successful_edit_removes_backup(self, skill_dir):
        with patch(f"{_MODULE}._find_skill", return_value={"path": skill_dir}), \
                patch(f"{_MODULE}._security_scan_skill", return_value=None):
            result = _edit_skill("test-skill", VALID_UPDATED_MD)

        assert result["success"] is True
        # The new content actually landed on disk ...
        assert "Updated test skill body." in (skill_dir / "SKILL.md").read_text(encoding="utf-8")
        # ... and the transactional backup was cleaned up.
        assert list(skill_dir.glob("SKILL.md.bak.*")) == []

    def test_revert_when_post_write_validation_fails(self, skill_dir):
        """Simulate a write that produces an empty file on disk."""
        skill_md = skill_dir / "SKILL.md"

        # The security scan is stubbed so any revert is attributable to
        # post-write validation alone.
        with patch(f"{_MODULE}._find_skill", return_value={"path": skill_dir}), \
                patch(f"{_MODULE}._security_scan_skill", return_value=None), \
                patch(f"{_MODULE}._atomic_write_text", side_effect=_truncating_write):
            result = _edit_skill("test-skill", VALID_UPDATED_MD)

        assert result["success"] is False
        assert "reverted" in result["error"].lower()
        # Original content restored
        assert skill_md.read_text(encoding="utf-8") == VALID_SKILL_MD

    def test_backup_preserved_after_revert(self, skill_dir):
        """A .bak file should survive when the edit is reverted (debugging aid)."""
        with patch(f"{_MODULE}._find_skill", return_value={"path": skill_dir}), \
                patch(f"{_MODULE}._security_scan_skill", return_value=None), \
                patch(f"{_MODULE}._atomic_write_text", side_effect=_truncating_write):
            result = _edit_skill("test-skill", VALID_UPDATED_MD)

        assert result["success"] is False
        backups = list(skill_dir.glob("SKILL.md.bak.*"))
        assert len(backups) == 1
        # The surviving backup holds the pre-edit content.
        assert backups[0].read_text(encoding="utf-8") == VALID_SKILL_MD

    def test_max_backups_enforced_after_multiple_edits(self, skill_dir):
        """A successful edit prunes stale backups down to MAX_BACKUPS_PER_FILE."""
        stale_count = MAX_BACKUPS_PER_FILE + 4
        for i in range(stale_count):
            # Plant stale backup files to simulate prior runs
            stale = skill_dir / f"SKILL.md.bak.{1000 + i}"
            stale.write_text("old backup", encoding="utf-8")

        with patch(f"{_MODULE}._find_skill", return_value={"path": skill_dir}), \
                patch(f"{_MODULE}._security_scan_skill", return_value=None):
            result = _edit_skill("test-skill", VALID_UPDATED_MD)

        assert result["success"] is True
        backups = list(skill_dir.glob("SKILL.md.bak.*"))
        assert len(backups) <= MAX_BACKUPS_PER_FILE


# ---------------------------------------------------------------------------
# Integration tests: _patch_skill auto-revert
# ---------------------------------------------------------------------------


class TestPatchSkillAutoRevert:
    def test_successful_patch_removes_backup(self, skill_dir):
        with patch(f"{_MODULE}._find_skill", return_value={"path": skill_dir}), \
                patch(f"{_MODULE}._security_scan_skill", return_value=None):
            result = _patch_skill(
                "test-skill",
                "A skill for testing auto-revert",
                "Updated description",
            )

        assert result["success"] is True
        # The replacement was applied ...
        assert "description: Updated description" in (skill_dir / "SKILL.md").read_text(encoding="utf-8")
        # ... and no backup was left behind.
        assert list(skill_dir.glob("SKILL.md.bak.*")) == []

    def test_revert_on_corrupt_write(self, skill_dir):
        skill_md = skill_dir / "SKILL.md"
        original = skill_md.read_text(encoding="utf-8")

        # The security scan is stubbed so any revert is attributable to
        # post-write validation alone.
        with patch(f"{_MODULE}._find_skill", return_value={"path": skill_dir}), \
                patch(f"{_MODULE}._security_scan_skill", return_value=None), \
                patch(f"{_MODULE}._atomic_write_text", side_effect=_truncating_write):
            result = _patch_skill(
                "test-skill",
                "A skill for testing",
                "A skill for testing auto-revert",
            )

        assert result["success"] is False
        assert "reverted" in result["error"].lower()
        assert skill_md.read_text(encoding="utf-8") == original