fix(config): atomic write for config.yaml to prevent data loss on crash
This commit is contained in:
@@ -851,24 +851,27 @@ _COMMENTED_SECTIONS = """
|
||||
|
||||
def save_config(config: Dict[str, Any]) -> None:
    """Save configuration to ~/.hermes/config.yaml.

    The config is normalized with ``_normalize_max_turns_config`` and then
    written in a single call to ``utils.atomic_yaml_write``. The target file
    is never opened directly in ``'w'`` mode here, because that would
    truncate the existing config before the new content is safely on disk.

    Args:
        config: The configuration mapping to persist.
    """
    # Imported at call time, as in the original block.
    from utils import atomic_yaml_write

    ensure_hermes_home()
    config_path = get_config_path()
    normalized = _normalize_max_turns_config(config)

    # Build optional commented-out sections for features that are off by
    # default or only relevant when explicitly configured. A section is
    # treated as "not configured" when its mapping is missing/empty or the
    # keys checked below are unset.
    sections = []
    sec = normalized.get("security", {})
    if not sec or sec.get("redact_secrets") is None:
        sections.append("security")
    fb = normalized.get("fallback_model", {})
    if not fb or not (fb.get("provider") and fb.get("model")):
        sections.append("fallback")

    # The commented template is appended as a whole whenever at least one
    # section is still unconfigured; otherwise nothing extra is written.
    atomic_yaml_write(
        config_path,
        normalized,
        extra_content=_COMMENTED_SECTIONS if sections else None,
    )
|
||||
|
||||
|
||||
def load_env() -> Dict[str, str]:
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import yaml
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -90,3 +92,62 @@ class TestSaveAndLoadRoundtrip:
|
||||
|
||||
reloaded = load_config()
|
||||
assert reloaded["terminal"]["timeout"] == 999
|
||||
|
||||
|
||||
class TestSaveConfigAtomicity:
    """Verify save_config uses atomic writes (tempfile + os.replace)."""

    def test_no_partial_write_on_crash(self, tmp_path):
        """If save_config crashes mid-write, the previous file stays intact."""
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            # Write an initial config
            config = load_config()
            config["model"] = "original-model"
            save_config(config)

            config_path = tmp_path / "config.yaml"
            assert config_path.exists()

            # Simulate a crash during yaml.dump by making atomic_yaml_write's
            # yaml.dump raise after the temp file is created but before replace.
            raised = False
            with patch("utils.yaml.dump", side_effect=OSError("disk full")):
                try:
                    config["model"] = "should-not-persist"
                    save_config(config)
                except OSError:
                    raised = True

            # Guard against a vacuous pass: the injected failure must have
            # actually interrupted the write.
            assert raised, "save_config did not hit the patched yaml.dump"

            # Original file must still be intact
            reloaded = load_config()
            assert reloaded["model"] == "original-model"

    def test_no_leftover_temp_files(self, tmp_path):
        """Failed writes must clean up their temp files."""
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            config = load_config()
            save_config(config)

            raised = False
            with patch("utils.yaml.dump", side_effect=OSError("disk full")):
                try:
                    save_config(config)
                except OSError:
                    raised = True

            # Without a real failure the "no leftovers" check proves nothing.
            assert raised, "save_config did not hit the patched yaml.dump"

            # No .tmp files should remain
            tmp_files = list(tmp_path.glob(".*config*.tmp"))
            assert tmp_files == []

    def test_atomic_write_creates_valid_yaml(self, tmp_path):
        """The written file must be valid YAML matching the input."""
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            config = load_config()
            config["model"] = "test/atomic-model"
            config["agent"]["max_turns"] = 77
            save_config(config)

            # Read raw YAML to verify it's valid and correct. The file is
            # written as UTF-8, so read it back with the same encoding
            # instead of the platform default.
            config_path = tmp_path / "config.yaml"
            raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
            assert raw["model"] == "test/atomic-model"
            assert raw["agent"]["max_turns"] == 77
|
||||
|
||||
48
utils.py
48
utils.py
@@ -6,6 +6,8 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def atomic_json_write(path: Union[str, Path], data: Any, *, indent: int = 2) -> None:
|
||||
"""Write JSON data to a file atomically.
|
||||
@@ -39,3 +41,49 @@ def atomic_json_write(path: Union[str, Path], data: Any, *, indent: int = 2) ->
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def atomic_yaml_write(
    path: Union[str, Path],
    data: Any,
    *,
    default_flow_style: bool = False,
    sort_keys: bool = False,
    extra_content: Union[str, None] = None,
) -> None:
    """Write YAML data to a file atomically.

    Uses temp file + fsync + os.replace to ensure the target file is never
    left in a partially-written state. If the process crashes mid-write,
    the previous version of the file remains intact.

    Args:
        path: Target file path (will be created or overwritten). Missing
            parent directories are created.
        data: YAML-serializable data to write.
        default_flow_style: YAML flow style (default False).
        sort_keys: Whether to sort dict keys (default False).
        extra_content: Optional string to append after the YAML dump
            (e.g. commented-out sections for user reference). ``None`` or
            an empty string appends nothing.

    Raises:
        BaseException: Any error raised while writing or replacing is
            re-raised unchanged after a best-effort removal of the temp file.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # The temp file lives next to the target so the final os.replace is a
    # rename within one directory rather than a cross-filesystem move.
    fd, tmp_path = tempfile.mkstemp(
        dir=str(path.parent),
        prefix=f".{path.stem}_",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            yaml.dump(data, f, default_flow_style=default_flow_style, sort_keys=sort_keys)
            if extra_content:
                f.write(extra_content)
            # Push Python's buffer to the OS, then ask the OS to flush to
            # disk, before the temp file takes the target's place.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # Covers KeyboardInterrupt/SystemExit as well: never leave a stray
        # temp file behind, but do not mask the original error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
|
||||
|
||||
Reference in New Issue
Block a user