Make skill file writes atomic
This commit is contained in:
@@ -37,6 +37,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
@@ -190,6 +191,38 @@ def _validate_file_path(file_path: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
|
||||
"""
|
||||
Atomically write text content to a file.
|
||||
|
||||
Uses a temporary file in the same directory and os.replace() to ensure
|
||||
the target file is never left in a partially-written state if the process
|
||||
crashes or is interrupted.
|
||||
|
||||
Args:
|
||||
file_path: Target file path
|
||||
content: Content to write
|
||||
encoding: Text encoding (default: utf-8)
|
||||
"""
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, temp_path = tempfile.mkstemp(
|
||||
dir=str(file_path.parent),
|
||||
prefix=f".{file_path.name}.tmp.",
|
||||
suffix="",
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding=encoding) as f:
|
||||
f.write(content)
|
||||
os.replace(temp_path, file_path)
|
||||
except Exception:
|
||||
# Clean up temp file on error
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Core actions
|
||||
# =============================================================================
|
||||
@@ -218,9 +251,9 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
|
||||
skill_dir = _resolve_skill_dir(name, category)
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write SKILL.md
|
||||
# Write SKILL.md atomically
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
skill_md.write_text(content, encoding="utf-8")
|
||||
_atomic_write_text(skill_md, content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(skill_dir)
|
||||
@@ -256,13 +289,13 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||
skill_md = existing["path"] / "SKILL.md"
|
||||
# Back up original content for rollback
|
||||
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
||||
skill_md.write_text(content, encoding="utf-8")
|
||||
_atomic_write_text(skill_md, content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
skill_md.write_text(original_content, encoding="utf-8")
|
||||
_atomic_write_text(skill_md, original_content)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
return {
|
||||
@@ -342,12 +375,12 @@ def _patch_skill(
|
||||
}
|
||||
|
||||
original_content = content # for rollback
|
||||
target.write_text(new_content, encoding="utf-8")
|
||||
_atomic_write_text(target, new_content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(skill_dir)
|
||||
if scan_error:
|
||||
target.write_text(original_content, encoding="utf-8")
|
||||
_atomic_write_text(target, original_content)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
replacements = count if replace_all else 1
|
||||
@@ -394,13 +427,13 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Back up for rollback
|
||||
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
||||
target.write_text(file_content, encoding="utf-8")
|
||||
_atomic_write_text(target, file_content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
target.write_text(original_content, encoding="utf-8")
|
||||
_atomic_write_text(target, original_content)
|
||||
else:
|
||||
target.unlink(missing_ok=True)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
Reference in New Issue
Block a user