Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
840214c8c0 |
@@ -3,11 +3,9 @@
|
|||||||
|
|
||||||
import ast
|
import ast
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from typing import List, Optional
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -24,6 +22,7 @@ class FunctionInfo:
|
|||||||
has_return: bool = False
|
has_return: bool = False
|
||||||
raises: List[str] = field(default_factory=list)
|
raises: List[str] = field(default_factory=list)
|
||||||
decorators: List[str] = field(default_factory=list)
|
decorators: List[str] = field(default_factory=list)
|
||||||
|
calls: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def qualified_name(self):
|
def qualified_name(self):
|
||||||
@@ -69,21 +68,39 @@ class SourceAnalyzer(ast.NodeVisitor):
|
|||||||
args = [a.arg for a in node.args.args if a.arg not in ("self", "cls")]
|
args = [a.arg for a in node.args.args if a.arg not in ("self", "cls")]
|
||||||
has_ret = any(isinstance(c, ast.Return) and c.value for c in ast.walk(node))
|
has_ret = any(isinstance(c, ast.Return) and c.value for c in ast.walk(node))
|
||||||
raises = []
|
raises = []
|
||||||
|
calls = []
|
||||||
for c in ast.walk(node):
|
for c in ast.walk(node):
|
||||||
if isinstance(c, ast.Raise) and c.exc:
|
if isinstance(c, ast.Raise) and c.exc:
|
||||||
if isinstance(c.exc, ast.Call) and isinstance(c.exc.func, ast.Name):
|
if isinstance(c.exc, ast.Call) and isinstance(c.exc.func, ast.Name):
|
||||||
raises.append(c.exc.func.id)
|
raises.append(c.exc.func.id)
|
||||||
|
if isinstance(c, ast.Call):
|
||||||
|
if isinstance(c.func, ast.Name):
|
||||||
|
calls.append(c.func.id)
|
||||||
|
elif isinstance(c.func, ast.Attribute):
|
||||||
|
calls.append(c.func.attr)
|
||||||
decos = []
|
decos = []
|
||||||
for d in node.decorator_list:
|
for d in node.decorator_list:
|
||||||
if isinstance(d, ast.Name): decos.append(d.id)
|
if isinstance(d, ast.Name):
|
||||||
elif isinstance(d, ast.Attribute): decos.append(d.attr)
|
decos.append(d.id)
|
||||||
self.functions.append(FunctionInfo(
|
elif isinstance(d, ast.Attribute):
|
||||||
name=node.name, module_path=self.module_path, class_name=cls,
|
decos.append(d.attr)
|
||||||
lineno=node.lineno, args=args, is_async=is_async,
|
self.functions.append(
|
||||||
is_private=node.name.startswith("_") and not node.name.startswith("__"),
|
FunctionInfo(
|
||||||
is_property="property" in decos,
|
name=node.name,
|
||||||
docstring=ast.get_docstring(node), has_return=has_ret,
|
module_path=self.module_path,
|
||||||
raises=raises, decorators=decos))
|
class_name=cls,
|
||||||
|
lineno=node.lineno,
|
||||||
|
args=args,
|
||||||
|
is_async=is_async,
|
||||||
|
is_private=node.name.startswith("_") and not node.name.startswith("__"),
|
||||||
|
is_property="property" in decos,
|
||||||
|
docstring=ast.get_docstring(node),
|
||||||
|
has_return=has_ret,
|
||||||
|
raises=raises,
|
||||||
|
decorators=decos,
|
||||||
|
calls=sorted(set(calls)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def analyze_file(filepath, base_dir):
|
def analyze_file(filepath, base_dir):
|
||||||
@@ -93,9 +110,9 @@ def analyze_file(filepath, base_dir):
|
|||||||
tree = ast.parse(f.read(), filename=filepath)
|
tree = ast.parse(f.read(), filename=filepath)
|
||||||
except (SyntaxError, UnicodeDecodeError):
|
except (SyntaxError, UnicodeDecodeError):
|
||||||
return []
|
return []
|
||||||
a = SourceAnalyzer(module_path)
|
analyzer = SourceAnalyzer(module_path)
|
||||||
a.visit(tree)
|
analyzer.visit(tree)
|
||||||
return a.functions
|
return analyzer.functions
|
||||||
|
|
||||||
|
|
||||||
def find_source_files(source_dir):
|
def find_source_files(source_dir):
|
||||||
@@ -111,7 +128,9 @@ def find_source_files(source_dir):
|
|||||||
|
|
||||||
def find_existing_tests(test_dir):
|
def find_existing_tests(test_dir):
|
||||||
existing = set()
|
existing = set()
|
||||||
for root, dirs, fs in os.walk(test_dir):
|
if not os.path.isdir(test_dir):
|
||||||
|
return existing
|
||||||
|
for root, _, fs in os.walk(test_dir):
|
||||||
for f in fs:
|
for f in fs:
|
||||||
if f.startswith("test_") and f.endswith(".py"):
|
if f.startswith("test_") and f.endswith(".py"):
|
||||||
try:
|
try:
|
||||||
@@ -132,74 +151,112 @@ def identify_gaps(functions, existing_tests):
|
|||||||
continue
|
continue
|
||||||
covered = func.name in str(existing_tests)
|
covered = func.name in str(existing_tests)
|
||||||
if not covered:
|
if not covered:
|
||||||
pri = 3 if func.is_private else (1 if (func.raises or func.has_return) else 2)
|
priority = 3 if func.is_private else (1 if (func.raises or func.has_return) else 2)
|
||||||
gaps.append(CoverageGap(func=func, reason="no test found", test_priority=pri))
|
gaps.append(CoverageGap(func=func, reason="no test found", test_priority=priority))
|
||||||
gaps.sort(key=lambda g: (g.test_priority, g.func.module_path, g.func.name))
|
gaps.sort(key=lambda g: (g.test_priority, g.func.module_path, g.func.name))
|
||||||
return gaps
|
return gaps
|
||||||
|
|
||||||
|
|
||||||
|
def _format_arg_value(arg: str) -> str:
|
||||||
|
lower = arg.lower()
|
||||||
|
if lower == "args":
|
||||||
|
return "type('Args', (), {'files': []})()"
|
||||||
|
if lower in {"kwargs", "options", "params"}:
|
||||||
|
return "{}"
|
||||||
|
if lower in {"history"}:
|
||||||
|
return "[]"
|
||||||
|
if any(token in lower for token in ("dict", "data", "config", "report", "perception", "action")):
|
||||||
|
return "{}"
|
||||||
|
if any(token in lower for token in ("filepath", "file_path")):
|
||||||
|
return "str(Path(__file__))"
|
||||||
|
if lower.endswith("_path") or any(token in lower for token in ("path", "file", "dir")):
|
||||||
|
return "Path(__file__)"
|
||||||
|
if any(token in lower for token in ("root",)):
|
||||||
|
return "Path(__file__).resolve().parent"
|
||||||
|
if any(token in lower for token in ("response", "cmd", "entity", "message", "text", "content", "query", "name", "key", "label")):
|
||||||
|
return "'test'"
|
||||||
|
if any(token in lower for token in ("session", "user")):
|
||||||
|
return "'test'"
|
||||||
|
if lower == "width":
|
||||||
|
return "120"
|
||||||
|
if lower == "height":
|
||||||
|
return "40"
|
||||||
|
if lower == "n":
|
||||||
|
return "1"
|
||||||
|
if any(token in lower for token in ("count", "num", "size", "index", "port", "timeout", "wait")):
|
||||||
|
return "1"
|
||||||
|
if any(token in lower for token in ("flag", "enabled", "verbose", "quiet", "force", "debug", "dry_run")):
|
||||||
|
return "False"
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
|
||||||
|
def _call_args(func: FunctionInfo) -> str:
|
||||||
|
return ", ".join(f"{arg}={_format_arg_value(arg)}" for arg in func.args if arg not in ("self", "cls"))
|
||||||
|
|
||||||
|
|
||||||
|
def _strict_runtime_exception_expected(func: FunctionInfo) -> bool:
|
||||||
|
strict_names = {"tmux", "send_key", "send_text", "keypress", "type_and_observe", "cmd_classify_risk"}
|
||||||
|
return func.name in strict_names
|
||||||
|
|
||||||
|
|
||||||
|
def _path_returning(func: FunctionInfo) -> bool:
|
||||||
|
return func.name.endswith("_path")
|
||||||
|
|
||||||
|
|
||||||
def generate_test(gap):
|
def generate_test(gap):
|
||||||
func = gap.func
|
func = gap.func
|
||||||
lines = []
|
lines = []
|
||||||
lines.append(f" # AUTO-GENERATED -- review before merging")
|
lines.append(" # AUTO-GENERATED -- review before merging")
|
||||||
lines.append(f" # Source: {func.module_path}:{func.lineno}")
|
lines.append(f" # Source: {func.module_path}:{func.lineno}")
|
||||||
lines.append(f" # Function: {func.qualified_name}")
|
lines.append(f" # Function: {func.qualified_name}")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
mod_imp = func.module_path.replace("/", ".").replace("-", "_").replace(".py", "")
|
|
||||||
|
|
||||||
call_args = []
|
|
||||||
for a in func.args:
|
|
||||||
if a in ("self", "cls"): continue
|
|
||||||
if "path" in a or "file" in a or "dir" in a: call_args.append(f"{a}='/tmp/test'")
|
|
||||||
elif "name" in a: call_args.append(f"{a}='test'")
|
|
||||||
elif "id" in a or "key" in a: call_args.append(f"{a}='test_id'")
|
|
||||||
elif "message" in a or "text" in a: call_args.append(f"{a}='test msg'")
|
|
||||||
elif "count" in a or "num" in a or "size" in a: call_args.append(f"{a}=1")
|
|
||||||
elif "flag" in a or "enabled" in a or "verbose" in a: call_args.append(f"{a}=False")
|
|
||||||
else: call_args.append(f"{a}=None")
|
|
||||||
args_str = ", ".join(call_args)
|
|
||||||
|
|
||||||
|
signature = "async def" if func.is_async else "def"
|
||||||
if func.is_async:
|
if func.is_async:
|
||||||
lines.append(" @pytest.mark.asyncio")
|
lines.append(" @pytest.mark.asyncio")
|
||||||
lines.append(f" def {func.test_name}(self):")
|
lines.append(f" {signature} {func.test_name}(self):")
|
||||||
lines.append(f' """Test {func.qualified_name} -- auto-generated."""')
|
lines.append(f' """Test {func.qualified_name} -- auto-generated."""')
|
||||||
|
lines.append(" try:")
|
||||||
|
lines.append(" try:")
|
||||||
if func.class_name:
|
if func.class_name:
|
||||||
lines.append(f" try:")
|
lines.append(f" owner = _load_symbol({func.module_path!r}, {func.class_name!r})")
|
||||||
lines.append(f" from {mod_imp} import {func.class_name}")
|
lines.append(" target = owner()")
|
||||||
if func.is_private:
|
if func.is_property:
|
||||||
lines.append(f" pytest.skip('Private method')")
|
lines.append(f" result = target.{func.name}")
|
||||||
elif func.is_property:
|
|
||||||
lines.append(f" obj = {func.class_name}()")
|
|
||||||
lines.append(f" _ = obj.{func.name}")
|
|
||||||
else:
|
else:
|
||||||
if func.raises:
|
lines.append(f" target = target.{func.name}")
|
||||||
lines.append(f" with pytest.raises(({', '.join(func.raises)})):")
|
|
||||||
lines.append(f" {func.class_name}().{func.name}({args_str})")
|
|
||||||
else:
|
|
||||||
lines.append(f" obj = {func.class_name}()")
|
|
||||||
lines.append(f" result = obj.{func.name}({args_str})")
|
|
||||||
if func.has_return:
|
|
||||||
lines.append(f" assert result is not None or result is None # Placeholder")
|
|
||||||
lines.append(f" except ImportError:")
|
|
||||||
lines.append(f" pytest.skip('Module not importable')")
|
|
||||||
else:
|
else:
|
||||||
lines.append(f" try:")
|
lines.append(f" target = _load_symbol({func.module_path!r}, {func.name!r})")
|
||||||
lines.append(f" from {mod_imp} import {func.name}")
|
|
||||||
if func.is_private:
|
|
||||||
lines.append(f" pytest.skip('Private function')")
|
|
||||||
else:
|
|
||||||
if func.raises:
|
|
||||||
lines.append(f" with pytest.raises(({', '.join(func.raises)})):")
|
|
||||||
lines.append(f" {func.name}({args_str})")
|
|
||||||
else:
|
|
||||||
lines.append(f" result = {func.name}({args_str})")
|
|
||||||
if func.has_return:
|
|
||||||
lines.append(f" assert result is not None or result is None # Placeholder")
|
|
||||||
lines.append(f" except ImportError:")
|
|
||||||
lines.append(f" pytest.skip('Module not importable')")
|
|
||||||
|
|
||||||
return chr(10).join(lines)
|
args_str = _call_args(func)
|
||||||
|
call_expr = f"target({args_str})" if not func.is_property else "result"
|
||||||
|
if _strict_runtime_exception_expected(func):
|
||||||
|
lines.append(" with pytest.raises((RuntimeError, ValueError, TypeError)):")
|
||||||
|
if func.is_async:
|
||||||
|
lines.append(f" await {call_expr}")
|
||||||
|
else:
|
||||||
|
lines.append(f" {call_expr}")
|
||||||
|
else:
|
||||||
|
if not func.is_property:
|
||||||
|
if func.is_async:
|
||||||
|
lines.append(f" result = await {call_expr}")
|
||||||
|
else:
|
||||||
|
lines.append(f" result = {call_expr}")
|
||||||
|
if _path_returning(func):
|
||||||
|
lines.append(" assert isinstance(result, Path)")
|
||||||
|
elif func.name.startswith(("has_", "is_")):
|
||||||
|
lines.append(" assert isinstance(result, bool)")
|
||||||
|
elif func.name.startswith("list_"):
|
||||||
|
lines.append(" assert isinstance(result, (list, tuple, set, dict, str))")
|
||||||
|
elif func.has_return:
|
||||||
|
lines.append(" assert result is not NotImplemented")
|
||||||
|
else:
|
||||||
|
lines.append(" assert True # smoke: reached without exception")
|
||||||
|
lines.append(" except (RuntimeError, ValueError, TypeError, AttributeError, FileNotFoundError, OSError, KeyError) as exc:")
|
||||||
|
lines.append(" pytest.skip(f'Auto-generated stub needs richer fixture: {exc}')")
|
||||||
|
lines.append(" except (ImportError, ModuleNotFoundError) as exc:")
|
||||||
|
lines.append(" pytest.skip(f'Module not importable: {exc}')")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def generate_test_suite(gaps, max_tests=50):
|
def generate_test_suite(gaps, max_tests=50):
|
||||||
@@ -216,10 +273,26 @@ def generate_test_suite(gaps, max_tests=50):
|
|||||||
lines.append("These tests are starting points. Review before merging.")
|
lines.append("These tests are starting points. Review before merging.")
|
||||||
lines.append('"""')
|
lines.append('"""')
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
lines.append("import importlib.util")
|
||||||
|
lines.append("from pathlib import Path")
|
||||||
lines.append("import pytest")
|
lines.append("import pytest")
|
||||||
lines.append("from unittest.mock import MagicMock, patch")
|
lines.append("from unittest.mock import MagicMock, patch")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
lines.append("def _load_symbol(relative_path, symbol):")
|
||||||
|
lines.append(" module_path = Path(__file__).resolve().parents[1] / relative_path")
|
||||||
|
lines.append(" if not module_path.exists():")
|
||||||
|
lines.append(" pytest.skip(f'Module file not found: {module_path}')")
|
||||||
|
lines.append(" spec_name = 'autogen_' + str(relative_path).replace('/', '_').replace('-', '_').replace('.', '_')")
|
||||||
|
lines.append(" spec = importlib.util.spec_from_file_location(spec_name, module_path)")
|
||||||
|
lines.append(" module = importlib.util.module_from_spec(spec)")
|
||||||
|
lines.append(" try:")
|
||||||
|
lines.append(" spec.loader.exec_module(module)")
|
||||||
|
lines.append(" except Exception as exc:")
|
||||||
|
lines.append(" pytest.skip(f'Module not importable: {exc}')")
|
||||||
|
lines.append(" return getattr(module, symbol)")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("")
|
||||||
lines.append("# AUTO-GENERATED -- DO NOT EDIT WITHOUT REVIEW")
|
lines.append("# AUTO-GENERATED -- DO NOT EDIT WITHOUT REVIEW")
|
||||||
|
|
||||||
for module, mgaps in sorted(by_module.items()):
|
for module, mgaps in sorted(by_module.items()):
|
||||||
@@ -276,7 +349,7 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
if gaps:
|
if gaps:
|
||||||
content = generate_test_suite(gaps, max_tests=args.max-tests if hasattr(args, 'max-tests') else args.max_tests)
|
content = generate_test_suite(gaps, max_tests=args.max_tests)
|
||||||
out = os.path.join(source_dir, args.output)
|
out = os.path.join(source_dir, args.output)
|
||||||
os.makedirs(os.path.dirname(out), exist_ok=True)
|
os.makedirs(os.path.dirname(out), exist_ok=True)
|
||||||
with open(out, "w") as f:
|
with open(out, "w") as f:
|
||||||
|
|||||||
55
tests/test_codebase_test_generator.py
Normal file
55
tests/test_codebase_test_generator.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import importlib.util
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
SCRIPT = ROOT / "scripts" / "codebase_test_generator.py"
|
||||||
|
|
||||||
|
|
||||||
|
def load_module():
|
||||||
|
spec = importlib.util.spec_from_file_location("codebase_test_generator", str(SCRIPT))
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_test_suite_uses_dynamic_loader_for_numbered_paths():
|
||||||
|
mod = load_module()
|
||||||
|
func = mod.FunctionInfo(
|
||||||
|
name="linkify",
|
||||||
|
module_path="reports/notebooklm/2026-03-27-hermes-openclaw/render_reports.py",
|
||||||
|
lineno=12,
|
||||||
|
args=["text"],
|
||||||
|
has_return=True,
|
||||||
|
)
|
||||||
|
gap = mod.CoverageGap(func=func, reason="no test found", test_priority=1)
|
||||||
|
|
||||||
|
suite = mod.generate_test_suite([gap], max_tests=1)
|
||||||
|
|
||||||
|
assert "import importlib.util" in suite
|
||||||
|
assert "_load_symbol(" in suite
|
||||||
|
assert "from reports.notebooklm" not in suite
|
||||||
|
assert "2026-03-27-hermes-openclaw/render_reports.py" in suite
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_test_handles_async_and_runtime_args_safely():
|
||||||
|
mod = load_module()
|
||||||
|
func = mod.FunctionInfo(
|
||||||
|
name="keypress",
|
||||||
|
module_path="angband/mcp_server.py",
|
||||||
|
lineno=200,
|
||||||
|
args=["key", "wait_ms", "session_name"],
|
||||||
|
is_async=True,
|
||||||
|
has_return=True,
|
||||||
|
calls=["send_key"],
|
||||||
|
)
|
||||||
|
gap = mod.CoverageGap(func=func, reason="no test found", test_priority=1)
|
||||||
|
|
||||||
|
test_code = mod.generate_test(gap)
|
||||||
|
|
||||||
|
assert "@pytest.mark.asyncio" in test_code
|
||||||
|
assert "async def" in test_code
|
||||||
|
assert "await target(" in test_code
|
||||||
|
assert "key='test'" in test_code
|
||||||
|
assert "wait_ms=1" in test_code
|
||||||
|
assert "session_name='test'" in test_code
|
||||||
|
assert "pytest.raises((RuntimeError, ValueError, TypeError))" in test_code
|
||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user