diff --git a/hermes_cli/config.py b/hermes_cli/config.py index a56b15e91..ccf3debc1 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -851,24 +851,27 @@ _COMMENTED_SECTIONS = """ def save_config(config: Dict[str, Any]): """Save configuration to ~/.hermes/config.yaml.""" + from utils import atomic_yaml_write + ensure_hermes_home() config_path = get_config_path() normalized = _normalize_max_turns_config(config) - - with open(config_path, 'w', encoding="utf-8") as f: - yaml.dump(normalized, f, default_flow_style=False, sort_keys=False) - # Append commented-out sections for features that are off by default - # or only relevant when explicitly configured. Skip sections the - # user has already uncommented and configured. - sections = [] - sec = normalized.get("security", {}) - if not sec or sec.get("redact_secrets") is None: - sections.append("security") - fb = normalized.get("fallback_model", {}) - if not fb or not (fb.get("provider") and fb.get("model")): - sections.append("fallback") - if sections: - f.write(_COMMENTED_SECTIONS) + + # Build optional commented-out sections for features that are off by + # default or only relevant when explicitly configured. + sections = [] + sec = normalized.get("security", {}) + if not sec or sec.get("redact_secrets") is None: + sections.append("security") + fb = normalized.get("fallback_model", {}) + if not fb or not (fb.get("provider") and fb.get("model")): + sections.append("fallback") + + atomic_yaml_write( + config_path, + normalized, + extra_content=_COMMENTED_SECTIONS if sections else None, + ) def load_env() -> Dict[str, str]: diff --git a/tests/hermes_cli/test_config.py b/tests/hermes_cli/test_config.py index f3b9f9355..df647fb6c 100644 --- a/tests/hermes_cli/test_config.py +++ b/tests/hermes_cli/test_config.py @@ -2,7 +2,9 @@ import os from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, MagicMock + +import yaml import yaml @@ -90,3 +92,62 @@ class TestSaveAndLoadRoundtrip: reloaded = load_config() assert reloaded["terminal"]["timeout"] == 999 + + +class TestSaveConfigAtomicity: + """Verify save_config uses atomic writes (tempfile + os.replace).""" + + def test_no_partial_write_on_crash(self, tmp_path): + """If save_config crashes mid-write, the previous file stays intact.""" + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + # Write an initial config + config = load_config() + config["model"] = "original-model" + save_config(config) + + config_path = tmp_path / "config.yaml" + assert config_path.exists() + + # Simulate a crash during yaml.dump by making atomic_yaml_write's + # yaml.dump raise after the temp file is created but before replace. + with patch("utils.yaml.dump", side_effect=OSError("disk full")): + try: + config["model"] = "should-not-persist" + save_config(config) + except OSError: + pass + + # Original file must still be intact + reloaded = load_config() + assert reloaded["model"] == "original-model" + + def test_no_leftover_temp_files(self, tmp_path): + """Failed writes must clean up their temp files.""" + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + config = load_config() + save_config(config) + + with patch("utils.yaml.dump", side_effect=OSError("disk full")): + try: + save_config(config) + except OSError: + pass + + # No .tmp files should remain + tmp_files = list(tmp_path.glob(".*config*.tmp")) + assert tmp_files == [] + + def test_atomic_write_creates_valid_yaml(self, tmp_path): + """The written file must be valid YAML matching the input.""" + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + config = load_config() + config["model"] = "test/atomic-model" + config["agent"]["max_turns"] = 77 + save_config(config) + + # Read raw YAML to verify it's valid and correct + config_path = tmp_path / "config.yaml" + with open(config_path) as f: + raw = yaml.safe_load(f) + assert raw["model"] == "test/atomic-model" + assert raw["agent"]["max_turns"] == 77 diff --git a/utils.py b/utils.py index 9c8b5e8c6..1b99d60fe 100644 --- a/utils.py +++ b/utils.py @@ -6,6 +6,8 @@ import tempfile from pathlib import Path from typing import Any, Union +import yaml + def atomic_json_write(path: Union[str, Path], data: Any, *, indent: int = 2) -> None: """Write JSON data to a file atomically. @@ -39,3 +41,49 @@ def atomic_json_write(path: Union[str, Path], data: Any, *, indent: int = 2) -> except OSError: pass raise + + +def atomic_yaml_write( + path: Union[str, Path], + data: Any, + *, + default_flow_style: bool = False, + sort_keys: bool = False, + extra_content: str | None = None, +) -> None: + """Write YAML data to a file atomically. + + Uses temp file + fsync + os.replace to ensure the target file is never + left in a partially-written state. If the process crashes mid-write, + the previous version of the file remains intact. + + Args: + path: Target file path (will be created or overwritten). + data: YAML-serializable data to write. + default_flow_style: YAML flow style (default False). + sort_keys: Whether to sort dict keys (default False). + extra_content: Optional string to append after the YAML dump + (e.g. commented-out sections for user reference). + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + fd, tmp_path = tempfile.mkstemp( + dir=str(path.parent), + prefix=f".{path.stem}_", + suffix=".tmp", + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + yaml.dump(data, f, default_flow_style=default_flow_style, sort_keys=sort_keys) + if extra_content: + f.write(extra_content) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, path) + except BaseException: + try: + os.unlink(tmp_path) + except OSError: + pass + raise