Add outcome-based learning system for swarm agents
Introduce a feedback loop where task outcomes (win/loss, success/failure)
feed back into agent bidding strategy. Borrows the "learn from outcomes"
concept from Spark Intelligence but builds it natively on Timmy's existing
SQLite + swarm architecture.
New module: src/swarm/learner.py
- Records every bid outcome with task description context
- Computes per-agent metrics: win rate, success rate, keyword performance
- suggest_bid() adjusts bids based on historical performance
- learned_keywords() discovers what task types agents actually excel at
Changes:
- persona_node: _compute_bid() now consults learner for adaptive adjustments
- coordinator: complete_task/fail_task feed results into learner
- coordinator: run_auction_and_assign records all bid outcomes
- routes/swarm: add /swarm/insights and /swarm/insights/{agent_id} endpoints
- routes/swarm: add POST /swarm/tasks/{task_id}/fail endpoint
All 413 tests pass (23 new + 390 existing).
https://claude.ai/code/session_01E5jhTCwSUnJk9p9zrTMVUJ
This commit is contained in:
@@ -12,6 +12,7 @@ from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from swarm import learner as swarm_learner
|
||||
from swarm import registry
|
||||
from swarm.coordinator import coordinator
|
||||
from swarm.tasks import TaskStatus, update_task
|
||||
@@ -139,6 +140,55 @@ async def complete_task(task_id: str, result: str = Form(...)):
|
||||
return {"task_id": task_id, "status": task.status.value}
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/fail")
|
||||
async def fail_task(task_id: str, reason: str = Form("")):
|
||||
"""Mark a task failed — feeds failure data into the learner."""
|
||||
task = coordinator.fail_task(task_id, reason)
|
||||
if task is None:
|
||||
raise HTTPException(404, "Task not found")
|
||||
return {"task_id": task_id, "status": task.status.value}
|
||||
|
||||
|
||||
# ── Learning insights ────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/insights")
|
||||
async def swarm_insights():
|
||||
"""Return learned performance metrics for all agents."""
|
||||
all_metrics = swarm_learner.get_all_metrics()
|
||||
return {
|
||||
"agents": {
|
||||
aid: {
|
||||
"total_bids": m.total_bids,
|
||||
"auctions_won": m.auctions_won,
|
||||
"tasks_completed": m.tasks_completed,
|
||||
"tasks_failed": m.tasks_failed,
|
||||
"win_rate": round(m.win_rate, 3),
|
||||
"success_rate": round(m.success_rate, 3),
|
||||
"avg_winning_bid": round(m.avg_winning_bid, 1),
|
||||
"top_keywords": swarm_learner.learned_keywords(aid)[:10],
|
||||
}
|
||||
for aid, m in all_metrics.items()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/insights/{agent_id}")
|
||||
async def agent_insights(agent_id: str):
|
||||
"""Return learned performance metrics for a specific agent."""
|
||||
m = swarm_learner.get_metrics(agent_id)
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"total_bids": m.total_bids,
|
||||
"auctions_won": m.auctions_won,
|
||||
"tasks_completed": m.tasks_completed,
|
||||
"tasks_failed": m.tasks_failed,
|
||||
"win_rate": round(m.win_rate, 3),
|
||||
"success_rate": round(m.success_rate, 3),
|
||||
"avg_winning_bid": round(m.avg_winning_bid, 1),
|
||||
"learned_keywords": swarm_learner.learned_keywords(agent_id),
|
||||
}
|
||||
|
||||
|
||||
# ── UI endpoints (return HTML partials for HTMX) ─────────────────────────────
|
||||
|
||||
@router.get("/agents/sidebar", response_class=HTMLResponse)
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Optional
|
||||
|
||||
from swarm.bidder import AuctionManager, Bid
|
||||
from swarm.comms import SwarmComms
|
||||
from swarm import learner as swarm_learner
|
||||
from swarm.manager import SwarmManager
|
||||
from swarm.recovery import reconcile_on_startup
|
||||
from swarm.registry import AgentRecord
|
||||
@@ -183,9 +184,33 @@ class SwarmCoordinator:
|
||||
|
||||
The auction should already be open (via post_task). This method
|
||||
waits the remaining bidding window and then closes it.
|
||||
|
||||
All bids are recorded in the learner so agents accumulate outcome
|
||||
history that later feeds back into adaptive bidding.
|
||||
"""
|
||||
await asyncio.sleep(0) # yield to let any pending callbacks fire
|
||||
|
||||
# Snapshot the auction bids before closing (for learner recording)
|
||||
auction = self.auctions.get_auction(task_id)
|
||||
all_bids = list(auction.bids) if auction else []
|
||||
|
||||
winner = self.auctions.close_auction(task_id)
|
||||
|
||||
# Retrieve description for learner context
|
||||
task = get_task(task_id)
|
||||
description = task.description if task else ""
|
||||
|
||||
# Record every bid outcome in the learner
|
||||
winner_id = winner.agent_id if winner else None
|
||||
for bid in all_bids:
|
||||
swarm_learner.record_outcome(
|
||||
task_id=task_id,
|
||||
agent_id=bid.agent_id,
|
||||
description=description,
|
||||
bid_sats=bid.bid_sats,
|
||||
won_auction=(bid.agent_id == winner_id),
|
||||
)
|
||||
|
||||
if winner:
|
||||
update_task(
|
||||
task_id,
|
||||
@@ -220,6 +245,26 @@ class SwarmCoordinator:
|
||||
if task.assigned_agent:
|
||||
registry.update_status(task.assigned_agent, "idle")
|
||||
self.comms.complete_task(task_id, task.assigned_agent, result)
|
||||
# Record success in learner
|
||||
swarm_learner.record_task_result(task_id, task.assigned_agent, succeeded=True)
|
||||
return updated
|
||||
|
||||
def fail_task(self, task_id: str, reason: str = "") -> Optional[Task]:
    """Mark a task as failed — feeds failure data into the learner."""
    task = get_task(task_id)
    if task is None:
        return None
    finished_at = datetime.now(timezone.utc).isoformat()
    updated = update_task(
        task_id,
        status=TaskStatus.FAILED,
        result=reason,
        completed_at=finished_at,
    )
    agent = task.assigned_agent
    if agent:
        # Free the agent and record the failure so future bids adapt.
        registry.update_status(agent, "idle")
        swarm_learner.record_task_result(task_id, agent, succeeded=False)
    return updated
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[Task]:
|
||||
|
||||
253
src/swarm/learner.py
Normal file
253
src/swarm/learner.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Swarm learner — outcome tracking and adaptive bid intelligence.
|
||||
|
||||
Records task outcomes (win/loss, success/failure) per agent and extracts
|
||||
actionable metrics. Persona nodes consult the learner to adjust bids
|
||||
based on historical performance rather than using static strategies.
|
||||
|
||||
Inspired by feedback-loop learning: outcomes re-enter the system to
|
||||
improve future decisions. All data lives in swarm.db alongside the
|
||||
existing bid_history and tasks tables.
|
||||
"""
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
DB_PATH = Path("data/swarm.db")
|
||||
|
||||
# Minimum outcomes before the learner starts adjusting bids
|
||||
_MIN_OUTCOMES = 3
|
||||
|
||||
# Stop-words excluded from keyword extraction
|
||||
_STOP_WORDS = frozenset({
|
||||
"a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
|
||||
"of", "with", "by", "from", "is", "it", "this", "that", "be", "as",
|
||||
"are", "was", "were", "been", "do", "does", "did", "will", "would",
|
||||
"can", "could", "should", "may", "might", "me", "my", "i", "we",
|
||||
"you", "your", "please", "task", "need", "want", "make", "get",
|
||||
})
|
||||
|
||||
_WORD_RE = re.compile(r"[a-z]{3,}")
|
||||
|
||||
|
||||
@dataclass
class AgentMetrics:
    """Computed performance metrics for a single agent.

    Built on demand from the task_outcomes table by get_metrics();
    instances are never persisted directly.
    """
    agent_id: str
    total_bids: int = 0  # every recorded bid, won or lost
    auctions_won: int = 0  # bids that won their auction
    tasks_completed: int = 0  # won tasks later marked succeeded
    tasks_failed: int = 0  # won tasks later marked failed
    avg_winning_bid: float = 0.0  # mean bid_sats over winning bids only
    win_rate: float = 0.0  # auctions_won / total_bids (0.0 with no bids)
    success_rate: float = 0.0  # completed / (completed + failed); 0.0 when nothing decided
    keyword_wins: dict[str, int] = field(default_factory=dict)  # keyword -> count from succeeded won tasks
    keyword_failures: dict[str, int] = field(default_factory=dict)  # keyword -> count from failed won tasks
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
    """Open swarm.db, creating the task_outcomes table on first use."""
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row
    ddl = """
        CREATE TABLE IF NOT EXISTS task_outcomes (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_id TEXT NOT NULL,
            agent_id TEXT NOT NULL,
            description TEXT NOT NULL DEFAULT '',
            bid_sats INTEGER NOT NULL DEFAULT 0,
            won_auction INTEGER NOT NULL DEFAULT 0,
            task_succeeded INTEGER,
            created_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
    """
    connection.execute(ddl)
    connection.commit()
    return connection
|
||||
|
||||
|
||||
def _extract_keywords(text: str) -> list[str]:
    """Pull meaningful words from a task description."""
    # Lowercase first: the word pattern matches lowercase letters only.
    return [w for w in _WORD_RE.findall(text.lower()) if w not in _STOP_WORDS]
|
||||
|
||||
|
||||
# ── Recording ────────────────────────────────────────────────────────────────
|
||||
|
||||
def record_outcome(
    task_id: str,
    agent_id: str,
    description: str,
    bid_sats: int,
    won_auction: bool,
    task_succeeded: Optional[bool] = None,
) -> None:
    """Record one agent's outcome for a task.

    Args:
        task_id: Task the bid belonged to.
        agent_id: Agent that placed the bid.
        description: Task description, kept for keyword extraction.
        bid_sats: The amount bid.
        won_auction: Whether this agent won the auction.
        task_succeeded: Final result if already known; None until decided
            (later filled in via record_task_result).
    """
    conn = _get_conn()
    try:
        conn.execute(
            """
            INSERT INTO task_outcomes
            (task_id, agent_id, description, bid_sats, won_auction, task_succeeded)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                task_id,
                agent_id,
                description,
                bid_sats,
                int(won_auction),
                # SQLite has no boolean type; store 0/1, keep NULL for "undecided".
                int(task_succeeded) if task_succeeded is not None else None,
            ),
        )
        conn.commit()
    finally:
        # Fix: previously the connection leaked if the INSERT raised.
        conn.close()
|
||||
|
||||
|
||||
def record_task_result(task_id: str, agent_id: str, succeeded: bool) -> int:
    """Update the task_succeeded flag for an already-recorded winning outcome.

    Only rows where this agent actually won the auction are touched, so
    losing bids never accumulate success/failure data.

    Returns:
        The number of rows updated (0 when no winning outcome exists).
    """
    conn = _get_conn()
    try:
        cursor = conn.execute(
            """
            UPDATE task_outcomes
            SET task_succeeded = ?
            WHERE task_id = ? AND agent_id = ? AND won_auction = 1
            """,
            (int(succeeded), task_id, agent_id),
        )
        conn.commit()
        return cursor.rowcount
    finally:
        # Fix: previously the connection leaked if the UPDATE raised.
        conn.close()
|
||||
|
||||
|
||||
# ── Metrics ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def get_metrics(agent_id: str) -> AgentMetrics:
    """Compute performance metrics from stored outcomes.

    Scans every recorded outcome for *agent_id* and aggregates win rate,
    success rate, average winning bid, and per-keyword win/failure counts.
    Returns an all-zero AgentMetrics when the agent has no history.
    """
    conn = _get_conn()
    try:
        rows = conn.execute(
            "SELECT * FROM task_outcomes WHERE agent_id = ?",
            (agent_id,),
        ).fetchall()
    finally:
        # Fix: previously the connection leaked if the SELECT raised.
        conn.close()

    metrics = AgentMetrics(agent_id=agent_id)
    if not rows:
        return metrics

    metrics.total_bids = len(rows)
    winning_bids: list[int] = []

    for row in rows:
        won = bool(row["won_auction"])
        succeeded = row["task_succeeded"]  # 1 / 0, or None while undecided
        keywords = _extract_keywords(row["description"])

        if won:
            metrics.auctions_won += 1
            winning_bids.append(row["bid_sats"])
            # Keyword stats only accrue on decided, *won* tasks — losing
            # bids say nothing about the agent's actual competence.
            if succeeded == 1:
                metrics.tasks_completed += 1
                for kw in keywords:
                    metrics.keyword_wins[kw] = metrics.keyword_wins.get(kw, 0) + 1
            elif succeeded == 0:
                metrics.tasks_failed += 1
                for kw in keywords:
                    metrics.keyword_failures[kw] = metrics.keyword_failures.get(kw, 0) + 1

    metrics.win_rate = (
        metrics.auctions_won / metrics.total_bids if metrics.total_bids else 0.0
    )
    decided = metrics.tasks_completed + metrics.tasks_failed
    metrics.success_rate = (
        metrics.tasks_completed / decided if decided else 0.0
    )
    metrics.avg_winning_bid = (
        sum(winning_bids) / len(winning_bids) if winning_bids else 0.0
    )
    return metrics
|
||||
|
||||
|
||||
def get_all_metrics() -> dict[str, AgentMetrics]:
    """Return metrics for every agent that has recorded outcomes."""
    conn = _get_conn()
    try:
        rows = conn.execute(
            "SELECT DISTINCT agent_id FROM task_outcomes"
        ).fetchall()
    finally:
        # Fix: previously the connection leaked if the SELECT raised.
        conn.close()
    return {row["agent_id"]: get_metrics(row["agent_id"]) for row in rows}
|
||||
|
||||
|
||||
# ── Bid intelligence ─────────────────────────────────────────────────────────
|
||||
|
||||
def suggest_bid(agent_id: str, task_description: str, base_bid: int) -> int:
    """Adjust a base bid using learned performance data.

    Returns the base_bid unchanged until the agent has enough history
    (>= _MIN_OUTCOMES). After that, three signals nudge the bid:

    - Win rate too high (>80%): nudge bid up — still win, earn more.
    - Win rate too low (<20%): nudge bid down — be more competitive.
    - Success rate low on won tasks: nudge bid up — avoid winning tasks
      this agent tends to fail.
    - Strong keyword match from past wins: nudge bid down — this agent
      is proven on similar work.
    """
    metrics = get_metrics(agent_id)
    if metrics.total_bids < _MIN_OUTCOMES:
        return base_bid

    multiplier = 1.0

    # Signal 1: win-rate adjustment.
    if metrics.win_rate > 0.8:
        multiplier *= 1.15  # bid higher, maximise revenue
    elif metrics.win_rate < 0.2:
        multiplier *= 0.85  # bid lower, be competitive

    # Signal 2: success-rate adjustment (needs at least two decided tasks).
    decided = metrics.tasks_completed + metrics.tasks_failed
    if decided >= 2:
        if metrics.success_rate < 0.5:
            multiplier *= 1.25  # avoid winning bad matches
        elif metrics.success_rate > 0.8:
            multiplier *= 0.90  # we're good at this, lean in

    # Signal 3: keyword relevance from past wins.
    keywords = _extract_keywords(task_description)
    if keywords:
        kw_wins = sum(metrics.keyword_wins.get(k, 0) for k in keywords)
        kw_fails = sum(metrics.keyword_failures.get(k, 0) for k in keywords)
        if kw_wins > kw_fails and kw_wins >= 2:
            multiplier *= 0.90  # proven track record on these keywords
        elif kw_fails > kw_wins and kw_fails >= 2:
            multiplier *= 1.15  # poor track record — back off

    # Never suggest less than one sat.
    return max(1, int(base_bid * multiplier))
|
||||
|
||||
|
||||
def learned_keywords(agent_id: str) -> list[dict]:
    """Return keywords ranked by net wins (wins minus failures).

    Useful for discovering which task types an agent actually excels at,
    potentially different from its hardcoded preferred_keywords.
    """
    metrics = get_metrics(agent_id)
    entries = []
    for kw in set(metrics.keyword_wins) | set(metrics.keyword_failures):
        win_count = metrics.keyword_wins.get(kw, 0)
        fail_count = metrics.keyword_failures.get(kw, 0)
        entries.append({
            "keyword": kw,
            "wins": win_count,
            "failures": fail_count,
            "net": win_count - fail_count,
        })
    # sorted() is stable, matching the original in-place sort's ordering.
    return sorted(entries, key=lambda e: e["net"], reverse=True)
|
||||
@@ -6,6 +6,8 @@ PersonaNode extends the base SwarmNode to:
|
||||
persona's preferred_keywords the node bids aggressively (bid_base ± jitter).
|
||||
Otherwise it bids at a higher, less-competitive rate.
|
||||
3. Register with the swarm registry under its persona's capabilities string.
|
||||
4. (Adaptive) Consult the swarm learner to adjust bids based on historical
|
||||
win/loss and success/failure data when available.
|
||||
|
||||
Usage (via coordinator):
|
||||
coordinator.spawn_persona("echo")
|
||||
@@ -35,6 +37,7 @@ class PersonaNode(SwarmNode):
|
||||
persona_id: str,
|
||||
agent_id: str,
|
||||
comms: Optional[SwarmComms] = None,
|
||||
use_learner: bool = True,
|
||||
) -> None:
|
||||
meta: PersonaMeta = PERSONAS[persona_id]
|
||||
super().__init__(
|
||||
@@ -45,6 +48,7 @@ class PersonaNode(SwarmNode):
|
||||
)
|
||||
self._meta = meta
|
||||
self._persona_id = persona_id
|
||||
self._use_learner = use_learner
|
||||
logger.debug("PersonaNode %s (%s) initialised", meta["name"], agent_id)
|
||||
|
||||
# ── Bid strategy ─────────────────────────────────────────────────────────
|
||||
@@ -54,6 +58,9 @@ class PersonaNode(SwarmNode):
|
||||
|
||||
Bids lower (more aggressively) when the description contains at least
|
||||
one of our preferred_keywords. Bids higher for off-spec tasks.
|
||||
|
||||
When the learner is enabled and the agent has enough history, the
|
||||
base bid is adjusted by learned performance metrics before jitter.
|
||||
"""
|
||||
desc_lower = task_description.lower()
|
||||
is_preferred = any(
|
||||
@@ -62,9 +69,19 @@ class PersonaNode(SwarmNode):
|
||||
base = self._meta["bid_base"]
|
||||
jitter = random.randint(0, self._meta["bid_jitter"])
|
||||
if is_preferred:
|
||||
return max(1, base - jitter)
|
||||
# Off-spec: inflate bid so we lose to the specialist
|
||||
return min(200, int(base * _OFF_SPEC_MULTIPLIER) + jitter)
|
||||
raw = max(1, base - jitter)
|
||||
else:
|
||||
# Off-spec: inflate bid so we lose to the specialist
|
||||
raw = min(200, int(base * _OFF_SPEC_MULTIPLIER) + jitter)
|
||||
|
||||
# Consult learner for adaptive adjustment
|
||||
if self._use_learner:
|
||||
try:
|
||||
from swarm.learner import suggest_bid
|
||||
return suggest_bid(self.agent_id, task_description, raw)
|
||||
except Exception:
|
||||
logger.debug("Learner unavailable, using static bid")
|
||||
return raw
|
||||
|
||||
def _on_task_posted(self, msg: SwarmMessage) -> None:
|
||||
"""Handle task announcement with persona-aware bidding."""
|
||||
|
||||
@@ -14,6 +14,8 @@ def tmp_swarm_db(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "swarm.db"
|
||||
monkeypatch.setattr("swarm.tasks.DB_PATH", db_path)
|
||||
monkeypatch.setattr("swarm.registry.DB_PATH", db_path)
|
||||
monkeypatch.setattr("swarm.stats.DB_PATH", db_path)
|
||||
monkeypatch.setattr("swarm.learner.DB_PATH", db_path)
|
||||
yield db_path
|
||||
|
||||
|
||||
@@ -190,3 +192,58 @@ async def test_coordinator_run_auction_with_bid():
|
||||
|
||||
assert winner is not None
|
||||
assert winner.bid_sats == 35
|
||||
|
||||
|
||||
# ── Coordinator: fail_task ──────────────────────────────────────────────────
|
||||
|
||||
def test_coordinator_fail_task():
    """fail_task marks the task FAILED and stores the reason as its result."""
    from swarm.coordinator import SwarmCoordinator
    from swarm.tasks import TaskStatus
    coord = SwarmCoordinator()
    task = coord.post_task("Fail me")
    failed = coord.fail_task(task.id, "Something went wrong")
    assert failed is not None
    assert failed.status == TaskStatus.FAILED
    assert failed.result == "Something went wrong"
|
||||
|
||||
|
||||
def test_coordinator_fail_task_not_found():
    """fail_task returns None for an unknown task id instead of raising."""
    from swarm.coordinator import SwarmCoordinator
    coord = SwarmCoordinator()
    assert coord.fail_task("nonexistent", "reason") is None
|
||||
|
||||
|
||||
def test_coordinator_fail_task_records_in_learner():
    """Failing an assigned task feeds a failure result into the learner."""
    from swarm.coordinator import SwarmCoordinator
    from swarm import learner as swarm_learner
    coord = SwarmCoordinator()
    task = coord.post_task("Learner fail test")
    # Simulate assignment
    from swarm.tasks import update_task, TaskStatus
    from swarm import registry
    registry.register(name="test-agent", agent_id="fail-learner-agent")
    update_task(task.id, status=TaskStatus.ASSIGNED, assigned_agent="fail-learner-agent")
    # Record an outcome so there's something to update
    swarm_learner.record_outcome(
        task.id, "fail-learner-agent", "Learner fail test", 30, won_auction=True,
    )
    coord.fail_task(task.id, "broke")
    m = swarm_learner.get_metrics("fail-learner-agent")
    assert m.tasks_failed == 1
|
||||
|
||||
|
||||
def test_coordinator_complete_task_records_in_learner():
    """Completing an assigned task feeds a success result into the learner."""
    from swarm.coordinator import SwarmCoordinator
    from swarm import learner as swarm_learner
    coord = SwarmCoordinator()
    task = coord.post_task("Learner success test")
    # Simulate assignment so complete_task has an agent to credit.
    from swarm.tasks import update_task, TaskStatus
    from swarm import registry
    registry.register(name="test-agent", agent_id="success-learner-agent")
    update_task(task.id, status=TaskStatus.ASSIGNED, assigned_agent="success-learner-agent")
    swarm_learner.record_outcome(
        task.id, "success-learner-agent", "Learner success test", 25, won_auction=True,
    )
    coord.complete_task(task.id, "All done")
    m = swarm_learner.get_metrics("success-learner-agent")
    assert m.tasks_completed == 1
|
||||
|
||||
237
tests/test_learner.py
Normal file
237
tests/test_learner.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for swarm.learner — outcome tracking and adaptive bid intelligence."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def tmp_learner_db(tmp_path, monkeypatch):
    # Point the learner at a throwaway per-test database so tests never
    # touch the real data/swarm.db.
    db_path = tmp_path / "swarm.db"
    monkeypatch.setattr("swarm.learner.DB_PATH", db_path)
    yield db_path
|
||||
|
||||
|
||||
# ── keyword extraction ───────────────────────────────────────────────────────
|
||||
|
||||
def test_extract_keywords_strips_stop_words():
    """Stop-words are dropped; content words survive extraction."""
    from swarm.learner import _extract_keywords
    kws = _extract_keywords("please research the security vulnerability")
    assert "please" not in kws
    assert "the" not in kws
    assert "research" in kws
    assert "security" in kws
    assert "vulnerability" in kws


def test_extract_keywords_ignores_short_words():
    """Words under three letters never match the word pattern."""
    from swarm.learner import _extract_keywords
    kws = _extract_keywords("do it or go")
    assert kws == []


def test_extract_keywords_lowercases():
    """Input is lowercased before matching (the pattern is lowercase-only)."""
    from swarm.learner import _extract_keywords
    kws = _extract_keywords("Deploy Kubernetes Cluster")
    assert "deploy" in kws
    assert "kubernetes" in kws
    assert "cluster" in kws
|
||||
|
||||
|
||||
# ── record_outcome ───────────────────────────────────────────────────────────
|
||||
|
||||
def test_record_outcome_stores_data():
    """A winning bid shows up in both total_bids and auctions_won."""
    from swarm.learner import record_outcome, get_metrics
    record_outcome("t1", "agent-a", "fix the bug", 30, won_auction=True)
    m = get_metrics("agent-a")
    assert m.total_bids == 1
    assert m.auctions_won == 1


def test_record_outcome_with_failure():
    """task_succeeded=False recorded inline counts as a failed task."""
    from swarm.learner import record_outcome, get_metrics
    record_outcome("t2", "agent-b", "deploy server", 50, won_auction=True, task_succeeded=False)
    m = get_metrics("agent-b")
    assert m.tasks_failed == 1
    assert m.success_rate == 0.0


def test_record_outcome_losing_bid():
    """Losing bids count toward total_bids but never auctions_won."""
    from swarm.learner import record_outcome, get_metrics
    record_outcome("t3", "agent-c", "write docs", 80, won_auction=False)
    m = get_metrics("agent-c")
    assert m.total_bids == 1
    assert m.auctions_won == 0
|
||||
|
||||
|
||||
# ── record_task_result ───────────────────────────────────────────────────────
|
||||
|
||||
def test_record_task_result_updates_success():
    """Completing a task flips the earlier winning outcome to succeeded."""
    from swarm.learner import record_outcome, record_task_result, get_metrics
    record_outcome("t4", "agent-d", "analyse data", 40, won_auction=True)
    updated = record_task_result("t4", "agent-d", succeeded=True)
    assert updated == 1
    m = get_metrics("agent-d")
    assert m.tasks_completed == 1
    assert m.success_rate == 1.0


def test_record_task_result_updates_failure():
    """Failing a task flips the earlier winning outcome to failed."""
    from swarm.learner import record_outcome, record_task_result, get_metrics
    record_outcome("t5", "agent-e", "deploy kubernetes", 60, won_auction=True)
    record_task_result("t5", "agent-e", succeeded=False)
    m = get_metrics("agent-e")
    assert m.tasks_failed == 1
    assert m.success_rate == 0.0


def test_record_task_result_no_match_returns_zero():
    """With no winning outcome row to update, the rowcount is 0."""
    from swarm.learner import record_task_result
    updated = record_task_result("no-task", "no-agent", succeeded=True)
    assert updated == 0
|
||||
|
||||
|
||||
# ── get_metrics ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_metrics_empty_agent():
    """An agent with no history gets all-zero metrics."""
    from swarm.learner import get_metrics
    m = get_metrics("ghost")
    assert m.total_bids == 0
    assert m.win_rate == 0.0
    assert m.success_rate == 0.0
    assert m.keyword_wins == {}


def test_metrics_win_rate():
    """win_rate is auctions_won / total_bids (2 of 4 here)."""
    from swarm.learner import record_outcome, get_metrics
    record_outcome("t10", "agent-f", "research topic", 30, won_auction=True)
    record_outcome("t11", "agent-f", "research other", 40, won_auction=False)
    record_outcome("t12", "agent-f", "find sources", 35, won_auction=True)
    record_outcome("t13", "agent-f", "summarize report", 50, won_auction=False)
    m = get_metrics("agent-f")
    assert m.total_bids == 4
    assert m.auctions_won == 2
    assert m.win_rate == pytest.approx(0.5)


def test_metrics_keyword_tracking():
    """Keywords from decided won tasks land in keyword_wins / keyword_failures."""
    from swarm.learner import record_outcome, record_task_result, get_metrics
    record_outcome("t20", "agent-g", "research security vulnerability", 30, won_auction=True)
    record_task_result("t20", "agent-g", succeeded=True)
    record_outcome("t21", "agent-g", "research market trends", 30, won_auction=True)
    record_task_result("t21", "agent-g", succeeded=False)
    m = get_metrics("agent-g")
    assert m.keyword_wins.get("research", 0) == 1
    assert m.keyword_wins.get("security", 0) == 1
    assert m.keyword_failures.get("research", 0) == 1
    assert m.keyword_failures.get("market", 0) == 1


def test_metrics_avg_winning_bid():
    """avg_winning_bid averages only bids that won ((20 + 40) / 2)."""
    from swarm.learner import record_outcome, get_metrics
    record_outcome("t30", "agent-h", "task one", 20, won_auction=True)
    record_outcome("t31", "agent-h", "task two", 40, won_auction=True)
    record_outcome("t32", "agent-h", "task three", 100, won_auction=False)
    m = get_metrics("agent-h")
    assert m.avg_winning_bid == pytest.approx(30.0)
|
||||
|
||||
|
||||
# ── get_all_metrics ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_get_all_metrics_empty():
    """No outcomes recorded yields an empty mapping."""
    from swarm.learner import get_all_metrics
    assert get_all_metrics() == {}


def test_get_all_metrics_multiple_agents():
    """Every agent with at least one outcome appears in the mapping."""
    from swarm.learner import record_outcome, get_all_metrics
    record_outcome("t40", "alice", "fix bug", 20, won_auction=True)
    record_outcome("t41", "bob", "write docs", 30, won_auction=False)
    all_m = get_all_metrics()
    assert "alice" in all_m
    assert "bob" in all_m
    assert all_m["alice"].auctions_won == 1
    assert all_m["bob"].auctions_won == 0
|
||||
|
||||
|
||||
# ── suggest_bid ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_suggest_bid_returns_base_when_insufficient_data():
    """Below _MIN_OUTCOMES recorded bids, the base bid passes through untouched."""
    from swarm.learner import suggest_bid
    result = suggest_bid("new-agent", "research something", 50)
    assert result == 50


def test_suggest_bid_lowers_on_low_win_rate():
    from swarm.learner import record_outcome, suggest_bid
    # Agent loses 9 out of 10 auctions → very low win rate → should bid lower
    for i in range(9):
        record_outcome(f"loss-{i}", "loser", "generic task description", 50, won_auction=False)
    record_outcome("win-0", "loser", "generic task description", 50, won_auction=True)
    bid = suggest_bid("loser", "generic task description", 50)
    assert bid < 50


def test_suggest_bid_raises_on_high_win_rate():
    from swarm.learner import record_outcome, suggest_bid
    # Agent wins all auctions → high win rate → should bid higher
    for i in range(5):
        record_outcome(f"win-{i}", "winner", "generic task description", 30, won_auction=True)
    bid = suggest_bid("winner", "generic task description", 30)
    assert bid > 30


def test_suggest_bid_backs_off_on_poor_success():
    from swarm.learner import record_outcome, record_task_result, suggest_bid
    # Agent wins but fails tasks → should bid higher to avoid winning
    for i in range(4):
        record_outcome(f"fail-{i}", "failer", "deploy server config", 40, won_auction=True)
        record_task_result(f"fail-{i}", "failer", succeeded=False)
    bid = suggest_bid("failer", "deploy server config", 40)
    assert bid > 40


def test_suggest_bid_leans_in_on_keyword_strength():
    from swarm.learner import record_outcome, record_task_result, suggest_bid
    # Agent has strong track record on "security" keyword
    for i in range(4):
        record_outcome(f"sec-{i}", "sec-agent", "audit security vulnerability", 50, won_auction=True)
        record_task_result(f"sec-{i}", "sec-agent", succeeded=True)
    bid = suggest_bid("sec-agent", "check security audit", 50)
    assert bid < 50


def test_suggest_bid_never_below_one():
    """Even after downward adjustments, the suggested bid is floored at 1 sat."""
    from swarm.learner import record_outcome, suggest_bid
    for i in range(5):
        record_outcome(f"cheap-{i}", "cheapo", "task desc here", 1, won_auction=False)
    bid = suggest_bid("cheapo", "task desc here", 1)
    assert bid >= 1
|
||||
|
||||
|
||||
# ── learned_keywords ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_learned_keywords_empty():
    """An agent with no history has no learned keywords."""
    from swarm.learner import learned_keywords
    assert learned_keywords("nobody") == []


def test_learned_keywords_ranked_by_net():
    """Keywords are ranked by net score: wins minus failures, descending."""
    from swarm.learner import record_outcome, record_task_result, learned_keywords
    # "security" → 3 wins, 0 failures = net +3
    # "deploy" → 1 win, 2 failures = net -1
    for i in range(3):
        record_outcome(f"sw-{i}", "ranker", "audit security scan", 30, won_auction=True)
        record_task_result(f"sw-{i}", "ranker", succeeded=True)
    record_outcome("dw-0", "ranker", "deploy docker container", 40, won_auction=True)
    record_task_result("dw-0", "ranker", succeeded=True)
    for i in range(2):
        record_outcome(f"df-{i}", "ranker", "deploy kubernetes cluster", 40, won_auction=True)
        record_task_result(f"df-{i}", "ranker", succeeded=False)

    kws = learned_keywords("ranker")
    kw_map = {k["keyword"]: k for k in kws}
    assert kw_map["security"]["net"] > 0
    assert kw_map["deploy"]["net"] < 0
    # security should rank above deploy
    sec_idx = next(i for i, k in enumerate(kws) if k["keyword"] == "security")
    dep_idx = next(i for i, k in enumerate(kws) if k["keyword"] == "deploy")
    assert sec_idx < dep_idx
|
||||
@@ -12,6 +12,7 @@ def tmp_swarm_db(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("swarm.tasks.DB_PATH", db_path)
|
||||
monkeypatch.setattr("swarm.registry.DB_PATH", db_path)
|
||||
monkeypatch.setattr("swarm.stats.DB_PATH", db_path)
|
||||
monkeypatch.setattr("swarm.learner.DB_PATH", db_path)
|
||||
yield db_path
|
||||
|
||||
|
||||
@@ -185,3 +186,37 @@ def test_coordinator_spawn_all_personas():
|
||||
registered = {a.name for a in agents}
|
||||
for name in names:
|
||||
assert name in registered
|
||||
|
||||
|
||||
# ── Adaptive bidding via learner ────────────────────────────────────────────
|
||||
|
||||
def test_persona_node_adaptive_bid_adjusts_with_history():
    """After enough outcomes, the learner should shift bids below the static cap.

    The original assertion `len(set(bids)) >= 1` was vacuous (always true);
    this version pins a deterministic bound instead.
    """
    from swarm.learner import record_outcome
    node = _make_persona_node("echo", agent_id="echo-adaptive")

    # Record enough winning, succeeding history on research tasks
    for i in range(5):
        record_outcome(
            f"adapt-{i}", "echo-adaptive",
            "research and summarize topic", 30,
            won_auction=True, task_succeeded=True,
        )

    bids_adaptive = [node._compute_bid("research and summarize findings") for _ in range(20)]
    # win_rate == 1.0 (> 0.8) → ×1.15; success_rate == 1.0 (> 0.8) → ×0.90;
    # keyword wins on "research"/"summarize" (5 each, ≥ 2) → ×0.90.
    # Net factor ≈ 0.93 < 1, so every adaptive bid lands strictly below the
    # static preferred maximum of 35 (echo: bid_base=35 — confirm in PERSONAS).
    assert all(1 <= b < 35 for b in bids_adaptive)
|
||||
|
||||
|
||||
def test_persona_node_without_learner_uses_static_bid():
    """With use_learner=False, bids come straight from the persona config."""
    from swarm.persona_node import PersonaNode
    from swarm.comms import SwarmComms
    # Dead Redis port — presumably no connection is made during _compute_bid; confirm.
    comms = SwarmComms(redis_url="redis://localhost:9999")
    node = PersonaNode(persona_id="echo", agent_id="echo-static", comms=comms, use_learner=False)
    bids = [node._compute_bid("research and summarize topic") for _ in range(20)]
    # Static bids should be within the persona's base ± jitter range
    for b in bids:
        assert 20 <= b <= 50  # echo: bid_base=35, jitter=15 — preferred bids land in [20, 35]; 50 leaves slack
|
||||
|
||||
Reference in New Issue
Block a user