Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
a1b744c327 fix: harden Gemma 4 tool-call argument normalization (#797)
All checks were successful
Lint / lint (pull_request) Successful in 29s
- normalize repairable Gemma 4 / Ollama tool-call argument quirks before validation
- keep truncated JSON marked incomplete so the agent retries instead of silently dropping fields
- merge consecutive assistant tool-call messages in API sanitization
- add regression coverage for whitespace, single quotes, trailing commas, bare key/value pairs, and streamed chunks

Closes #797
2026-04-22 10:44:30 -04:00
4 changed files with 391 additions and 535 deletions

View File

@@ -1,70 +1,43 @@
from __future__ import annotations
"""
A2A agent card generation for fleet discovery.
Agent Card — A2A-compliant agent discovery.
Part of #843: fix: implement A2A agent card for fleet discovery (#819)
Refs #801.
Closes #802.
Provides metadata about the agent's identity, capabilities, and installed skills
for discovery by other agents in the fleet.
"""
import argparse
import json
import logging
import os
import socket
import sys
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Sequence
from urllib.parse import urlparse, urlunparse
from pathlib import Path
from typing import Any, Dict, List, Optional
from hermes_cli import __version__
from hermes_cli.config import load_config
from hermes_cli.config import load_config, get_hermes_home
from agent.skill_utils import (
get_all_skills_dirs,
get_disabled_skill_names,
iter_skill_index_files,
parse_frontmatter,
skill_matches_platform,
get_all_skills_dirs,
get_disabled_skill_names,
skill_matches_platform
)
logger = logging.getLogger(__name__)
DEFAULT_DESCRIPTION = "Sovereign AI agent — orchestration, code, research"
DEFAULT_INPUT_MODES = ["text/plain", "application/json"]
DEFAULT_OUTPUT_MODES = ["text/plain", "application/json"]
_REQUIRED_CAPABILITY_FLAGS = (
"streaming",
"pushNotifications",
"stateTransitionHistory",
)
@dataclass
class AgentSkill:
id: str
name: str
description: str = ""
tags: List[str] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
data: Dict[str, Any] = {"id": self.id, "name": self.name}
if self.description:
data["description"] = self.description
if self.tags:
data["tags"] = self.tags
return data
version: str = "1.0.0"
@dataclass
class AgentCapabilities:
streaming: bool = True
pushNotifications: bool = False
stateTransitionHistory: bool = True
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
tools: bool = True
vision: bool = False
reasoning: bool = False
@dataclass
class AgentCard:
@@ -74,81 +47,14 @@ class AgentCard:
version: str = __version__
capabilities: AgentCapabilities = field(default_factory=AgentCapabilities)
skills: List[AgentSkill] = field(default_factory=list)
defaultInputModes: List[str] = field(default_factory=lambda: list(DEFAULT_INPUT_MODES))
defaultOutputModes: List[str] = field(default_factory=lambda: list(DEFAULT_OUTPUT_MODES))
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
data: Dict[str, Any] = {
"name": self.name,
"description": self.description,
"url": self.url,
"version": self.version,
"capabilities": self.capabilities.to_dict(),
"skills": [skill.to_dict() for skill in self.skills],
"defaultInputModes": list(self.defaultInputModes),
"defaultOutputModes": list(self.defaultOutputModes),
}
if self.metadata:
data["metadata"] = dict(self.metadata)
return data
def to_json(self, indent: int = 2) -> str:
return json.dumps(self.to_dict(), indent=indent)
def _env_or_empty(key: str) -> str:
return os.environ.get(key, "").strip()
def _as_agent_config(config: Mapping[str, Any] | None) -> Dict[str, Any]:
if not isinstance(config, Mapping):
return {}
agent_cfg = config.get("agent")
return dict(agent_cfg) if isinstance(agent_cfg, Mapping) else {}
def _as_a2a_config(config: Mapping[str, Any] | None) -> Dict[str, Any]:
if not isinstance(config, Mapping):
return {}
a2a_cfg = config.get("a2a")
return dict(a2a_cfg) if isinstance(a2a_cfg, Mapping) else {}
def _normalize_string_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, str):
parts = value.split(",")
elif isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)):
parts = list(value)
else:
parts = [value]
out: List[str] = []
seen = set()
for item in parts:
text = str(item).strip()
if not text or text in seen:
continue
seen.add(text)
out.append(text)
return out
def _normalize_skill_tags(frontmatter: Mapping[str, Any]) -> List[str]:
tags = _normalize_string_list(frontmatter.get("tags"))
category = str(frontmatter.get("category") or "").strip()
if category and category not in tags:
tags.append(category)
return tags
defaultInputModes: List[str] = field(default_factory=lambda: ["text/plain"])
defaultOutputModes: List[str] = field(default_factory=lambda: ["text/plain"])
def _load_skills() -> List[AgentSkill]:
"""Scan enabled skills and return A2A skill metadata."""
skills: List[AgentSkill] = []
"""Scan all enabled skills and return metadata."""
skills = []
disabled = get_disabled_skill_names()
seen_ids = set()
for skills_dir in get_all_skills_dirs():
if not skills_dir.is_dir():
continue
@@ -159,262 +65,71 @@ def _load_skills() -> List[AgentSkill]:
except Exception:
continue
skill_name = frontmatter.get("name") or skill_file.parent.name
if str(skill_name) in disabled:
continue
if not skill_matches_platform(frontmatter):
continue
skill_id = str(frontmatter.get("name") or skill_file.parent.name).strip().lower().replace(" ", "-")
if skill_id in disabled or skill_id in seen_ids:
continue
seen_ids.add(skill_id)
skills.append(AgentSkill(
id=str(skill_name),
name=str(frontmatter.get("name", skill_name)),
description=str(frontmatter.get("description", "")),
version=str(frontmatter.get("version", "1.0.0"))
))
return skills
display_name = str(frontmatter.get("title") or frontmatter.get("name") or skill_file.parent.name).strip()
description = str(frontmatter.get("description") or "").strip()
tags = _normalize_skill_tags(frontmatter)
skills.append(
AgentSkill(
id=skill_id,
name=display_name,
description=description,
tags=tags,
)
)
def build_agent_card() -> AgentCard:
"""Build the agent card from current configuration and environment."""
config = load_config()
# Identity
name = os.environ.get("HERMES_AGENT_NAME") or config.get("agent", {}).get("name") or "hermes"
description = os.environ.get("HERMES_AGENT_DESCRIPTION") or config.get("agent", {}).get("description") or "Sovereign AI agent"
# URL - try to determine from environment or config
port = os.environ.get("HERMES_WEB_PORT") or "9119"
host = os.environ.get("HERMES_WEB_HOST") or "localhost"
url = f"http://{host}:{port}"
# Capabilities
# In a real scenario, we'd check model metadata for vision/reasoning
capabilities = AgentCapabilities(
streaming=True,
tools=True,
vision=False, # Default to false unless we can confirm
reasoning=False
)
# Skills
skills = _load_skills()
return AgentCard(
name=name,
description=description,
url=url,
version=__version__,
capabilities=capabilities,
skills=skills
)
return sorted(skills, key=lambda skill: skill.id)
def _get_agent_name(config: Mapping[str, Any] | None, override: str | None = None) -> str:
if override:
return override
env_name = _env_or_empty("HERMES_AGENT_NAME") or _env_or_empty("AGENT_NAME")
if env_name:
return env_name
agent_cfg = _as_agent_config(config)
if agent_cfg.get("name"):
return str(agent_cfg["name"]).strip()
def get_agent_card_json() -> str:
"""Return the agent card as a JSON string."""
try:
hostname = socket.gethostname().split(".", 1)[0].strip()
if hostname:
return hostname
except Exception:
pass
return "hermes"
def _get_description(config: Mapping[str, Any] | None, override: str | None = None) -> str:
if override:
return override
env_description = _env_or_empty("HERMES_AGENT_DESCRIPTION") or _env_or_empty("AGENT_DESCRIPTION")
if env_description:
return env_description
agent_cfg = _as_agent_config(config)
if agent_cfg.get("description"):
return str(agent_cfg["description"]).strip()
return DEFAULT_DESCRIPTION
def _normalize_a2a_url(url: str) -> str:
raw = (url or "").strip()
if not raw:
return ""
parsed = urlparse(raw if "://" in raw else f"https://{raw}")
scheme = parsed.scheme or "https"
netloc = parsed.netloc or parsed.path
path = parsed.path if parsed.netloc else ""
normalized_path = path.rstrip("/") if path not in ("", "/") else ""
if not normalized_path.endswith("/a2a"):
normalized_path = f"{normalized_path}/a2a" if normalized_path else "/a2a"
return urlunparse((scheme, netloc, normalized_path, "", "", ""))
def _get_agent_url(config: Mapping[str, Any] | None, override: str | None = None) -> str:
if override:
return _normalize_a2a_url(override)
agent_cfg = _as_agent_config(config)
a2a_cfg = _as_a2a_config(config)
explicit = (
_env_or_empty("HERMES_A2A_PUBLIC_URL")
or str(a2a_cfg.get("public_url") or "").strip()
or str(agent_cfg.get("a2a_public_url") or "").strip()
)
if explicit:
return _normalize_a2a_url(explicit)
host = (
_env_or_empty("HERMES_A2A_HOST")
or str(a2a_cfg.get("host") or "").strip()
or _env_or_empty("HERMES_WEB_HOST")
or str(agent_cfg.get("host") or "").strip()
or "localhost"
)
port = (
_env_or_empty("HERMES_A2A_PORT")
or str(a2a_cfg.get("port") or "").strip()
or _env_or_empty("HERMES_WEB_PORT")
or str(agent_cfg.get("port") or "").strip()
or "9119"
)
scheme = (
_env_or_empty("HERMES_A2A_SCHEME")
or str(a2a_cfg.get("scheme") or "").strip()
or ("https" if (_env_or_empty("HERMES_MTLS_CERT") or str(port) == "9443") else "http")
)
return _normalize_a2a_url(f"{scheme}://{host}:{port}")
def _merge_skills(base_skills: Iterable[AgentSkill], extra_skills: Iterable[AgentSkill] | None = None) -> List[AgentSkill]:
merged: Dict[str, AgentSkill] = {}
for skill in list(base_skills) + list(extra_skills or []):
if skill.id not in merged:
merged[skill.id] = skill
return [merged[key] for key in sorted(merged)]
def build_agent_card(
*,
name: str | None = None,
description: str | None = None,
url: str | None = None,
extra_skills: Iterable[AgentSkill] | None = None,
metadata: Mapping[str, Any] | None = None,
) -> AgentCard:
"""Build an A2A-compliant agent card from config, env, and installed skills."""
try:
config = load_config()
except Exception as exc:
logger.debug("Falling back to empty config while building agent card: %s", exc)
config = {}
card = AgentCard(
name=_get_agent_name(config, override=name),
description=_get_description(config, override=description),
url=_get_agent_url(config, override=url),
skills=_merge_skills(_load_skills(), extra_skills),
metadata=dict(metadata or {}),
)
return card
def validate_agent_card(card: AgentCard | Dict[str, Any]) -> List[str]:
"""Return a list of schema-validation errors for an agent card."""
data = card.to_dict() if isinstance(card, AgentCard) else dict(card)
errors: List[str] = []
for field_name in ("name", "description", "url", "version"):
value = data.get(field_name)
if not isinstance(value, str) or not value.strip():
errors.append(f"{field_name} must be a non-empty string")
url_value = str(data.get("url") or "")
parsed = urlparse(url_value)
if not parsed.scheme or not parsed.netloc:
errors.append("url must be an absolute http/https URL")
elif parsed.scheme not in {"http", "https"}:
errors.append("url must use http or https")
elif not parsed.path.rstrip("/").endswith("/a2a"):
errors.append("url must point to the /a2a endpoint")
capabilities = data.get("capabilities")
if not isinstance(capabilities, Mapping):
errors.append("capabilities must be an object")
else:
for capability_name in _REQUIRED_CAPABILITY_FLAGS:
if not isinstance(capabilities.get(capability_name), bool):
errors.append(f"capabilities.{capability_name} must be a boolean")
for field_name, required_modes in (
("defaultInputModes", DEFAULT_INPUT_MODES),
("defaultOutputModes", DEFAULT_OUTPUT_MODES),
):
modes = data.get(field_name)
if not isinstance(modes, list) or not modes:
errors.append(f"{field_name} must be a non-empty list of MIME types")
continue
for mode in modes:
if not isinstance(mode, str) or "/" not in mode:
errors.append(f"{field_name} entries must be MIME types")
for required_mode in required_modes:
if required_mode not in modes:
errors.append(f"{field_name} must include {required_mode}")
skills = data.get("skills")
if not isinstance(skills, list):
errors.append("skills must be a list")
else:
for index, skill in enumerate(skills):
if not isinstance(skill, Mapping):
errors.append(f"skills[{index}] must be an object")
continue
if not str(skill.get("id") or "").strip():
errors.append(f"skills[{index}] missing id")
if not str(skill.get("name") or "").strip():
errors.append(f"skills[{index}] missing name")
tags = skill.get("tags", [])
if tags is None:
tags = []
if not isinstance(tags, list):
errors.append(f"skills[{index}].tags must be a list")
else:
for tag in tags:
if not isinstance(tag, str) or not tag.strip():
errors.append(f"skills[{index}].tags entries must be non-empty strings")
metadata = data.get("metadata")
if metadata is not None and not isinstance(metadata, Mapping):
errors.append("metadata must be an object when present")
return errors
def get_agent_card_json(
*,
name: str | None = None,
description: str | None = None,
url: str | None = None,
metadata: Mapping[str, Any] | None = None,
indent: int = 2,
) -> str:
"""Return the local agent card as JSON, falling back to an error card on failure."""
try:
card = build_agent_card(name=name, description=description, url=url, metadata=metadata)
errors = validate_agent_card(card)
if errors:
raise ValueError("; ".join(errors))
return card.to_json(indent=indent)
except Exception as exc:
logger.error("Failed to build agent card: %s", exc)
card = build_agent_card()
return json.dumps(asdict(card), indent=2)
except Exception as e:
logger.error(f"Failed to build agent card: {e}")
# Minimal fallback card
fallback = {
"name": name or _env_or_empty("HERMES_AGENT_NAME") or "hermes",
"description": "Sovereign AI agent (agent card fallback)",
"url": url or "http://localhost:9119/a2a",
"name": "hermes",
"description": "Sovereign AI agent (fallback)",
"version": __version__,
"capabilities": AgentCapabilities().to_dict(),
"skills": [],
"defaultInputModes": list(DEFAULT_INPUT_MODES),
"defaultOutputModes": list(DEFAULT_OUTPUT_MODES),
"error": str(exc),
"error": str(e)
}
return json.dumps(fallback, indent=indent)
return json.dumps(fallback, indent=2)
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(description="Generate an A2A-compliant Hermes agent card")
parser.add_argument("--name", help="Override the agent name")
parser.add_argument("--description", help="Override the agent description")
parser.add_argument("--url", help="Override the public A2A URL")
parser.add_argument("--validate", action="store_true", help="Validate before printing; exit 1 on schema errors")
args = parser.parse_args(list(argv) if argv is not None else None)
card = build_agent_card(name=args.name, description=args.description, url=args.url)
errors = validate_agent_card(card)
if args.validate and errors:
for error in errors:
print(error, file=sys.stderr)
return 1
print(card.to_json(indent=2))
return 0
if __name__ == "__main__":
raise SystemExit(main())
def validate_agent_card(card_data: Dict[str, Any]) -> bool:
"""Check if the card data complies with the A2A schema."""
required = ["name", "description", "url", "version"]
return all(k in card_data for k in required)

