fix: harden codebase test generator output (#667)
This commit is contained in:
@@ -3,11 +3,9 @@
|
||||
|
||||
import ast
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -24,6 +22,7 @@ class FunctionInfo:
|
||||
has_return: bool = False
|
||||
raises: List[str] = field(default_factory=list)
|
||||
decorators: List[str] = field(default_factory=list)
|
||||
calls: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def qualified_name(self):
|
||||
@@ -69,21 +68,39 @@ class SourceAnalyzer(ast.NodeVisitor):
|
||||
args = [a.arg for a in node.args.args if a.arg not in ("self", "cls")]
|
||||
has_ret = any(isinstance(c, ast.Return) and c.value for c in ast.walk(node))
|
||||
raises = []
|
||||
calls = []
|
||||
for c in ast.walk(node):
|
||||
if isinstance(c, ast.Raise) and c.exc:
|
||||
if isinstance(c.exc, ast.Call) and isinstance(c.exc.func, ast.Name):
|
||||
raises.append(c.exc.func.id)
|
||||
if isinstance(c, ast.Call):
|
||||
if isinstance(c.func, ast.Name):
|
||||
calls.append(c.func.id)
|
||||
elif isinstance(c.func, ast.Attribute):
|
||||
calls.append(c.func.attr)
|
||||
decos = []
|
||||
for d in node.decorator_list:
|
||||
if isinstance(d, ast.Name): decos.append(d.id)
|
||||
elif isinstance(d, ast.Attribute): decos.append(d.attr)
|
||||
self.functions.append(FunctionInfo(
|
||||
name=node.name, module_path=self.module_path, class_name=cls,
|
||||
lineno=node.lineno, args=args, is_async=is_async,
|
||||
is_private=node.name.startswith("_") and not node.name.startswith("__"),
|
||||
is_property="property" in decos,
|
||||
docstring=ast.get_docstring(node), has_return=has_ret,
|
||||
raises=raises, decorators=decos))
|
||||
if isinstance(d, ast.Name):
|
||||
decos.append(d.id)
|
||||
elif isinstance(d, ast.Attribute):
|
||||
decos.append(d.attr)
|
||||
self.functions.append(
|
||||
FunctionInfo(
|
||||
name=node.name,
|
||||
module_path=self.module_path,
|
||||
class_name=cls,
|
||||
lineno=node.lineno,
|
||||
args=args,
|
||||
is_async=is_async,
|
||||
is_private=node.name.startswith("_") and not node.name.startswith("__"),
|
||||
is_property="property" in decos,
|
||||
docstring=ast.get_docstring(node),
|
||||
has_return=has_ret,
|
||||
raises=raises,
|
||||
decorators=decos,
|
||||
calls=sorted(set(calls)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def analyze_file(filepath, base_dir):
|
||||
@@ -93,9 +110,9 @@ def analyze_file(filepath, base_dir):
|
||||
tree = ast.parse(f.read(), filename=filepath)
|
||||
except (SyntaxError, UnicodeDecodeError):
|
||||
return []
|
||||
a = SourceAnalyzer(module_path)
|
||||
a.visit(tree)
|
||||
return a.functions
|
||||
analyzer = SourceAnalyzer(module_path)
|
||||
analyzer.visit(tree)
|
||||
return analyzer.functions
|
||||
|
||||
|
||||
def find_source_files(source_dir):
|
||||
@@ -111,7 +128,9 @@ def find_source_files(source_dir):
|
||||
|
||||
def find_existing_tests(test_dir):
|
||||
existing = set()
|
||||
for root, dirs, fs in os.walk(test_dir):
|
||||
if not os.path.isdir(test_dir):
|
||||
return existing
|
||||
for root, _, fs in os.walk(test_dir):
|
||||
for f in fs:
|
||||
if f.startswith("test_") and f.endswith(".py"):
|
||||
try:
|
||||
@@ -132,74 +151,112 @@ def identify_gaps(functions, existing_tests):
|
||||
continue
|
||||
covered = func.name in str(existing_tests)
|
||||
if not covered:
|
||||
pri = 3 if func.is_private else (1 if (func.raises or func.has_return) else 2)
|
||||
gaps.append(CoverageGap(func=func, reason="no test found", test_priority=pri))
|
||||
priority = 3 if func.is_private else (1 if (func.raises or func.has_return) else 2)
|
||||
gaps.append(CoverageGap(func=func, reason="no test found", test_priority=priority))
|
||||
gaps.sort(key=lambda g: (g.test_priority, g.func.module_path, g.func.name))
|
||||
return gaps
|
||||
|
||||
|
||||
def _format_arg_value(arg: str) -> str:
|
||||
lower = arg.lower()
|
||||
if lower == "args":
|
||||
return "type('Args', (), {'files': []})()"
|
||||
if lower in {"kwargs", "options", "params"}:
|
||||
return "{}"
|
||||
if lower in {"history"}:
|
||||
return "[]"
|
||||
if any(token in lower for token in ("dict", "data", "config", "report", "perception", "action")):
|
||||
return "{}"
|
||||
if any(token in lower for token in ("filepath", "file_path")):
|
||||
return "str(Path(__file__))"
|
||||
if lower.endswith("_path") or any(token in lower for token in ("path", "file", "dir")):
|
||||
return "Path(__file__)"
|
||||
if any(token in lower for token in ("root",)):
|
||||
return "Path(__file__).resolve().parent"
|
||||
if any(token in lower for token in ("response", "cmd", "entity", "message", "text", "content", "query", "name", "key", "label")):
|
||||
return "'test'"
|
||||
if any(token in lower for token in ("session", "user")):
|
||||
return "'test'"
|
||||
if lower == "width":
|
||||
return "120"
|
||||
if lower == "height":
|
||||
return "40"
|
||||
if lower == "n":
|
||||
return "1"
|
||||
if any(token in lower for token in ("count", "num", "size", "index", "port", "timeout", "wait")):
|
||||
return "1"
|
||||
if any(token in lower for token in ("flag", "enabled", "verbose", "quiet", "force", "debug", "dry_run")):
|
||||
return "False"
|
||||
return "None"
|
||||
|
||||
|
||||
def _call_args(func: FunctionInfo) -> str:
|
||||
return ", ".join(f"{arg}={_format_arg_value(arg)}" for arg in func.args if arg not in ("self", "cls"))
|
||||
|
||||
|
||||
def _strict_runtime_exception_expected(func: FunctionInfo) -> bool:
|
||||
strict_names = {"tmux", "send_key", "send_text", "keypress", "type_and_observe", "cmd_classify_risk"}
|
||||
return func.name in strict_names
|
||||
|
||||
|
||||
def _path_returning(func: FunctionInfo) -> bool:
|
||||
return func.name.endswith("_path")
|
||||
|
||||
|
||||
def generate_test(gap):
|
||||
func = gap.func
|
||||
lines = []
|
||||
lines.append(f" # AUTO-GENERATED -- review before merging")
|
||||
lines.append(" # AUTO-GENERATED -- review before merging")
|
||||
lines.append(f" # Source: {func.module_path}:{func.lineno}")
|
||||
lines.append(f" # Function: {func.qualified_name}")
|
||||
lines.append("")
|
||||
mod_imp = func.module_path.replace("/", ".").replace("-", "_").replace(".py", "")
|
||||
|
||||
call_args = []
|
||||
for a in func.args:
|
||||
if a in ("self", "cls"): continue
|
||||
if "path" in a or "file" in a or "dir" in a: call_args.append(f"{a}='/tmp/test'")
|
||||
elif "name" in a: call_args.append(f"{a}='test'")
|
||||
elif "id" in a or "key" in a: call_args.append(f"{a}='test_id'")
|
||||
elif "message" in a or "text" in a: call_args.append(f"{a}='test msg'")
|
||||
elif "count" in a or "num" in a or "size" in a: call_args.append(f"{a}=1")
|
||||
elif "flag" in a or "enabled" in a or "verbose" in a: call_args.append(f"{a}=False")
|
||||
else: call_args.append(f"{a}=None")
|
||||
args_str = ", ".join(call_args)
|
||||
|
||||
signature = "async def" if func.is_async else "def"
|
||||
if func.is_async:
|
||||
lines.append(" @pytest.mark.asyncio")
|
||||
lines.append(f" def {func.test_name}(self):")
|
||||
lines.append(f" {signature} {func.test_name}(self):")
|
||||
lines.append(f' """Test {func.qualified_name} -- auto-generated."""')
|
||||
|
||||
lines.append(" try:")
|
||||
lines.append(" try:")
|
||||
if func.class_name:
|
||||
lines.append(f" try:")
|
||||
lines.append(f" from {mod_imp} import {func.class_name}")
|
||||
if func.is_private:
|
||||
lines.append(f" pytest.skip('Private method')")
|
||||
elif func.is_property:
|
||||
lines.append(f" obj = {func.class_name}()")
|
||||
lines.append(f" _ = obj.{func.name}")
|
||||
lines.append(f" owner = _load_symbol({func.module_path!r}, {func.class_name!r})")
|
||||
lines.append(" target = owner()")
|
||||
if func.is_property:
|
||||
lines.append(f" result = target.{func.name}")
|
||||
else:
|
||||
if func.raises:
|
||||
lines.append(f" with pytest.raises(({', '.join(func.raises)})):")
|
||||
lines.append(f" {func.class_name}().{func.name}({args_str})")
|
||||
else:
|
||||
lines.append(f" obj = {func.class_name}()")
|
||||
lines.append(f" result = obj.{func.name}({args_str})")
|
||||
if func.has_return:
|
||||
lines.append(f" assert result is not None or result is None # Placeholder")
|
||||
lines.append(f" except ImportError:")
|
||||
lines.append(f" pytest.skip('Module not importable')")
|
||||
lines.append(f" target = target.{func.name}")
|
||||
else:
|
||||
lines.append(f" try:")
|
||||
lines.append(f" from {mod_imp} import {func.name}")
|
||||
if func.is_private:
|
||||
lines.append(f" pytest.skip('Private function')")
|
||||
else:
|
||||
if func.raises:
|
||||
lines.append(f" with pytest.raises(({', '.join(func.raises)})):")
|
||||
lines.append(f" {func.name}({args_str})")
|
||||
else:
|
||||
lines.append(f" result = {func.name}({args_str})")
|
||||
if func.has_return:
|
||||
lines.append(f" assert result is not None or result is None # Placeholder")
|
||||
lines.append(f" except ImportError:")
|
||||
lines.append(f" pytest.skip('Module not importable')")
|
||||
lines.append(f" target = _load_symbol({func.module_path!r}, {func.name!r})")
|
||||
|
||||
return chr(10).join(lines)
|
||||
args_str = _call_args(func)
|
||||
call_expr = f"target({args_str})" if not func.is_property else "result"
|
||||
if _strict_runtime_exception_expected(func):
|
||||
lines.append(" with pytest.raises((RuntimeError, ValueError, TypeError)):")
|
||||
if func.is_async:
|
||||
lines.append(f" await {call_expr}")
|
||||
else:
|
||||
lines.append(f" {call_expr}")
|
||||
else:
|
||||
if not func.is_property:
|
||||
if func.is_async:
|
||||
lines.append(f" result = await {call_expr}")
|
||||
else:
|
||||
lines.append(f" result = {call_expr}")
|
||||
if _path_returning(func):
|
||||
lines.append(" assert isinstance(result, Path)")
|
||||
elif func.name.startswith(("has_", "is_")):
|
||||
lines.append(" assert isinstance(result, bool)")
|
||||
elif func.name.startswith("list_"):
|
||||
lines.append(" assert isinstance(result, (list, tuple, set, dict, str))")
|
||||
elif func.has_return:
|
||||
lines.append(" assert result is not NotImplemented")
|
||||
else:
|
||||
lines.append(" assert True # smoke: reached without exception")
|
||||
lines.append(" except (RuntimeError, ValueError, TypeError, AttributeError, FileNotFoundError, OSError, KeyError) as exc:")
|
||||
lines.append(" pytest.skip(f'Auto-generated stub needs richer fixture: {exc}')")
|
||||
lines.append(" except (ImportError, ModuleNotFoundError) as exc:")
|
||||
lines.append(" pytest.skip(f'Module not importable: {exc}')")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_test_suite(gaps, max_tests=50):
|
||||
@@ -216,10 +273,26 @@ def generate_test_suite(gaps, max_tests=50):
|
||||
lines.append("These tests are starting points. Review before merging.")
|
||||
lines.append('"""')
|
||||
lines.append("")
|
||||
lines.append("import importlib.util")
|
||||
lines.append("from pathlib import Path")
|
||||
lines.append("import pytest")
|
||||
lines.append("from unittest.mock import MagicMock, patch")
|
||||
lines.append("")
|
||||
lines.append("")
|
||||
lines.append("def _load_symbol(relative_path, symbol):")
|
||||
lines.append(" module_path = Path(__file__).resolve().parents[1] / relative_path")
|
||||
lines.append(" if not module_path.exists():")
|
||||
lines.append(" pytest.skip(f'Module file not found: {module_path}')")
|
||||
lines.append(" spec_name = 'autogen_' + str(relative_path).replace('/', '_').replace('-', '_').replace('.', '_')")
|
||||
lines.append(" spec = importlib.util.spec_from_file_location(spec_name, module_path)")
|
||||
lines.append(" module = importlib.util.module_from_spec(spec)")
|
||||
lines.append(" try:")
|
||||
lines.append(" spec.loader.exec_module(module)")
|
||||
lines.append(" except Exception as exc:")
|
||||
lines.append(" pytest.skip(f'Module not importable: {exc}')")
|
||||
lines.append(" return getattr(module, symbol)")
|
||||
lines.append("")
|
||||
lines.append("")
|
||||
lines.append("# AUTO-GENERATED -- DO NOT EDIT WITHOUT REVIEW")
|
||||
|
||||
for module, mgaps in sorted(by_module.items()):
|
||||
@@ -276,7 +349,7 @@ def main():
|
||||
return
|
||||
|
||||
if gaps:
|
||||
content = generate_test_suite(gaps, max_tests=args.max-tests if hasattr(args, 'max-tests') else args.max_tests)
|
||||
content = generate_test_suite(gaps, max_tests=args.max_tests)
|
||||
out = os.path.join(source_dir, args.output)
|
||||
os.makedirs(os.path.dirname(out), exist_ok=True)
|
||||
with open(out, "w") as f:
|
||||
|
||||
Reference in New Issue
Block a user