Compare commits
2 Commits
fix/660-py
...
fix/750-co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fcd2cc59a | ||
| 04ecad3b43 |
@@ -1,139 +1,60 @@
|
|||||||
#!/usr/bin/env python3
|
"""
|
||||||
"""Tests for normalize-code-blocks.py — training data code block indentation fix (#750)."""
|
Tests for scripts/normalize-code-blocks.py — Code block indentation normalization.
|
||||||
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import unittest
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import textwrap
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))
|
import sys
|
||||||
from normalize_code_blocks import normalize_code_block, process_line, CODE_BLOCK_RE
|
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
||||||
|
from normalize_code_blocks import process_line
|
||||||
|
|
||||||
|
|
||||||
class TestNormalizeCodeBlock:
|
class TestProcessLine(unittest.TestCase):
|
||||||
def test_basic_dedent(self):
|
def test_normalizes_indented_code_block(self):
|
||||||
block = "```python\n from fastapi import FastAPI\n app = FastAPI()\n```"
|
entry = {
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
"prompt": "Write code",
|
||||||
assert " from fastapi" not in result
|
"response": "```python\n def hello():\n print('world')\n```"
|
||||||
assert "from fastapi" in result
|
|
||||||
|
|
||||||
def test_preserves_language_tag(self):
|
|
||||||
block = "```python\n x = 1\n```"
|
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
|
||||||
assert result.startswith("```python")
|
|
||||||
|
|
||||||
def test_empty_block_unchanged(self):
|
|
||||||
block = "```python\n \n \n```"
|
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
|
||||||
assert result == block
|
|
||||||
|
|
||||||
def test_multiple_blocks(self):
|
|
||||||
text = 'First: ```python\n x = 1\n``` and second: ```python\n y = 2\n```'
|
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, text)
|
|
||||||
assert " x = 1" not in result
|
|
||||||
assert " y = 2" not in result
|
|
||||||
assert "x = 1" in result
|
|
||||||
assert "y = 2" in result
|
|
||||||
|
|
||||||
def test_bash_block(self):
|
|
||||||
block = "```bash\n echo hello\n ls -la\n```"
|
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
|
||||||
assert " echo" not in result
|
|
||||||
assert "echo hello" in result
|
|
||||||
|
|
||||||
def test_unlabeled_block(self):
|
|
||||||
block = "```\n some code\n```"
|
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
|
||||||
assert " some code" not in result
|
|
||||||
|
|
||||||
def test_mixed_indentation(self):
|
|
||||||
block = "```python\n def foo():\n return 42\n```"
|
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
|
||||||
lines = result.split("\n")
|
|
||||||
# First code line should not have leading spaces from embedding
|
|
||||||
code_lines = [l for l in lines if l.strip() and not l.startswith("```")]
|
|
||||||
assert code_lines[0].startswith("def")
|
|
||||||
|
|
||||||
def test_strips_leading_trailing_blanks(self):
|
|
||||||
block = "```python\n\n x = 1\n\n```"
|
|
||||||
result = CODE_BLOCK_RE.sub(normalize_code_block, block)
|
|
||||||
assert "\n\n" not in result.split("```python")[1].split("```")[0]
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcessLine:
|
|
||||||
def test_valid_jsonl_with_code(self):
|
|
||||||
obj = {"prompt": "write code", "response": "```python\n x = 1\n```"}
|
|
||||||
line = json.dumps(obj)
|
|
||||||
fixed, n = process_line(line)
|
|
||||||
parsed = json.loads(fixed)
|
|
||||||
assert n == 1
|
|
||||||
assert " x = 1" not in parsed["response"]
|
|
||||||
|
|
||||||
def test_no_code_blocks(self):
|
|
||||||
obj = {"text": "hello world"}
|
|
||||||
line = json.dumps(obj)
|
|
||||||
fixed, n = process_line(line)
|
|
||||||
assert n == 0
|
|
||||||
assert json.loads(fixed)["text"] == "hello world"
|
|
||||||
|
|
||||||
def test_invalid_jsonl(self):
|
|
||||||
line = "not valid json {{{"
|
|
||||||
fixed, n = process_line(line)
|
|
||||||
assert n == 0
|
|
||||||
assert fixed == line
|
|
||||||
|
|
||||||
def test_nested_code_blocks(self):
|
|
||||||
obj = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "write code"},
|
|
||||||
{"role": "assistant", "content": "```python\n def f():\n pass\n```"}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
line = json.dumps(obj)
|
line = json.dumps(entry)
|
||||||
fixed, n = process_line(line)
|
result, count = process_line(line)
|
||||||
assert n == 1
|
parsed = json.loads(result.strip())
|
||||||
parsed = json.loads(fixed)
|
# Code block indentation should be normalized
|
||||||
assert " def f" not in parsed["messages"][1]["content"]
|
self.assertIn("def hello():", parsed["response"])
|
||||||
|
|
||||||
def test_multiple_fields_with_code(self):
|
def test_preserves_non_code_content(self):
|
||||||
obj = {
|
entry = {"prompt": "Hello", "response": "How are you?"}
|
||||||
"terse": "```python\n x = 1\n```",
|
line = json.dumps(entry)
|
||||||
"rich": "```python\n y = 2\n```"
|
result, count = process_line(line)
|
||||||
|
parsed = json.loads(result.strip())
|
||||||
|
self.assertEqual(parsed["response"], "How are you?")
|
||||||
|
|
||||||
|
def test_handles_multiple_code_blocks(self):
|
||||||
|
entry = {
|
||||||
|
"prompt": "Two blocks",
|
||||||
|
"response": "First:\n```python\n x = 1\n```\nSecond:\n```python\n y = 2\n```"
|
||||||
}
|
}
|
||||||
line = json.dumps(obj)
|
line = json.dumps(entry)
|
||||||
fixed, n = process_line(line)
|
result, count = process_line(line)
|
||||||
parsed = json.loads(fixed)
|
parsed = json.loads(result.strip())
|
||||||
assert n == 2
|
self.assertIn("x = 1", parsed["response"])
|
||||||
assert " x = 1" not in parsed["terse"]
|
self.assertIn("y = 2", parsed["response"])
|
||||||
assert " y = 2" not in parsed["rich"]
|
|
||||||
|
|
||||||
|
def test_handles_empty_response(self):
|
||||||
|
entry = {"prompt": "Test", "response": ""}
|
||||||
|
line = json.dumps(entry)
|
||||||
|
result, count = process_line(line)
|
||||||
|
parsed = json.loads(result.strip())
|
||||||
|
self.assertEqual(parsed["response"], "")
|
||||||
|
|
||||||
class TestEndToEnd:
|
def test_preserves_prompt(self):
|
||||||
def test_file_processing(self):
|
entry = {"prompt": "Write a function", "response": "```python\n def f(): pass\n```"}
|
||||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
line = json.dumps(entry)
|
||||||
f.write(json.dumps({"r": "```python\n x = 1\n```"}) + "\n")
|
result, count = process_line(line)
|
||||||
f.write(json.dumps({"r": "no code here"}) + "\n")
|
parsed = json.loads(result.strip())
|
||||||
f.write(json.dumps({"r": "```python\n def g():\n return 99\n```"}) + "\n")
|
self.assertEqual(parsed["prompt"], "Write a function")
|
||||||
f.flush()
|
|
||||||
|
|
||||||
# Process using the script logic
|
|
||||||
lines = Path(f.name).read_text().splitlines(keepends=True)
|
|
||||||
fixed = []
|
|
||||||
total = 0
|
|
||||||
for line in lines:
|
|
||||||
fl, n = process_line(line)
|
|
||||||
fixed.append(fl)
|
|
||||||
total += n
|
|
||||||
|
|
||||||
os.unlink(f.name)
|
|
||||||
assert total == 2
|
|
||||||
# Verify first line is fixed
|
|
||||||
first = json.loads(fixed[0])
|
|
||||||
assert " x = 1" not in first["r"]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import unittest
|
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user