diff --git a/run_agent.py b/run_agent.py
index 95e8a5453..11a545a9a 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -2740,7 +2740,7 @@ class AIAgent:
             "messages": api_messages,
         }
         if self.max_tokens is not None:
-            summary_kwargs["max_tokens"] = self.max_tokens
+            summary_kwargs.update(self._max_tokens_param(self.max_tokens))
         if summary_extra_body:
             summary_kwargs["extra_body"] = summary_extra_body
 
diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py
index ada6685c6..d5f4c2cdc 100644
--- a/tests/test_run_agent.py
+++ b/tests/test_run_agent.py
@@ -913,3 +913,31 @@ class TestConversationHistoryNotMutated:
         )
         # Result should have more messages than the original history
         assert len(result["messages"]) > original_len
+
+
+# ---------------------------------------------------------------------------
+# _max_tokens_param consistency
+# ---------------------------------------------------------------------------
+
+class TestMaxTokensParam:
+    """Verify _max_tokens_param returns the correct key for each provider."""
+
+    def test_returns_max_completion_tokens_for_direct_openai(self, agent):
+        agent.base_url = "https://api.openai.com/v1"
+        result = agent._max_tokens_param(4096)
+        assert result == {"max_completion_tokens": 4096}
+
+    def test_returns_max_tokens_for_openrouter(self, agent):
+        agent.base_url = "https://openrouter.ai/api/v1"
+        result = agent._max_tokens_param(4096)
+        assert result == {"max_tokens": 4096}
+
+    def test_returns_max_tokens_for_local(self, agent):
+        agent.base_url = "http://localhost:11434/v1"
+        result = agent._max_tokens_param(4096)
+        assert result == {"max_tokens": 4096}
+
+    def test_not_tricked_by_openai_in_openrouter_url(self, agent):
+        agent.base_url = "https://openrouter.ai/api/v1/api.openai.com"
+        result = agent._max_tokens_param(4096)
+        assert result == {"max_tokens": 4096}