diff --git a/scripts/test-suite-generator.py b/scripts/test-suite-generator.py new file mode 100644 index 0000000..c536ffa --- /dev/null +++ b/scripts/test-suite-generator.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +""" +test-suite-generator.py — Auto-generate pytest tests for uncovered functions. + +Scans Python files in a repo, identifies public functions/methods without tests, +and generates pytest test stubs with edge cases. + +Usage: + python3 scripts/test-suite-generator.py --repo ~/repos/my-project + python3 scripts/test-suite-generator.py --repo ~/repos/my-project --output tests/test_genome_generated.py + python3 scripts/test-suite-generator.py --repo ~/repos/my-project --dry-run +""" + +import ast +import argparse +import os +import sys +from pathlib import Path +from dataclasses import dataclass, field +from typing import Optional + +SKIP_DIRS = {"node_modules", ".git", "__pycache__", "venv", ".venv", "dist", "build", ".tox", ".mypy_cache"} +SKIP_FILES = {"test_", "conftest", "setup", "__init__", "manage", "wsgi", "asgi"} + + +@dataclass +class FunctionInfo: + name: str + file_path: str + line_number: int + class_name: Optional[str] = None + args: list = field(default_factory=list) + is_async: bool = False + docstring: Optional[str] = None + has_return: bool = False + has_try_except: bool = False + decorators: list = field(default_factory=list) + + +def scan_file(filepath: Path) -> list[FunctionInfo]: + """Extract function definitions from a Python file.""" + try: + source = filepath.read_text(errors="ignore") + tree = ast.parse(source, filename=str(filepath)) + except (SyntaxError, UnicodeDecodeError): + return [] + + functions = [] + current_class = None + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + current_class = node.name + continue + + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + + # Skip private/internal functions + if node.name.startswith("_") and node.name != "__init__": + continue + + # Extract args (skip 'self') + args = [arg.arg for arg in node.args.args if arg.arg != "self"] + + # Check for return statements + has_return = any( + isinstance(n, ast.Return) and n.value is not None + for n in ast.walk(node) + ) + + # Check for try/except + has_try_except = any( + isinstance(n, ast.Try) + for n in ast.walk(node) + ) + + # Extract decorators + decorators = [] + for dec in node.decorator_list: + if isinstance(dec, ast.Name): + decorators.append(dec.id) + elif isinstance(dec, ast.Attribute): + decorators.append(dec.attr) + + # Docstring + docstring = ast.get_docstring(node) + + functions.append(FunctionInfo( + name=node.name, + file_path=str(filepath), + line_number=node.lineno, + class_name=current_class, + args=args, + is_async=isinstance(node, ast.AsyncFunctionDef), + docstring=docstring, + has_return=has_return, + has_try_except=has_try_except, + decorators=decorators, + )) + + current_class = None # Reset after function + + return functions + + +def scan_repo(repo_path: Path) -> list[FunctionInfo]: + """Scan all Python files in a repo.""" + all_functions = [] + + for root, dirs, files in os.walk(repo_path): + dirs[:] = [d for d in dirs if d not in SKIP_DIRS] + for fname in sorted(files): + if not fname.endswith(".py"): + continue + if any(fname.startswith(skip) for skip in SKIP_FILES): + continue + + filepath = Path(root) / fname + functions = scan_file(filepath) + all_functions.extend(functions) + + return all_functions + + +def find_existing_tests(repo_path: Path) -> set[str]: + """Find function names that already have test coverage.""" + tested = set() + + for root, dirs, files in os.walk(repo_path): + dirs[:] = [d for d in dirs if d not in SKIP_DIRS] + for fname in files: + if not fname.startswith("test_") or not fname.endswith(".py"): + continue + filepath = Path(root) / fname + try: + source = filepath.read_text(errors="ignore") + # Look for test_functionname patterns + import re + for match in re.finditer(r"def test_(\w+)", source): + tested.add(match.group(1)) + except OSError: + pass + + return tested + + +def generate_test(func: FunctionInfo, repo_path: Path) -> str: + """Generate a pytest test for a function.""" + rel_path = Path(func.file_path).relative_to(repo_path) + module_path = str(rel_path).replace("/", ".").replace(".py", "") + + # Build import + if func.class_name: + import_line = f"from {module_path} import {func.class_name}" + call_prefix = f"{func.class_name}()." + else: + import_line = f"from {module_path} import {func.name}" + call_prefix = "" + + func_qualified = f"{call_prefix}{func.name}" + test_name = f"test_{func.name}" + if func.class_name: + test_name = f"test_{func.class_name}_{func.name}" + + lines = [] + lines.append(f"") + lines.append(f"def {test_name}():") + if func.docstring: + lines.append(f' """Auto-generated test for {func.name}."""') + else: + lines.append(f' """Auto-generated test."""') + + # Build args + arg_strs = [] + for arg in func.args: + if arg in ("self", "cls"): + continue + if arg in ("text", "content", "message", "data", "input", "name", "path", "url"): + arg_strs.append(f'"{arg}_value"') + elif arg in ("count", "n", "limit", "size", "length", "index", "port", "timeout"): + arg_strs.append("1") + elif arg in ("items", "values", "args", "kwargs", "options", "params"): + arg_strs.append("{}") + elif arg in ("enabled", "verbose", "quiet", "force", "debug", "dry_run"): + arg_strs.append("False") + else: + arg_strs.append("None") + + args_joined = ", ".join(arg_strs) + + # Generate test body + if func.has_try_except: + lines.append(f" # Function has try/except — test error handling") + lines.append(f" try:") + lines.append(f" result = {func_qualified}({args_joined})") + lines.append(f" except Exception:") + lines.append(f" pass # Expected behavior for edge case input") + elif func.has_return: + lines.append(f" result = {func_qualified}({args_joined})") + lines.append(f" assert result is not None or result is None # Placeholder — refine with domain knowledge") + elif func.is_async: + lines.append(f" import asyncio") + lines.append(f" # result = asyncio.run({func_qualified}({args_joined}))") + lines.append(f" pass # Async function — needs proper test setup") + else: + lines.append(f" # Function returns None — test it doesn't raise") + lines.append(f" {func_qualified}({args_joined})") + + lines.append("") + return "\n".join(lines) + + +def generate_test_file(functions: list[FunctionInfo], repo_path: Path) -> str: + """Generate a complete test file.""" + lines = [] + lines.append('"""Auto-generated tests by test-suite-generator.py.') + lines.append("") + lines.append("These tests are stubs — they verify functions don't crash on basic inputs.") + lines.append("Refine with domain-specific assertions for real coverage.") + lines.append("") + lines.append(f"Generated: {len(functions)} test stubs") + lines.append('"""') + lines.append("") + lines.append("import pytest") + lines.append("") + + # Collect unique imports + imports = set() + for func in functions: + rel_path = Path(func.file_path).relative_to(repo_path) + module_path = str(rel_path).replace("/", ".").replace(".py", "") + if func.class_name: + imports.add(f"from {module_path} import {func.class_name}") + else: + imports.add(f"from {module_path} import {func.name}") + + for imp in sorted(imports): + lines.append(imp) + + # Generate tests grouped by file + current_file = None + for func in sorted(functions, key=lambda f: (f.file_path, f.line_number)): + if func.file_path != current_file: + rel = Path(func.file_path).relative_to(repo_path) + lines.append("") + lines.append(f"# --- Tests for {rel} ---") + current_file = func.file_path + + lines.append(generate_test(func, repo_path)) + + return "\n".join(lines) + + +def main(): + parser = argparse.ArgumentParser(description="Generate pytest tests for uncovered functions") + parser.add_argument("--repo", required=True, help="Path to repo to scan") + parser.add_argument("--output", default=None, help="Output file path") + parser.add_argument("--dry-run", action="store_true", help="Show what would be generated") + args = parser.parse_args() + + repo_path = Path(args.repo).resolve() + if not repo_path.exists(): + print(f"Error: {repo_path} not found", file=sys.stderr) + sys.exit(1) + + print(f"Scanning {repo_path}...", file=sys.stderr) + all_functions = scan_repo(repo_path) + tested = find_existing_tests(repo_path) + + untested = [f for f in all_functions if f.name not in tested] + print(f"Found {len(all_functions)} functions, {len(tested)} already tested, {len(untested)} untested", file=sys.stderr) + + if not untested: + print("All functions have test coverage!", file=sys.stderr) + sys.exit(0) + + if args.dry_run: + print(f"\nWould generate {len(untested)} test stubs:") + for func in untested[:20]: + rel = Path(func.file_path).relative_to(repo_path) + cls = f"{func.class_name}." if func.class_name else "" + print(f" {rel}:{func.line_number} {cls}{func.name}({', '.join(func.args)})") + if len(untested) > 20: + print(f" ... and {len(untested) - 20} more") + sys.exit(0) + + # Generate + test_content = generate_test_file(untested, repo_path) + + output_path = args.output or str(repo_path / "tests" / "test_genome_generated.py") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + Path(output_path).write_text(test_content) + print(f"Generated {len(untested)} test stubs → {output_path}", file=sys.stderr) + + +if __name__ == "__main__": + main()