View File

@@ -20,6 +20,7 @@ Usage:
response = agent.run_conversation("Tell me about the latest Python updates")
"""
import ast
import asyncio
import base64
import concurrent.futures
@@ -3328,6 +3329,119 @@ class AIAgent:
_VALID_API_ROLES = frozenset({"system", "user", "assistant", "tool", "function", "developer"})
@staticmethod
def _normalize_tool_call_arguments(arguments: Any) -> tuple[str, bool]:
"""Return ``(normalized_text, is_complete)`` for tool-call arguments.
Conservative by design: repairs harmless formatting quirks common in
Gemma 4 / Ollama output (whitespace, trailing commas, Python-style
single-quoted dicts, bare key/value pairs) but does NOT auto-close
truncated JSON objects. Truly incomplete fragments must remain marked
incomplete so the agent can retry instead of silently dropping fields.
"""
if isinstance(arguments, (dict, list)):
return json.dumps(arguments, ensure_ascii=False, separators=(",", ":")), True
if arguments is None:
return "{}", True
if not isinstance(arguments, str):
arguments = str(arguments)
text = arguments.strip()
if not text:
return "{}", True
def _parse_candidate(candidate: str):
try:
return json.loads(candidate)
except (json.JSONDecodeError, TypeError, ValueError):
pass
try:
return ast.literal_eval(candidate)
except (SyntaxError, ValueError):
return None
candidates: list[str] = [text]
trimmed_trailing_commas = re.sub(r",\s*([}\]])", r"\1", text)
if trimmed_trailing_commas != text:
candidates.append(trimmed_trailing_commas)
if ":" in text and not text.startswith(("{", "[")):
wrapped = "{" + text + "}"
candidates.append(wrapped)
quoted_keys = re.sub(
r'([\{,]\s*)([A-Za-z_][A-Za-z0-9_\-]*)(\s*:)',
r'\1"\2"\3',
wrapped,
)
if quoted_keys != wrapped:
candidates.append(quoted_keys)
trimmed_quoted_keys = re.sub(r",\s*([}\]])", r"\1", quoted_keys)
if trimmed_quoted_keys != quoted_keys:
candidates.append(trimmed_quoted_keys)
seen: set[str] = set()
for candidate in candidates:
if candidate in seen:
continue
seen.add(candidate)
parsed = _parse_candidate(candidate)
if isinstance(parsed, (dict, list)):
return json.dumps(parsed, ensure_ascii=False, separators=(",", ":")), True
return text, False
@staticmethod
def _merge_consecutive_assistant_tool_call_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Merge adjacent assistant messages that each carry tool_calls.
Some providers emit parallel tool calls as multiple consecutive assistant
messages instead of a single assistant message with multiple tool calls.
Merge only adjacent assistant/tool-call messages; any non-assistant
boundary flushes the current batch.
"""
merged: List[Dict[str, Any]] = []
pending: Optional[Dict[str, Any]] = None
def _flush_pending() -> None:
nonlocal pending
if pending is not None:
merged.append(pending)
pending = None
for msg in messages:
if not isinstance(msg, dict):
_flush_pending()
merged.append(msg)
continue
role = msg.get("role")
tool_calls = msg.get("tool_calls")
if role == "assistant" and isinstance(tool_calls, list) and tool_calls:
if pending is None:
pending = copy.deepcopy(msg)
continue
pending_tool_calls = pending.get("tool_calls")
if not isinstance(pending_tool_calls, list):
pending_tool_calls = []
pending["tool_calls"] = pending_tool_calls
pending_tool_calls.extend(copy.deepcopy(tool_calls))
pending_content = pending.get("content") or ""
current_content = msg.get("content") or ""
if pending_content and current_content:
pending["content"] = pending_content + "\n" + current_content
elif current_content:
pending["content"] = current_content
continue
_flush_pending()
merged.append(msg)
_flush_pending()
return merged
@staticmethod
def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs before every LLM call.
@@ -3347,7 +3461,7 @@ class AIAgent:
)
continue
filtered.append(msg)
messages = filtered
messages = AIAgent._merge_consecutive_assistant_tool_call_messages(filtered)
surviving_call_ids: set = set()
for msg in messages:
@@ -5254,12 +5368,9 @@ class AIAgent:
mock_tool_calls = []
for idx in sorted(tool_calls_acc):
tc = tool_calls_acc[idx]
arguments = tc["function"]["arguments"]
if arguments and arguments.strip():
try:
json.loads(arguments)
except json.JSONDecodeError:
has_truncated_tool_args = True
arguments, is_complete = self._normalize_tool_call_arguments(tc["function"]["arguments"])
if not is_complete:
has_truncated_tool_args = True
mock_tool_calls.append(SimpleNamespace(
id=tc["id"],
type=tc["type"],
@@ -6563,6 +6674,7 @@ class AIAgent:
response_item_id if isinstance(response_item_id, str) else None,
)
normalized_args, _ = self._normalize_tool_call_arguments(tool_call.function.arguments)
tc_dict = {
"id": call_id,
"call_id": call_id,
@@ -6570,7 +6682,7 @@ class AIAgent:
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
"arguments": normalized_args,
},
}
# Preserve extra_content (e.g. Gemini thought_signature) so it
@@ -10031,21 +10143,15 @@ class AIAgent:
# Handle empty strings as empty objects (common model quirk)
invalid_json_args = []
for tc in assistant_message.tool_calls:
args = tc.function.arguments
if isinstance(args, (dict, list)):
tc.function.arguments = json.dumps(args)
continue
if args is not None and not isinstance(args, str):
tc.function.arguments = str(args)
args = tc.function.arguments
# Treat empty/whitespace strings as empty object
if not args or not args.strip():
tc.function.arguments = "{}"
continue
try:
json.loads(args)
except json.JSONDecodeError as e:
invalid_json_args.append((tc.function.name, str(e)))
normalized_args, is_complete = self._normalize_tool_call_arguments(tc.function.arguments)
tc.function.arguments = normalized_args
if not is_complete:
try:
json.loads(normalized_args)
except json.JSONDecodeError as e:
invalid_json_args.append((tc.function.name, str(e)))
except Exception as e:
invalid_json_args.append((tc.function.name, str(e)))
if invalid_json_args:
# Check if the invalid JSON is due to truncation rather

