Compare commits
2 Commits
step35/443
...
issue-750-
| Author | SHA1 | Date | |
|---|---|---|---|
| cb1408aafb | |||
| 4b0cbd123e |
162
training/scripts/fix_training_indentation.py
Normal file
162
training/scripts/fix_training_indentation.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Fix Training Data Code Block Indentation
|
||||||
|
Issue #750: Training data code blocks have inconsistent indentation
|
||||||
|
|
||||||
|
Normalizes code block indentation in JSONL training data files using textwrap.dedent.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 fix_training_indentation.py --input data.jsonl
|
||||||
|
python3 fix_training_indentation.py --input data.jsonl --output fixed.jsonl
|
||||||
|
python3 fix_training_indentation.py --input data.jsonl --dry-run
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def fix_code_block_indentation(text):
|
||||||
|
"""
|
||||||
|
Find code blocks in text and normalize their indentation.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- ```python ... ``` blocks
|
||||||
|
- ```bash ... ``` blocks
|
||||||
|
- ``` ... ``` blocks (no language)
|
||||||
|
- Nested code blocks in JSON strings
|
||||||
|
"""
|
||||||
|
if not text or '```' not in text:
|
||||||
|
return text, 0
|
||||||
|
|
||||||
|
fixes = 0
|
||||||
|
result = text
|
||||||
|
|
||||||
|
# Pattern to match code blocks: ```language\n...code...\n```
|
||||||
|
# Also handles cases where code block is indented
|
||||||
|
code_block_pattern = re.compile(
|
||||||
|
r'(```(?:\w+)?\n)(.*?)(```)',
|
||||||
|
re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
def fix_block(match):
|
||||||
|
nonlocal fixes
|
||||||
|
opening = match.group(1) # ```python\n
|
||||||
|
code = match.group(2) # The code content
|
||||||
|
closing = match.group(3) # ```
|
||||||
|
|
||||||
|
if not code.strip():
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
# Use textwrap.dedent to remove common leading whitespace
|
||||||
|
dedented = textwrap.dedent(code)
|
||||||
|
|
||||||
|
# Also handle the case where first line has different indentation
|
||||||
|
lines = dedented.split('\n')
|
||||||
|
if lines:
|
||||||
|
# Find minimum indentation (excluding empty lines)
|
||||||
|
min_indent = float('inf')
|
||||||
|
for line in lines:
|
||||||
|
if line.strip():
|
||||||
|
indent = len(line) - len(line.lstrip())
|
||||||
|
min_indent = min(min_indent, indent)
|
||||||
|
|
||||||
|
if min_indent > 0 and min_indent != float('inf'):
|
||||||
|
# Remove the minimum indentation from all lines
|
||||||
|
lines = [line[min_indent:] if line.strip() else line for line in lines]
|
||||||
|
dedented = '\n'.join(lines)
|
||||||
|
|
||||||
|
if dedented != code:
|
||||||
|
fixes += 1
|
||||||
|
|
||||||
|
return opening + dedented + closing
|
||||||
|
|
||||||
|
result = code_block_pattern.sub(fix_block, result)
|
||||||
|
return result, fixes
|
||||||
|
|
||||||
|
|
||||||
|
def process_jsonl_file(input_path, output_path=None, dry_run=False):
|
||||||
|
"""Process a JSONL file and fix code block indentation."""
|
||||||
|
input_path = Path(input_path)
|
||||||
|
if output_path is None:
|
||||||
|
output_path = input_path.with_suffix('.fixed.jsonl')
|
||||||
|
else:
|
||||||
|
output_path = Path(output_path)
|
||||||
|
|
||||||
|
if not input_path.exists():
|
||||||
|
print(f"Error: {input_path} does not exist")
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
total_entries = 0
|
||||||
|
total_fixes = 0
|
||||||
|
entries_with_fixes = 0
|
||||||
|
|
||||||
|
with open(input_path, 'r', encoding='utf-8') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
fixed_lines = []
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
entry = json.loads(line)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"Warning: Line {i+1} is not valid JSON: {e}")
|
||||||
|
fixed_lines.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
total_entries += 1
|
||||||
|
entry_fixes = 0
|
||||||
|
|
||||||
|
# Process all string fields in the entry
|
||||||
|
for key in entry:
|
||||||
|
if isinstance(entry[key], str):
|
||||||
|
fixed_text, fixes = fix_code_block_indentation(entry[key])
|
||||||
|
if fixes > 0:
|
||||||
|
entry[key] = fixed_text
|
||||||
|
entry_fixes += fixes
|
||||||
|
|
||||||
|
if entry_fixes > 0:
|
||||||
|
entries_with_fixes += 1
|
||||||
|
total_fixes += entry_fixes
|
||||||
|
|
||||||
|
fixed_lines.append(json.dumps(entry, ensure_ascii=False))
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
print(f"DRY RUN: Would fix {total_fixes} code blocks in {entries_with_fixes}/{total_entries} entries")
|
||||||
|
return total_fixes, entries_with_fixes
|
||||||
|
|
||||||
|
# Write fixed data
|
||||||
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
|
for line in fixed_lines:
|
||||||
|
f.write(line + '\n')
|
||||||
|
|
||||||
|
print(f"Fixed {total_fixes} code blocks in {entries_with_fixes}/{total_entries} entries")
|
||||||
|
print(f"Output: {output_path}")
|
||||||
|
|
||||||
|
return total_fixes, entries_with_fixes
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description='Fix training data code block indentation')
|
||||||
|
parser.add_argument('--input', required=True, help='Input JSONL file')
|
||||||
|
parser.add_argument('--output', default=None, help='Output JSONL file (default: input.fixed.jsonl)')
|
||||||
|
parser.add_argument('--dry-run', action='store_true', help='Show what would be fixed without writing')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
fixes, entries = process_jsonl_file(args.input, args.output, args.dry_run)
|
||||||
|
|
||||||
|
if fixes == 0:
|
||||||
|
print("No fixes needed - code blocks are properly indented")
|
||||||
|
elif not args.dry_run:
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
121
training/tests/test_fix_training_indentation.py
Normal file
121
training/tests/test_fix_training_indentation.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for fix_training_indentation.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Import the module
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
from fix_training_indentation import fix_code_block_indentation, process_jsonl_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_fix_code_block_indentation():
|
||||||
|
"""Test code block indentation fixing."""
|
||||||
|
|
||||||
|
# Test 1: Python code block with extra indentation
|
||||||
|
text1 = """Here is some code:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/users/{user_id}")
|
||||||
|
def get_user(user_id: int):
|
||||||
|
return {"user_id": user_id}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
fixed1, fixes1 = fix_code_block_indentation(text1)
|
||||||
|
assert fixes1 == 1, f"Expected 1 fix, got {fixes1}"
|
||||||
|
assert "from fastapi import FastAPI" in fixed1
|
||||||
|
# Check that indentation is normalized
|
||||||
|
lines = fixed1.split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if "from fastapi" in line:
|
||||||
|
assert line.startswith("from"), f"First line should not have leading spaces: {repr(line)}"
|
||||||
|
break
|
||||||
|
|
||||||
|
# Test 2: Bash code block
|
||||||
|
text2 = """Run these commands:
|
||||||
|
```bash
|
||||||
|
python3 script.py
|
||||||
|
git commit -m "fix"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
fixed2, fixes2 = fix_code_block_indentation(text2)
|
||||||
|
assert fixes2 == 1, f"Expected 1 fix, got {fixes2}"
|
||||||
|
|
||||||
|
# Test 3: No code block
|
||||||
|
text3 = "This is plain text with no code blocks."
|
||||||
|
fixed3, fixes3 = fix_code_block_indentation(text3)
|
||||||
|
assert fixes3 == 0, f"Expected 0 fixes, got {fixes3}"
|
||||||
|
assert fixed3 == text3
|
||||||
|
|
||||||
|
# Test 4: Empty code block
|
||||||
|
text4 = """Empty:
|
||||||
|
```
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
fixed4, fixes4 = fix_code_block_indentation(text4)
|
||||||
|
assert fixes4 == 0, f"Expected 0 fixes for empty block, got {fixes4}"
|
||||||
|
|
||||||
|
print("All tests passed!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_jsonl_file():
|
||||||
|
"""Test processing a JSONL file."""
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
test_data = [
|
||||||
|
{
|
||||||
|
"prompt": "Write a function",
|
||||||
|
"chosen": "```python\ndef hello():\n print('hello')\n```",
|
||||||
|
"rejected": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "Run command",
|
||||||
|
"chosen": "```bash\necho 'test'\n```",
|
||||||
|
"rejected": ""
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
|
||||||
|
for entry in test_data:
|
||||||
|
f.write(json.dumps(entry) + "\n")
|
||||||
|
input_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
fixes, entries = process_jsonl_file(input_path, dry_run=True)
|
||||||
|
print(f"Dry run: {fixes} fixes in {entries} entries")
|
||||||
|
|
||||||
|
# Actually fix
|
||||||
|
output_path = input_path.replace('.jsonl', '.fixed.jsonl')
|
||||||
|
fixes, entries = process_jsonl_file(input_path, output_path)
|
||||||
|
print(f"Fixed: {fixes} fixes in {entries} entries")
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
with open(output_path, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
entry = json.loads(line)
|
||||||
|
if '```python' in entry.get('chosen', ''):
|
||||||
|
# Check indentation
|
||||||
|
code = entry['chosen']
|
||||||
|
lines = code.split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if "def hello" in line:
|
||||||
|
assert not line.startswith(" "), f"Code should not have extra indentation: {repr(line)}"
|
||||||
|
|
||||||
|
print("JSONL processing test passed!")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
Path(input_path).unlink(missing_ok=True)
|
||||||
|
Path(input_path.replace('.jsonl', '.fixed.jsonl')).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_fix_code_block_indentation()
|
||||||
|
test_process_jsonl_file()
|
||||||
|
print("\nAll tests passed!")
|
||||||
Reference in New Issue
Block a user