diff --git a/src/timmy/session.py b/src/timmy/session.py
index 1116309..ed5d93d 100644
--- a/src/timmy/session.py
+++ b/src/timmy/session.py
@@ -13,6 +13,7 @@ import re
 
 import httpx
 
+from timmy.confidence import estimate_confidence
 from timmy.session_logger import get_session_logger
 
 logger = logging.getLogger(__name__)
@@ -105,8 +106,12 @@ async def chat(message: str, session_id: str | None = None) -> str:
     # Post-processing: clean up any leaked tool calls or chain-of-thought
     response_text = _clean_response(response_text)
 
+    # Estimate confidence of the response
+    confidence = estimate_confidence(response_text)
+    logger.debug("Response confidence: %.2f", confidence)
+
     # Record Timmy response after getting it
-    session_logger.record_message("timmy", response_text)
+    session_logger.record_message("timmy", response_text, confidence=confidence)
 
     # Flush session logs to disk
     session_logger.flush()
@@ -141,7 +146,11 @@ async def chat_with_tools(message: str, session_id: str | None = None):
         response_text = (
             run_output.content if hasattr(run_output, "content") and run_output.content else ""
         )
-        session_logger.record_message("timmy", response_text)
+        # Empty responses get no confidence score; only log when one was computed.
+        confidence = estimate_confidence(response_text) if response_text else None
+        if confidence is not None:
+            logger.debug("Response confidence: %.2f", confidence)
+        session_logger.record_message("timmy", response_text, confidence=confidence)
         session_logger.flush()
         return run_output
     except (httpx.ConnectError, httpx.ReadError, ConnectionError) as exc:
@@ -178,7 +187,11 @@ async def continue_chat(run_output, session_id: str | None = None):
         result = await agent.acontinue_run(run_response=run_output, stream=False, session_id=sid)
         # Record Timmy response after getting it
         response_text = result.content if hasattr(result, "content") and result.content else ""
-        session_logger.record_message("timmy", response_text)
+        # Empty responses get no confidence score; only log when one was computed.
+        confidence = estimate_confidence(response_text) if response_text else None
+        if confidence is not None:
+            logger.debug("Response confidence: %.2f", confidence)
+        session_logger.record_message("timmy", response_text, confidence=confidence)
         session_logger.flush()
         return result
     except (httpx.ConnectError, httpx.ReadError, ConnectionError) as exc:
diff --git a/tests/timmy/test_audit_trail.py b/tests/timmy/test_audit_trail.py
index d95bb15..022c050 100644
--- a/tests/timmy/test_audit_trail.py
+++ b/tests/timmy/test_audit_trail.py
@@ -55,13 +55,14 @@ async def test_chat_records_timmy_response_after_agent_call():
     with (
         patch("timmy.session._get_agent", return_value=mock_agent),
         patch("timmy.session.get_session_logger", return_value=mock_session_logger),
+        patch("timmy.session.estimate_confidence", return_value=0.75),
     ):
         from timmy.session import chat
 
         await chat("Hi Timmy")
 
-    # Verify Timmy response was recorded after agent call
-    mock_session_logger.record_message.assert_any_call("timmy", "Hello, sir.")
+    # Verify Timmy response was recorded after agent call with confidence
+    mock_session_logger.record_message.assert_any_call("timmy", "Hello, sir.", confidence=0.75)
 
 
 @pytest.mark.asyncio
@@ -182,13 +183,14 @@ async def test_chat_with_tools_records_timmy_response():
     with (
         patch("timmy.session._get_agent", return_value=mock_agent),
         patch("timmy.session.get_session_logger", return_value=mock_session_logger),
+        patch("timmy.session.estimate_confidence", return_value=0.75),
     ):
         from timmy.session import chat_with_tools
 
         await chat_with_tools("Use a tool")
 
-    # Verify Timmy response was recorded
-    mock_session_logger.record_message.assert_any_call("timmy", "Tool result here")
+    # Verify Timmy response was recorded with confidence
+    mock_session_logger.record_message.assert_any_call("timmy", "Tool result here", confidence=0.75)
 
 
 @pytest.mark.asyncio
