def acquire_lease(self, credential_id: Optional[str] = None) -> Optional[str]:
    """Acquire a soft lease on a credential and return the leased id.

    The lease is "soft": it only increments a per-credential counter that
    later selections consult; nothing ever blocks waiting for a free slot.

    Args:
        credential_id: When given (and truthy), lease exactly that entry.
            The id must belong to an entry of this pool. When omitted, the
            pool picks the available entry with the fewest active leases,
            using ``priority`` as a stable tie-breaker. Entries below the
            per-credential soft cap are preferred; if every entry is at the
            cap, the least-leased one is still returned.

    Returns:
        The id of the leased credential, or ``None`` when no entry is
        available or when ``credential_id`` does not match any pool entry.

    Note:
        Every successful acquire overwrites ``self._current_id``, so that
        attribute reflects the most recent lease on the pool, which is not
        necessarily the lease of a given caller when several callers
        acquire concurrently.
    """
    with self._lock:
        if credential_id:
            # Refuse ids the pool does not know about: counting them would
            # leave a phantom lease counter behind and point _current_id at
            # an entry that does not exist.
            if not any(entry.id == credential_id for entry in self._entries):
                return None
            chosen_id = credential_id
        else:
            available = self._available_entries(clear_expired=True, refresh=True)
            if not available:
                return None

            leases = self._active_leases
            below_cap = [
                entry for entry in available
                if leases.get(entry.id, 0) < self._max_concurrent
            ]
            # Soft cap: fall back to every available entry rather than block.
            candidates = below_cap or available
            chosen_id = min(
                candidates,
                key=lambda entry: (leases.get(entry.id, 0), entry.priority),
            ).id

        self._active_leases[chosen_id] = self._active_leases.get(chosen_id, 0) + 1
        self._current_id = chosen_id
        return chosen_id


def release_lease(self, credential_id: str) -> None:
    """Release one previously acquired lease on ``credential_id``.

    The counter entry is dropped once it would reach zero, so the lease map
    only holds credentials that are actually in use. Releasing an id that
    has no active lease is a no-op.
    """
    with self._lock:
        remaining = self._active_leases.get(credential_id, 0) - 1
        if remaining > 0:
            self._active_leases[credential_id] = remaining
        else:
            self._active_leases.pop(credential_id, None)


def active_lease_count(self, credential_id: str) -> int:
    """Return the number of active leases for ``credential_id`` (0 if none)."""
    with self._lock:
        return self._active_leases.get(credential_id, 0)
def test_acquire_lease_prefers_unleased_entry(tmp_path, monkeypatch):
    """Two back-to-back leases land on two different credentials.

    With both entries unleased the lower priority value wins first; the
    second acquire must then skip the already-leased entry.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

    primary = {
        "id": "cred-1",
        "label": "primary",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "***",
    }
    secondary = {
        "id": "cred-2",
        "label": "secondary",
        "auth_type": "api_key",
        "priority": 1,
        "source": "manual",
        "access_token": "***",
    }
    auth_store = {
        "version": 1,
        "credential_pool": {"openrouter": [primary, secondary]},
    }
    _write_auth_store(tmp_path, auth_store)

    from agent.credential_pool import load_pool

    pool = load_pool("openrouter")
    leased_ids = [pool.acquire_lease(), pool.acquire_lease()]

    assert leased_ids[0] == "cred-1"
    assert leased_ids[1] == "cred-2"
    for cred_id in ("cred-1", "cred-2"):
        assert pool.active_lease_count(cred_id) == 1


def test_release_lease_decrements_counter(tmp_path, monkeypatch):
    """Releasing the only lease on a credential brings its count back to 0."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

    only_entry = {
        "id": "cred-1",
        "label": "primary",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "***",
    }
    _write_auth_store(
        tmp_path,
        {"version": 1, "credential_pool": {"openrouter": [only_entry]}},
    )

    from agent.credential_pool import load_pool

    pool = load_pool("openrouter")

    lease_id = pool.acquire_lease()
    assert lease_id == "cred-1"
    assert pool.active_lease_count("cred-1") == 1

    pool.release_lease("cred-1")
    assert pool.active_lease_count("cred-1") == 0
