Adds WeCom as a gateway platform adapter using the AI Bot WebSocket gateway for real-time bidirectional communication. No public endpoint or new pip dependencies needed (uses existing aiohttp + httpx). Features: - WebSocket persistent connection with auto-reconnect (exponential backoff) - DM and group messaging with configurable access policies - Media upload/download with AES decryption for encrypted attachments - Markdown rendering, quote context preservation - Proactive + passive reply message modes - Chunked media upload pipeline (512KB chunks) Cherry-picked from PR #1898 by EvilRan with: - Moved to current main (PR was 300 commits behind) - Skipped base.py regressions (reply_to additions are good but belong in a separate PR since they affect all platforms) - Fixed test assertions to match current base class send() signature (reply_to=None kwarg now explicit) - All 16 integration points added surgically to current main - No new pip dependencies (aiohttp + httpx already installed) Fixes #1898 Co-authored-by: EvilRan <EvilRan@users.noreply.github.com>
1339 lines
52 KiB
Python
1339 lines
52 KiB
Python
"""
|
||
WeCom (Enterprise WeChat) platform adapter.
|
||
|
||
Uses the WeCom AI Bot WebSocket gateway for inbound and outbound messages.
|
||
The adapter focuses on the core gateway path:
|
||
|
||
- authenticate via ``aibot_subscribe``
|
||
- receive inbound ``aibot_msg_callback`` events
|
||
- send outbound markdown messages via ``aibot_send_msg``
|
||
- upload outbound media via ``aibot_upload_media_*`` and send native attachments
|
||
- best-effort download of inbound image/file attachments for agent context
|
||
|
||
Configuration in config.yaml:
|
||
platforms:
|
||
wecom:
|
||
enabled: true
|
||
extra:
|
||
bot_id: "your-bot-id" # or WECOM_BOT_ID env var
|
||
secret: "your-secret" # or WECOM_SECRET env var
|
||
websocket_url: "wss://openws.work.weixin.qq.com"
|
||
dm_policy: "open" # open | allowlist | disabled | pairing
|
||
allow_from: ["user_id_1"]
|
||
group_policy: "open" # open | allowlist | disabled
|
||
group_allow_from: ["group_id_1"]
|
||
groups:
|
||
group_id_1:
|
||
allow_from: ["user_id_1"]
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import mimetypes
|
||
import os
|
||
import re
|
||
import time
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
from urllib.parse import unquote, urlparse
|
||
|
||
try:
|
||
import aiohttp
|
||
AIOHTTP_AVAILABLE = True
|
||
except ImportError:
|
||
AIOHTTP_AVAILABLE = False
|
||
aiohttp = None # type: ignore[assignment]
|
||
|
||
try:
|
||
import httpx
|
||
HTTPX_AVAILABLE = True
|
||
except ImportError:
|
||
HTTPX_AVAILABLE = False
|
||
httpx = None # type: ignore[assignment]
|
||
|
||
from gateway.config import Platform, PlatformConfig
|
||
from gateway.platforms.base import (
|
||
BasePlatformAdapter,
|
||
MessageEvent,
|
||
MessageType,
|
||
SendResult,
|
||
cache_document_from_bytes,
|
||
cache_image_from_bytes,
|
||
)
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Gateway endpoint used when neither the config nor WECOM_WEBSOCKET_URL
# supplies one (see WeComAdapter.__init__).
DEFAULT_WS_URL = "wss://openws.work.weixin.qq.com"

# Values of the ``cmd`` field carried by websocket frames.
APP_CMD_SUBSCRIBE = "aibot_subscribe"  # authentication frame sent right after connecting
APP_CMD_CALLBACK = "aibot_msg_callback"  # inbound message callback
APP_CMD_LEGACY_CALLBACK = "aibot_callback"  # routed exactly like APP_CMD_CALLBACK
APP_CMD_EVENT_CALLBACK = "aibot_event_callback"  # received but ignored by the dispatcher
APP_CMD_SEND = "aibot_send_msg"
APP_CMD_RESPONSE = "aibot_respond_msg"  # default cmd of _send_reply_request
APP_CMD_PING = "ping"  # application-level heartbeat
APP_CMD_UPLOAD_MEDIA_INIT = "aibot_upload_media_init"
APP_CMD_UPLOAD_MEDIA_CHUNK = "aibot_upload_media_chunk"
APP_CMD_UPLOAD_MEDIA_FINISH = "aibot_upload_media_finish"

# Commands handed to the inbound-message handler.
CALLBACK_COMMANDS = {APP_CMD_CALLBACK, APP_CMD_LEGACY_CALLBACK}
# Commands that are never treated as the response to a pending request,
# even when their req_id matches one.
NON_RESPONSE_COMMANDS = CALLBACK_COMMANDS | {APP_CMD_EVENT_CALLBACK}

# Exposed as WeComAdapter.MAX_MESSAGE_LENGTH.
# NOTE(review): unit (characters vs. bytes) is not visible here - confirm.
MAX_MESSAGE_LENGTH = 4000
# Seconds: subscribe-handshake deadline; also passed as ``timeout`` to ws_connect.
CONNECT_TIMEOUT_SECONDS = 20.0
# Seconds: default wait for the response to a correlated request.
REQUEST_TIMEOUT_SECONDS = 15.0
# Seconds between application-level pings; the websocket-level heartbeat
# is configured at twice this value.
HEARTBEAT_INTERVAL_SECONDS = 30.0
# Seconds to sleep before each successive reconnect attempt; the last
# value is reused once the list is exhausted.
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]

# Message de-duplication: entries older than the window are pruned once
# the table grows beyond the maximum size.
DEDUP_WINDOW_SECONDS = 300
DEDUP_MAX_SIZE = 1000

# Per-type size ceilings (bytes) applied by _apply_file_size_limits.
IMAGE_MAX_BYTES = 10 * 1024 * 1024
VIDEO_MAX_BYTES = 10 * 1024 * 1024
VOICE_MAX_BYTES = 2 * 1024 * 1024
FILE_MAX_BYTES = 20 * 1024 * 1024
# Hard ceiling for any attachment; also caps inbound media downloads.
ABSOLUTE_MAX_BYTES = FILE_MAX_BYTES
# 512 KiB.  NOTE(review): presumably the outbound upload chunk size and the
# cap on chunks per upload - their use is outside this view, confirm.
UPLOAD_CHUNK_SIZE = 512 * 1024
MAX_UPLOAD_CHUNKS = 100
# MIME types that may be sent as a native voice message; other audio is
# downgraded to a plain file.
VOICE_SUPPORTED_MIMES = {"audio/amr"}
|
||
|
||
|
||
def check_wecom_requirements() -> bool:
    """Report whether both optional HTTP libraries (aiohttp and httpx) were importable."""
    if not AIOHTTP_AVAILABLE:
        return False
    return HTTPX_AVAILABLE
|
||
|
||
|
||
def _coerce_list(value: Any) -> List[str]:
|
||
"""Coerce config values into a trimmed string list."""
|
||
if value is None:
|
||
return []
|
||
if isinstance(value, str):
|
||
return [item.strip() for item in value.split(",") if item.strip()]
|
||
if isinstance(value, (list, tuple, set)):
|
||
return [str(item).strip() for item in value if str(item).strip()]
|
||
return [str(value).strip()] if str(value).strip() else []
|
||
|
||
|
||
def _normalize_entry(raw: str) -> str:
|
||
"""Normalize allowlist entries such as ``wecom:user:foo``."""
|
||
value = str(raw).strip()
|
||
value = re.sub(r"^wecom:", "", value, flags=re.IGNORECASE)
|
||
value = re.sub(r"^(user|group):", "", value, flags=re.IGNORECASE)
|
||
return value.strip()
|
||
|
||
|
||
def _entry_matches(entries: List[str], target: str) -> bool:
    """Return True when *target* equals any normalized entry, ignoring case.

    An entry that normalizes to ``*`` matches every target.
    """
    wanted = str(target).strip().lower()
    normalized_entries = (_normalize_entry(entry).lower() for entry in entries)
    return any(candidate in ("*", wanted) for candidate in normalized_entries)
|
||
|
||
|
||
class WeComAdapter(BasePlatformAdapter):
|
||
"""WeCom AI Bot adapter backed by a persistent WebSocket connection."""
|
||
|
||
MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH
|
||
|
||
def __init__(self, config: PlatformConfig):
    """Read adapter settings from ``config.extra`` (falling back to env vars) and reset runtime state."""
    super().__init__(config, Platform.WECOM)

    settings = config.extra or {}

    def _setting(key: str, env_name: str, default: str) -> str:
        # A config value wins; the environment is consulted only when it is missing/falsy.
        return str(settings.get(key) or os.getenv(env_name, default)).strip()

    self._bot_id = _setting("bot_id", "WECOM_BOT_ID", "")
    self._secret = _setting("secret", "WECOM_SECRET", "")

    configured_url = str(
        settings.get("websocket_url")
        or settings.get("websocketUrl")
        or os.getenv("WECOM_WEBSOCKET_URL", DEFAULT_WS_URL)
    ).strip()
    self._ws_url = configured_url or DEFAULT_WS_URL

    # Access policies: direct messages first, then groups.
    self._dm_policy = _setting("dm_policy", "WECOM_DM_POLICY", "open").lower()
    self._allow_from = _coerce_list(settings.get("allow_from") or settings.get("allowFrom"))

    self._group_policy = _setting("group_policy", "WECOM_GROUP_POLICY", "open").lower()
    self._group_allow_from = _coerce_list(
        settings.get("group_allow_from") or settings.get("groupAllowFrom")
    )
    groups_cfg = settings.get("groups")
    self._groups = groups_cfg if isinstance(groups_cfg, dict) else {}

    # Connection handles and background tasks; populated by connect().
    self._session: Optional["aiohttp.ClientSession"] = None
    self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
    self._http_client: Optional["httpx.AsyncClient"] = None
    self._listen_task: Optional[asyncio.Task] = None
    self._heartbeat_task: Optional[asyncio.Task] = None
    # req_id -> future awaiting the correlated response frame.
    self._pending_responses: Dict[str, asyncio.Future] = {}
    # message id -> time first seen (de-duplication table).
    self._seen_messages: Dict[str, float] = {}
    # inbound message id -> req_id of the callback frame that delivered it.
    self._reply_req_ids: Dict[str, str] = {}
|
||
|
||
# ------------------------------------------------------------------
|
||
# Connection lifecycle
|
||
# ------------------------------------------------------------------
|
||
|
||
async def connect(self) -> bool:
    """Open the gateway connection and start the listen/heartbeat tasks.

    Returns True on success.  Missing dependencies, missing credentials or
    any connection error are recorded via ``_set_fatal_error`` (flagged
    retryable) and reported by returning False.
    """
    if not AIOHTTP_AVAILABLE:
        message = "WeCom startup failed: aiohttp not installed"
        self._set_fatal_error("wecom_missing_dependency", message, retryable=True)
        logger.warning("[%s] %s. Run: pip install aiohttp", self.name, message)
        return False

    if not HTTPX_AVAILABLE:
        message = "WeCom startup failed: httpx not installed"
        self._set_fatal_error("wecom_missing_dependency", message, retryable=True)
        logger.warning("[%s] %s. Run: pip install httpx", self.name, message)
        return False

    if not (self._bot_id and self._secret):
        message = "WeCom startup failed: WECOM_BOT_ID and WECOM_SECRET are required"
        self._set_fatal_error("wecom_missing_credentials", message, retryable=True)
        logger.warning("[%s] %s", self.name, message)
        return False

    try:
        self._http_client = httpx.AsyncClient(timeout=30.0, follow_redirects=True)
        await self._open_connection()
        self._mark_connected()
        self._listen_task = asyncio.create_task(self._listen_loop())
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        logger.info("[%s] Connected to %s", self.name, self._ws_url)
    except Exception as exc:
        message = f"WeCom startup failed: {exc}"
        self._set_fatal_error("wecom_connect_error", message, retryable=True)
        logger.error("[%s] Failed to connect: %s", self.name, exc, exc_info=True)
        # Undo whatever was set up before the failure.
        await self._cleanup_ws()
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
        return False
    return True
|
||
|
||
async def disconnect(self) -> None:
    """Stop background tasks, fail in-flight requests and close every connection."""
    self._running = False
    self._mark_disconnected()

    # Cancel the listener first, then the heartbeat, waiting for each to finish.
    for task_attr in ("_listen_task", "_heartbeat_task"):
        task = getattr(self, task_attr)
        if not task:
            continue
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        setattr(self, task_attr, None)

    self._fail_pending_responses(RuntimeError("WeCom adapter disconnected"))
    await self._cleanup_ws()

    http_client = self._http_client
    if http_client:
        await http_client.aclose()
        self._http_client = None

    self._seen_messages.clear()
    logger.info("[%s] Disconnected", self.name)
|
||
|
||
async def _cleanup_ws(self) -> None:
|
||
"""Close the live websocket/session, if any."""
|
||
if self._ws and not self._ws.closed:
|
||
await self._ws.close()
|
||
self._ws = None
|
||
|
||
if self._session and not self._session.closed:
|
||
await self._session.close()
|
||
self._session = None
|
||
|
||
async def _open_connection(self) -> None:
    """Open a fresh websocket and authenticate it with ``aibot_subscribe``.

    Any previous websocket/session is closed first.  If connecting, sending
    the subscribe frame or the handshake fails, the newly created session
    and socket are closed again before the error propagates, so a failed
    attempt never leaves a half-open, unauthenticated connection behind.

    Raises:
        RuntimeError: the gateway answered the subscribe request with a
            non-zero ``errcode``, or the socket closed during the handshake.
        TimeoutError: no subscribe acknowledgement arrived in time.
    """
    await self._cleanup_ws()
    self._session = aiohttp.ClientSession()
    try:
        # NOTE(review): in aiohttp the ``timeout`` argument of ws_connect
        # has historically been the close timeout, not a connect timeout -
        # confirm the intended semantics against the pinned aiohttp version.
        self._ws = await self._session.ws_connect(
            self._ws_url,
            heartbeat=HEARTBEAT_INTERVAL_SECONDS * 2,
            timeout=CONNECT_TIMEOUT_SECONDS,
        )

        req_id = self._new_req_id("subscribe")
        await self._send_json(
            {
                "cmd": APP_CMD_SUBSCRIBE,
                "headers": {"req_id": req_id},
                "body": {"bot_id": self._bot_id, "secret": self._secret},
            }
        )

        auth_payload = await self._wait_for_handshake(req_id)
        errcode = auth_payload.get("errcode", 0)
        if errcode not in (0, None):
            errmsg = auth_payload.get("errmsg", "authentication failed")
            raise RuntimeError(f"{errmsg} (errcode={errcode})")
    except BaseException:
        # Covers cancellation as well: release the session/socket created above.
        await self._cleanup_ws()
        raise
|
||
|
||
async def _wait_for_handshake(self, req_id: str) -> Dict[str, Any]:
    """Read frames until the one answering *req_id* arrives and return it.

    Pings, unparsable frames and unrelated payloads are skipped.

    Raises:
        RuntimeError: no websocket exists, or it closed/errored while waiting.
        TimeoutError: CONNECT_TIMEOUT_SECONDS elapsed without an acknowledgement.
    """
    if not self._ws:
        raise RuntimeError("WebSocket not initialized")

    loop = asyncio.get_running_loop()
    deadline = loop.time() + CONNECT_TIMEOUT_SECONDS
    terminal_types = (
        aiohttp.WSMsgType.CLOSED,
        aiohttp.WSMsgType.CLOSE,
        aiohttp.WSMsgType.ERROR,
    )
    while True:
        remaining = deadline - loop.time()
        if remaining <= 0:
            raise TimeoutError("Timed out waiting for WeCom subscribe acknowledgement")

        frame = await asyncio.wait_for(self._ws.receive(), timeout=remaining)
        if frame.type in terminal_types:
            raise RuntimeError("WeCom websocket closed during authentication")
        if frame.type != aiohttp.WSMsgType.TEXT:
            continue

        payload = self._parse_json(frame.data)
        if not payload or payload.get("cmd") == APP_CMD_PING:
            continue
        if self._payload_req_id(payload) == req_id:
            return payload
        logger.debug("[%s] Ignoring pre-auth payload: %s", self.name, payload.get("cmd"))
|
||
|
||
async def _listen_loop(self) -> None:
    """Read websocket events until the adapter stops, reconnecting after every loss.

    Each time the connection is lost - whether ``_read_events`` raised or
    simply returned because the socket was already closed - outstanding
    requests are failed, the loop sleeps for the next RECONNECT_BACKOFF
    delay and then tries to open and authenticate a new connection.
    Cancellation ends the loop.
    """
    backoff_idx = 0
    while self._running:
        try:
            await self._read_events()
            # A clean return means the socket closed without an error; it
            # still has to go through the reconnect path below.  Re-entering
            # _read_events on a closed socket would return immediately and
            # spin without ever yielding to the event loop.
            backoff_idx = 0
        except asyncio.CancelledError:
            return
        except Exception as exc:
            if not self._running:
                return
            logger.warning("[%s] WebSocket error: %s", self.name, exc)

        if not self._running:
            return
        # Nothing can answer in-flight requests any more; fail them now
        # instead of letting each one run into its timeout.
        self._fail_pending_responses(RuntimeError("WeCom connection interrupted"))

        delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)]
        backoff_idx += 1
        await asyncio.sleep(delay)
        if not self._running:
            # Shutdown was requested during the backoff sleep: do not reopen.
            return

        try:
            await self._open_connection()
            backoff_idx = 0
            logger.info("[%s] Reconnected", self.name)
        except Exception as reconnect_exc:
            logger.warning("[%s] Reconnect failed: %s", self.name, reconnect_exc)
|
||
|
||
async def _read_events(self) -> None:
    """Dispatch incoming text frames until the adapter stops or the socket closes.

    Raises:
        RuntimeError: there is no websocket, or a close/error frame arrived.
    """
    if not self._ws:
        raise RuntimeError("WebSocket not connected")

    terminal_types = (
        aiohttp.WSMsgType.CLOSE,
        aiohttp.WSMsgType.CLOSED,
        aiohttp.WSMsgType.ERROR,
    )
    while self._running and self._ws and not self._ws.closed:
        frame = await self._ws.receive()
        if frame.type in terminal_types:
            raise RuntimeError("WeCom websocket closed")
        if frame.type != aiohttp.WSMsgType.TEXT:
            continue
        payload = self._parse_json(frame.data)
        if payload:
            await self._dispatch_payload(payload)
|
||
|
||
async def _heartbeat_loop(self) -> None:
    """Send an application-level ``ping`` frame every HEARTBEAT_INTERVAL_SECONDS.

    Ticks are skipped while no open websocket exists; send failures are
    only logged.  Cancellation ends the loop silently.
    """
    try:
        while self._running:
            await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
            if self._ws and not self._ws.closed:
                ping_frame = {
                    "cmd": APP_CMD_PING,
                    "headers": {"req_id": self._new_req_id("ping")},
                    "body": {},
                }
                try:
                    await self._send_json(ping_frame)
                except Exception as exc:
                    logger.debug("[%s] Heartbeat send failed: %s", self.name, exc)
    except asyncio.CancelledError:
        pass
|
||
|
||
async def _dispatch_payload(self, payload: Dict[str, Any]) -> None:
    """Route one parsed frame.

    A frame whose req_id matches a pending request (and whose cmd is not a
    callback/event command) resolves that request.  Message callbacks go to
    ``_on_message``; pings and event callbacks are dropped; anything else
    is logged at debug level.
    """
    req_id = self._payload_req_id(payload)
    cmd = str(payload.get("cmd") or "")

    is_response = (
        bool(req_id)
        and req_id in self._pending_responses
        and cmd not in NON_RESPONSE_COMMANDS
    )
    if is_response:
        waiter = self._pending_responses.get(req_id)
        if waiter and not waiter.done():
            waiter.set_result(payload)
        return

    if cmd in CALLBACK_COMMANDS:
        await self._on_message(payload)
    elif cmd not in {APP_CMD_PING, APP_CMD_EVENT_CALLBACK}:
        logger.debug("[%s] Ignoring websocket payload: %s", self.name, cmd or payload)
|
||
|
||
def _fail_pending_responses(self, exc: Exception) -> None:
|
||
"""Fail all outstanding request futures."""
|
||
for req_id, future in list(self._pending_responses.items()):
|
||
if not future.done():
|
||
future.set_exception(exc)
|
||
self._pending_responses.pop(req_id, None)
|
||
|
||
async def _send_json(self, payload: Dict[str, Any]) -> None:
|
||
"""Send a raw JSON frame over the active websocket."""
|
||
if not self._ws or self._ws.closed:
|
||
raise RuntimeError("WeCom websocket is not connected")
|
||
await self._ws.send_json(payload)
|
||
|
||
async def _send_request(self, cmd: str, body: Dict[str, Any], timeout: float = REQUEST_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """Send *cmd* with a freshly generated req_id and return the correlated response frame.

    Raises:
        RuntimeError: the websocket is not connected.
        asyncio.TimeoutError: no response arrived within *timeout* seconds.
    """
    if not self._ws or self._ws.closed:
        raise RuntimeError("WeCom websocket is not connected")

    req_id = self._new_req_id(cmd)
    waiter = asyncio.get_running_loop().create_future()
    # Register before sending so a fast reply cannot slip past the dispatcher.
    self._pending_responses[req_id] = waiter
    frame = {"cmd": cmd, "headers": {"req_id": req_id}, "body": body}
    try:
        await self._send_json(frame)
        return await asyncio.wait_for(waiter, timeout=timeout)
    finally:
        self._pending_responses.pop(req_id, None)
|
||
|
||
async def _send_reply_request(
    self,
    reply_req_id: str,
    body: Dict[str, Any],
    cmd: str = APP_CMD_RESPONSE,
    timeout: float = REQUEST_TIMEOUT_SECONDS,
) -> Dict[str, Any]:
    """Send a frame that reuses an inbound callback's req_id and return the gateway's answer.

    Raises:
        RuntimeError: the websocket is not connected.
        ValueError: *reply_req_id* is empty or blank.
        asyncio.TimeoutError: no response arrived within *timeout* seconds.
    """
    if not self._ws or self._ws.closed:
        raise RuntimeError("WeCom websocket is not connected")

    correlation_id = str(reply_req_id or "").strip()
    if not correlation_id:
        raise ValueError("reply_req_id is required")

    waiter = asyncio.get_running_loop().create_future()
    # Register before sending so the answer is matched by the dispatcher.
    self._pending_responses[correlation_id] = waiter
    frame = {"cmd": cmd, "headers": {"req_id": correlation_id}, "body": body}
    try:
        await self._send_json(frame)
        return await asyncio.wait_for(waiter, timeout=timeout)
    finally:
        self._pending_responses.pop(correlation_id, None)
|
||
|
||
@staticmethod
|
||
def _new_req_id(prefix: str) -> str:
|
||
return f"{prefix}-{uuid.uuid4().hex}"
|
||
|
||
@staticmethod
|
||
def _payload_req_id(payload: Dict[str, Any]) -> str:
|
||
headers = payload.get("headers")
|
||
if isinstance(headers, dict):
|
||
return str(headers.get("req_id") or "")
|
||
return ""
|
||
|
||
@staticmethod
|
||
def _parse_json(raw: Any) -> Optional[Dict[str, Any]]:
|
||
try:
|
||
payload = json.loads(raw)
|
||
except Exception:
|
||
logger.debug("Failed to parse WeCom payload: %r", raw)
|
||
return None
|
||
return payload if isinstance(payload, dict) else None
|
||
|
||
# ------------------------------------------------------------------
|
||
# Inbound message parsing
|
||
# ------------------------------------------------------------------
|
||
|
||
async def _on_message(self, payload: Dict[str, Any]) -> None:
    """Turn an inbound message callback into a MessageEvent and pass it to ``handle_message``.

    The frame is dropped when it has no dict body, repeats an already seen
    message id, carries no chat id, is rejected by the DM/group policy, or
    contains neither text nor media.
    """
    body = payload.get("body")
    if not isinstance(body, dict):
        return

    callback_req_id = self._payload_req_id(payload)
    msg_id = str(body.get("msgid") or callback_req_id or uuid.uuid4().hex)
    if self._is_duplicate(msg_id):
        logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
        return
    # Remember which callback frame delivered this message so a later reply
    # can be correlated to it.
    self._remember_reply_req_id(msg_id, callback_req_id)

    raw_sender = body.get("from")
    sender = raw_sender if isinstance(raw_sender, dict) else {}
    sender_id = str(sender.get("userid") or "").strip()
    # Without a chat id the conversation is keyed by the sender.
    chat_id = str(body.get("chatid") or sender_id).strip()
    if not chat_id:
        logger.debug("[%s] Missing chat id, skipping message", self.name)
        return

    is_group = str(body.get("chattype") or "").lower() == "group"
    if is_group and not self._is_group_allowed(chat_id, sender_id):
        logger.debug("[%s] Group %s / sender %s blocked by policy", self.name, chat_id, sender_id)
        return
    if not is_group and not self._is_dm_allowed(sender_id):
        logger.debug("[%s] DM sender %s blocked by policy", self.name, sender_id)
        return

    text, reply_text = self._extract_text(body)
    media_urls, media_types = await self._extract_media(body)
    message_type = self._derive_message_type(body, text, media_types)
    # A quote only counts as reply context when the message has content of its own.
    has_reply_context = bool(reply_text and (text or media_urls))

    if not text and reply_text and not media_urls:
        # A bare quote: promote the quoted text to the message text.
        text = reply_text

    if not (text or media_urls):
        logger.debug("[%s] Empty WeCom message skipped", self.name)
        return

    if has_reply_context:
        quoted_id, quoted_text = f"quote:{msg_id}", reply_text
    else:
        quoted_id, quoted_text = None, None

    source = self.build_source(
        chat_id=chat_id,
        chat_type="group" if is_group else "dm",
        user_id=sender_id or None,
        user_name=sender_id or None,
    )
    event = MessageEvent(
        text=text,
        message_type=message_type,
        source=source,
        raw_message=payload,
        message_id=msg_id,
        media_urls=media_urls,
        media_types=media_types,
        reply_to_message_id=quoted_id,
        reply_to_text=quoted_text,
        timestamp=datetime.now(tz=timezone.utc),
    )
    await self.handle_message(event)
|
||
|
||
@staticmethod
|
||
def _extract_text(body: Dict[str, Any]) -> Tuple[str, Optional[str]]:
|
||
"""Extract plain text and quoted text from a callback payload."""
|
||
text_parts: List[str] = []
|
||
reply_text: Optional[str] = None
|
||
msgtype = str(body.get("msgtype") or "").lower()
|
||
|
||
if msgtype == "mixed":
|
||
mixed = body.get("mixed") if isinstance(body.get("mixed"), dict) else {}
|
||
items = mixed.get("msg_item") if isinstance(mixed.get("msg_item"), list) else []
|
||
for item in items:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
if str(item.get("msgtype") or "").lower() == "text":
|
||
text_block = item.get("text") if isinstance(item.get("text"), dict) else {}
|
||
content = str(text_block.get("content") or "").strip()
|
||
if content:
|
||
text_parts.append(content)
|
||
else:
|
||
text_block = body.get("text") if isinstance(body.get("text"), dict) else {}
|
||
content = str(text_block.get("content") or "").strip()
|
||
if content:
|
||
text_parts.append(content)
|
||
|
||
if msgtype == "voice":
|
||
voice_block = body.get("voice") if isinstance(body.get("voice"), dict) else {}
|
||
voice_text = str(voice_block.get("content") or "").strip()
|
||
if voice_text:
|
||
text_parts.append(voice_text)
|
||
|
||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||
quote_type = str(quote.get("msgtype") or "").lower()
|
||
if quote_type == "text":
|
||
quote_text = quote.get("text") if isinstance(quote.get("text"), dict) else {}
|
||
reply_text = str(quote_text.get("content") or "").strip() or None
|
||
elif quote_type == "voice":
|
||
quote_voice = quote.get("voice") if isinstance(quote.get("voice"), dict) else {}
|
||
reply_text = str(quote_voice.get("content") or "").strip() or None
|
||
|
||
return "\n".join(part for part in text_parts if part).strip(), reply_text
|
||
|
||
async def _extract_media(self, body: Dict[str, Any]) -> Tuple[List[str], List[str]]:
|
||
"""Best-effort extraction of inbound media to local cache paths."""
|
||
media_paths: List[str] = []
|
||
media_types: List[str] = []
|
||
refs: List[Tuple[str, Dict[str, Any]]] = []
|
||
msgtype = str(body.get("msgtype") or "").lower()
|
||
|
||
if msgtype == "mixed":
|
||
mixed = body.get("mixed") if isinstance(body.get("mixed"), dict) else {}
|
||
items = mixed.get("msg_item") if isinstance(mixed.get("msg_item"), list) else []
|
||
for item in items:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
item_type = str(item.get("msgtype") or "").lower()
|
||
if item_type == "image" and isinstance(item.get("image"), dict):
|
||
refs.append(("image", item["image"]))
|
||
else:
|
||
if isinstance(body.get("image"), dict):
|
||
refs.append(("image", body["image"]))
|
||
if msgtype == "file" and isinstance(body.get("file"), dict):
|
||
refs.append(("file", body["file"]))
|
||
|
||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||
quote_type = str(quote.get("msgtype") or "").lower()
|
||
if quote_type == "image" and isinstance(quote.get("image"), dict):
|
||
refs.append(("image", quote["image"]))
|
||
elif quote_type == "file" and isinstance(quote.get("file"), dict):
|
||
refs.append(("file", quote["file"]))
|
||
|
||
for kind, ref in refs:
|
||
cached = await self._cache_media(kind, ref)
|
||
if cached:
|
||
path, content_type = cached
|
||
media_paths.append(path)
|
||
media_types.append(content_type)
|
||
|
||
return media_paths, media_types
|
||
|
||
async def _cache_media(self, kind: str, media: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    """Store one inbound media reference in the local cache.

    Args:
        kind: ``"image"`` for images; anything else is stored as a document.
        media: reference dict carrying either inline ``base64`` data or a
            ``url`` (optionally with an ``aeskey`` for encrypted payloads).

    Returns:
        ``(local path, content type)``, or None when the reference has no
        usable data or decoding, downloading or decrypting fails (failures
        are logged at debug level only - this is best-effort).
    """
    if "base64" in media and media.get("base64"):
        try:
            raw = self._decode_base64(media["base64"])
        except Exception as exc:
            logger.debug("[%s] Failed to decode %s base64 media: %s", self.name, kind, exc)
            return None

        if kind == "image":
            ext = self._detect_image_ext(raw)
            return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg")

        filename = str(media.get("filename") or media.get("name") or "wecom_file")
        return cache_document_from_bytes(raw, filename), mimetypes.guess_type(filename)[0] or "application/octet-stream"

    url = str(media.get("url") or "").strip()
    if not url:
        return None

    try:
        raw, headers = await self._download_remote_bytes(url, max_bytes=ABSOLUTE_MAX_BYTES)
    except Exception as exc:
        logger.debug("[%s] Failed to download %s from %s: %s", self.name, kind, url, exc)
        return None

    aes_key = str(media.get("aeskey") or "").strip()
    if aes_key:
        try:
            raw = self._decrypt_file_bytes(raw, aes_key)
        except Exception as exc:
            logger.debug("[%s] Failed to decrypt %s from %s: %s", self.name, kind, url, exc)
            return None

    content_type = str(headers.get("content-type") or "").split(";", 1)[0].strip() or "application/octet-stream"
    if kind == "image":
        if content_type.startswith("image/"):
            ext = self._guess_extension(url, content_type, fallback=self._detect_image_ext(raw))
            return cache_image_from_bytes(raw, ext), content_type
        # The server did not declare an image type (commonly a generic
        # ``application/octet-stream``).  Trusting that header would give
        # the file a non-image extension and make the message look like a
        # document, so derive both from the image bytes instead.
        ext = self._detect_image_ext(raw)
        return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg")

    filename = self._guess_filename(url, headers.get("content-disposition"), content_type)
    return cache_document_from_bytes(raw, filename), content_type
|
||
|
||
@staticmethod
|
||
def _decode_base64(data: str) -> bytes:
|
||
payload = data.split(",", 1)[-1].strip()
|
||
return base64.b64decode(payload)
|
||
|
||
@staticmethod
|
||
def _detect_image_ext(data: bytes) -> str:
|
||
if data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||
return ".png"
|
||
if data.startswith(b"\xff\xd8\xff"):
|
||
return ".jpg"
|
||
if data.startswith(b"GIF87a") or data.startswith(b"GIF89a"):
|
||
return ".gif"
|
||
if data.startswith(b"RIFF") and data[8:12] == b"WEBP":
|
||
return ".webp"
|
||
return ".jpg"
|
||
|
||
@staticmethod
|
||
def _mime_for_ext(ext: str, fallback: str = "application/octet-stream") -> str:
|
||
return mimetypes.types_map.get(ext.lower(), fallback)
|
||
|
||
@staticmethod
|
||
def _guess_extension(url: str, content_type: str, fallback: str) -> str:
|
||
ext = mimetypes.guess_extension(content_type) if content_type else None
|
||
if ext:
|
||
return ext
|
||
path_ext = Path(urlparse(url).path).suffix
|
||
if path_ext:
|
||
return path_ext
|
||
return fallback
|
||
|
||
@staticmethod
|
||
def _guess_filename(url: str, content_disposition: Optional[str], content_type: str) -> str:
|
||
if content_disposition:
|
||
match = re.search(r'filename="?([^";]+)"?', content_disposition)
|
||
if match:
|
||
return match.group(1)
|
||
|
||
name = Path(urlparse(url).path).name or "document"
|
||
if "." not in name:
|
||
ext = mimetypes.guess_extension(content_type) or ".bin"
|
||
name = f"{name}{ext}"
|
||
return name
|
||
|
||
@staticmethod
def _derive_message_type(body: Dict[str, Any], text: str, media_types: List[str]) -> MessageType:
    """Map the message content to a MessageType.

    Any ``application/*`` or ``text/*`` attachment makes it a DOCUMENT; an
    image makes it PHOTO unless text accompanies it (then TEXT); a
    ``voice`` msgtype gives VOICE; everything else is TEXT.
    """
    has_document = any(
        media_type.startswith("application/") or media_type.startswith("text/")
        for media_type in media_types
    )
    if has_document:
        return MessageType.DOCUMENT

    has_image = any(media_type.startswith("image/") for media_type in media_types)
    if has_image:
        if text:
            return MessageType.TEXT
        return MessageType.PHOTO

    is_voice = str(body.get("msgtype") or "").lower() == "voice"
    return MessageType.VOICE if is_voice else MessageType.TEXT
|
||
|
||
# ------------------------------------------------------------------
|
||
# Policy helpers
|
||
# ------------------------------------------------------------------
|
||
|
||
def _is_dm_allowed(self, sender_id: str) -> bool:
|
||
if self._dm_policy == "disabled":
|
||
return False
|
||
if self._dm_policy == "allowlist":
|
||
return _entry_matches(self._allow_from, sender_id)
|
||
return True
|
||
|
||
def _is_group_allowed(self, chat_id: str, sender_id: str) -> bool:
|
||
if self._group_policy == "disabled":
|
||
return False
|
||
if self._group_policy == "allowlist" and not _entry_matches(self._group_allow_from, chat_id):
|
||
return False
|
||
|
||
group_cfg = self._resolve_group_cfg(chat_id)
|
||
sender_allow = _coerce_list(group_cfg.get("allow_from") or group_cfg.get("allowFrom"))
|
||
if sender_allow:
|
||
return _entry_matches(sender_allow, sender_id)
|
||
return True
|
||
|
||
def _resolve_group_cfg(self, chat_id: str) -> Dict[str, Any]:
|
||
if not isinstance(self._groups, dict):
|
||
return {}
|
||
if chat_id in self._groups and isinstance(self._groups[chat_id], dict):
|
||
return self._groups[chat_id]
|
||
lowered = chat_id.lower()
|
||
for key, value in self._groups.items():
|
||
if isinstance(key, str) and key.lower() == lowered and isinstance(value, dict):
|
||
return value
|
||
wildcard = self._groups.get("*")
|
||
return wildcard if isinstance(wildcard, dict) else {}
|
||
|
||
def _is_duplicate(self, msg_id: str) -> bool:
    """Record *msg_id* and report whether it was already in the seen-message table.

    Once the table holds more than DEDUP_MAX_SIZE entries, those older than
    DEDUP_WINDOW_SECONDS are pruned, together with reply req_ids whose
    message is no longer tracked.
    """
    now = time.time()
    seen = self._seen_messages
    if len(seen) > DEDUP_MAX_SIZE:
        cutoff = now - DEDUP_WINDOW_SECONDS
        seen = {key: seen_at for key, seen_at in seen.items() if seen_at > cutoff}
        self._seen_messages = seen
        if self._reply_req_ids:
            self._reply_req_ids = {
                key: req_id for key, req_id in self._reply_req_ids.items() if key in seen
            }

    if msg_id in seen:
        return True
    seen[msg_id] = now
    return False
|
||
|
||
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
    """Map an inbound message id to the req_id of the callback that delivered it.

    Blank ids are ignored.  The table is capped at DEDUP_MAX_SIZE entries;
    the oldest insertions are evicted first.
    """
    message_key = str(message_id or "").strip()
    callback_req_id = str(req_id or "").strip()
    if not (message_key and callback_req_id):
        return

    table = self._reply_req_ids
    table[message_key] = callback_req_id
    while len(table) > DEDUP_MAX_SIZE:
        oldest_key = next(iter(table))
        table.pop(oldest_key)
|
||
|
||
def _reply_req_id_for_message(self, reply_to: Optional[str]) -> Optional[str]:
|
||
normalized = str(reply_to or "").strip()
|
||
if not normalized or normalized.startswith("quote:"):
|
||
return None
|
||
return self._reply_req_ids.get(normalized)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Outbound messaging
|
||
# ------------------------------------------------------------------
|
||
|
||
@staticmethod
|
||
def _guess_mime_type(filename: str) -> str:
|
||
mime_type = mimetypes.guess_type(filename)[0]
|
||
if mime_type:
|
||
return mime_type
|
||
if Path(filename).suffix.lower() == ".amr":
|
||
return "audio/amr"
|
||
return "application/octet-stream"
|
||
|
||
@staticmethod
def _normalize_content_type(content_type: str, filename: str) -> str:
    """Return a lower-cased MIME type without parameters.

    An empty or generic declared type (``application/octet-stream``,
    ``text/plain``) is replaced by a guess based on *filename*.
    """
    declared = str(content_type or "").split(";", 1)[0].strip().lower()
    guessed = WeComAdapter._guess_mime_type(filename)
    if declared in {"", "application/octet-stream", "text/plain"}:
        return guessed
    return declared
|
||
|
||
@staticmethod
|
||
def _detect_wecom_media_type(content_type: str) -> str:
|
||
mime_type = str(content_type or "").strip().lower()
|
||
if mime_type.startswith("image/"):
|
||
return "image"
|
||
if mime_type.startswith("video/"):
|
||
return "video"
|
||
if mime_type.startswith("audio/") or mime_type == "application/ogg":
|
||
return "voice"
|
||
return "file"
|
||
|
||
@staticmethod
def _apply_file_size_limits(file_size: int, detected_type: str, content_type: Optional[str] = None) -> Dict[str, Any]:
    """Decide how an outbound attachment of *file_size* bytes may be sent.

    Returns a dict with ``final_type``, ``rejected``, ``reject_reason``,
    ``downgraded`` and ``downgrade_note``.  Anything above
    ABSOLUTE_MAX_BYTES is rejected; an oversized image, video or voice
    clip, or a voice clip whose MIME type is not in VOICE_SUPPORTED_MIMES,
    is downgraded to a plain ``file``.
    """

    def _verdict(
        final_type: str,
        reject_reason: Optional[str] = None,
        downgrade_note: Optional[str] = None,
    ) -> Dict[str, Any]:
        return {
            "final_type": final_type,
            "rejected": reject_reason is not None,
            "reject_reason": reject_reason,
            "downgraded": downgrade_note is not None,
            "downgrade_note": downgrade_note,
        }

    file_size_mb = file_size / (1024 * 1024)
    normalized_type = str(detected_type or "file").lower()
    normalized_content_type = str(content_type or "").strip().lower()

    # The hard ceiling applies regardless of media type.
    if file_size > ABSOLUTE_MAX_BYTES:
        return _verdict(
            normalized_type,
            reject_reason=(
                f"文件大小 {file_size_mb:.2f}MB 超过了企业微信允许的最大限制 20MB,无法发送。"
                "请尝试压缩文件或减小文件大小。"
            ),
        )

    if normalized_type == "image" and file_size > IMAGE_MAX_BYTES:
        return _verdict(
            "file",
            downgrade_note=f"图片大小 {file_size_mb:.2f}MB 超过 10MB 限制,已转为文件格式发送",
        )

    if normalized_type == "video" and file_size > VIDEO_MAX_BYTES:
        return _verdict(
            "file",
            downgrade_note=f"视频大小 {file_size_mb:.2f}MB 超过 10MB 限制,已转为文件格式发送",
        )

    if normalized_type == "voice":
        # The format check comes before the size check.
        if normalized_content_type and normalized_content_type not in VOICE_SUPPORTED_MIMES:
            return _verdict(
                "file",
                downgrade_note=(
                    f"语音格式 {normalized_content_type} 不支持,企微仅支持 AMR 格式,已转为文件格式发送"
                ),
            )
        if file_size > VOICE_MAX_BYTES:
            return _verdict(
                "file",
                downgrade_note=f"语音大小 {file_size_mb:.2f}MB 超过 2MB 限制,已转为文件格式发送",
            )

    return _verdict(normalized_type)
|
||
|
||
@staticmethod
|
||
def _response_error(response: Dict[str, Any]) -> Optional[str]:
|
||
errcode = response.get("errcode", 0)
|
||
if errcode in (0, None):
|
||
return None
|
||
errmsg = str(response.get("errmsg") or "unknown error")
|
||
return f"WeCom errcode {errcode}: {errmsg}"
|
||
|
||
@classmethod
def _raise_for_wecom_error(cls, response: Dict[str, Any], operation: str) -> None:
    """Raise ``RuntimeError`` when *response* carries a WeCom error.

    *operation* names the step being performed and prefixes the message.
    Returns normally when ``_response_error`` reports no error.
    """
    error = cls._response_error(response)
    if not error:
        return
    raise RuntimeError(f"{operation} failed: {error}")
|
||
|
||
@staticmethod
|
||
def _decrypt_file_bytes(encrypted_data: bytes, aes_key: str) -> bytes:
|
||
if not encrypted_data:
|
||
raise ValueError("encrypted_data is empty")
|
||
if not aes_key:
|
||
raise ValueError("aes_key is required")
|
||
|
||
key = base64.b64decode(aes_key)
|
||
if len(key) != 32:
|
||
raise ValueError(f"Invalid WeCom AES key length: expected 32 bytes, got {len(key)}")
|
||
|
||
try:
|
||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||
except ImportError as exc: # pragma: no cover - dependency is environment-specific
|
||
raise RuntimeError("cryptography is required for WeCom media decryption") from exc
|
||
|
||
cipher = Cipher(algorithms.AES(key), modes.CBC(key[:16]))
|
||
decryptor = cipher.decryptor()
|
||
decrypted = decryptor.update(encrypted_data) + decryptor.finalize()
|
||
|
||
pad_len = decrypted[-1]
|
||
if pad_len < 1 or pad_len > 32 or pad_len > len(decrypted):
|
||
raise ValueError(f"Invalid PKCS#7 padding value: {pad_len}")
|
||
if any(byte != pad_len for byte in decrypted[-pad_len:]):
|
||
raise ValueError("Invalid PKCS#7 padding: padding bytes mismatch")
|
||
|
||
return decrypted[:-pad_len]
|
||
|
||
async def _download_remote_bytes(
    self,
    url: str,
    max_bytes: int,
) -> Tuple[bytes, Dict[str, str]]:
    """Stream *url* with GET into memory, refusing bodies above *max_bytes*.

    ``self._http_client`` is used when set; otherwise a temporary
    ``httpx.AsyncClient`` is created and closed before returning.

    Returns:
        ``(body, headers)`` where header names are lower-cased.

    Raises:
        RuntimeError: httpx is not available.
        ValueError: The declared ``Content-Length`` or the bytes actually
            received exceed *max_bytes*.
    """
    if not HTTPX_AVAILABLE:
        raise RuntimeError("httpx is required for WeCom media download")

    shared_client = self._http_client
    if shared_client:
        client, owns_client = shared_client, False
    else:
        client = httpx.AsyncClient(timeout=30.0, follow_redirects=True)
        owns_client = True

    request_headers = {
        "User-Agent": "HermesAgent/1.0",
        "Accept": "*/*",
    }
    try:
        async with client.stream("GET", url, headers=request_headers) as response:
            response.raise_for_status()
            headers = {name.lower(): value for name, value in response.headers.items()}

            # Fail fast when the server already announces an oversized body.
            content_length = headers.get("content-length")
            if content_length and content_length.isdigit() and int(content_length) > max_bytes:
                raise ValueError(
                    f"Remote media exceeds WeCom limit: {int(content_length)} bytes > {max_bytes} bytes"
                )

            # The header may be absent or wrong, so enforce the cap while streaming too.
            data = bytearray()
            async for chunk in response.aiter_bytes():
                data += chunk
                if len(data) > max_bytes:
                    raise ValueError(
                        f"Remote media exceeds WeCom limit while downloading: {len(data)} bytes > {max_bytes} bytes"
                    )

            return bytes(data), headers
    finally:
        if owns_client:
            await client.aclose()
