Compare commits
2 Commits
fix/628-ha
...
issue-750-
| Author | SHA1 | Date | |
|---|---|---|---|
| cb1408aafb | |||
| 4b0cbd123e |
@@ -1,264 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
hash_dedup.py — Deduplication with bounded hash storage.
|
||||
|
||||
Stores SHA-256 hashes of validated content in date-stamped files.
|
||||
Rotates daily. Keeps only last N days. Prevents unbounded growth.
|
||||
|
||||
Usage:
|
||||
from hash_dedup import HashDedup
|
||||
|
||||
dedup = HashDedup("/path/to/.hashes")
|
||||
if dedup.is_duplicate("some content"):
|
||||
print("Already seen")
|
||||
else:
|
||||
dedup.add("some content")
|
||||
print("New content")
|
||||
|
||||
# Cleanup old hashes
|
||||
dedup.cleanup(keep_days=7)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Set, Optional
|
||||
|
||||
|
||||
class HashDedup:
|
||||
"""
|
||||
Bounded hash-based deduplication with daily rotation.
|
||||
|
||||
Storage layout:
|
||||
.hashes/
|
||||
2026-04-15.json (one file per day)
|
||||
2026-04-14.json
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, storage_dir: str, max_hashes_per_file: int = 100000):
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.max_hashes_per_file = max_hashes_per_file
|
||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._today_hashes: Optional[Set[str]] = None
|
||||
self._today_file: Optional[Path] = None
|
||||
|
||||
def _today(self) -> str:
|
||||
"""Current date string."""
|
||||
return datetime.utcnow().strftime("%Y-%m-%d")
|
||||
|
||||
def _date_file(self, date_str: str) -> Path:
|
||||
"""Path to hash file for a given date."""
|
||||
return self.storage_dir / f"{date_str}.json"
|
||||
|
||||
@property
|
||||
def today_file(self) -> Path:
|
||||
if self._today_file is None:
|
||||
self._today_file = self._date_file(self._today())
|
||||
return self._today_file
|
||||
|
||||
def _load_today(self) -> Set[str]:
|
||||
"""Load today's hashes from disk."""
|
||||
if self._today_hashes is not None:
|
||||
return self._today_hashes
|
||||
|
||||
path = self.today_file
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
self._today_hashes = set(data.get("hashes", []))
|
||||
except (json.JSONDecodeError, IOError):
|
||||
self._today_hashes = set()
|
||||
else:
|
||||
self._today_hashes = set()
|
||||
return self._today_hashes
|
||||
|
||||
def _save_today(self):
|
||||
"""Save today's hashes to disk."""
|
||||
hashes = self._load_today()
|
||||
path = self.today_file
|
||||
|
||||
# Enforce max size
|
||||
if len(hashes) > self.max_hashes_per_file:
|
||||
hashes = set(list(hashes)[:self.max_hashes_per_file])
|
||||
|
||||
data = {
|
||||
"date": self._today(),
|
||||
"count": len(hashes),
|
||||
"hashes": sorted(hashes),
|
||||
}
|
||||
with open(path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
@staticmethod
|
||||
def compute_hash(content: str) -> str:
|
||||
"""Compute SHA-256 hex digest of content."""
|
||||
return hashlib.sha256(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def is_duplicate(self, content: str) -> bool:
|
||||
"""Check if content hash exists in today's file or recent files."""
|
||||
h = self.compute_hash(content)
|
||||
|
||||
# Check today
|
||||
if h in self._load_today():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def is_duplicate_any(self, content: str, lookback_days: int = 7) -> bool:
|
||||
"""Check if content hash exists in any file within lookback period."""
|
||||
h = self.compute_hash(content)
|
||||
|
||||
# Check today first
|
||||
if h in self._load_today():
|
||||
return True
|
||||
|
||||
# Check recent files
|
||||
for i in range(1, lookback_days + 1):
|
||||
date_str = (datetime.utcnow() - timedelta(days=i)).strftime("%Y-%m-%d")
|
||||
path = self._date_file(date_str)
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
if h in set(data.get("hashes", [])):
|
||||
return True
|
||||
except (json.JSONDecodeError, IOError):
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def add(self, content: str) -> bool:
|
||||
"""Add content hash. Returns True if added (was new), False if duplicate."""
|
||||
h = self.compute_hash(content)
|
||||
hashes = self._load_today()
|
||||
|
||||
if h in hashes:
|
||||
return False
|
||||
|
||||
hashes.add(h)
|
||||
self._save_today()
|
||||
return True
|
||||
|
||||
def add_batch(self, contents: list) -> int:
|
||||
"""Add multiple content hashes. Returns count of new hashes added."""
|
||||
hashes = self._load_today()
|
||||
new_count = 0
|
||||
|
||||
for content in contents:
|
||||
h = self.compute_hash(content)
|
||||
if h not in hashes:
|
||||
hashes.add(h)
|
||||
new_count += 1
|
||||
|
||||
if new_count > 0:
|
||||
self._save_today()
|
||||
|
||||
return new_count
|
||||
|
||||
def cleanup(self, keep_days: int = 7) -> int:
|
||||
"""
|
||||
Remove hash files older than keep_days.
|
||||
|
||||
Returns count of files removed.
|
||||
"""
|
||||
removed = 0
|
||||
cutoff = datetime.utcnow() - timedelta(days=keep_days)
|
||||
|
||||
for path in self.storage_dir.glob("*.json"):
|
||||
try:
|
||||
date_str = path.stem
|
||||
file_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
if file_date < cutoff:
|
||||
path.unlink()
|
||||
removed += 1
|
||||
except ValueError:
|
||||
# Not a date-named file, skip
|
||||
continue
|
||||
|
||||
return removed
|
||||
|
||||
def stats(self) -> dict:
|
||||
"""Get dedup statistics."""
|
||||
total_hashes = 0
|
||||
file_count = 0
|
||||
oldest = None
|
||||
newest = None
|
||||
|
||||
for path in self.storage_dir.glob("*.json"):
|
||||
try:
|
||||
date_str = path.stem
|
||||
file_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
file_count += 1
|
||||
|
||||
if oldest is None or file_date < oldest:
|
||||
oldest = file_date
|
||||
if newest is None or file_date > newest:
|
||||
newest = file_date
|
||||
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
total_hashes += data.get("count", 0)
|
||||
except (ValueError, json.JSONDecodeError, IOError):
|
||||
continue
|
||||
|
||||
return {
|
||||
"file_count": file_count,
|
||||
"total_hashes": total_hashes,
|
||||
"oldest_file": oldest.strftime("%Y-%m-%d") if oldest else None,
|
||||
"newest_file": newest.strftime("%Y-%m-%d") if newest else None,
|
||||
"today_count": len(self._load_today()),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI for hash_dedup operations."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Hash dedup with bounded storage")
|
||||
parser.add_argument("--dir", default=".hashes", help="Hash storage directory")
|
||||
parser.add_argument("--cleanup", type=int, metavar="DAYS", help="Remove files older than N days")
|
||||
parser.add_argument("--stats", action="store_true", help="Show statistics")
|
||||
parser.add_argument("--check", type=str, help="Check if content hash exists")
|
||||
parser.add_argument("--add", type=str, help="Add content hash")
|
||||
parser.add_argument("--file", type=str, help="Add hashes from file (one per line)")
|
||||
args = parser.parse_args()
|
||||
|
||||
dedup = HashDedup(args.dir)
|
||||
|
||||
if args.cleanup is not None:
|
||||
removed = dedup.cleanup(keep_days=args.cleanup)
|
||||
print(f"Cleaned up {removed} files older than {args.cleanup} days")
|
||||
|
||||
if args.stats:
|
||||
stats = dedup.stats()
|
||||
print(f"Hash Dedup Statistics:")
|
||||
print(f" Files: {stats['file_count']}")
|
||||
print(f" Total hashes: {stats['total_hashes']}")
|
||||
print(f" Today: {stats['today_count']}")
|
||||
print(f" Date range: {stats['oldest_file']} to {stats['newest_file']}")
|
||||
|
||||
if args.check:
|
||||
if dedup.is_duplicate_any(args.check):
|
||||
print("DUPLICATE")
|
||||
else:
|
||||
print("NEW")
|
||||
|
||||
if args.add:
|
||||
if dedup.add(args.add):
|
||||
print(f"Added: {dedup.compute_hash(args.add)}")
|
||||
else:
|
||||
print("Already exists")
|
||||
|
||||
if args.file:
|
||||
with open(args.file) as f:
|
||||
lines = [l.strip() for l in f if l.strip()]
|
||||
added = dedup.add_batch(lines)
|
||||
print(f"Added {added}/{len(lines)} new hashes")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,139 +0,0 @@
|
||||
"""
|
||||
Tests for scripts/hash_dedup.py — Bounded hash deduplication.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
||||
from hash_dedup import HashDedup
|
||||
|
||||
|
||||
class TestHashDedup(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.dedup = HashDedup(self.tmpdir)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
|
||||
def test_compute_hash(self):
|
||||
h = HashDedup.compute_hash("test content")
|
||||
self.assertEqual(len(h), 64) # SHA-256 hex
|
||||
self.assertTrue(all(c in '0123456789abcdef' for c in h))
|
||||
|
||||
def test_same_content_same_hash(self):
|
||||
h1 = HashDedup.compute_hash("hello")
|
||||
h2 = HashDedup.compute_hash("hello")
|
||||
self.assertEqual(h1, h2)
|
||||
|
||||
def test_different_content_different_hash(self):
|
||||
h1 = HashDedup.compute_hash("hello")
|
||||
h2 = HashDedup.compute_hash("world")
|
||||
self.assertNotEqual(h1, h2)
|
||||
|
||||
def test_add_new(self):
|
||||
result = self.dedup.add("new content")
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_add_duplicate(self):
|
||||
self.dedup.add("content")
|
||||
result = self.dedup.add("content")
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_is_duplicate_false(self):
|
||||
self.assertFalse(self.dedup.is_duplicate("unknown"))
|
||||
|
||||
def test_is_duplicate_true(self):
|
||||
self.dedup.add("known content")
|
||||
self.assertTrue(self.dedup.is_duplicate("known content"))
|
||||
|
||||
def test_add_batch(self):
|
||||
items = ["a", "b", "c"]
|
||||
added = self.dedup.add_batch(items)
|
||||
self.assertEqual(added, 3)
|
||||
|
||||
def test_add_batch_deduplicates(self):
|
||||
items = ["a", "b", "a", "c", "b"]
|
||||
added = self.dedup.add_batch(items)
|
||||
self.assertEqual(added, 3)
|
||||
|
||||
def test_creates_date_file(self):
|
||||
self.dedup.add("test")
|
||||
today = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
path = Path(self.tmpdir) / f"{today}.json"
|
||||
self.assertTrue(path.exists())
|
||||
|
||||
def test_file_format(self):
|
||||
self.dedup.add("test")
|
||||
today = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
path = Path(self.tmpdir) / f"{today}.json"
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
self.assertEqual(data["date"], today)
|
||||
self.assertEqual(data["count"], 1)
|
||||
self.assertEqual(len(data["hashes"]), 1)
|
||||
|
||||
def test_cleanup_removes_old(self):
|
||||
# Create fake old file
|
||||
old_date = (datetime.utcnow() - timedelta(days=10)).strftime("%Y-%m-%d")
|
||||
old_path = Path(self.tmpdir) / f"{old_date}.json"
|
||||
with open(old_path, 'w') as f:
|
||||
json.dump({"date": old_date, "count": 0, "hashes": []}, f)
|
||||
|
||||
removed = self.dedup.cleanup(keep_days=7)
|
||||
self.assertEqual(removed, 1)
|
||||
self.assertFalse(old_path.exists())
|
||||
|
||||
def test_cleanup_keeps_recent(self):
|
||||
recent_date = (datetime.utcnow() - timedelta(days=3)).strftime("%Y-%m-%d")
|
||||
recent_path = Path(self.tmpdir) / f"{recent_date}.json"
|
||||
with open(recent_path, 'w') as f:
|
||||
json.dump({"date": recent_date, "count": 0, "hashes": []}, f)
|
||||
|
||||
removed = self.dedup.cleanup(keep_days=7)
|
||||
self.assertEqual(removed, 0)
|
||||
self.assertTrue(recent_path.exists())
|
||||
|
||||
def test_cleanup_ignores_non_date_files(self):
|
||||
junk = Path(self.tmpdir) / "not-a-date.json"
|
||||
with open(junk, 'w') as f:
|
||||
f.write("{}")
|
||||
|
||||
removed = self.dedup.cleanup(keep_days=1)
|
||||
self.assertEqual(removed, 0)
|
||||
self.assertTrue(junk.exists())
|
||||
|
||||
def test_stats_empty(self):
|
||||
stats = self.dedup.stats()
|
||||
self.assertEqual(stats["file_count"], 0)
|
||||
self.assertEqual(stats["total_hashes"], 0)
|
||||
|
||||
def test_stats_with_data(self):
|
||||
self.dedup.add("one")
|
||||
self.dedup.add("two")
|
||||
stats = self.dedup.stats()
|
||||
self.assertEqual(stats["file_count"], 1)
|
||||
self.assertEqual(stats["total_hashes"], 2)
|
||||
self.assertEqual(stats["today_count"], 2)
|
||||
|
||||
def test_max_hashes_per_file(self):
|
||||
dedup = HashDedup(self.tmpdir, max_hashes_per_file=3)
|
||||
for i in range(10):
|
||||
dedup.add(f"content-{i}")
|
||||
|
||||
today = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
path = Path(self.tmpdir) / f"{today}.json"
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
self.assertLessEqual(len(data["hashes"]), 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
162
training/scripts/fix_training_indentation.py
Normal file
162
training/scripts/fix_training_indentation.py
Normal file
@@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fix Training Data Code Block Indentation
|
||||
Issue #750: Training data code blocks have inconsistent indentation
|
||||
|
||||
Normalizes code block indentation in JSONL training data files using textwrap.dedent.
|
||||
|
||||
Usage:
|
||||
python3 fix_training_indentation.py --input data.jsonl
|
||||
python3 fix_training_indentation.py --input data.jsonl --output fixed.jsonl
|
||||
python3 fix_training_indentation.py --input data.jsonl --dry-run
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def fix_code_block_indentation(text):
|
||||
"""
|
||||
Find code blocks in text and normalize their indentation.
|
||||
|
||||
Handles:
|
||||
- ```python ... ``` blocks
|
||||
- ```bash ... ``` blocks
|
||||
- ``` ... ``` blocks (no language)
|
||||
- Nested code blocks in JSON strings
|
||||
"""
|
||||
if not text or '```' not in text:
|
||||
return text, 0
|
||||
|
||||
fixes = 0
|
||||
result = text
|
||||
|
||||
# Pattern to match code blocks: ```language\n...code...\n```
|
||||
# Also handles cases where code block is indented
|
||||
code_block_pattern = re.compile(
|
||||
r'(```(?:\w+)?\n)(.*?)(```)',
|
||||
re.DOTALL
|
||||
)
|
||||
|
||||
def fix_block(match):
|
||||
nonlocal fixes
|
||||
opening = match.group(1) # ```python\n
|
||||
code = match.group(2) # The code content
|
||||
closing = match.group(3) # ```
|
||||
|
||||
if not code.strip():
|
||||
return match.group(0)
|
||||
|
||||
# Use textwrap.dedent to remove common leading whitespace
|
||||
dedented = textwrap.dedent(code)
|
||||
|
||||
# Also handle the case where first line has different indentation
|
||||
lines = dedented.split('\n')
|
||||
if lines:
|
||||
# Find minimum indentation (excluding empty lines)
|
||||
min_indent = float('inf')
|
||||
for line in lines:
|
||||
if line.strip():
|
||||
indent = len(line) - len(line.lstrip())
|
||||
min_indent = min(min_indent, indent)
|
||||
|
||||
if min_indent > 0 and min_indent != float('inf'):
|
||||
# Remove the minimum indentation from all lines
|
||||
lines = [line[min_indent:] if line.strip() else line for line in lines]
|
||||
dedented = '\n'.join(lines)
|
||||
|
||||
if dedented != code:
|
||||
fixes += 1
|
||||
|
||||
return opening + dedented + closing
|
||||
|
||||
result = code_block_pattern.sub(fix_block, result)
|
||||
return result, fixes
|
||||
|
||||
|
||||
def process_jsonl_file(input_path, output_path=None, dry_run=False):
|
||||
"""Process a JSONL file and fix code block indentation."""
|
||||
input_path = Path(input_path)
|
||||
if output_path is None:
|
||||
output_path = input_path.with_suffix('.fixed.jsonl')
|
||||
else:
|
||||
output_path = Path(output_path)
|
||||
|
||||
if not input_path.exists():
|
||||
print(f"Error: {input_path} does not exist")
|
||||
return 0, 0
|
||||
|
||||
total_entries = 0
|
||||
total_fixes = 0
|
||||
entries_with_fixes = 0
|
||||
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
fixed_lines = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Warning: Line {i+1} is not valid JSON: {e}")
|
||||
fixed_lines.append(line)
|
||||
continue
|
||||
|
||||
total_entries += 1
|
||||
entry_fixes = 0
|
||||
|
||||
# Process all string fields in the entry
|
||||
for key in entry:
|
||||
if isinstance(entry[key], str):
|
||||
fixed_text, fixes = fix_code_block_indentation(entry[key])
|
||||
if fixes > 0:
|
||||
entry[key] = fixed_text
|
||||
entry_fixes += fixes
|
||||
|
||||
if entry_fixes > 0:
|
||||
entries_with_fixes += 1
|
||||
total_fixes += entry_fixes
|
||||
|
||||
fixed_lines.append(json.dumps(entry, ensure_ascii=False))
|
||||
|
||||
if dry_run:
|
||||
print(f"DRY RUN: Would fix {total_fixes} code blocks in {entries_with_fixes}/{total_entries} entries")
|
||||
return total_fixes, entries_with_fixes
|
||||
|
||||
# Write fixed data
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for line in fixed_lines:
|
||||
f.write(line + '\n')
|
||||
|
||||
print(f"Fixed {total_fixes} code blocks in {entries_with_fixes}/{total_entries} entries")
|
||||
print(f"Output: {output_path}")
|
||||
|
||||
return total_fixes, entries_with_fixes
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='Fix training data code block indentation')
|
||||
parser.add_argument('--input', required=True, help='Input JSONL file')
|
||||
parser.add_argument('--output', default=None, help='Output JSONL file (default: input.fixed.jsonl)')
|
||||
parser.add_argument('--dry-run', action='store_true', help='Show what would be fixed without writing')
|
||||
args = parser.parse_args()
|
||||
|
||||
fixes, entries = process_jsonl_file(args.input, args.output, args.dry_run)
|
||||
|
||||
if fixes == 0:
|
||||
print("No fixes needed - code blocks are properly indented")
|
||||
elif not args.dry_run:
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
121
training/tests/test_fix_training_indentation.py
Normal file
121
training/tests/test_fix_training_indentation.py
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for fix_training_indentation.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
# Import the module
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from fix_training_indentation import fix_code_block_indentation, process_jsonl_file
|
||||
|
||||
|
||||
def test_fix_code_block_indentation():
|
||||
"""Test code block indentation fixing."""
|
||||
|
||||
# Test 1: Python code block with extra indentation
|
||||
text1 = """Here is some code:
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/users/{user_id}")
|
||||
def get_user(user_id: int):
|
||||
return {"user_id": user_id}
|
||||
```
|
||||
"""
|
||||
fixed1, fixes1 = fix_code_block_indentation(text1)
|
||||
assert fixes1 == 1, f"Expected 1 fix, got {fixes1}"
|
||||
assert "from fastapi import FastAPI" in fixed1
|
||||
# Check that indentation is normalized
|
||||
lines = fixed1.split("\n")
|
||||
for line in lines:
|
||||
if "from fastapi" in line:
|
||||
assert line.startswith("from"), f"First line should not have leading spaces: {repr(line)}"
|
||||
break
|
||||
|
||||
# Test 2: Bash code block
|
||||
text2 = """Run these commands:
|
||||
```bash
|
||||
python3 script.py
|
||||
git commit -m "fix"
|
||||
```
|
||||
"""
|
||||
fixed2, fixes2 = fix_code_block_indentation(text2)
|
||||
assert fixes2 == 1, f"Expected 1 fix, got {fixes2}"
|
||||
|
||||
# Test 3: No code block
|
||||
text3 = "This is plain text with no code blocks."
|
||||
fixed3, fixes3 = fix_code_block_indentation(text3)
|
||||
assert fixes3 == 0, f"Expected 0 fixes, got {fixes3}"
|
||||
assert fixed3 == text3
|
||||
|
||||
# Test 4: Empty code block
|
||||
text4 = """Empty:
|
||||
```
|
||||
```
|
||||
"""
|
||||
fixed4, fixes4 = fix_code_block_indentation(text4)
|
||||
assert fixes4 == 0, f"Expected 0 fixes for empty block, got {fixes4}"
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
|
||||
def test_process_jsonl_file():
|
||||
"""Test processing a JSONL file."""
|
||||
|
||||
# Create test data
|
||||
test_data = [
|
||||
{
|
||||
"prompt": "Write a function",
|
||||
"chosen": "```python\ndef hello():\n print('hello')\n```",
|
||||
"rejected": ""
|
||||
},
|
||||
{
|
||||
"prompt": "Run command",
|
||||
"chosen": "```bash\necho 'test'\n```",
|
||||
"rejected": ""
|
||||
}
|
||||
]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||
for entry in test_data:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
input_path = f.name
|
||||
|
||||
try:
|
||||
fixes, entries = process_jsonl_file(input_path, dry_run=True)
|
||||
print(f"Dry run: {fixes} fixes in {entries} entries")
|
||||
|
||||
# Actually fix
|
||||
output_path = input_path.replace('.jsonl', '.fixed.jsonl')
|
||||
fixes, entries = process_jsonl_file(input_path, output_path)
|
||||
print(f"Fixed: {fixes} fixes in {entries} entries")
|
||||
|
||||
# Verify output
|
||||
with open(output_path, 'r') as f:
|
||||
for line in f:
|
||||
entry = json.loads(line)
|
||||
if '```python' in entry.get('chosen', ''):
|
||||
# Check indentation
|
||||
code = entry['chosen']
|
||||
lines = code.split("\n")
|
||||
for line in lines:
|
||||
if "def hello" in line:
|
||||
assert not line.startswith(" "), f"Code should not have extra indentation: {repr(line)}"
|
||||
|
||||
print("JSONL processing test passed!")
|
||||
|
||||
finally:
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
Path(input_path.replace('.jsonl', '.fixed.jsonl')).unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fix_code_block_indentation()
|
||||
test_process_jsonl_file()
|
||||
print("\nAll tests passed!")
|
||||
Reference in New Issue
Block a user