diff --git a/tests/test_normalize_code_blocks.py b/tests/test_normalize_code_blocks.py new file mode 100644 index 00000000..ed323eb0 --- /dev/null +++ b/tests/test_normalize_code_blocks.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +"""Tests for normalize-code-blocks.py — training data code block indentation fix (#750).""" + +import json +import os +import sys +import tempfile +import textwrap +from pathlib import Path + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts")) +from normalize_code_blocks import normalize_code_block, process_line, CODE_BLOCK_RE + + +class TestNormalizeCodeBlock: + def test_basic_dedent(self): + block = "```python\n from fastapi import FastAPI\n app = FastAPI()\n```" + result = CODE_BLOCK_RE.sub(normalize_code_block, block) + assert " from fastapi" not in result + assert "from fastapi" in result + + def test_preserves_language_tag(self): + block = "```python\n x = 1\n```" + result = CODE_BLOCK_RE.sub(normalize_code_block, block) + assert result.startswith("```python") + + def test_empty_block_unchanged(self): + block = "```python\n \n \n```" + result = CODE_BLOCK_RE.sub(normalize_code_block, block) + assert result == block + + def test_multiple_blocks(self): + text = 'First: ```python\n x = 1\n``` and second: ```python\n y = 2\n```' + result = CODE_BLOCK_RE.sub(normalize_code_block, text) + assert " x = 1" not in result + assert " y = 2" not in result + assert "x = 1" in result + assert "y = 2" in result + + def test_bash_block(self): + block = "```bash\n echo hello\n ls -la\n```" + result = CODE_BLOCK_RE.sub(normalize_code_block, block) + assert " echo" not in result + assert "echo hello" in result + + def test_unlabeled_block(self): + block = "```\n some code\n```" + result = CODE_BLOCK_RE.sub(normalize_code_block, block) + assert " some code" not in result + + def test_mixed_indentation(self): + block = "```python\n def foo():\n return 42\n```" + result = CODE_BLOCK_RE.sub(normalize_code_block, block) + lines = result.split("\n") + # First code line should not have leading spaces from embedding + code_lines = [l for l in lines if l.strip() and not l.startswith("```")] + assert code_lines[0].startswith("def") + + def test_strips_leading_trailing_blanks(self): + block = "```python\n\n x = 1\n\n```" + result = CODE_BLOCK_RE.sub(normalize_code_block, block) + assert "\n\n" not in result.split("```python")[1].split("```")[0] + + +class TestProcessLine: + def test_valid_jsonl_with_code(self): + obj = {"prompt": "write code", "response": "```python\n x = 1\n```"} + line = json.dumps(obj) + fixed, n = process_line(line) + parsed = json.loads(fixed) + assert n == 1 + assert " x = 1" not in parsed["response"] + + def test_no_code_blocks(self): + obj = {"text": "hello world"} + line = json.dumps(obj) + fixed, n = process_line(line) + assert n == 0 + assert json.loads(fixed)["text"] == "hello world" + + def test_invalid_jsonl(self): + line = "not valid json {{{" + fixed, n = process_line(line) + assert n == 0 + assert fixed == line + + def test_nested_code_blocks(self): + obj = { + "messages": [ + {"role": "user", "content": "write code"}, + {"role": "assistant", "content": "```python\n def f():\n pass\n```"} + ] + } + line = json.dumps(obj) + fixed, n = process_line(line) + assert n == 1 + parsed = json.loads(fixed) + assert " def f" not in parsed["messages"][1]["content"] + + def test_multiple_fields_with_code(self): + obj = { + "terse": "```python\n x = 1\n```", + "rich": "```python\n y = 2\n```" + } + line = json.dumps(obj) + fixed, n = process_line(line) + parsed = json.loads(fixed) + assert n == 2 + assert " x = 1" not in parsed["terse"] + assert " y = 2" not in parsed["rich"] + + +class TestEndToEnd: + def test_file_processing(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + f.write(json.dumps({"r": "```python\n x = 1\n```"}) + "\n") + f.write(json.dumps({"r": "no code here"}) + "\n") + f.write(json.dumps({"r": "```python\n def g():\n return 99\n```"}) + "\n") + f.flush() + + # Process using the script logic + lines = Path(f.name).read_text().splitlines(keepends=True) + fixed = [] + total = 0 + for line in lines: + fl, n = process_line(line) + fixed.append(fl) + total += n + + os.unlink(f.name) + assert total == 2 + # Verify first line is fixed + first = json.loads(fixed[0]) + assert " x = 1" not in first["r"] + + +if __name__ == "__main__": + import unittest + unittest.main()