feat: test suite generator — fill coverage gaps (#667)
Some checks failed
Smoke Test / smoke (pull_request) Failing after 23s

This commit is contained in:
2026-04-16 01:39:58 +00:00
parent 10fd467b28
commit a251d3b75d

View File

@@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
test-suite-generator.py — Auto-generate pytest tests for uncovered functions.
Scans Python files in a repo, identifies public functions/methods without tests,
and generates pytest test stubs with edge cases.
Usage:
python3 scripts/test-suite-generator.py --repo ~/repos/my-project
python3 scripts/test-suite-generator.py --repo ~/repos/my-project --output tests/test_genome_generated.py
python3 scripts/test-suite-generator.py --repo ~/repos/my-project --dry-run
"""
import ast
import argparse
import os
import sys
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional
# Directory names pruned from every os.walk traversal in this script
# (compared against the bare directory name, not the full path).
SKIP_DIRS = {"node_modules", ".git", "__pycache__", "venv", ".venv", "dist", "build", ".tox", ".mypy_cache"}
# File-name *prefixes* (matched with str.startswith in scan_repo): any .py file
# whose name begins with one of these is excluded from scanning.
SKIP_FILES = {"test_", "conftest", "setup", "__init__", "manage", "wsgi", "asgi"}
@dataclass
class FunctionInfo:
    """Metadata describing one function or method discovered in a source file."""

    # Required identification fields.
    name: str          # function name as written in the ``def`` statement
    file_path: str     # path of the file containing the definition
    line_number: int   # line of the ``def`` statement

    # Optional details; the defaults describe a plain synchronous function
    # with no arguments, docstring, return value or decorators.
    class_name: Optional[str] = None    # enclosing class name for methods
    args: list = field(default_factory=list)    # parameter names
    is_async: bool = False              # defined with ``async def``
    docstring: Optional[str] = None     # docstring text, if any
    has_return: bool = False            # contains ``return <value>``
    has_try_except: bool = False        # contains a ``try`` statement
    decorators: list = field(default_factory=list)    # decorator names
def scan_file(filepath: Path) -> list[FunctionInfo]:
    """Extract public function and method definitions from a Python file.

    Functions whose name starts with an underscore are skipped, except
    ``__init__``. Functions nested inside another function are not reported
    because they cannot be imported by a generated test.

    Args:
        filepath: Path of the Python source file to parse.

    Returns:
        One ``FunctionInfo`` per discovered definition, ordered by line
        number. An empty list if the file cannot be read or parsed.
    """
    try:
        source = filepath.read_text(errors="ignore")
        tree = ast.parse(source, filename=str(filepath))
    except (OSError, SyntaxError, ValueError, RecursionError):
        # Best effort: an unreadable or unparsable file is skipped rather than
        # aborting the whole scan. ValueError also covers UnicodeDecodeError
        # (its subclass) and source text containing null bytes.
        return []

    functions = []
    # Explicit work list of (node, enclosing class name). ast.walk() cannot be
    # used to track the enclosing class: it yields nodes in no particular
    # order relative to their parents, so a "current class" variable ends up
    # attached to unrelated functions.
    pending = [(tree, None)]
    while pending:
        parent, class_name = pending.pop()
        for node in ast.iter_child_nodes(parent):
            if isinstance(node, ast.ClassDef):
                pending.append((node, node.name))
                continue
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Descend into compound statements (if/try/with/...) that may
                # hold definitions; expressions never contain ``def``.
                if not isinstance(node, ast.expr):
                    pending.append((node, class_name))
                continue

            # Skip private/internal functions
            if node.name.startswith("_") and node.name != "__init__":
                continue

            # Extract args (skip 'self')
            args = [arg.arg for arg in node.args.args if arg.arg != "self"]

            # A bare ``return`` does not count as returning a value.
            has_return = any(
                isinstance(n, ast.Return) and n.value is not None
                for n in ast.walk(node)
            )
            has_try_except = any(isinstance(n, ast.Try) for n in ast.walk(node))

            # Record decorator names; for ``@decorator(...)`` use the callee.
            decorators = []
            for dec in node.decorator_list:
                target = dec.func if isinstance(dec, ast.Call) else dec
                if isinstance(target, ast.Name):
                    decorators.append(target.id)
                elif isinstance(target, ast.Attribute):
                    decorators.append(target.attr)

            functions.append(FunctionInfo(
                name=node.name,
                file_path=str(filepath),
                line_number=node.lineno,
                class_name=class_name,
                args=args,
                is_async=isinstance(node, ast.AsyncFunctionDef),
                docstring=ast.get_docstring(node),
                has_return=has_return,
                has_try_except=has_try_except,
                decorators=decorators,
            ))

    # The work list is not processed in source order; restore it here.
    functions.sort(key=lambda info: info.line_number)
    return functions
def scan_repo(repo_path: Path) -> list[FunctionInfo]:
    """Scan all Python files in a repo.

    Directories named in ``SKIP_DIRS`` are not entered, and ``.py`` files
    whose name starts with a ``SKIP_FILES`` prefix are ignored.

    Args:
        repo_path: Root directory to walk.

    Returns:
        The combined ``scan_file`` results for every scanned file, in a
        deterministic (sorted) traversal order.
    """
    skip_prefixes = tuple(SKIP_FILES)
    all_functions = []
    for root, dirs, files in os.walk(repo_path):
        # Assign in place so os.walk honours the pruning; sorting makes the
        # directory order deterministic, matching the sorted file order below.
        dirs[:] = sorted(d for d in dirs if d not in SKIP_DIRS)
        for fname in sorted(files):
            if not fname.endswith(".py") or fname.startswith(skip_prefixes):
                continue
            all_functions.extend(scan_file(Path(root) / fname))
    return all_functions
def find_existing_tests(repo_path: Path) -> set[str]:
    """Collect the names covered by existing ``test_*.py`` files.

    Every ``def test_<name>`` found in a test file under ``repo_path``
    contributes ``<name>`` to the result. Directories listed in
    ``SKIP_DIRS`` are not searched, and unreadable files are ignored.
    """
    import re

    test_def = re.compile(r"def test_(\w+)")
    covered = set()
    for root, dirs, files in os.walk(repo_path):
        # In-place assignment keeps os.walk out of the skipped directories.
        dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
        for fname in files:
            if not (fname.startswith("test_") and fname.endswith(".py")):
                continue
            try:
                source = (Path(root) / fname).read_text(errors="ignore")
            except OSError:
                continue
            # With a single group, findall returns the captured names.
            covered.update(test_def.findall(source))
    return covered
def generate_test(func: FunctionInfo, repo_path: Path) -> str:
    """Generate a pytest test stub for one function or method.

    The stub calls the target with placeholder arguments chosen from each
    parameter's name. Methods are called on a default-constructed instance,
    and ``__init__`` is exercised by calling the class itself.

    Args:
        func: Metadata for the function under test.
        repo_path: Repository root. Not used when building the stub (the
            import statements are emitted by ``generate_test_file``); kept
            so existing callers keep working.

    Returns:
        Source text of the test function, beginning with a blank line and
        ending with a newline.
    """
    if func.class_name and func.name == "__init__":
        # ``Cls().__init__(...)`` would first construct the object without
        # arguments and then initialise it again; call the constructor.
        func_qualified = func.class_name
    elif func.class_name:
        func_qualified = f"{func.class_name}().{func.name}"
    else:
        func_qualified = func.name

    if func.class_name:
        test_name = f"test_{func.class_name}_{func.name}"
    else:
        test_name = f"test_{func.name}"

    lines = ["", f"def {test_name}():"]
    if func.docstring:
        lines.append(f'    """Auto-generated test for {func.name}."""')
    else:
        lines.append('    """Auto-generated test."""')

    # Pick a placeholder literal for each parameter based on its name.
    arg_strs = []
    for arg in func.args:
        if arg in ("self", "cls"):
            continue
        if arg in ("text", "content", "message", "data", "input", "name", "path", "url"):
            arg_strs.append(f'"{arg}_value"')
        elif arg in ("count", "n", "limit", "size", "length", "index", "port", "timeout"):
            arg_strs.append("1")
        elif arg in ("items", "values", "args", "kwargs", "options", "params"):
            arg_strs.append("{}")
        elif arg in ("enabled", "verbose", "quiet", "force", "debug", "dry_run"):
            arg_strs.append("False")
        else:
            arg_strs.append("None")
    args_joined = ", ".join(arg_strs)

    # Generate test body
    if func.is_async:
        # Checked first: calling a coroutine function without awaiting it only
        # creates a coroutine object, so the synchronous templates below would
        # never execute the function body.
        lines.append("    import asyncio")
        lines.append(f"    # result = asyncio.run({func_qualified}({args_joined}))")
        lines.append("    pass  # Async function — needs proper test setup")
    elif func.has_try_except:
        lines.append("    # Function has try/except — test error handling")
        lines.append("    try:")
        lines.append(f"        result = {func_qualified}({args_joined})")
        lines.append("    except Exception:")
        lines.append("        pass  # Expected behavior for edge case input")
    elif func.has_return:
        lines.append(f"    result = {func_qualified}({args_joined})")
        lines.append("    assert result is not None or result is None  # Placeholder — refine with domain knowledge")
    else:
        lines.append("    # Function returns None — test it doesn't raise")
        lines.append(f"    {func_qualified}({args_joined})")

    lines.append("")
    return "\n".join(lines)
def generate_test_file(functions: list[FunctionInfo], repo_path: Path) -> str:
    """Generate a complete test file.

    Emits a module docstring, one import per tested function or class, and
    the stubs from ``generate_test`` grouped by source file. Functions whose
    file cannot be expressed as a dotted module path (for example a file
    name containing a hyphen) are left out and listed in a comment, because
    a single invalid ``from ... import`` line would make the whole generated
    file a syntax error.

    Args:
        functions: Functions to generate stubs for.
        repo_path: Repository root; every ``file_path`` must lie beneath it.

    Returns:
        The text of the generated test module.

    Raises:
        ValueError: If a function's ``file_path`` is not under ``repo_path``.
    """
    import keyword

    # Resolve each function's module path once, from the path components, so
    # the result is independent of the platform's path separator and only the
    # final suffix is removed.
    resolved = []
    skipped = set()
    for func in functions:
        rel = Path(func.file_path).relative_to(repo_path)
        parts = rel.with_suffix("").parts
        if parts and all(p.isidentifier() and not keyword.iskeyword(p) for p in parts):
            resolved.append((func, rel, ".".join(parts)))
        else:
            skipped.add(rel.as_posix())

    lines = [
        '"""Auto-generated tests by test-suite-generator.py.',
        "",
        "These tests are stubs — they verify functions don't crash on basic inputs.",
        "Refine with domain-specific assertions for real coverage.",
        "",
        f"Generated: {len(resolved)} test stubs",
        '"""',
        "",
        "import pytest",
        "",
    ]

    # Collect unique imports; methods are reached through their class.
    imports = {
        f"from {module} import {func.class_name or func.name}"
        for func, _rel, module in resolved
    }
    lines.extend(sorted(imports))

    if skipped:
        lines.append("")
        lines.extend(
            f"# Skipped (module path is not importable): {path}"
            for path in sorted(skipped)
        )

    # Generate tests grouped by file
    current_file = None
    ordered = sorted(resolved, key=lambda item: (item[0].file_path, item[0].line_number))
    for func, rel, _module in ordered:
        if func.file_path != current_file:
            lines.append("")
            lines.append(f"# --- Tests for {rel} ---")
            current_file = func.file_path
        lines.append(generate_test(func, repo_path))
    return "\n".join(lines)