@@ -248,13 +250,16 @@ async def test_continue_chat_records_timmy_response():
     with (
         patch("timmy.session._get_agent", return_value=mock_agent),
         patch("timmy.session.get_session_logger", return_value=mock_session_logger),
+        patch("timmy.session.estimate_confidence", return_value=0.75),
     ):
         from timmy.session import continue_chat
 
         await continue_chat(mock_run_output)
 
-    # Verify Timmy response was recorded
-    mock_session_logger.record_message.assert_called_once_with("timmy", "Continued result")
+    # Verify Timmy response was recorded with confidence
+    mock_session_logger.record_message.assert_called_once_with(
+        "timmy", "Continued result", confidence=0.75
+    )
 
 
 @pytest.mark.asyncio
diff --git a/tests/timmy/test_session.py b/tests/timmy/test_session.py
index 6c75ad3..0b9f4cf 100644
--- a/tests/timmy/test_session.py
+++ b/tests/timmy/test_session.py
@@ -213,3 +213,87 @@ def test_reset_session_clears_context():
         reset_session("test-session")
 
         mock_cm.clear_context.assert_called_once_with("test-session")
+
+
+# ---------------------------------------------------------------------------
+# Confidence estimation integration
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_chat_passes_confidence_to_record_message():
+    """chat() should estimate confidence and pass it to record_message."""
+    mock_agent = MagicMock()
+    mock_agent.arun = AsyncMock(return_value=MagicMock(content="This is a confident answer."))
+
+    with (
+        patch("timmy.session._get_agent", return_value=mock_agent),
+        patch("timmy.session.estimate_confidence", return_value=0.85) as mock_estimate,
+        patch("timmy.session.get_session_logger") as mock_get_logger,
+    ):
+        mock_logger = MagicMock()
+        mock_get_logger.return_value = mock_logger
+
+        from timmy.session import chat
+
+        await chat("test message")
+
+    mock_estimate.assert_called_once_with("This is a confident answer.")
+    # Check that record_message was called with confidence
+    calls = mock_logger.record_message.call_args_list
+    assert len(calls) >= 2  # user message + timmy response
+    # Last call should be timmy response with confidence
+    _, kwargs = calls[-1]
+    assert kwargs.get("confidence") == 0.85
+
+
+@pytest.mark.asyncio
+async def test_chat_with_tools_passes_confidence_to_record_message():
+    """chat_with_tools() should estimate confidence and pass it to record_message."""
+    mock_agent = MagicMock()
+    mock_agent.arun = AsyncMock(return_value=MagicMock(content="Tool response here."))
+
+    with (
+        patch("timmy.session._get_agent", return_value=mock_agent),
+        patch("timmy.session.estimate_confidence", return_value=0.72) as mock_estimate,
+        patch("timmy.session.get_session_logger") as mock_get_logger,
+    ):
+        mock_logger = MagicMock()
+        mock_get_logger.return_value = mock_logger
+
+        from timmy.session import chat_with_tools
+
+        await chat_with_tools("test message")
+
+    mock_estimate.assert_called_once_with("Tool response here.")
+    calls = mock_logger.record_message.call_args_list
+    assert len(calls) >= 2
+    _, kwargs = calls[-1]
+    assert kwargs.get("confidence") == 0.72
+
+
+@pytest.mark.asyncio
+async def test_continue_chat_passes_confidence_to_record_message():
+    """continue_chat() should estimate confidence and pass it to record_message."""
+    mock_agent = MagicMock()
+    mock_agent.acontinue_run = AsyncMock(return_value=MagicMock(content="Continued response."))
+
+    mock_run_output = MagicMock()
+
+    with (
+        patch("timmy.session._get_agent", return_value=mock_agent),
+        patch("timmy.session.estimate_confidence", return_value=0.91) as mock_estimate,
+        patch("timmy.session.get_session_logger") as mock_get_logger,
+    ):
+        mock_logger = MagicMock()
+        mock_get_logger.return_value = mock_logger
+
+        from timmy.session import continue_chat
+
+        await continue_chat(mock_run_output)
+
+    mock_estimate.assert_called_once_with("Continued response.")
+    calls = mock_logger.record_message.call_args_list
+    assert len(calls) >= 1  # should have timmy response
+    _, kwargs = calls[-1]
+    assert kwargs.get("confidence") == 0.91