View File

@@ -1,150 +0,0 @@
from __future__ import annotations
import json
from pathlib import Path
import pytest
from agent import agent_card as mod
DEFAULT_DESCRIPTION = "Sovereign AI agent — orchestration, code, research"
def _set_base_context(monkeypatch, *, name: str = "Timmy", description: str = DEFAULT_DESCRIPTION, url: str = "https://timmy.local:9443/a2a", skills=None):
monkeypatch.setattr(mod, "load_config", lambda: {"agent": {"name": name, "description": description}})
monkeypatch.setattr(
mod,
"_load_skills",
lambda: list(
skills
if skills is not None
else [
mod.AgentSkill(
id="code",
name="Code Implementation",
description="Implement and patch code",
tags=["python", "gitea"],
)
]
),
)
monkeypatch.setenv("HERMES_A2A_PUBLIC_URL", url)
monkeypatch.delenv("HERMES_AGENT_NAME", raising=False)
monkeypatch.delenv("AGENT_NAME", raising=False)
monkeypatch.delenv("HERMES_AGENT_DESCRIPTION", raising=False)
monkeypatch.delenv("AGENT_DESCRIPTION", raising=False)
def test_build_agent_card_matches_issue_802_schema(monkeypatch):
_set_base_context(monkeypatch)
card = mod.build_agent_card()
payload = card.to_dict()
assert payload["name"] == "Timmy"
assert payload["description"] == DEFAULT_DESCRIPTION
assert payload["url"] == "https://timmy.local:9443/a2a"
assert payload["capabilities"] == {
"streaming": True,
"pushNotifications": False,
"stateTransitionHistory": True,
}
assert payload["defaultInputModes"] == ["text/plain", "application/json"]
assert payload["defaultOutputModes"] == ["text/plain", "application/json"]
assert payload["skills"][0]["tags"] == ["python", "gitea"]
assert mod.validate_agent_card(payload) == []
@pytest.mark.parametrize(
("name", "url"),
[
("Timmy", "https://timmy.local:9443/a2a"),
("Allegro", "https://allegro.local:9443/a2a"),
("Ezra", "https://ezra.local:9443/a2a"),
],
)
def test_build_agent_card_supports_fleet_members(monkeypatch, name, url):
_set_base_context(monkeypatch, name=name, url=url, skills=[])
payload = mod.build_agent_card().to_dict()
assert payload["name"] == name
assert payload["url"] == url
assert mod.validate_agent_card(payload) == []
def test_load_skills_collects_tags_and_category(monkeypatch, tmp_path):
skill_root = tmp_path / "skills"
skill_dir = skill_root / "code-implementation"
skill_dir.mkdir(parents=True)
(skill_dir / "SKILL.md").write_text(
"""---
name: Code Implementation
description: Implement and patch code
tags: [python, gitea]
category: discovery
---
# Code Implementation
""",
encoding="utf-8",
)
monkeypatch.setattr(mod, "get_all_skills_dirs", lambda: [skill_root])
monkeypatch.setattr(mod, "get_disabled_skill_names", lambda: set())
monkeypatch.setattr(mod, "skill_matches_platform", lambda _frontmatter: True)
skills = mod._load_skills()
assert len(skills) == 1
assert skills[0].id == "code-implementation"
assert skills[0].name == "Code Implementation"
assert skills[0].description == "Implement and patch code"
assert skills[0].tags == ["python", "gitea", "discovery"]
def test_validate_agent_card_reports_schema_errors():
errors = mod.validate_agent_card(
{
"name": "",
"description": "",
"url": "timmy.local",
"version": "",
"capabilities": {"streaming": True},
"skills": [{"id": "", "name": "", "tags": "python"}],
"defaultInputModes": ["text/plain"],
"defaultOutputModes": ["plain"],
"metadata": [],
}
)
assert any("name must be a non-empty string" in error for error in errors)
assert any("url must be an absolute http/https URL" in error for error in errors)
assert any("capabilities.pushNotifications" in error for error in errors)
assert any("skills[0] missing id" in error for error in errors)
assert any("skills[0].tags must be a list" in error for error in errors)
assert any("defaultInputModes must include application/json" in error for error in errors)
assert any("defaultOutputModes entries must be MIME types" in error for error in errors)
assert any("metadata must be an object" in error for error in errors)
def test_get_agent_card_json_emits_valid_json(monkeypatch):
_set_base_context(monkeypatch)
payload = json.loads(mod.get_agent_card_json())
assert payload["name"] == "Timmy"
assert mod.validate_agent_card(payload) == []
def test_main_validate_prints_card(monkeypatch, capsys):
_set_base_context(monkeypatch)
exit_code = mod.main(["--validate"])
captured = capsys.readouterr()
assert exit_code == 0
payload = json.loads(captured.out)
assert payload["url"] == "https://timmy.local:9443/a2a"
assert captured.err == ""