|
||
|
||
@staticmethod
|
||
def _looks_like_url(media_source: str) -> bool:
|
||
parsed = urlparse(str(media_source or ""))
|
||
return parsed.scheme in {"http", "https"}
|
||
|
||
async def _load_outbound_media(
|
||
self,
|
||
media_source: str,
|
||
file_name: Optional[str] = None,
|
||
) -> Tuple[bytes, str, str]:
|
||
source = str(media_source or "").strip()
|
||
if not source:
|
||
raise ValueError("media source is required")
|
||
if re.fullmatch(r"<[^>\n]+>", source):
|
||
raise ValueError(f"Media placeholder was not replaced with a real file path: {source}")
|
||
|
||
parsed = urlparse(source)
|
||
if parsed.scheme in {"http", "https"}:
|
||
data, headers = await self._download_remote_bytes(source, max_bytes=ABSOLUTE_MAX_BYTES)
|
||
content_disposition = headers.get("content-disposition")
|
||
resolved_name = file_name or self._guess_filename(source, content_disposition, headers.get("content-type", ""))
|
||
content_type = self._normalize_content_type(headers.get("content-type", ""), resolved_name)
|
||
return data, content_type, resolved_name
|
||
|
||
if parsed.scheme == "file":
|
||
local_path = Path(unquote(parsed.path)).expanduser()
|
||
else:
|
||
local_path = Path(source).expanduser()
|
||
|
||
if not local_path.is_absolute():
|
||
local_path = (Path.cwd() / local_path).resolve()
|
||
|
||
if not local_path.exists() or not local_path.is_file():
|
||
raise FileNotFoundError(f"Media file not found: {local_path}")
|
||
|
||
data = local_path.read_bytes()
|
||
resolved_name = file_name or local_path.name
|
||
content_type = self._normalize_content_type("", resolved_name)
|
||
return data, content_type, resolved_name
|
||
|
||
async def _prepare_outbound_media(
|
||
self,
|
||
media_source: str,
|
||
file_name: Optional[str] = None,
|
||
) -> Dict[str, Any]:
|
||
data, content_type, resolved_name = await self._load_outbound_media(media_source, file_name=file_name)
|
||
detected_type = self._detect_wecom_media_type(content_type)
|
||
size_check = self._apply_file_size_limits(len(data), detected_type, content_type)
|
||
return {
|
||
"data": data,
|
||
"content_type": content_type,
|
||
"file_name": resolved_name,
|
||
"detected_type": detected_type,
|
||
**size_check,
|
||
}
|
||
|
||
async def _upload_media_bytes(self, data: bytes, media_type: str, filename: str) -> Dict[str, Any]:
|
||
if not data:
|
||
raise ValueError("Cannot upload empty media")
|
||
|
||
total_size = len(data)
|
||
total_chunks = (total_size + UPLOAD_CHUNK_SIZE - 1) // UPLOAD_CHUNK_SIZE
|
||
if total_chunks > MAX_UPLOAD_CHUNKS:
|
||
raise ValueError(
|
||
f"File too large: {total_chunks} chunks exceeds maximum of {MAX_UPLOAD_CHUNKS} chunks"
|
||
)
|
||
|
||
init_response = await self._send_request(
|
||
APP_CMD_UPLOAD_MEDIA_INIT,
|
||
{
|
||
"type": media_type,
|
||
"filename": filename,
|
||
"total_size": total_size,
|
||
"total_chunks": total_chunks,
|
||
"md5": hashlib.md5(data).hexdigest(),
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(init_response, "media upload init")
|
||
|
||
init_body = init_response.get("body") if isinstance(init_response.get("body"), dict) else {}
|
||
upload_id = str(init_body.get("upload_id") or "").strip()
|
||
if not upload_id:
|
||
raise RuntimeError(f"media upload init failed: missing upload_id in response {init_response}")
|
||
|
||
for chunk_index, start in enumerate(range(0, total_size, UPLOAD_CHUNK_SIZE)):
|
||
chunk = data[start : start + UPLOAD_CHUNK_SIZE]
|
||
chunk_response = await self._send_request(
|
||
APP_CMD_UPLOAD_MEDIA_CHUNK,
|
||
{
|
||
"upload_id": upload_id,
|
||
# Match the official SDK implementation, which currently uses 0-based chunk indexes.
|
||
"chunk_index": chunk_index,
|
||
"base64_data": base64.b64encode(chunk).decode("ascii"),
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(chunk_response, f"media upload chunk {chunk_index}")
|
||
|
||
finish_response = await self._send_request(
|
||
APP_CMD_UPLOAD_MEDIA_FINISH,
|
||
{"upload_id": upload_id},
|
||
)
|
||
self._raise_for_wecom_error(finish_response, "media upload finish")
|
||
|
||
finish_body = finish_response.get("body") if isinstance(finish_response.get("body"), dict) else {}
|
||
media_id = str(finish_body.get("media_id") or "").strip()
|
||
if not media_id:
|
||
raise RuntimeError(f"media upload finish failed: missing media_id in response {finish_response}")
|
||
|
||
return {
|
||
"type": str(finish_body.get("type") or media_type),
|
||
"media_id": media_id,
|
||
"created_at": finish_body.get("created_at"),
|
||
}
|
||
|
||
async def _send_media_message(self, chat_id: str, media_type: str, media_id: str) -> Dict[str, Any]:
    """Send an already-uploaded media item to *chat_id* with ``APP_CMD_SEND``.

    *media_type* is used both as ``msgtype`` and as the key holding the
    ``media_id``. Raises ``RuntimeError`` when the response reports an error;
    otherwise the response is returned.
    """
    payload = {
        "chatid": chat_id,
        "msgtype": media_type,
        media_type: {"media_id": media_id},
    }
    response = await self._send_request(APP_CMD_SEND, payload)
    self._raise_for_wecom_error(response, "send media message")
    return response
|
||
|
||
async def _send_reply_stream(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||
response = await self._send_reply_request(
|
||
reply_req_id,
|
||
{
|
||
"msgtype": "stream",
|
||
"stream": {
|
||
"id": self._new_req_id("stream"),
|
||
"finish": True,
|
||
"content": content[:self.MAX_MESSAGE_LENGTH],
|
||
},
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(response, "send reply stream")
|
||
return response
|
||
|
||
async def _send_reply_media_message(
|
||
self,
|
||
reply_req_id: str,
|
||
media_type: str,
|
||
media_id: str,
|
||
) -> Dict[str, Any]:
|
||
response = await self._send_reply_request(
|
||
reply_req_id,
|
||
{
|
||
"msgtype": media_type,
|
||
media_type: {"media_id": media_id},
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(response, "send reply media message")
|
||
return response
|
||
|
||
async def _send_followup_markdown(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
) -> Optional[SendResult]:
    """Send a secondary markdown message (caption or notice) via ``send``.

    Returns ``None`` without sending when *content* is empty. A failed send
    is logged as a warning and its result is still returned to the caller.
    """
    if content:
        outcome = await self.send(chat_id=chat_id, content=content, reply_to=reply_to)
        if not outcome.success:
            logger.warning("[%s] Follow-up markdown send failed: %s", self.name, outcome.error)
        return outcome
    return None
|
||
|
||
async def _send_media_source(
    self,
    chat_id: str,
    media_source: str,
    caption: Optional[str] = None,
    file_name: Optional[str] = None,
    reply_to: Optional[str] = None,
) -> SendResult:
    """Load, upload and send one media item, plus optional follow-up texts.

    Steps:
      1. prepare the media (load it and apply the size/format limits);
      2. if it was rejected, tell the chat why and report failure;
      3. upload it and send it, as a reply when a request id is known for
         *reply_to*, otherwise as a regular message to *chat_id*;
      4. send *caption* and any downgrade note as separate markdown messages.

    Failures of the follow-up messages do not fail the whole send; they are
    reported under ``caption_error`` / ``downgrade_error`` in ``raw_response``.
    """
    if not chat_id:
        return SendResult(success=False, error="chat_id is required")

    try:
        prepared = await self._prepare_outbound_media(media_source, file_name=file_name)
    except FileNotFoundError as exc:
        # A missing file is an expected user error, so it is not logged as an error.
        return SendResult(success=False, error=str(exc))
    except Exception as exc:
        logger.error("[%s] Failed to prepare outbound media %s: %s", self.name, media_source, exc)
        return SendResult(success=False, error=str(exc))

    if prepared["rejected"]:
        await self._send_followup_markdown(
            chat_id,
            f"⚠️ {prepared['reject_reason']}",
            reply_to=reply_to,
        )
        return SendResult(success=False, error=prepared["reject_reason"])

    reply_req_id = self._reply_req_id_for_message(reply_to)
    try:
        upload_result = await self._upload_media_bytes(
            prepared["data"],
            prepared["final_type"],
            prepared["file_name"],
        )
        # Pick the destination first; both senders take (target, type, media_id).
        if reply_req_id:
            deliver, target = self._send_reply_media_message, reply_req_id
        else:
            deliver, target = self._send_media_message, chat_id
        media_response = await deliver(
            target,
            prepared["final_type"],
            upload_result["media_id"],
        )
    except asyncio.TimeoutError:
        return SendResult(success=False, error="Timeout sending media to WeCom")
    except Exception as exc:
        logger.error("[%s] Failed to send media %s: %s", self.name, media_source, exc)
        return SendResult(success=False, error=str(exc))

    caption_result = None
    if caption:
        caption_result = await self._send_followup_markdown(
            chat_id,
            caption,
            reply_to=reply_to,
        )

    downgrade_result = None
    if prepared["downgraded"] and prepared["downgrade_note"]:
        downgrade_result = await self._send_followup_markdown(
            chat_id,
            f"ℹ️ {prepared['downgrade_note']}",
            reply_to=reply_to,
        )

    def raw_of(followup):
        return followup.raw_response if followup else None

    def error_of(followup):
        return followup.error if followup and not followup.success else None

    return SendResult(
        success=True,
        message_id=self._payload_req_id(media_response) or uuid.uuid4().hex[:12],
        raw_response={
            "upload": upload_result,
            "media": media_response,
            "caption": raw_of(caption_result),
            "caption_error": error_of(caption_result),
            "downgrade": raw_of(downgrade_result),
            "downgrade_error": error_of(downgrade_result),
        },
    )
|
||
|
||
async def send(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """Send text to a WeCom chat.

    When a request id is known for *reply_to*, the text goes out as a reply
    stream for that request; otherwise it is sent as a markdown message with
    ``APP_CMD_SEND``. *content* is cut to ``MAX_MESSAGE_LENGTH`` characters
    and *metadata* is ignored.

    Returns:
        A ``SendResult``; exceptions and WeCom error codes are reported
        through ``success=False`` instead of being raised.
    """
    del metadata

    if not chat_id:
        return SendResult(success=False, error="chat_id is required")

    try:
        reply_req_id = self._reply_req_id_for_message(reply_to)
        if reply_req_id:
            response = await self._send_reply_stream(reply_req_id, content)
        else:
            payload = {
                "chatid": chat_id,
                "msgtype": "markdown",
                "markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]},
            }
            response = await self._send_request(APP_CMD_SEND, payload)
    except asyncio.TimeoutError:
        return SendResult(success=False, error="Timeout sending message to WeCom")
    except Exception as exc:
        logger.error("[%s] Send failed: %s", self.name, exc)
        return SendResult(success=False, error=str(exc))

    error = self._response_error(response)
    if not error:
        return SendResult(
            success=True,
            message_id=self._payload_req_id(response) or uuid.uuid4().hex[:12],
            raw_response=response,
        )
    return SendResult(success=False, error=error)
|
||
|
||
async def send_image(
    self,
    chat_id: str,
    image_url: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """Send an image as native media, falling back to a text link for URLs.

    When the media send fails and *image_url* is an http(s) URL, the URL
    (preceded by *caption* on its own line, if any) is sent as plain text
    instead. For non-URL sources the failed result is returned as is.
    *metadata* is ignored.
    """
    del metadata

    result = await self._send_media_source(
        chat_id=chat_id,
        media_source=image_url,
        caption=caption,
        reply_to=reply_to,
    )
    if not result.success and self._looks_like_url(image_url):
        logger.warning("[%s] Falling back to text send for image URL %s: %s", self.name, image_url, result.error)
        fallback_text = f"{caption}\n{image_url}" if caption else image_url
        return await self.send(chat_id=chat_id, content=fallback_text, reply_to=reply_to)
    return result
|
||
|
||
async def send_image_file(
    self,
    chat_id: str,
    image_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the image at *image_path* as native media; extra keyword arguments are ignored."""
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=image_path,
        caption=caption,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_document(
    self,
    chat_id: str,
    file_path: str,
    caption: Optional[str] = None,
    file_name: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the file at *file_path* as native media.

    *file_name*, when given, is passed on as the name to send the file
    under. Extra keyword arguments are ignored.
    """
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=file_path,
        caption=caption,
        file_name=file_name,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_voice(
    self,
    chat_id: str,
    audio_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the audio at *audio_path* as native media; extra keyword arguments are ignored."""
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=audio_path,
        caption=caption,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_video(
    self,
    chat_id: str,
    video_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the video at *video_path* as native media; extra keyword arguments are ignored."""
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=video_path,
        caption=caption,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_typing(self, chat_id: str, metadata=None) -> None:
    """No-op: this adapter sends no typing indicator to WeCom."""
    del chat_id, metadata
    return None
|
||
|
||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
    """Return a minimal description of a chat.

    ``name`` echoes *chat_id*; ``type`` is ``group`` when the id starts with
    ``group`` (case-insensitive) and ``dm`` otherwise, including for an
    empty id.
    """
    is_group = bool(chat_id) and chat_id.lower().startswith("group")
    chat_type = "group" if is_group else "dm"
    return {"name": chat_id, "type": chat_type}
|