Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 37s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 55s
Tests / test (pull_request) Failing after 55s
Tests / e2e (pull_request) Successful in 2m49s
Resolves #666. RIDER reranks retrieved passages by how well the LLM can actually answer from them, bridging the gap between high retrieval recall (98.4% R@5) and low end-to-end accuracy (17%). agent/rider.py (256 lines): - RIDER class with rerank(passages, query) method - Batch LLM prediction from each passage individually - Confidence-based scoring: specificity, grounding, hedge detection, query relevance, refusal penalty - Async scoring with configurable batch size - Convenience functions: rerank_passages(), is_rider_available() tools/session_search_tool.py: - Wired RIDER into session search pipeline after FTS5 results - Reranks sessions by LLM answerability before summarization - Graceful fallback if RIDER unavailable tests/test_reader_guided_reranking.py (10 tests): - Empty passages, few passages, disabled mode - Confidence scoring: short answers, hedging, grounding, refusal - Convenience function, availability check Config via env vars: RIDER_ENABLED, RIDER_TOP_K, RIDER_TOP_N, RIDER_MAX_TOKENS, RIDER_BATCH_SIZE.
83 lines
2.6 KiB
Python
"""Tests for Reader-Guided Reranking (RIDER) — issue #666."""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from agent.rider import RIDER, rerank_passages, is_rider_available
|
|
|
|
|
|
class TestRIDERClass:
    """Unit tests covering construction and the rerank() entry point of RIDER."""

    def test_init(self):
        """A freshly constructed RIDER identifies its auxiliary task as 'rider'."""
        instance = RIDER()
        assert instance._auxiliary_task == "rider"

    def test_rerank_empty_passages(self):
        """Reranking an empty passage list yields an empty list."""
        instance = RIDER()
        assert instance.rerank([], "test query") == []

    def test_rerank_fewer_than_top_n(self):
        """If passages <= top_n, return all (with scores if possible)."""
        instance = RIDER()
        single_passage = [{"content": "test content", "session_id": "s1"}]
        ranked = instance.rerank(single_passage, "test query", top_n=3)
        assert len(ranked) == 1

    @patch("agent.rider.RIDER_ENABLED", False)
    def test_rerank_disabled(self):
        """When disabled, return original order."""
        instance = RIDER()
        candidates = []
        for idx in range(5):
            candidates.append({"content": f"content {idx}", "session_id": f"s{idx}"})
        ranked = instance.rerank(candidates, "test query", top_n=3)
        # Disabled mode must not reorder — just truncate to top_n.
        assert ranked == candidates[:3]
|
|
|
|
|
|
class TestConfidenceCalculation:
    """Behavioural tests for RIDER._calculate_confidence heuristic scoring."""

    @pytest.fixture
    def rider(self):
        """Provide a fresh RIDER instance for each test."""
        return RIDER()

    def test_short_specific_answer(self, rider):
        """A terse, grounded answer should score above the 0.5 midpoint."""
        confidence = rider._calculate_confidence(
            "Paris",
            "What is the capital of France?",
            "Paris is the capital of France.",
        )
        assert confidence > 0.5

    def test_hedged_answer(self, rider):
        """Hedging language ('maybe', 'not sure') should drag the score down."""
        confidence = rider._calculate_confidence(
            "Maybe it could be Paris, but I'm not sure",
            "What is the capital of France?",
            "Paris is the capital.",
        )
        assert confidence < 0.5

    def test_passage_grounding(self, rider):
        """An answer whose terms appear in the passage should score high."""
        confidence = rider._calculate_confidence(
            "The system uses SQLite for storage",
            "What database is used?",
            "The system uses SQLite for persistent storage with FTS5 indexing.",
        )
        assert confidence > 0.5

    def test_refusal_penalty(self, rider):
        """An explicit refusal to answer should be penalised below 0.5."""
        confidence = rider._calculate_confidence(
            "I cannot answer this from the given context",
            "What is X?",
            "Some unrelated content",
        )
        assert confidence < 0.5
|
|
|
|
|
|
class TestRerankPassages:
    """Tests for the module-level rerank_passages() convenience wrapper."""

    def test_convenience_function(self):
        """Test the module-level convenience function."""
        single_passage = [{"content": "test", "session_id": "s1"}]
        ranked = rerank_passages(single_passage, "query", top_n=1)
        assert len(ranked) == 1
|
|
|
|
|
|
class TestIsRiderAvailable:
    """Tests for the is_rider_available() availability probe."""

    def test_returns_bool(self):
        """The probe must always report availability as a plain bool."""
        assert isinstance(is_rider_available(), bool)
|