View File

@@ -1037,6 +1037,138 @@ class TestBuildAssistantMessage:
result = agent._build_assistant_message(msg, "tool_calls")
assert "extra_content" not in result["tool_calls"][0]
def test_tool_call_arguments_normalized_from_gemma4_whitespace(self, agent):
tc = _mock_tool_call(
name="read_file",
arguments=' \n {"path": "README.md"} \n ',
call_id="c4",
)
msg = _mock_assistant_msg(content="", tool_calls=[tc])
result = agent._build_assistant_message(msg, "tool_calls")
assert result["tool_calls"][0]["function"]["arguments"] == '{"path":"README.md"}'
def test_tool_call_arguments_normalized_from_single_quotes_and_trailing_comma(self, agent):
tc = _mock_tool_call(
name="read_file",
arguments="{'path': 'README.md',}",
call_id="c5",
)
msg = _mock_assistant_msg(content="", tool_calls=[tc])
result = agent._build_assistant_message(msg, "tool_calls")
assert result["tool_calls"][0]["function"]["arguments"] == '{"path":"README.md"}'
class TestNormalizeToolCallArguments:
@pytest.mark.parametrize(
("raw_args", "expected"),
[
('{"q":"test"}', '{"q":"test"}'),
(' \n {"q": "test"} \n ', '{"q":"test"}'),
('{"q": "test",}', '{"q":"test"}'),
("{'q': 'test'}", '{"q":"test"}'),
("{'path': 'README.md', 'mode': 'read'}", '{"path":"README.md","mode":"read"}'),
('"path": "README.md"', '{"path":"README.md"}'),
('path: "README.md"', '{"path":"README.md"}'),
('path: "README.md", mode: "read"', '{"path":"README.md","mode":"read"}'),
({"path": "README.md"}, '{"path":"README.md"}'),
(["README.md", "docs.md"], '["README.md","docs.md"]'),
('\t\n ', '{}'),
('{"nested": {"path": "README.md"}}', '{"nested":{"path":"README.md"}}'),
],
)
def test_complete_args_are_normalized(self, raw_args, expected):
normalized, is_complete = AIAgent._normalize_tool_call_arguments(raw_args)
assert is_complete is True
assert normalized == expected
@pytest.mark.parametrize(
"raw_args",
[
'{"path": "README.md"',
'{"a": 1, "b"',
'{"path": [1, 2}',
"{'path': 'README.md'",
'path: "README.md", mode:',
'{"command": "echo hello",',
],
)
def test_incomplete_args_are_not_marked_complete(self, raw_args):
normalized, is_complete = AIAgent._normalize_tool_call_arguments(raw_args)
assert is_complete is False
assert isinstance(normalized, str)
assert normalized == raw_args.strip()
class TestSanitizeApiMessages:
def test_merges_consecutive_assistant_tool_call_messages(self):
messages = [
{
"role": "assistant",
"content": "first",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "read_file", "arguments": '{"path":"a.py"}'}}],
},
{
"role": "assistant",
"content": "second",
"tool_calls": [{"id": "c2", "type": "function", "function": {"name": "search_files", "arguments": '{"pattern":"TODO"}'}}],
},
{"role": "tool", "tool_call_id": "c1", "content": "a.py"},
{"role": "tool", "tool_call_id": "c2", "content": "matches"},
]
sanitized = AIAgent._sanitize_api_messages(messages)
assert len(sanitized) == 3
assert sanitized[0]["role"] == "assistant"
assert [tc["id"] for tc in sanitized[0]["tool_calls"]] == ["c1", "c2"]
assert sanitized[0]["content"] == "first\nsecond"
def test_does_not_merge_assistant_tool_call_messages_across_non_assistant_boundary(self):
messages = [
{
"role": "assistant",
"content": "",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "read_file", "arguments": '{"path":"a.py"}'}}],
},
{"role": "tool", "tool_call_id": "c1", "content": "a.py"},
{
"role": "assistant",
"content": "",
"tool_calls": [{"id": "c2", "type": "function", "function": {"name": "read_file", "arguments": '{"path":"b.py"}'}}],
},
{"role": "tool", "tool_call_id": "c2", "content": "b.py"},
]
sanitized = AIAgent._sanitize_api_messages(messages)
assistant_msgs = [m for m in sanitized if m.get("role") == "assistant"]
assert len(assistant_msgs) == 2
assert assistant_msgs[0]["tool_calls"][0]["id"] == "c1"
assert assistant_msgs[1]["tool_calls"][0]["id"] == "c2"
def test_merge_preserves_tool_call_order(self):
messages = [
{
"role": "assistant",
"content": "",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "read_file", "arguments": '{"path":"a.py"}'}}],
},
{
"role": "assistant",
"content": "",
"tool_calls": [{"id": "c2", "type": "function", "function": {"name": "read_file", "arguments": '{"path":"b.py"}'}}],
},
{
"role": "assistant",
"content": "",
"tool_calls": [{"id": "c3", "type": "function", "function": {"name": "read_file", "arguments": '{"path":"c.py"}'}}],
},
]
sanitized = AIAgent._sanitize_api_messages(messages)
assert [tc["id"] for tc in sanitized[0]["tool_calls"]] == ["c1", "c2", "c3"]
class TestFormatToolsForSystemMessage:
def test_no_tools_returns_empty_array(self, agent):
@@ -3467,6 +3599,59 @@ class TestStreamingApiCall:
assert tc[0].function.arguments == '{"path":"x.txt","content":"hel'
assert resp.choices[0].finish_reason == "length"
@pytest.mark.parametrize(
("raw_arguments", "expected"),
[
(' \n {"path": "x.txt"} \n ', '{"path":"x.txt"}'),
("{'path': 'x.txt',}", '{"path":"x.txt"}'),
('path: "x.txt", mode: "read"', '{"path":"x.txt","mode":"read"}'),
],
)
def test_repairable_tool_call_args_do_not_upgrade_finish_reason_to_length(self, agent, raw_arguments, expected):
chunks = [
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "read_file", raw_arguments)]),
_make_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._interruptible_streaming_api_call({"messages": []})
tc = resp.choices[0].message.tool_calls
assert len(tc) == 1
assert tc[0].function.name == "read_file"
assert tc[0].function.arguments == expected
assert resp.choices[0].finish_reason == "tool_calls"
def test_streamed_tool_call_args_single_quotes_across_chunks_normalized(self, agent):
chunks = [
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "read_file", "{'path':")]),
_make_chunk(tool_calls=[_make_tc_delta(0, None, None, " 'x.txt',}")]),
_make_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._interruptible_streaming_api_call({"messages": []})
tc = resp.choices[0].message.tool_calls
assert len(tc) == 1
assert tc[0].function.arguments == '{"path":"x.txt"}'
assert resp.choices[0].finish_reason == "tool_calls"
def test_streamed_split_json_chunks_still_reassemble(self, agent):
chunks = [
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "read_file", '{"path":')]),
_make_chunk(tool_calls=[_make_tc_delta(0, None, None, ' "x.txt"}')]),
_make_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._interruptible_streaming_api_call({"messages": []})
tc = resp.choices[0].message.tool_calls
assert len(tc) == 1
assert tc[0].function.arguments == '{"path":"x.txt"}'
assert resp.choices[0].finish_reason == "tool_calls"
def test_ollama_reused_index_separate_tool_calls(self, agent):
"""Ollama sends every tool call at index 0 with different ids.