feat: test suite generator — fill coverage gaps (#667)
Some checks failed
Smoke Test / smoke (pull_request) Failing after 23s
Some checks failed
Smoke Test / smoke (pull_request) Failing after 23s
This commit is contained in:
298
scripts/test-suite-generator.py
Normal file
298
scripts/test-suite-generator.py
Normal file
@@ -0,0 +1,298 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
test-suite-generator.py — Auto-generate pytest tests for uncovered functions.
|
||||
|
||||
Scans Python files in a repo, identifies public functions/methods without tests,
|
||||
and generates pytest test stubs with edge cases.
|
||||
|
||||
Usage:
|
||||
python3 scripts/test-suite-generator.py --repo ~/repos/my-project
|
||||
python3 scripts/test-suite-generator.py --repo ~/repos/my-project --output tests/test_genome_generated.py
|
||||
python3 scripts/test-suite-generator.py --repo ~/repos/my-project --dry-run
|
||||
"""
|
||||
|
||||
import ast
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
# Directory names that are pruned from every os.walk() traversal in this
# script (dependency trees, VCS metadata, caches, virtualenvs, build output).
SKIP_DIRS = {"node_modules", ".git", "__pycache__", "venv", ".venv", "dist", "build", ".tox", ".mypy_cache"}
# Filename markers for Python files that are never scanned for functions:
# test modules plus packaging/framework boilerplate. scan_repo() compares
# these against each candidate file name.
SKIP_FILES = {"test_", "conftest", "setup", "__init__", "manage", "wsgi", "asgi"}
|
||||
|
||||
|
||||
@dataclass
class FunctionInfo:
    """Metadata recorded by the scanner for one function or method definition."""

    # Name from the ``def`` / ``async def`` statement.
    name: str
    # Path of the file containing the definition, stored as a string.
    file_path: str
    # Line number of the definition (the AST node's ``lineno``).
    line_number: int
    # Name of the class the scanner associated with this function, if any.
    class_name: Optional[str] = None
    # Positional parameter names, with ``self`` already removed.
    args: list = field(default_factory=list)
    # True when the definition is an ``async def``.
    is_async: bool = False
    # Docstring as returned by ``ast.get_docstring``; None when absent.
    docstring: Optional[str] = None
    # True when the body contains a ``return`` that carries a value.
    has_return: bool = False
    # True when the body contains a ``try`` statement.
    has_try_except: bool = False
    # Decorator names: the identifier for ``@name``, the final attribute for
    # ``@pkg.name``.
    decorators: list = field(default_factory=list)
|
||||
|
||||
|
||||
def scan_file(filepath: Path) -> list[FunctionInfo]:
    """Extract importable public functions and methods from a Python file.

    The file is parsed with ``ast`` and traversed scope by scope, so every
    method is attributed to the class that actually contains it.
    ``ast.walk`` makes no ordering guarantee, which is why class context
    cannot be tracked with a variable updated while walking.

    Reported definitions:
      * functions defined at module level (including inside module-level
        ``if`` / ``try`` / ``with`` blocks) -> ``class_name`` is ``None``;
      * methods defined in a module-level class -> ``class_name`` is that
        class's name.

    Definitions nested inside another function, or inside a class nested in
    another class, are skipped because they cannot be imported by name.
    Names starting with ``_`` are skipped, except ``__init__``.

    Args:
        filepath: Path of the Python source file to inspect.

    Returns:
        One ``FunctionInfo`` per reported definition, in source order.
        An empty list when the file cannot be read or parsed.
    """
    try:
        source = filepath.read_text(errors="ignore")
        tree = ast.parse(source, filename=str(filepath))
    except (OSError, SyntaxError, ValueError):
        # OSError: unreadable file. ValueError: source containing null bytes
        # on interpreters that do not report that as SyntaxError (it also
        # covers UnicodeDecodeError). One bad file must not abort the scan.
        return []

    def_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    # ast.TryStar (``try ... except*``) only exists on Python 3.11+.
    try_types = tuple(
        t for t in (ast.Try, getattr(ast, "TryStar", None)) if t is not None
    )
    functions: list[FunctionInfo] = []

    def describe(node, class_name):
        """Build the FunctionInfo record for one definition node."""
        has_return = False
        has_try_except = False
        # Single pass over the body for both structural flags.
        for inner in ast.walk(node):
            if isinstance(inner, ast.Return) and inner.value is not None:
                has_return = True
            elif isinstance(inner, try_types):
                has_try_except = True

        decorators = []
        for dec in node.decorator_list:
            # ``@decorator(...)`` is a Call wrapping the real decorator name.
            target = dec.func if isinstance(dec, ast.Call) else dec
            if isinstance(target, ast.Name):
                decorators.append(target.id)
            elif isinstance(target, ast.Attribute):
                decorators.append(target.attr)

        return FunctionInfo(
            name=node.name,
            file_path=str(filepath),
            line_number=node.lineno,
            class_name=class_name,
            # Skip 'self' so the list holds only caller-supplied parameters.
            args=[a.arg for a in node.args.args if a.arg != "self"],
            is_async=isinstance(node, ast.AsyncFunctionDef),
            docstring=ast.get_docstring(node),
            has_return=has_return,
            has_try_except=has_try_except,
            decorators=decorators,
        )

    def visit(parent, class_name):
        """Record definitions directly reachable from ``parent``'s scope."""
        for child in ast.iter_child_nodes(parent):
            if isinstance(child, def_types):
                # Private helpers are skipped; function bodies are never
                # entered because nested definitions are not importable.
                if not child.name.startswith("_") or child.name == "__init__":
                    functions.append(describe(child, class_name))
            elif isinstance(child, ast.ClassDef):
                # Only module-level classes can be imported by bare name.
                if class_name is None:
                    visit(child, child.name)
            elif not isinstance(child, ast.expr):
                # Descend through compound statements (if/try/with/...) but
                # not expressions, which cannot contain ``def`` statements.
                visit(child, class_name)

    visit(tree, None)
    return functions
|
||||
|
||||
|
||||
def scan_repo(repo_path: Path) -> list[FunctionInfo]:
    """Scan every non-test Python source file under ``repo_path``.

    Directories named in ``SKIP_DIRS`` are pruned from the walk. A ``.py``
    file is skipped when it is a test module (``test_*.py`` or ``*_test.py``)
    or when its stem matches a ``SKIP_FILES`` entry. Entries ending in ``_``
    (such as ``"test_"``) are treated as prefixes; all other entries must
    equal the stem exactly, so ``manage.py`` is skipped but ``manager.py`` and
    ``setup_logging.py`` are still scanned.

    Args:
        repo_path: Root directory of the repository.

    Returns:
        The concatenated ``scan_file`` results for every scanned file, in a
        deterministic (sorted) traversal order.
    """
    all_functions = []

    for root, dirs, files in os.walk(repo_path):
        # Assign in place so os.walk honours the pruning; sort so the
        # traversal order does not depend on the filesystem.
        dirs[:] = sorted(d for d in dirs if d not in SKIP_DIRS)

        for fname in sorted(files):
            if not fname.endswith(".py"):
                continue
            stem = fname[: -len(".py")]

            # pytest also collects ``*_test.py``; those hold tests, not
            # functions that need tests.
            if stem.endswith("_test"):
                continue
            if any(
                stem.startswith(skip) if skip.endswith("_") else stem == skip
                for skip in SKIP_FILES
            ):
                continue

            all_functions.extend(scan_file(Path(root) / fname))

    return all_functions
|
||||
|
||||
|
||||
def find_existing_tests(repo_path: Path) -> set[str]:
    """Collect the names covered by existing ``test_*`` functions.

    Every test module under ``repo_path`` (``test_*.py`` or ``*_test.py``,
    outside ``SKIP_DIRS``) is searched textually for ``def test_<name>``.
    The ``<name>`` part of each match is collected.

    Args:
        repo_path: Root directory of the repository.

    Returns:
        The set of ``<name>`` suffixes found, e.g. ``{"parse"}`` for
        ``def test_parse():``.
    """
    # Imported here, once, because the file-level imports do not include re.
    import re

    test_def = re.compile(r"def test_(\w+)")
    tested = set()

    for root, dirs, files in os.walk(repo_path):
        # Assign in place so os.walk does not descend into skipped dirs.
        dirs[:] = [d for d in dirs if d not in SKIP_DIRS]

        for fname in files:
            if not fname.endswith(".py"):
                continue
            if not (fname.startswith("test_") or fname.endswith("_test.py")):
                continue

            try:
                source = (Path(root) / fname).read_text(errors="ignore")
            except OSError:
                # Best effort: an unreadable test file contributes nothing.
                continue

            tested.update(test_def.findall(source))

    return tested
|
||||
|
||||
|
||||
def generate_test(func: FunctionInfo, repo_path: Path) -> str:
|
||||
"""Generate a pytest test for a function."""
|
||||
rel_path = Path(func.file_path).relative_to(repo_path)
|
||||
module_path = str(rel_path).replace("/", ".").replace(".py", "")
|
||||
|
||||
# Build import
|
||||
if func.class_name:
|
||||
import_line = f"from {module_path} import {func.class_name}"
|
||||
call_prefix = f"{func.class_name}()."
|
||||
else:
|
||||
import_line = f"from {module_path} import {func.name}"
|
||||
call_prefix = ""
|
||||
|
||||
func_qualified = f"{call_prefix}{func.name}"
|
||||
test_name = f"test_{func.name}"
|
||||
if func.class_name:
|
||||
test_name = f"test_{func.class_name}_{func.name}"
|
||||
|
||||
lines = []
|
||||
lines.append(f"")
|
||||
lines.append(f"def {test_name}():")
|
||||
if func.docstring:
|
||||
lines.append(f' """Auto-generated test for {func.name}."""')
|
||||
else:
|
||||
lines.append(f' """Auto-generated test."""')
|
||||
|
||||
# Build args
|
||||
arg_strs = []
|
||||
for arg in func.args:
|
||||
if arg in ("self", "cls"):
|
||||
continue
|
||||
if arg in ("text", "content", "message", "data", "input", "name", "path", "url"):
|
||||
arg_strs.append(f'"{arg}_value"')
|
||||
elif arg in ("count", "n", "limit", "size", "length", "index", "port", "timeout"):
|
||||
arg_strs.append("1")
|
||||
elif arg in ("items", "values", "args", "kwargs", "options", "params"):
|
||||
arg_strs.append("{}")
|
||||
elif arg in ("enabled", "verbose", "quiet", "force", "debug", "dry_run"):
|
||||
arg_strs.append("False")
|
||||
else:
|
||||
arg_strs.append("None")
|
||||
|
||||
args_joined = ", ".join(arg_strs)
|
||||
|
||||
# Generate test body
|
||||
if func.has_try_except:
|
||||
lines.append(f" # Function has try/except — test error handling")
|
||||
lines.append(f" try:")
|
||||
lines.append(f" result = {func_qualified}({args_joined})")
|
||||
lines.append(f" except Exception:")
|
||||
lines.append(f" pass # Expected behavior for edge case input")
|
||||
elif func.has_return:
|
||||
lines.append(f" result = {func_qualified}({args_joined})")
|
||||
lines.append(f" assert result is not None or result is None # Placeholder — refine with domain knowledge")
|
||||
elif func.is_async:
|
||||
lines.append(f" import asyncio")
|
||||
lines.append(f" # result = asyncio.run({func_qualified}({args_joined}))")
|
||||
lines.append(f" pass # Async function — needs proper test setup")
|
||||
else:
|
||||
lines.append(f" # Function returns None — test it doesn't raise")
|
||||
lines.append(f" {func_qualified}({args_joined})")
|
||||
|
||||
lines.append("")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_test_file(functions: list[FunctionInfo], repo_path: Path) -> str:
    """Generate the full text of a pytest module for ``functions``.

    Layout: a module docstring with the stub count, ``import pytest``, one
    de-duplicated ``from <module> import <symbol>`` line per function or
    class (sorted), then the stubs grouped under a
    ``# --- Tests for <file> ---`` banner and ordered by file and line number.

    The dotted module name is built from the components of each file's path
    relative to ``repo_path``. This works with any path separator and removes
    only the final file suffix.

    Args:
        functions: Scanner records to generate stubs for.
        repo_path: Repository root; every ``file_path`` must be inside it.

    Returns:
        The generated module source as a single string.

    Raises:
        ValueError: If a function's ``file_path`` is not under ``repo_path``.
    """
    lines = [
        '"""Auto-generated tests by test-suite-generator.py.',
        "",
        "These tests are stubs — they verify functions don't crash on basic inputs.",
        "Refine with domain-specific assertions for real coverage.",
        "",
        f"Generated: {len(functions)} test stubs",
        '"""',
        "",
        "import pytest",
        "",
    ]

    # Collect unique imports: methods import their class, functions themselves.
    imports = set()
    for func in functions:
        rel_path = Path(func.file_path).relative_to(repo_path)
        # Path components, not string replacement: "pkg/mod.py" -> "pkg.mod".
        module_path = ".".join(rel_path.with_suffix("").parts)
        symbol = func.class_name if func.class_name else func.name
        imports.add(f"from {module_path} import {symbol}")
    lines.extend(sorted(imports))

    # Emit the stubs grouped by source file, in definition order.
    current_file = None
    for func in sorted(functions, key=lambda f: (f.file_path, f.line_number)):
        if func.file_path != current_file:
            rel = Path(func.file_path).relative_to(repo_path)
            lines.append("")
            lines.append(f"# --- Tests for {rel} ---")
            current_file = func.file_path
        lines.append(generate_test(func, repo_path))

    return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: scan a repo and write test stubs for it.

    Options:
        --repo     Path of the repository to scan (required; must be a
                   directory).
        --output   Output file path; defaults to
                   ``<repo>/tests/test_genome_generated.py``.
        --dry-run  List up to 20 functions that would get a stub and exit
                   without writing anything.

    A function counts as already tested when a ``test_<name>`` exists or,
    for methods, a ``test_<Class>_<name>``. The second form is the name the
    generator gives method stubs, so re-running does not duplicate them.

    Progress goes to stderr; the dry-run listing goes to stdout. Exits with
    status 1 when the repo path is missing or is not a directory, and with
    status 0 when nothing is untested or after a dry run.
    """
    parser = argparse.ArgumentParser(description="Generate pytest tests for uncovered functions")
    parser.add_argument("--repo", required=True, help="Path to repo to scan")
    parser.add_argument("--output", default=None, help="Output file path")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be generated")
    args = parser.parse_args()

    repo_path = Path(args.repo).resolve()
    if not repo_path.exists():
        print(f"Error: {repo_path} not found", file=sys.stderr)
        sys.exit(1)
    if not repo_path.is_dir():
        # A plain file would walk to nothing and misreport full coverage.
        print(f"Error: {repo_path} is not a directory", file=sys.stderr)
        sys.exit(1)

    print(f"Scanning {repo_path}...", file=sys.stderr)
    all_functions = scan_repo(repo_path)
    tested = find_existing_tests(repo_path)

    def has_test(func):
        """Return True when an existing test name covers ``func``."""
        if func.name in tested:
            return True
        return bool(func.class_name) and f"{func.class_name}_{func.name}" in tested

    untested = [f for f in all_functions if not has_test(f)]
    print(f"Found {len(all_functions)} functions, {len(tested)} already tested, {len(untested)} untested", file=sys.stderr)

    if not untested:
        print("All functions have test coverage!", file=sys.stderr)
        sys.exit(0)

    if args.dry_run:
        print(f"\nWould generate {len(untested)} test stubs:")
        for func in untested[:20]:
            rel = Path(func.file_path).relative_to(repo_path)
            cls = f"{func.class_name}." if func.class_name else ""
            print(f" {rel}:{func.line_number} {cls}{func.name}({', '.join(func.args)})")
        if len(untested) > 20:
            print(f" ... and {len(untested) - 20} more")
        sys.exit(0)

    # Generate
    test_content = generate_test_file(untested, repo_path)

    if args.output:
        output_path = Path(args.output)
    else:
        output_path = repo_path / "tests" / "test_genome_generated.py"
    # Path.parent is "." for a bare filename, so this also works for
    # ``--output out.py`` (os.makedirs("") would raise).
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # The generated header contains non-ASCII text; do not depend on the
    # platform's default encoding.
    output_path.write_text(test_content, encoding="utf-8")
    print(f"Generated {len(untested)} test stubs → {output_path}", file=sys.stderr)
|
||||
|
||||
# Run the command-line interface only when executed as a script, so importing
# this module has no side effects.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user