def main():
    """Command-line entry point.

    Scans ``--repo`` for public functions, drops those that already have a
    matching ``test_<name>`` (or ``test_<Class>_<name>`` for methods), and
    either lists the remainder (``--dry-run``) or writes generated stubs to
    ``--output`` (default: ``<repo>/tests/test_genome_generated.py``).
    Progress messages go to stderr. Exits with status 1 if the repo path is
    missing or is not a directory.
    """
    parser = argparse.ArgumentParser(description="Generate pytest tests for uncovered functions")
    parser.add_argument("--repo", required=True, help="Path to repo to scan")
    parser.add_argument("--output", default=None, help="Output file path")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be generated")
    args = parser.parse_args()

    repo_path = Path(args.repo).resolve()
    if not repo_path.exists():
        print(f"Error: {repo_path} not found", file=sys.stderr)
        sys.exit(1)
    if not repo_path.is_dir():
        # os.walk on a file yields nothing, which would otherwise be reported
        # as "all functions have test coverage".
        print(f"Error: {repo_path} is not a directory", file=sys.stderr)
        sys.exit(1)

    print(f"Scanning {repo_path}...", file=sys.stderr)
    all_functions = scan_repo(repo_path)
    tested = find_existing_tests(repo_path)

    def is_tested(func):
        # Stubs for methods are named test_<Class>_<method>, so check that
        # form as well as the bare function name.
        if func.name in tested:
            return True
        return bool(func.class_name) and f"{func.class_name}_{func.name}" in tested

    untested = [f for f in all_functions if not is_tested(f)]
    print(f"Found {len(all_functions)} functions, {len(tested)} already tested, {len(untested)} untested", file=sys.stderr)

    if not untested:
        print("All functions have test coverage!", file=sys.stderr)
        sys.exit(0)

    if args.dry_run:
        print(f"\nWould generate {len(untested)} test stubs:")
        for func in untested[:20]:
            rel = Path(func.file_path).relative_to(repo_path)
            cls = f"{func.class_name}." if func.class_name else ""
            print(f"  {rel}:{func.line_number} {cls}{func.name}({', '.join(func.args)})")
        if len(untested) > 20:
            print(f"  ... and {len(untested) - 20} more")
        sys.exit(0)

    # Generate
    test_content = generate_test_file(untested, repo_path)
    if args.output:
        output_path = Path(args.output)
    else:
        output_path = repo_path / "tests" / "test_genome_generated.py"
    # Path.parent is "." for a bare file name, so this also works when
    # --output has no directory component.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Python source is read as UTF-8 by default and the generated text
    # contains non-ASCII characters, so do not rely on the locale encoding.
    output_path.write_text(test_content, encoding="utf-8")
    print(f"Generated {len(untested)} test stubs → {output_path}", file=sys.stderr)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()