Compare commits
6 Commits
kimi/issue
...
review-fix
| Author | SHA1 | Date | |
|---|---|---|---|
| d60eff31fe | |||
| d8d792a6e9 | |||
| c93ec2792d | |||
| ab4a185248 | |||
| 48103bb076 | |||
| 9f244ffc70 |
@@ -101,7 +101,7 @@ async def _process_chat(user_msg: str) -> dict | JSONResponse:
|
||||
try:
|
||||
response_text = await agent_chat(
|
||||
_build_context_prefix() + user_msg,
|
||||
session_id="mobile",
|
||||
session_id=body.get("session_id", "mobile"),
|
||||
)
|
||||
message_log.append(role="user", content=user_msg, timestamp=timestamp, source="api")
|
||||
message_log.append(role="agent", content=response_text, timestamp=timestamp, source="api")
|
||||
@@ -165,6 +165,11 @@ async def api_upload(file: UploadFile = File(...)):
|
||||
if not str(resolved).startswith(str(upload_root)):
|
||||
raise HTTPException(status_code=400, detail="Invalid file name")
|
||||
|
||||
# Validate MIME type
|
||||
allowed_types = ["image/png", "image/jpeg", "image/gif", "application/pdf", "text/plain"]
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(status_code=400, detail=f"File type {file.content_type} not allowed")
|
||||
|
||||
contents = await file.read()
|
||||
if len(contents) > _MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(status_code=413, detail="File too large (max 50 MB)")
|
||||
|
||||
@@ -60,7 +60,12 @@ class MessageLog:
|
||||
self._conn: sqlite3.Connection | None = None
|
||||
|
||||
# Lazy connection — opened on first use, not at import time.
|
||||
def _ensure_conn(self) -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
path = self._db_path or DB_PATH
|
||||
with closing(sqlite3.connect(str(path), check_same_thread=False)) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
yield conn
|
||||
if self._conn is None:
|
||||
# Open a persistent connection for the class instance
|
||||
path = self._db_path or DB_PATH
|
||||
|
||||
@@ -79,7 +79,17 @@ class WebSocketManager:
|
||||
message = ws_event.to_json()
|
||||
disconnected = []
|
||||
|
||||
for ws in self._connections:
|
||||
import asyncio
|
||||
tasks = [ws.send_text(message) for ws in self._connections]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
disconnected = []
|
||||
for ws, result in zip(self._connections, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"WebSocket send error: {result}")
|
||||
disconnected.append(ws)
|
||||
|
||||
# Skip the old loop
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except ConnectionError:
|
||||
|
||||
117
src/integrations/chat_bridge/vendors/discord.py
vendored
117
src/integrations/chat_bridge/vendors/discord.py
vendored
@@ -515,25 +515,36 @@ class DiscordVendor(ChatPlatform):
|
||||
|
||||
async def _handle_message(self, message) -> None:
|
||||
"""Process an incoming message and respond via a thread."""
|
||||
# Strip the bot mention from the message content
|
||||
content = message.content
|
||||
if self._client.user:
|
||||
content = content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
|
||||
content = self._extract_content(message)
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Create or reuse a thread for this conversation
|
||||
thread = await self._get_or_create_thread(message)
|
||||
target = thread or message.channel
|
||||
session_id = f"discord_{thread.id}" if thread else f"discord_{message.channel.id}"
|
||||
|
||||
# Derive session_id for per-conversation history via Agno's SQLite
|
||||
if thread:
|
||||
session_id = f"discord_{thread.id}"
|
||||
else:
|
||||
session_id = f"discord_{message.channel.id}"
|
||||
run_output, response = await self._invoke_agent(content, session_id, target)
|
||||
|
||||
# Run Timmy agent with typing indicator and timeout
|
||||
if run_output is not None:
|
||||
await self._handle_paused_run(run_output, target, session_id)
|
||||
raw_content = run_output.content if hasattr(run_output, "content") else ""
|
||||
response = _clean_response(raw_content or "")
|
||||
|
||||
await self._send_response(response, target)
|
||||
|
||||
def _extract_content(self, message) -> str:
|
||||
"""Strip the bot mention and return clean message text."""
|
||||
content = message.content
|
||||
if self._client.user:
|
||||
content = content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
return content
|
||||
|
||||
async def _invoke_agent(self, content: str, session_id: str, target):
|
||||
"""Run chat_with_tools with a typing indicator and timeout.
|
||||
|
||||
Returns a (run_output, error_response) tuple. On success the
|
||||
error_response is ``None``; on failure run_output is ``None``.
|
||||
"""
|
||||
run_output = None
|
||||
response = None
|
||||
try:
|
||||
@@ -548,51 +559,57 @@ class DiscordVendor(ChatPlatform):
|
||||
except Exception as exc:
|
||||
logger.error("Discord: chat_with_tools() failed: %s", exc)
|
||||
response = "I'm having trouble reaching my inference backend right now. Please try again shortly."
|
||||
return run_output, response
|
||||
|
||||
# Check if Agno paused the run for tool confirmation
|
||||
if run_output is not None:
|
||||
status = getattr(run_output, "status", None)
|
||||
is_paused = status == "PAUSED" or str(status) == "RunStatus.paused"
|
||||
async def _handle_paused_run(self, run_output, target, session_id: str) -> None:
|
||||
"""If Agno paused the run for tool confirmation, enqueue approvals."""
|
||||
status = getattr(run_output, "status", None)
|
||||
is_paused = status == "PAUSED" or str(status) == "RunStatus.paused"
|
||||
|
||||
if is_paused and getattr(run_output, "active_requirements", None):
|
||||
from config import settings
|
||||
if not (is_paused and getattr(run_output, "active_requirements", None)):
|
||||
return
|
||||
|
||||
if settings.discord_confirm_actions:
|
||||
for req in run_output.active_requirements:
|
||||
if getattr(req, "needs_confirmation", False):
|
||||
te = req.tool_execution
|
||||
tool_name = getattr(te, "tool_name", "unknown")
|
||||
tool_args = getattr(te, "tool_args", {}) or {}
|
||||
from config import settings
|
||||
|
||||
from timmy.approvals import create_item
|
||||
if not settings.discord_confirm_actions:
|
||||
return
|
||||
|
||||
item = create_item(
|
||||
title=f"Discord: {tool_name}",
|
||||
description=_format_action_description(tool_name, tool_args),
|
||||
proposed_action=json.dumps({"tool": tool_name, "args": tool_args}),
|
||||
impact=_get_impact_level(tool_name),
|
||||
)
|
||||
self._pending_actions[item.id] = {
|
||||
"run_output": run_output,
|
||||
"requirement": req,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"target": target,
|
||||
"session_id": session_id,
|
||||
}
|
||||
await self._send_confirmation(target, tool_name, tool_args, item.id)
|
||||
for req in run_output.active_requirements:
|
||||
if not getattr(req, "needs_confirmation", False):
|
||||
continue
|
||||
te = req.tool_execution
|
||||
tool_name = getattr(te, "tool_name", "unknown")
|
||||
tool_args = getattr(te, "tool_args", {}) or {}
|
||||
|
||||
raw_content = run_output.content if hasattr(run_output, "content") else ""
|
||||
response = _clean_response(raw_content or "")
|
||||
from timmy.approvals import create_item
|
||||
|
||||
# Discord has a 2000 character limit — send with error handling
|
||||
if response and response.strip():
|
||||
for chunk in _chunk_message(response, 2000):
|
||||
try:
|
||||
await target.send(chunk)
|
||||
except Exception as exc:
|
||||
logger.error("Discord: failed to send message chunk: %s", exc)
|
||||
break
|
||||
item = create_item(
|
||||
title=f"Discord: {tool_name}",
|
||||
description=_format_action_description(tool_name, tool_args),
|
||||
proposed_action=json.dumps({"tool": tool_name, "args": tool_args}),
|
||||
impact=_get_impact_level(tool_name),
|
||||
)
|
||||
self._pending_actions[item.id] = {
|
||||
"run_output": run_output,
|
||||
"requirement": req,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"target": target,
|
||||
"session_id": session_id,
|
||||
}
|
||||
await self._send_confirmation(target, tool_name, tool_args, item.id)
|
||||
|
||||
@staticmethod
|
||||
async def _send_response(response: str | None, target) -> None:
|
||||
"""Send a response to Discord, chunked to the 2000-char limit."""
|
||||
if not response or not response.strip():
|
||||
return
|
||||
for chunk in _chunk_message(response, 2000):
|
||||
try:
|
||||
await target.send(chunk)
|
||||
except Exception as exc:
|
||||
logger.error("Discord: failed to send message chunk: %s", exc)
|
||||
break
|
||||
|
||||
async def _get_or_create_thread(self, message):
|
||||
"""Get the active thread for a channel, or create one.
|
||||
|
||||
@@ -78,6 +78,11 @@ DEFAULT_MAX_UTTERANCE = 30.0 # safety cap — don't record forever
|
||||
DEFAULT_SESSION_ID = "voice"
|
||||
|
||||
|
||||
def _rms(block: np.ndarray) -> float:
|
||||
"""Compute root-mean-square energy of an audio block."""
|
||||
return float(np.sqrt(np.mean(block.astype(np.float32) ** 2)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig:
|
||||
"""Configuration for the voice loop."""
|
||||
@@ -161,13 +166,6 @@ class VoiceLoop:
|
||||
min_blocks = int(self.config.min_utterance / 0.1)
|
||||
max_blocks = int(self.config.max_utterance / 0.1)
|
||||
|
||||
audio_chunks: list[np.ndarray] = []
|
||||
silent_count = 0
|
||||
recording = False
|
||||
|
||||
def _rms(block: np.ndarray) -> float:
|
||||
return float(np.sqrt(np.mean(block.astype(np.float32) ** 2)))
|
||||
|
||||
sys.stdout.write("\n 🎤 Listening... (speak now)\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
@@ -177,42 +175,69 @@ class VoiceLoop:
|
||||
dtype="float32",
|
||||
blocksize=block_size,
|
||||
) as stream:
|
||||
while self._running:
|
||||
block, overflowed = stream.read(block_size)
|
||||
if overflowed:
|
||||
logger.debug("Audio buffer overflowed")
|
||||
chunks = self._capture_audio_blocks(stream, block_size, silence_blocks, max_blocks)
|
||||
|
||||
rms = _rms(block)
|
||||
return self._finalize_utterance(chunks, min_blocks, sr)
|
||||
|
||||
if not recording:
|
||||
if rms > self.config.silence_threshold:
|
||||
recording = True
|
||||
silent_count = 0
|
||||
audio_chunks.append(block.copy())
|
||||
sys.stdout.write(" 📢 Recording...\r")
|
||||
sys.stdout.flush()
|
||||
def _capture_audio_blocks(
|
||||
self,
|
||||
stream,
|
||||
block_size: int,
|
||||
silence_blocks: int,
|
||||
max_blocks: int,
|
||||
) -> list[np.ndarray]:
|
||||
"""Read audio blocks from *stream* until silence or max length.
|
||||
|
||||
Returns the list of captured audio chunks (may be empty).
|
||||
"""
|
||||
chunks: list[np.ndarray] = []
|
||||
silent_count = 0
|
||||
recording = False
|
||||
|
||||
while self._running:
|
||||
block, overflowed = stream.read(block_size)
|
||||
if overflowed:
|
||||
logger.debug("Audio buffer overflowed")
|
||||
|
||||
rms = _rms(block)
|
||||
|
||||
if not recording:
|
||||
if rms > self.config.silence_threshold:
|
||||
recording = True
|
||||
silent_count = 0
|
||||
chunks.append(block.copy())
|
||||
sys.stdout.write(" 📢 Recording...\r")
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
chunks.append(block.copy())
|
||||
|
||||
if rms < self.config.silence_threshold:
|
||||
silent_count += 1
|
||||
else:
|
||||
audio_chunks.append(block.copy())
|
||||
silent_count = 0
|
||||
|
||||
if rms < self.config.silence_threshold:
|
||||
silent_count += 1
|
||||
else:
|
||||
silent_count = 0
|
||||
if silent_count >= silence_blocks:
|
||||
break
|
||||
|
||||
# End of utterance
|
||||
if silent_count >= silence_blocks:
|
||||
break
|
||||
if len(chunks) >= max_blocks:
|
||||
logger.info("Max utterance length reached, stopping.")
|
||||
break
|
||||
|
||||
# Safety cap
|
||||
if len(audio_chunks) >= max_blocks:
|
||||
logger.info("Max utterance length reached, stopping.")
|
||||
break
|
||||
return chunks
|
||||
|
||||
if not audio_chunks or len(audio_chunks) < min_blocks:
|
||||
@staticmethod
|
||||
def _finalize_utterance(
|
||||
chunks: list[np.ndarray], min_blocks: int, sample_rate: int
|
||||
) -> np.ndarray | None:
|
||||
"""Concatenate recorded chunks and report duration.
|
||||
|
||||
Returns ``None`` if the utterance is too short to be meaningful.
|
||||
"""
|
||||
if not chunks or len(chunks) < min_blocks:
|
||||
return None
|
||||
|
||||
audio = np.concatenate(audio_chunks, axis=0).flatten()
|
||||
duration = len(audio) / sr
|
||||
audio = np.concatenate(chunks, axis=0).flatten()
|
||||
duration = len(audio) / sample_rate
|
||||
sys.stdout.write(f" ✂️ Captured {duration:.1f}s of audio\n")
|
||||
sys.stdout.flush()
|
||||
return audio
|
||||
|
||||
@@ -174,6 +174,103 @@ class TestDiscordVendor:
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestExtractContent:
|
||||
def test_strips_bot_mention(self):
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
vendor._client = MagicMock()
|
||||
vendor._client.user.id = 12345
|
||||
msg = MagicMock()
|
||||
msg.content = "<@12345> hello there"
|
||||
assert vendor._extract_content(msg) == "hello there"
|
||||
|
||||
def test_no_client_user(self):
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
vendor._client = MagicMock()
|
||||
vendor._client.user = None
|
||||
msg = MagicMock()
|
||||
msg.content = "hello"
|
||||
assert vendor._extract_content(msg) == "hello"
|
||||
|
||||
def test_empty_after_strip(self):
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
vendor._client = MagicMock()
|
||||
vendor._client.user.id = 99
|
||||
msg = MagicMock()
|
||||
msg.content = "<@99>"
|
||||
assert vendor._extract_content(msg) == ""
|
||||
|
||||
|
||||
class TestInvokeAgent:
|
||||
@staticmethod
|
||||
def _make_typing_target():
|
||||
"""Build a mock target whose .typing() is an async context manager."""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
target = AsyncMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _typing():
|
||||
yield
|
||||
|
||||
target.typing = _typing
|
||||
return target
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_returns_error(self):
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
target = self._make_typing_target()
|
||||
|
||||
with patch(
|
||||
"integrations.chat_bridge.vendors.discord.chat_with_tools", side_effect=TimeoutError
|
||||
):
|
||||
run_output, response = await vendor._invoke_agent("hi", "sess", target)
|
||||
assert run_output is None
|
||||
assert "too long" in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_error(self):
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
target = self._make_typing_target()
|
||||
|
||||
with patch(
|
||||
"integrations.chat_bridge.vendors.discord.chat_with_tools",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
run_output, response = await vendor._invoke_agent("hi", "sess", target)
|
||||
assert run_output is None
|
||||
assert "trouble" in response
|
||||
|
||||
|
||||
class TestSendResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_empty(self):
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
target = AsyncMock()
|
||||
await DiscordVendor._send_response(None, target)
|
||||
target.send.assert_not_called()
|
||||
await DiscordVendor._send_response("", target)
|
||||
target.send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_short_message(self):
|
||||
from integrations.chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
target = AsyncMock()
|
||||
await DiscordVendor._send_response("hello", target)
|
||||
target.send.assert_called_once_with("hello")
|
||||
|
||||
|
||||
class TestChunkMessage:
|
||||
def test_short_message(self):
|
||||
from integrations.chat_bridge.vendors.discord import _chunk_message
|
||||
|
||||
@@ -15,7 +15,7 @@ except ImportError:
|
||||
np = None
|
||||
|
||||
try:
|
||||
from timmy.voice_loop import VoiceConfig, VoiceLoop, _strip_markdown
|
||||
from timmy.voice_loop import VoiceConfig, VoiceLoop, _rms, _strip_markdown
|
||||
except ImportError:
|
||||
pass # pytestmark will skip all tests anyway
|
||||
|
||||
@@ -147,6 +147,31 @@ class TestStripMarkdown:
|
||||
assert "*" not in result
|
||||
|
||||
|
||||
class TestRms:
|
||||
def test_silent_block(self):
|
||||
block = np.zeros(1600, dtype=np.float32)
|
||||
assert _rms(block) == pytest.approx(0.0, abs=1e-7)
|
||||
|
||||
def test_loud_block(self):
|
||||
block = np.ones(1600, dtype=np.float32)
|
||||
assert _rms(block) == pytest.approx(1.0, abs=1e-5)
|
||||
|
||||
|
||||
class TestFinalizeUtterance:
|
||||
def test_returns_none_for_empty(self):
|
||||
assert VoiceLoop._finalize_utterance([], min_blocks=5, sample_rate=16000) is None
|
||||
|
||||
def test_returns_none_for_too_short(self):
|
||||
chunks = [np.zeros(1600, dtype=np.float32) for _ in range(3)]
|
||||
assert VoiceLoop._finalize_utterance(chunks, min_blocks=5, sample_rate=16000) is None
|
||||
|
||||
def test_returns_audio_for_sufficient_chunks(self):
|
||||
chunks = [np.ones(1600, dtype=np.float32) for _ in range(6)]
|
||||
result = VoiceLoop._finalize_utterance(chunks, min_blocks=5, sample_rate=16000)
|
||||
assert result is not None
|
||||
assert len(result) == 6 * 1600
|
||||
|
||||
|
||||
class TestThink:
|
||||
def test_think_returns_response(self):
|
||||
loop = VoiceLoop()
|
||||
|
||||
Reference in New Issue
Block a user