Compare commits
1 Commits
burn/750-1
...
fix/686-co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07570c652d |
267
scripts/config_drift.py
Normal file
267
scripts/config_drift.py
Normal file
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
config_drift.py — Detect configuration drift across fleet nodes.
|
||||
|
||||
Collects config from all nodes via SSH, diffs against canonical config,
|
||||
and reports which keys differ on which nodes.
|
||||
|
||||
Usage:
|
||||
python3 config_drift.py --nodes allegro,ezra,bezalel
|
||||
python3 config_drift.py --inventory ansible/playbooks/inventory
|
||||
python3 config_drift.py --check-only # don't fetch, compare existing
|
||||
python3 config_drift.py --sync # auto-sync with approval
|
||||
|
||||
Exit codes:
|
||||
0 = no drift detected
|
||||
1 = drift detected
|
||||
2 = error
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
# Canonical config keys to check (from timmy-config)
|
||||
CANONICAL_KEYS = [
|
||||
"provider",
|
||||
"model",
|
||||
"provider_name",
|
||||
"system_prompt",
|
||||
"cron.enabled",
|
||||
"cron.workers",
|
||||
"cron.tick_seconds",
|
||||
"session.reset_after",
|
||||
"session.max_turns",
|
||||
]
|
||||
|
||||
CANONICAL_CONFIG_PATH = Path(__file__).parent.parent / "config" / "config.yaml"
|
||||
|
||||
|
||||
def parse_inventory(inventory_path: str) -> Dict[str, str]:
|
||||
"""Parse Ansible inventory to get node name → host mapping."""
|
||||
nodes = {}
|
||||
current_section = None
|
||||
|
||||
with open(inventory_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
if line.startswith('[') and line.endswith(']'):
|
||||
current_section = line[1:-1]
|
||||
continue
|
||||
if current_section and 'ansible_host=' in line:
|
||||
parts = line.split()
|
||||
name = parts[0]
|
||||
host = None
|
||||
for p in parts:
|
||||
if p.startswith('ansible_host='):
|
||||
host = p.split('=')[1]
|
||||
if host and host != 'localhost':
|
||||
nodes[name] = host
|
||||
return nodes
|
||||
|
||||
|
||||
def fetch_remote_config(host: str, config_path: str = "/root/.hermes/config.yaml") -> Optional[Dict]:
|
||||
"""Fetch config from remote node via SSH."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
|
||||
f"root@{host}", f"cat {config_path} 2>/dev/null || echo '{{}}'"],
|
||||
capture_output=True, text=True, timeout=30
|
||||
)
|
||||
if result.returncode == 0:
|
||||
try:
|
||||
import yaml
|
||||
return yaml.safe_load(result.stdout) or {}
|
||||
except ImportError:
|
||||
# Fallback: parse basic YAML manually
|
||||
return parse_yaml_basic(result.stdout)
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def parse_yaml_basic(content: str) -> Dict:
|
||||
"""Basic YAML parser for simple key-value configs."""
|
||||
result = {}
|
||||
for line in content.split('\n'):
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
if ':' in line:
|
||||
key, _, value = line.partition(':')
|
||||
key = key.strip()
|
||||
value = value.strip().strip('"').strip("'")
|
||||
if value.lower() == 'true':
|
||||
value = True
|
||||
elif value.lower() == 'false':
|
||||
value = False
|
||||
elif value.isdigit():
|
||||
value = int(value)
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def get_nested_value(config: Dict, key_path: str):
|
||||
"""Get value from nested dict using dot notation."""
|
||||
keys = key_path.split('.')
|
||||
value = config
|
||||
for k in keys:
|
||||
if isinstance(value, dict):
|
||||
value = value.get(k)
|
||||
else:
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def compare_configs(canonical: Dict, remote: Dict, keys: List[str]) -> List[Tuple[str, str, any, any]]:
|
||||
"""
|
||||
Compare canonical config against remote config.
|
||||
|
||||
Returns list of (key, node, canonical_value, remote_value) for differences.
|
||||
"""
|
||||
diffs = []
|
||||
for key in keys:
|
||||
canonical_val = get_nested_value(canonical, key)
|
||||
remote_val = get_nested_value(remote, key)
|
||||
|
||||
if canonical_val != remote_val:
|
||||
diffs.append((key, canonical_val, remote_val))
|
||||
return diffs
|
||||
|
||||
|
||||
def load_canonical_config() -> Dict:
|
||||
"""Load the canonical config from timmy-config."""
|
||||
if CANONICAL_CONFIG_PATH.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(CANONICAL_CONFIG_PATH) as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
except ImportError:
|
||||
with open(CANONICAL_CONFIG_PATH) as f:
|
||||
return parse_yaml_basic(f.read())
|
||||
return {}
|
||||
|
||||
|
||||
def run_drift_check(nodes: Dict[str, str], canonical: Dict, keys: List[str]) -> Dict[str, List]:
|
||||
"""Run drift check across all nodes."""
|
||||
results = {}
|
||||
for name, host in nodes.items():
|
||||
remote_config = fetch_remote_config(host)
|
||||
if remote_config is None:
|
||||
results[name] = {"status": "unreachable", "diffs": []}
|
||||
continue
|
||||
|
||||
diffs = compare_configs(canonical, remote_config, keys)
|
||||
results[name] = {
|
||||
"status": "drift" if diffs else "ok",
|
||||
"host": host,
|
||||
"diffs": [(k, str(cv), str(rv)) for k, cv, rv in diffs],
|
||||
}
|
||||
return results
|
||||
|
||||
|
||||
def generate_report(results: Dict, canonical_keys: List[str]) -> str:
|
||||
"""Generate human-readable drift report."""
|
||||
lines = []
|
||||
lines.append("=" * 60)
|
||||
lines.append(" CONFIG DRIFT REPORT")
|
||||
lines.append(f" {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}")
|
||||
lines.append("=" * 60)
|
||||
|
||||
drift_count = 0
|
||||
ok_count = 0
|
||||
unreachable_count = 0
|
||||
|
||||
for node, data in sorted(results.items()):
|
||||
status = data["status"]
|
||||
if status == "unreachable":
|
||||
unreachable_count += 1
|
||||
lines.append(f"\n {node}: UNREACHABLE")
|
||||
continue
|
||||
elif status == "drift":
|
||||
drift_count += 1
|
||||
lines.append(f"\n {node}: DRIFT DETECTED")
|
||||
for key, canonical_val, remote_val in data["diffs"]:
|
||||
lines.append(f" {key}:")
|
||||
lines.append(f" canonical: {canonical_val}")
|
||||
lines.append(f" remote: {remote_val}")
|
||||
else:
|
||||
ok_count += 1
|
||||
lines.append(f"\n {node}: OK")
|
||||
|
||||
lines.append(f"\n{'=' * 60}")
|
||||
lines.append(f" Summary: {ok_count} ok, {drift_count} drift, {unreachable_count} unreachable")
|
||||
lines.append(f" Keys checked: {len(canonical_keys)}")
|
||||
lines.append("=" * 60)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Config drift detection across fleet")
|
||||
parser.add_argument("--inventory", help="Ansible inventory file path")
|
||||
parser.add_argument("--nodes", help="Comma-separated node list (name:host)")
|
||||
parser.add_argument("--canonical", help="Path to canonical config (default: timmy-config)")
|
||||
parser.add_argument("--keys", help="Comma-separated keys to check")
|
||||
parser.add_argument("--json", action="store_true", help="JSON output")
|
||||
parser.add_argument("--check-only", action="store_true", help="Use cached configs only")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load canonical config
|
||||
if args.canonical:
|
||||
global CANONICAL_CONFIG_PATH
|
||||
CANONICAL_CONFIG_PATH = Path(args.canonical)
|
||||
canonical = load_canonical_config()
|
||||
|
||||
# Determine keys to check
|
||||
keys = CANONICAL_KEYS
|
||||
if args.keys:
|
||||
keys = args.keys.split(',')
|
||||
|
||||
# Determine nodes
|
||||
nodes = {}
|
||||
if args.inventory:
|
||||
nodes = parse_inventory(args.inventory)
|
||||
elif args.nodes:
|
||||
for pair in args.nodes.split(','):
|
||||
if ':' in pair:
|
||||
name, host = pair.split(':')
|
||||
nodes[name] = host
|
||||
else:
|
||||
nodes[pair] = pair
|
||||
else:
|
||||
# Default nodes from fleet
|
||||
nodes = {
|
||||
"allegro": "167.99.126.228",
|
||||
"ezra": "143.198.27.163",
|
||||
"bezalel": "159.203.146.185",
|
||||
}
|
||||
|
||||
if not nodes:
|
||||
print("ERROR: No nodes specified", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
# Run check
|
||||
results = run_drift_check(nodes, canonical, keys)
|
||||
|
||||
# Output
|
||||
if args.json:
|
||||
print(json.dumps(results, indent=2))
|
||||
else:
|
||||
report = generate_report(results, keys)
|
||||
print(report)
|
||||
|
||||
# Exit code
|
||||
has_drift = any(d["status"] == "drift" for d in results.values())
|
||||
sys.exit(1 if has_drift else 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,139 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
normalize-code-blocks.py — Fix inconsistent indentation in training data code blocks.
|
||||
|
||||
When code blocks are embedded in JSONL as triple-quoted strings, indentation
|
||||
accumulates from the surrounding context. This script normalizes code block
|
||||
content using textwrap.dedent and consistent 4-space indentation.
|
||||
|
||||
Usage:
|
||||
python3 scripts/normalize-code-blocks.py training/data/preference_pairs.jsonl
|
||||
python3 scripts/normalize-code-blocks.py --dry-run training/data/*.jsonl
|
||||
python3 scripts/normalize-code-blocks.py --check training/data/*.jsonl # CI mode
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
# Matches ```python ... ``` or ``` ... ``` blocks inside string values
|
||||
CODE_BLOCK_RE = re.compile(
|
||||
r'(?P<open>```(?:python|py|bash|sh|javascript|js|typescript|ts|go|rust|ruby)?\s*\n)'
|
||||
r'(?P<code>.*?)'
|
||||
r'(?P<close>```)',
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def normalize_code_block(match: re.Match) -> str:
|
||||
"""Normalize indentation in a single code block."""
|
||||
open_tag = match.group("open")
|
||||
code = match.group("code")
|
||||
close_tag = match.group("close")
|
||||
|
||||
# Skip empty blocks
|
||||
if not code.strip():
|
||||
return match.group(0)
|
||||
|
||||
# Dedent the code
|
||||
dedented = textwrap.dedent(code)
|
||||
|
||||
# Strip leading/trailing blank lines
|
||||
lines = dedented.split("\n")
|
||||
while lines and not lines[0].strip():
|
||||
lines.pop(0)
|
||||
while lines and not lines[-1].strip():
|
||||
lines.pop()
|
||||
|
||||
normalized = "\n".join(lines)
|
||||
|
||||
return f"{open_tag}{normalized}\n{close_tag}"
|
||||
|
||||
|
||||
def process_line(line: str) -> tuple[str, int]:
|
||||
"""Process a single JSONL line. Returns (new_line, num_fixes)."""
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return line, 0
|
||||
|
||||
fixes = 0
|
||||
|
||||
def fix_strings(obj):
|
||||
nonlocal fixes
|
||||
if isinstance(obj, str):
|
||||
original = obj
|
||||
fixed = CODE_BLOCK_RE.sub(normalize_code_block, obj)
|
||||
if fixed != original:
|
||||
fixes += 1
|
||||
return fixed
|
||||
elif isinstance(obj, dict):
|
||||
return {k: fix_strings(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [fix_strings(item) for item in obj]
|
||||
return obj
|
||||
|
||||
fixed_obj = fix_strings(obj)
|
||||
return json.dumps(fixed_obj, ensure_ascii=False) + "\n", fixes
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Normalize code block indentation in JSONL training data")
|
||||
parser.add_argument("files", nargs="+", help="JSONL files to process")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Show changes without writing")
|
||||
parser.add_argument("--check", action="store_true", help="CI mode: exit 1 if fixes needed")
|
||||
args = parser.parse_args()
|
||||
|
||||
total_fixes = 0
|
||||
total_lines = 0
|
||||
files_changed = 0
|
||||
|
||||
for filepath in args.files:
|
||||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
print(f"SKIP: {path} not found", file=sys.stderr)
|
||||
continue
|
||||
|
||||
lines = path.read_text().splitlines(keepends=True)
|
||||
fixed_lines = []
|
||||
file_fixes = 0
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if not line.strip():
|
||||
fixed_lines.append(line)
|
||||
continue
|
||||
fixed_line, n = process_line(line)
|
||||
fixed_lines.append(fixed_line)
|
||||
file_fixes += n
|
||||
total_lines += 1
|
||||
|
||||
if file_fixes > 0:
|
||||
files_changed += 1
|
||||
total_fixes += file_fixes
|
||||
print(f"{'CHECK' if args.check else 'FIX'}: {path} — {file_fixes} code blocks normalized")
|
||||
|
||||
if args.check:
|
||||
# Show diff
|
||||
for i, (old, new) in enumerate(zip(lines, fixed_lines)):
|
||||
if old != new:
|
||||
print(f" Line {i+1}: indentation changed")
|
||||
elif not args.dry_run:
|
||||
path.write_text("".join(fixed_lines))
|
||||
print(f" Written: {path}")
|
||||
else:
|
||||
print(f"OK: {path} — no indentation issues")
|
||||
|
||||
print(f"\nSummary: {total_fixes} code blocks fixed across {files_changed} files ({total_lines} lines processed)")
|
||||
|
||||
if args.check and total_fixes > 0:
|
||||
print("FAIL: Code block indentation issues found. Run without --check to fix.")
|
||||
sys.exit(1)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
149
tests/test_config_drift.py
Normal file
149
tests/test_config_drift.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Tests for scripts/config_drift.py — Config drift detection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
||||
from config_drift import (
|
||||
get_nested_value,
|
||||
compare_configs,
|
||||
parse_yaml_basic,
|
||||
generate_report,
|
||||
)
|
||||
|
||||
|
||||
class TestGetNestedValue(unittest.TestCase):
|
||||
def test_top_level(self):
|
||||
config = {"provider": "openrouter"}
|
||||
self.assertEqual(get_nested_value(config, "provider"), "openrouter")
|
||||
|
||||
def test_nested(self):
|
||||
config = {"cron": {"enabled": True, "workers": 4}}
|
||||
self.assertEqual(get_nested_value(config, "cron.enabled"), True)
|
||||
self.assertEqual(get_nested_value(config, "cron.workers"), 4)
|
||||
|
||||
def test_missing_key(self):
|
||||
config = {"provider": "openrouter"}
|
||||
self.assertIsNone(get_nested_value(config, "missing"))
|
||||
|
||||
def test_missing_nested(self):
|
||||
config = {"cron": {}}
|
||||
self.assertIsNone(get_nested_value(config, "cron.enabled"))
|
||||
|
||||
def test_deep_nesting(self):
|
||||
config = {"a": {"b": {"c": "value"}}}
|
||||
self.assertEqual(get_nested_value(config, "a.b.c"), "value")
|
||||
|
||||
|
||||
class TestCompareConfigs(unittest.TestCase):
|
||||
def test_no_diff(self):
|
||||
canonical = {"provider": "openrouter", "model": "mimo"}
|
||||
remote = {"provider": "openrouter", "model": "mimo"}
|
||||
diffs = compare_configs(canonical, remote, ["provider", "model"])
|
||||
self.assertEqual(diffs, [])
|
||||
|
||||
def test_single_diff(self):
|
||||
canonical = {"provider": "openrouter"}
|
||||
remote = {"provider": "anthropic"}
|
||||
diffs = compare_configs(canonical, remote, ["provider"])
|
||||
self.assertEqual(len(diffs), 1)
|
||||
self.assertEqual(diffs[0][0], "provider")
|
||||
self.assertEqual(diffs[0][1], "openrouter")
|
||||
self.assertEqual(diffs[0][2], "anthropic")
|
||||
|
||||
def test_multiple_diffs(self):
|
||||
canonical = {"provider": "openrouter", "model": "mimo"}
|
||||
remote = {"provider": "anthropic", "model": "claude"}
|
||||
diffs = compare_configs(canonical, remote, ["provider", "model"])
|
||||
self.assertEqual(len(diffs), 2)
|
||||
|
||||
def test_nested_diff(self):
|
||||
canonical = {"cron": {"enabled": True}}
|
||||
remote = {"cron": {"enabled": False}}
|
||||
diffs = compare_configs(canonical, remote, ["cron.enabled"])
|
||||
self.assertEqual(len(diffs), 1)
|
||||
self.assertEqual(diffs[0][0], "cron.enabled")
|
||||
|
||||
def test_missing_in_remote(self):
|
||||
canonical = {"provider": "openrouter"}
|
||||
remote = {}
|
||||
diffs = compare_configs(canonical, remote, ["provider"])
|
||||
self.assertEqual(len(diffs), 1)
|
||||
|
||||
def test_extra_in_remote(self):
|
||||
canonical = {}
|
||||
remote = {"provider": "openrouter"}
|
||||
diffs = compare_configs(canonical, remote, ["provider"])
|
||||
self.assertEqual(len(diffs), 1)
|
||||
|
||||
|
||||
class TestParseYamlBasic(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
content = "provider: openrouter\nmodel: mimo-v2-pro\n"
|
||||
result = parse_yaml_basic(content)
|
||||
self.assertEqual(result["provider"], "openrouter")
|
||||
self.assertEqual(result["model"], "mimo-v2-pro")
|
||||
|
||||
def test_boolean(self):
|
||||
content = "enabled: true\ndisabled: false\n"
|
||||
result = parse_yaml_basic(content)
|
||||
self.assertEqual(result["enabled"], True)
|
||||
self.assertEqual(result["disabled"], False)
|
||||
|
||||
def test_integer(self):
|
||||
content = "workers: 4\nport: 8080\n"
|
||||
result = parse_yaml_basic(content)
|
||||
self.assertEqual(result["workers"], 4)
|
||||
self.assertEqual(result["port"], 8080)
|
||||
|
||||
def test_comments_skipped(self):
|
||||
content = "# This is a comment\nprovider: openrouter\n"
|
||||
result = parse_yaml_basic(content)
|
||||
self.assertNotIn("#", result)
|
||||
self.assertEqual(result["provider"], "openrouter")
|
||||
|
||||
def test_quoted_values(self):
|
||||
content = 'name: "hello world"\nother: \'single quotes\'\n'
|
||||
result = parse_yaml_basic(content)
|
||||
self.assertEqual(result["name"], "hello world")
|
||||
self.assertEqual(result["other"], "single quotes")
|
||||
|
||||
|
||||
class TestGenerateReport(unittest.TestCase):
|
||||
def test_all_ok(self):
|
||||
results = {
|
||||
"node1": {"status": "ok", "diffs": []},
|
||||
"node2": {"status": "ok", "diffs": []},
|
||||
}
|
||||
report = generate_report(results, ["provider"])
|
||||
self.assertIn("OK", report)
|
||||
self.assertIn("2 ok", report)
|
||||
|
||||
def test_drift_reported(self):
|
||||
results = {
|
||||
"node1": {
|
||||
"status": "drift",
|
||||
"diffs": [("provider", "openrouter", "anthropic")]
|
||||
},
|
||||
"node2": {"status": "ok", "diffs": []},
|
||||
}
|
||||
report = generate_report(results, ["provider"])
|
||||
self.assertIn("DRIFT DETECTED", report)
|
||||
self.assertIn("openrouter", report)
|
||||
self.assertIn("anthropic", report)
|
||||
|
||||
def test_unreachable_reported(self):
|
||||
results = {
|
||||
"node1": {"status": "unreachable", "diffs": []},
|
||||
}
|
||||
report = generate_report(results, ["provider"])
|
||||
self.assertIn("UNREACHABLE", report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,139 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for normalize-code-blocks.py — training data code block indentation fix (#750)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))
|
||||
from normalize_code_blocks import normalize_code_block, process_line, CODE_BLOCK_RE
|
||||
|
||||
|
||||
class TestNormalizeCodeBlock:
|
||||
def test_basic_dedent(self):
|
||||
block = "```python\n from fastapi import FastAPI\n app = FastAPI()\n```"
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
||||
assert " from fastapi" not in result
|
||||
assert "from fastapi" in result
|
||||
|
||||
def test_preserves_language_tag(self):
|
||||
block = "```python\n x = 1\n```"
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
||||
assert result.startswith("```python")
|
||||
|
||||
def test_empty_block_unchanged(self):
|
||||
block = "```python\n \n \n```"
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
||||
assert result == block
|
||||
|
||||
def test_multiple_blocks(self):
|
||||
text = 'First: ```python\n x = 1\n``` and second: ```python\n y = 2\n```'
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, text)
|
||||
assert " x = 1" not in result
|
||||
assert " y = 2" not in result
|
||||
assert "x = 1" in result
|
||||
assert "y = 2" in result
|
||||
|
||||
def test_bash_block(self):
|
||||
block = "```bash\n echo hello\n ls -la\n```"
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
||||
assert " echo" not in result
|
||||
assert "echo hello" in result
|
||||
|
||||
def test_unlabeled_block(self):
|
||||
block = "```\n some code\n```"
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
||||
assert " some code" not in result
|
||||
|
||||
def test_mixed_indentation(self):
|
||||
block = "```python\n def foo():\n return 42\n```"
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
||||
lines = result.split("\n")
|
||||
# First code line should not have leading spaces from embedding
|
||||
code_lines = [l for l in lines if l.strip() and not l.startswith("```")]
|
||||
assert code_lines[0].startswith("def")
|
||||
|
||||
def test_strips_leading_trailing_blanks(self):
|
||||
block = "```python\n\n x = 1\n\n```"
|
||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
||||
assert "\n\n" not in result.split("```python")[1].split("```")[0]
|
||||
|
||||
|
||||
class TestProcessLine:
|
||||
def test_valid_jsonl_with_code(self):
|
||||
obj = {"prompt": "write code", "response": "```python\n x = 1\n```"}
|
||||
line = json.dumps(obj)
|
||||
fixed, n = process_line(line)
|
||||
parsed = json.loads(fixed)
|
||||
assert n == 1
|
||||
assert " x = 1" not in parsed["response"]
|
||||
|
||||
def test_no_code_blocks(self):
|
||||
obj = {"text": "hello world"}
|
||||
line = json.dumps(obj)
|
||||
fixed, n = process_line(line)
|
||||
assert n == 0
|
||||
assert json.loads(fixed)["text"] == "hello world"
|
||||
|
||||
def test_invalid_jsonl(self):
|
||||
line = "not valid json {{{"
|
||||
fixed, n = process_line(line)
|
||||
assert n == 0
|
||||
assert fixed == line
|
||||
|
||||
def test_nested_code_blocks(self):
|
||||
obj = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "write code"},
|
||||
{"role": "assistant", "content": "```python\n def f():\n pass\n```"}
|
||||
]
|
||||
}
|
||||
line = json.dumps(obj)
|
||||
fixed, n = process_line(line)
|
||||
assert n == 1
|
||||
parsed = json.loads(fixed)
|
||||
assert " def f" not in parsed["messages"][1]["content"]
|
||||
|
||||
def test_multiple_fields_with_code(self):
|
||||
obj = {
|
||||
"terse": "```python\n x = 1\n```",
|
||||
"rich": "```python\n y = 2\n```"
|
||||
}
|
||||
line = json.dumps(obj)
|
||||
fixed, n = process_line(line)
|
||||
parsed = json.loads(fixed)
|
||||
assert n == 2
|
||||
assert " x = 1" not in parsed["terse"]
|
||||
assert " y = 2" not in parsed["rich"]
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
def test_file_processing(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(json.dumps({"r": "```python\n x = 1\n```"}) + "\n")
|
||||
f.write(json.dumps({"r": "no code here"}) + "\n")
|
||||
f.write(json.dumps({"r": "```python\n def g():\n return 99\n```"}) + "\n")
|
||||
f.flush()
|
||||
|
||||
# Process using the script logic
|
||||
lines = Path(f.name).read_text().splitlines(keepends=True)
|
||||
fixed = []
|
||||
total = 0
|
||||
for line in lines:
|
||||
fl, n = process_line(line)
|
||||
fixed.append(fl)
|
||||
total += n
|
||||
|
||||
os.unlink(f.name)
|
||||
assert total == 2
|
||||
# Verify first line is fixed
|
||||
first = json.loads(fixed[0])
|
||||
assert " x = 1" not in first["r"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user