#!/usr/bin/env python3 """ test-suite-generator.py — Auto-generate pytest tests for uncovered functions. Scans Python files in a repo, identifies public functions/methods without tests, and generates pytest test stubs with edge cases. Usage: python3 scripts/test-suite-generator.py --repo ~/repos/my-project python3 scripts/test-suite-generator.py --repo ~/repos/my-project --output tests/test_genome_generated.py python3 scripts/test-suite-generator.py --repo ~/repos/my-project --dry-run """ import ast import argparse import os import sys from pathlib import Path from dataclasses import dataclass, field from typing import Optional SKIP_DIRS = {"node_modules", ".git", "__pycache__", "venv", ".venv", "dist", "build", ".tox", ".mypy_cache"} SKIP_FILES = {"test_", "conftest", "setup", "__init__", "manage", "wsgi", "asgi"} @dataclass class FunctionInfo: name: str file_path: str line_number: int class_name: Optional[str] = None args: list = field(default_factory=list) is_async: bool = False docstring: Optional[str] = None has_return: bool = False has_try_except: bool = False decorators: list = field(default_factory=list) def scan_file(filepath: Path) -> list[FunctionInfo]: """Extract function definitions from a Python file.""" try: source = filepath.read_text(errors="ignore") tree = ast.parse(source, filename=str(filepath)) except (SyntaxError, UnicodeDecodeError): return [] functions = [] current_class = None for node in ast.walk(tree): if isinstance(node, ast.ClassDef): current_class = node.name continue if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): continue # Skip private/internal functions if node.name.startswith("_") and node.name != "__init__": continue # Extract args (skip 'self') args = [arg.arg for arg in node.args.args if arg.arg != "self"] # Check for return statements has_return = any( isinstance(n, ast.Return) and n.value is not None for n in ast.walk(node) ) # Check for try/except has_try_except = any( isinstance(n, ast.Try) for n in ast.walk(node) ) # Extract decorators decorators = [] for dec in node.decorator_list: if isinstance(dec, ast.Name): decorators.append(dec.id) elif isinstance(dec, ast.Attribute): decorators.append(dec.attr) # Docstring docstring = ast.get_docstring(node) functions.append(FunctionInfo( name=node.name, file_path=str(filepath), line_number=node.lineno, class_name=current_class, args=args, is_async=isinstance(node, ast.AsyncFunctionDef), docstring=docstring, has_return=has_return, has_try_except=has_try_except, decorators=decorators, )) current_class = None # Reset after function return functions def scan_repo(repo_path: Path) -> list[FunctionInfo]: """Scan all Python files in a repo.""" all_functions = [] for root, dirs, files in os.walk(repo_path): dirs[:] = [d for d in dirs if d not in SKIP_DIRS] for fname in sorted(files): if not fname.endswith(".py"): continue if any(fname.startswith(skip) for skip in SKIP_FILES): continue filepath = Path(root) / fname functions = scan_file(filepath) all_functions.extend(functions) return all_functions def find_existing_tests(repo_path: Path) -> set[str]: """Find function names that already have test coverage.""" tested = set() for root, dirs, files in os.walk(repo_path): dirs[:] = [d for d in dirs if d not in SKIP_DIRS] for fname in files: if not fname.startswith("test_") or not fname.endswith(".py"): continue filepath = Path(root) / fname try: source = filepath.read_text(errors="ignore") # Look for test_functionname patterns import re for match in re.finditer(r"def test_(\w+)", source): tested.add(match.group(1)) except OSError: pass return tested def generate_test(func: FunctionInfo, repo_path: Path) -> str: """Generate a pytest test for a function.""" rel_path = Path(func.file_path).relative_to(repo_path) module_path = str(rel_path).replace("/", ".").replace(".py", "") # Build import if func.class_name: import_line = f"from {module_path} import {func.class_name}" call_prefix = f"{func.class_name}()." else: import_line = f"from {module_path} import {func.name}" call_prefix = "" func_qualified = f"{call_prefix}{func.name}" test_name = f"test_{func.name}" if func.class_name: test_name = f"test_{func.class_name}_{func.name}" lines = [] lines.append(f"") lines.append(f"def {test_name}():") if func.docstring: lines.append(f' """Auto-generated test for {func.name}."""') else: lines.append(f' """Auto-generated test."""') # Build args arg_strs = [] for arg in func.args: if arg in ("self", "cls"): continue if arg in ("text", "content", "message", "data", "input", "name", "path", "url"): arg_strs.append(f'"{arg}_value"') elif arg in ("count", "n", "limit", "size", "length", "index", "port", "timeout"): arg_strs.append("1") elif arg in ("items", "values", "args", "kwargs", "options", "params"): arg_strs.append("{}") elif arg in ("enabled", "verbose", "quiet", "force", "debug", "dry_run"): arg_strs.append("False") else: arg_strs.append("None") args_joined = ", ".join(arg_strs) # Generate test body if func.has_try_except: lines.append(f" # Function has try/except — test error handling") lines.append(f" try:") lines.append(f" result = {func_qualified}({args_joined})") lines.append(f" except Exception:") lines.append(f" pass # Expected behavior for edge case input") elif func.has_return: lines.append(f" result = {func_qualified}({args_joined})") lines.append(f" assert result is not None or result is None # Placeholder — refine with domain knowledge") elif func.is_async: lines.append(f" import asyncio") lines.append(f" # result = asyncio.run({func_qualified}({args_joined}))") lines.append(f" pass # Async function — needs proper test setup") else: lines.append(f" # Function returns None — test it doesn't raise") lines.append(f" {func_qualified}({args_joined})") lines.append("") return "\n".join(lines) def generate_test_file(functions: list[FunctionInfo], repo_path: Path) -> str: """Generate a complete test file.""" lines = [] lines.append('"""Auto-generated tests by test-suite-generator.py.') lines.append("") lines.append("These tests are stubs — they verify functions don't crash on basic inputs.") lines.append("Refine with domain-specific assertions for real coverage.") lines.append("") lines.append(f"Generated: {len(functions)} test stubs") lines.append('"""') lines.append("") lines.append("import pytest") lines.append("") # Collect unique imports imports = set() for func in functions: rel_path = Path(func.file_path).relative_to(repo_path) module_path = str(rel_path).replace("/", ".").replace(".py", "") if func.class_name: imports.add(f"from {module_path} import {func.class_name}") else: imports.add(f"from {module_path} import {func.name}") for imp in sorted(imports): lines.append(imp) # Generate tests grouped by file current_file = None for func in sorted(functions, key=lambda f: (f.file_path, f.line_number)): if func.file_path != current_file: rel = Path(func.file_path).relative_to(repo_path) lines.append("") lines.append(f"# --- Tests for {rel} ---") current_file = func.file_path lines.append(generate_test(func, repo_path)) return "\n".join(lines) def main(): parser = argparse.ArgumentParser(description="Generate pytest tests for uncovered functions") parser.add_argument("--repo", required=True, help="Path to repo to scan") parser.add_argument("--output", default=None, help="Output file path") parser.add_argument("--dry-run", action="store_true", help="Show what would be generated") args = parser.parse_args() repo_path = Path(args.repo).resolve() if not repo_path.exists(): print(f"Error: {repo_path} not found", file=sys.stderr) sys.exit(1) print(f"Scanning {repo_path}...", file=sys.stderr) all_functions = scan_repo(repo_path) tested = find_existing_tests(repo_path) untested = [f for f in all_functions if f.name not in tested] print(f"Found {len(all_functions)} functions, {len(tested)} already tested, {len(untested)} untested", file=sys.stderr) if not untested: print("All functions have test coverage!", file=sys.stderr) sys.exit(0) if args.dry_run: print(f"\nWould generate {len(untested)} test stubs:") for func in untested[:20]: rel = Path(func.file_path).relative_to(repo_path) cls = f"{func.class_name}." if func.class_name else "" print(f" {rel}:{func.line_number} {cls}{func.name}({', '.join(func.args)})") if len(untested) > 20: print(f" ... and {len(untested) - 20} more") sys.exit(0) # Generate test_content = generate_test_file(untested, repo_path) output_path = args.output or str(repo_path / "tests" / "test_genome_generated.py") os.makedirs(os.path.dirname(output_path), exist_ok=True) Path(output_path).write_text(test_content) print(f"Generated {len(untested)} test stubs → {output_path}", file=sys.stderr) if __name__ == "__main__": main()