class TestChildCredentialPoolResolution(unittest.TestCase):
    """Which credential pool a delegated child agent ends up with."""

    @staticmethod
    def _parent_with_pool():
        """Return a mock parent together with the pool attached to it."""
        parent = _make_mock_parent()
        pool = MagicMock()
        parent._credential_pool = pool
        return parent, pool

    def test_same_provider_shares_parent_pool(self):
        parent, parent_pool = self._parent_with_pool()

        resolved = _resolve_child_credential_pool("openrouter", parent)

        self.assertIs(resolved, parent_pool)

    def test_no_provider_inherits_parent_pool(self):
        parent, parent_pool = self._parent_with_pool()

        resolved = _resolve_child_credential_pool(None, parent)

        self.assertIs(resolved, parent_pool)

    def test_different_provider_loads_own_pool(self):
        parent, _ = self._parent_with_pool()
        provider_pool = MagicMock()
        provider_pool.has_credentials.return_value = True

        with patch("agent.credential_pool.load_pool", return_value=provider_pool):
            resolved = _resolve_child_credential_pool("anthropic", parent)

        self.assertIs(resolved, provider_pool)

    def test_different_provider_empty_pool_returns_none(self):
        parent, _ = self._parent_with_pool()
        provider_pool = MagicMock()
        provider_pool.has_credentials.return_value = False

        with patch("agent.credential_pool.load_pool", return_value=provider_pool):
            resolved = _resolve_child_credential_pool("anthropic", parent)

        self.assertIsNone(resolved)

    def test_different_provider_load_failure_returns_none(self):
        parent, _ = self._parent_with_pool()

        with patch("agent.credential_pool.load_pool", side_effect=Exception("disk error")):
            resolved = _resolve_child_credential_pool("anthropic", parent)

        self.assertIsNone(resolved)

    def test_build_child_agent_assigns_parent_pool_when_shared(self):
        parent, parent_pool = self._parent_with_pool()
        built_child = MagicMock()

        # The patched AIAgent constructor hands back our mock child so the
        # attribute assignment done by _build_child_agent can be inspected.
        with patch("run_agent.AIAgent", return_value=built_child):
            _build_child_agent(
                task_index=0,
                goal="Test pool assignment",
                context=None,
                toolsets=["terminal"],
                model=None,
                max_iterations=10,
                parent_agent=parent,
            )

        self.assertEqual(built_child._credential_pool, parent_pool)


class TestChildCredentialLeasing(unittest.TestCase):
    """Lease bookkeeping performed around a single child run."""

    def test_run_single_child_acquires_and_releases_lease(self):
        from tools.delegate_tool import _run_single_child

        bound_entry = MagicMock()
        bound_entry.id = "cred-b"

        pool = MagicMock()
        pool.acquire_lease.return_value = "cred-b"
        pool.current.return_value = bound_entry

        child = MagicMock()
        child._credential_pool = pool
        child.run_conversation.return_value = {
            "final_response": "done",
            "completed": True,
            "interrupted": False,
            "api_calls": 1,
            "messages": [],
        }

        outcome = _run_single_child(
            task_index=0,
            goal="Investigate rate limits",
            child=child,
            parent_agent=_make_mock_parent(),
        )

        self.assertEqual(outcome["status"], "completed")
        pool.acquire_lease.assert_called_once_with()
        child._swap_credential.assert_called_once_with(bound_entry)
        pool.release_lease.assert_called_once_with("cred-b")

    def test_run_single_child_releases_lease_after_failure(self):
        from tools.delegate_tool import _run_single_child

        pool = MagicMock()
        pool.acquire_lease.return_value = "cred-a"
        pool.current.return_value = MagicMock(id="cred-a")

        child = MagicMock()
        child._credential_pool = pool
        child.run_conversation.side_effect = RuntimeError("boom")

        outcome = _run_single_child(
            task_index=1,
            goal="Trigger failure",
            child=child,
            parent_agent=_make_mock_parent(),
        )

        # The lease must be handed back even though the run blew up.
        self.assertEqual(outcome["status"], "error")
        pool.release_lease.assert_called_once_with("cred-a")
# for any subsequent execute_code calls or other consumers. import model_tools @@ -430,6 +454,8 @@ def _run_single_child( if isinstance(saved_tool_names, list): model_tools._last_resolved_tool_names = list(saved_tool_names) + # Remove child from active tracking + # Unregister child from interrupt propagation if hasattr(parent_agent, '_active_children'): try: @@ -626,6 +652,38 @@ def delegate_task( }, ensure_ascii=False) +def _resolve_child_credential_pool(effective_provider: Optional[str], parent_agent): + """Resolve a credential pool for the child agent. + + Rules: + 1. Same provider as the parent -> share the parent's pool so cooldown state + and rotation stay synchronized. + 2. Different provider -> try to load that provider's own pool. + 3. No pool available -> return None and let the child keep the inherited + fixed credential behavior. + """ + if not effective_provider: + return getattr(parent_agent, "_credential_pool", None) + + parent_provider = getattr(parent_agent, "provider", None) or "" + parent_pool = getattr(parent_agent, "_credential_pool", None) + if parent_pool is not None and effective_provider == parent_provider: + return parent_pool + + try: + from agent.credential_pool import load_pool + pool = load_pool(effective_provider) + if pool is not None and pool.has_credentials(): + return pool + except Exception as exc: + logger.debug( + "Could not load credential pool for child provider '%s': %s", + effective_provider, + exc, + ) + return None + + def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict: """Resolve credentials for subagent delegation.