fix(delegate): share credential pools with subagents + per-task leasing
Cherry-picked from PR #5580 by MestreY0d4-Uninter. - Share parent's credential pool with child agents for key rotation - Leasing layer spreads parallel children across keys (least-loaded) - Thread-safe acquire_lease/release_lease in CredentialPool - Reverted sneaked-in tool-name restoration change (kept original getattr + isinstance guard pattern)
This commit is contained in:
committed by
Teknium
parent
8dee82ea1e
commit
f2c11ff30c
@@ -348,6 +348,9 @@ def get_pool_strategy(provider: str) -> str:
|
||||
return STRATEGY_FILL_FIRST
|
||||
|
||||
|
||||
DEFAULT_MAX_CONCURRENT_PER_CREDENTIAL = 1
|
||||
|
||||
|
||||
class CredentialPool:
|
||||
def __init__(self, provider: str, entries: List[PooledCredential]):
|
||||
self.provider = provider
|
||||
@@ -355,6 +358,8 @@ class CredentialPool:
|
||||
self._current_id: Optional[str] = None
|
||||
self._strategy = get_pool_strategy(provider)
|
||||
self._lock = threading.Lock()
|
||||
self._active_leases: Dict[str, int] = {}
|
||||
self._max_concurrent = DEFAULT_MAX_CONCURRENT_PER_CREDENTIAL
|
||||
|
||||
def has_credentials(self) -> bool:
|
||||
return bool(self._entries)
|
||||
@@ -760,6 +765,51 @@ class CredentialPool:
|
||||
logger.info("credential pool: rotated to %s", _next_label)
|
||||
return next_entry
|
||||
|
||||
def acquire_lease(self, credential_id: Optional[str] = None) -> Optional[str]:
|
||||
"""Acquire a soft lease on a credential.
|
||||
|
||||
If a specific credential_id is provided, lease that entry directly.
|
||||
Otherwise prefer the least-leased available credential, using priority as
|
||||
a stable tie-breaker. When every credential is already at the soft cap,
|
||||
still return the least-leased one instead of blocking.
|
||||
"""
|
||||
with self._lock:
|
||||
if credential_id:
|
||||
self._active_leases[credential_id] = self._active_leases.get(credential_id, 0) + 1
|
||||
self._current_id = credential_id
|
||||
return credential_id
|
||||
|
||||
available = self._available_entries(clear_expired=True, refresh=True)
|
||||
if not available:
|
||||
return None
|
||||
|
||||
below_cap = [
|
||||
entry for entry in available
|
||||
if self._active_leases.get(entry.id, 0) < self._max_concurrent
|
||||
]
|
||||
candidates = below_cap if below_cap else available
|
||||
chosen = min(
|
||||
candidates,
|
||||
key=lambda entry: (self._active_leases.get(entry.id, 0), entry.priority),
|
||||
)
|
||||
self._active_leases[chosen.id] = self._active_leases.get(chosen.id, 0) + 1
|
||||
self._current_id = chosen.id
|
||||
return chosen.id
|
||||
|
||||
def release_lease(self, credential_id: str) -> None:
|
||||
"""Release a previously acquired credential lease."""
|
||||
with self._lock:
|
||||
count = self._active_leases.get(credential_id, 0)
|
||||
if count <= 1:
|
||||
self._active_leases.pop(credential_id, None)
|
||||
else:
|
||||
self._active_leases[credential_id] = count - 1
|
||||
|
||||
def active_lease_count(self, credential_id: str) -> int:
|
||||
"""Return the number of active leases for a credential."""
|
||||
with self._lock:
|
||||
return self._active_leases.get(credential_id, 0)
|
||||
|
||||
def try_refresh_current(self) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
return self._try_refresh_current_unlocked()
|
||||
|
||||
@@ -947,7 +947,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-xxx",
|
||||
"access_token": "***",
|
||||
}
|
||||
],
|
||||
"custom:together.ai": [
|
||||
@@ -957,7 +957,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-tog-xxx",
|
||||
"access_token": "***",
|
||||
}
|
||||
],
|
||||
"custom:fireworks": [
|
||||
@@ -967,7 +967,7 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-fw-xxx",
|
||||
"access_token": "***",
|
||||
}
|
||||
],
|
||||
"custom:empty": [],
|
||||
@@ -980,3 +980,78 @@ def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
result = list_custom_pool_providers()
|
||||
assert result == ["custom:fireworks", "custom:together.ai"]
|
||||
# "custom:empty" not included because it's empty
|
||||
|
||||
|
||||
|
||||
def test_acquire_lease_prefers_unleased_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
first = pool.acquire_lease()
|
||||
second = pool.acquire_lease()
|
||||
|
||||
assert first == "cred-1"
|
||||
assert second == "cred-2"
|
||||
assert pool.active_lease_count("cred-1") == 1
|
||||
assert pool.active_lease_count("cred-2") == 1
|
||||
|
||||
|
||||
|
||||
def test_release_lease_decrements_counter(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
leased = pool.acquire_lease()
|
||||
assert leased == "cred-1"
|
||||
assert pool.active_lease_count("cred-1") == 1
|
||||
|
||||
pool.release_lease("cred-1")
|
||||
assert pool.active_lease_count("cred-1") == 0
|
||||
|
||||
@@ -26,6 +26,7 @@ from tools.delegate_tool import (
|
||||
_build_child_agent,
|
||||
_build_child_system_prompt,
|
||||
_strip_blocked_tools,
|
||||
_resolve_child_credential_pool,
|
||||
_resolve_delegation_credentials,
|
||||
)
|
||||
|
||||
@@ -930,5 +931,126 @@ class TestDelegationProviderIntegration(unittest.TestCase):
|
||||
self.assertEqual(kwargs["base_url"], parent.base_url)
|
||||
|
||||
|
||||
class TestChildCredentialPoolResolution(unittest.TestCase):
|
||||
def test_same_provider_shares_parent_pool(self):
|
||||
parent = _make_mock_parent()
|
||||
mock_pool = MagicMock()
|
||||
parent._credential_pool = mock_pool
|
||||
|
||||
result = _resolve_child_credential_pool("openrouter", parent)
|
||||
self.assertIs(result, mock_pool)
|
||||
|
||||
def test_no_provider_inherits_parent_pool(self):
|
||||
parent = _make_mock_parent()
|
||||
mock_pool = MagicMock()
|
||||
parent._credential_pool = mock_pool
|
||||
|
||||
result = _resolve_child_credential_pool(None, parent)
|
||||
self.assertIs(result, mock_pool)
|
||||
|
||||
def test_different_provider_loads_own_pool(self):
|
||||
parent = _make_mock_parent()
|
||||
parent._credential_pool = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.has_credentials.return_value = True
|
||||
|
||||
with patch("agent.credential_pool.load_pool", return_value=mock_pool):
|
||||
result = _resolve_child_credential_pool("anthropic", parent)
|
||||
|
||||
self.assertIs(result, mock_pool)
|
||||
|
||||
def test_different_provider_empty_pool_returns_none(self):
|
||||
parent = _make_mock_parent()
|
||||
parent._credential_pool = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.has_credentials.return_value = False
|
||||
|
||||
with patch("agent.credential_pool.load_pool", return_value=mock_pool):
|
||||
result = _resolve_child_credential_pool("anthropic", parent)
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_different_provider_load_failure_returns_none(self):
|
||||
parent = _make_mock_parent()
|
||||
parent._credential_pool = MagicMock()
|
||||
|
||||
with patch("agent.credential_pool.load_pool", side_effect=Exception("disk error")):
|
||||
result = _resolve_child_credential_pool("anthropic", parent)
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_build_child_agent_assigns_parent_pool_when_shared(self):
|
||||
parent = _make_mock_parent()
|
||||
mock_pool = MagicMock()
|
||||
parent._credential_pool = mock_pool
|
||||
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = MagicMock()
|
||||
MockAgent.return_value = mock_child
|
||||
|
||||
_build_child_agent(
|
||||
task_index=0,
|
||||
goal="Test pool assignment",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model=None,
|
||||
max_iterations=10,
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
self.assertEqual(mock_child._credential_pool, mock_pool)
|
||||
|
||||
|
||||
class TestChildCredentialLeasing(unittest.TestCase):
|
||||
def test_run_single_child_acquires_and_releases_lease(self):
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
leased_entry = MagicMock()
|
||||
leased_entry.id = "cred-b"
|
||||
|
||||
child = MagicMock()
|
||||
child._credential_pool = MagicMock()
|
||||
child._credential_pool.acquire_lease.return_value = "cred-b"
|
||||
child._credential_pool.current.return_value = leased_entry
|
||||
child.run_conversation.return_value = {
|
||||
"final_response": "done",
|
||||
"completed": True,
|
||||
"interrupted": False,
|
||||
"api_calls": 1,
|
||||
"messages": [],
|
||||
}
|
||||
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Investigate rate limits",
|
||||
child=child,
|
||||
parent_agent=_make_mock_parent(),
|
||||
)
|
||||
|
||||
self.assertEqual(result["status"], "completed")
|
||||
child._credential_pool.acquire_lease.assert_called_once_with()
|
||||
child._swap_credential.assert_called_once_with(leased_entry)
|
||||
child._credential_pool.release_lease.assert_called_once_with("cred-b")
|
||||
|
||||
def test_run_single_child_releases_lease_after_failure(self):
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
child = MagicMock()
|
||||
child._credential_pool = MagicMock()
|
||||
child._credential_pool.acquire_lease.return_value = "cred-a"
|
||||
child._credential_pool.current.return_value = MagicMock(id="cred-a")
|
||||
child.run_conversation.side_effect = RuntimeError("boom")
|
||||
|
||||
result = _run_single_child(
|
||||
task_index=1,
|
||||
goal="Trigger failure",
|
||||
child=child,
|
||||
parent_agent=_make_mock_parent(),
|
||||
)
|
||||
|
||||
self.assertEqual(result["status"], "error")
|
||||
child._credential_pool.release_lease.assert_called_once_with("cred-a")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -279,6 +279,12 @@ def _build_child_agent(
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
|
||||
# Share a credential pool with the child when possible so subagents can
|
||||
# rotate credentials on rate limits instead of getting pinned to one key.
|
||||
child_pool = _resolve_child_credential_pool(effective_provider, parent_agent)
|
||||
if child_pool is not None:
|
||||
child._credential_pool = child_pool
|
||||
|
||||
# Register child for interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
lock = getattr(parent_agent, '_active_children_lock', None)
|
||||
@@ -312,6 +318,18 @@ def _run_single_child(
|
||||
_saved_tool_names = getattr(child, "_delegate_saved_tool_names",
|
||||
list(model_tools._last_resolved_tool_names))
|
||||
|
||||
child_pool = getattr(child, '_credential_pool', None)
|
||||
leased_cred_id = None
|
||||
if child_pool is not None:
|
||||
leased_cred_id = child_pool.acquire_lease()
|
||||
if leased_cred_id is not None:
|
||||
try:
|
||||
leased_entry = child_pool.current()
|
||||
if leased_entry is not None and hasattr(child, '_swap_credential'):
|
||||
child._swap_credential(leased_entry)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to bind child to leased credential: %s", exc)
|
||||
|
||||
try:
|
||||
result = child.run_conversation(user_message=goal)
|
||||
|
||||
@@ -422,6 +440,12 @@ def _run_single_child(
|
||||
}
|
||||
|
||||
finally:
|
||||
if child_pool is not None and leased_cred_id is not None:
|
||||
try:
|
||||
child_pool.release_lease(leased_cred_id)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to release credential lease: %s", exc)
|
||||
|
||||
# Restore the parent's tool names so the process-global is correct
|
||||
# for any subsequent execute_code calls or other consumers.
|
||||
import model_tools
|
||||
@@ -430,6 +454,8 @@ def _run_single_child(
|
||||
if isinstance(saved_tool_names, list):
|
||||
model_tools._last_resolved_tool_names = list(saved_tool_names)
|
||||
|
||||
# Remove child from active tracking
|
||||
|
||||
# Unregister child from interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
try:
|
||||
@@ -626,6 +652,38 @@ def delegate_task(
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def _resolve_child_credential_pool(effective_provider: Optional[str], parent_agent):
|
||||
"""Resolve a credential pool for the child agent.
|
||||
|
||||
Rules:
|
||||
1. Same provider as the parent -> share the parent's pool so cooldown state
|
||||
and rotation stay synchronized.
|
||||
2. Different provider -> try to load that provider's own pool.
|
||||
3. No pool available -> return None and let the child keep the inherited
|
||||
fixed credential behavior.
|
||||
"""
|
||||
if not effective_provider:
|
||||
return getattr(parent_agent, "_credential_pool", None)
|
||||
|
||||
parent_provider = getattr(parent_agent, "provider", None) or ""
|
||||
parent_pool = getattr(parent_agent, "_credential_pool", None)
|
||||
if parent_pool is not None and effective_provider == parent_provider:
|
||||
return parent_pool
|
||||
|
||||
try:
|
||||
from agent.credential_pool import load_pool
|
||||
pool = load_pool(effective_provider)
|
||||
if pool is not None and pool.has_credentials():
|
||||
return pool
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Could not load credential pool for child provider '%s': %s",
|
||||
effective_provider,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict:
|
||||
"""Resolve credentials for subagent delegation.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user